feat: authsession service

This commit is contained in:
Ilia Denisov
2026-04-08 16:23:07 +02:00
committed by GitHub
parent 28f04916af
commit 86a68ed9d0
174 changed files with 31732 additions and 112 deletions
@@ -0,0 +1,88 @@
package blockuser
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteRetriesProjectionPublishesForBlockFlow(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{
Errors: []error{errors.New("publish failed"), nil},
}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, "blocked", result.Outcome)
assert.EqualValues(t, 1, result.AffectedSessionCount)
require.Len(t, publisher.PublishedSnapshots(), 2)
}
func TestExecuteRepairsProjectionOnRepeatedAlreadyBlockedRequest(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts)
sessionRecord, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, sessionRecord.Revocation)
assert.Equal(t, devicesession.StatusRevoked, sessionRecord.Status)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, sessionRecord.Revocation.ReasonCode)
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
require.NoError(t, resolveErr)
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
publisher.Err = nil
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, "already_blocked", result.Outcome)
assert.EqualValues(t, 0, result.AffectedSessionCount)
require.NotNil(t, result.AffectedDeviceSessionIDs)
assert.Empty(t, result.AffectedDeviceSessionIDs)
require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1)
}
@@ -0,0 +1,91 @@
package blockuser
import (
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/service/confirmemailcode"
"galaxy/authsession/internal/service/sendemailcode"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const blockFlowPublicKey = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="
func TestBlockUserAffectsLaterSendAndConfirmFlows(t *testing.T) {
t.Parallel()
challengeStore := &testkit.InMemoryChallengeStore{}
sessionStore := &testkit.InMemorySessionStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{}
idGenerator := &testkit.SequenceIDGenerator{
ChallengeIDs: []common.ChallengeID{"challenge-1"},
DeviceSessionIDs: []common.DeviceSessionID{"device-session-1"},
}
hasher := testkit.DeterministicCodeHasher{}
mailSender := &testkit.RecordingMailSender{}
now := time.Unix(20, 0).UTC()
clock := testkit.FixedClock{Time: now}
blockService, err := New(userDirectory, sessionStore, publisher, clock)
require.NoError(t, err)
_, err = blockService.Execute(context.Background(), Input{
Email: "pilot@example.com",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
sendService, err := sendemailcode.New(
challengeStore,
userDirectory,
idGenerator,
testkit.FixedCodeGenerator{Code: "654321"},
hasher,
mailSender,
clock,
)
require.NoError(t, err)
sendResult, err := sendService.Execute(context.Background(), sendemailcode.Input{Email: "pilot@example.com"})
require.NoError(t, err)
assert.Equal(t, "challenge-1", sendResult.ChallengeID)
assert.Empty(t, mailSender.RecordedInputs())
challengeRecord, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, err)
assert.Equal(t, challenge.StatusDeliverySuppressed, challengeRecord.Status)
assert.Equal(t, challenge.DeliverySuppressed, challengeRecord.DeliveryState)
confirmService, err := confirmemailcode.New(
challengeStore,
sessionStore,
userDirectory,
testkit.StaticConfigProvider{},
publisher,
idGenerator,
hasher,
clock,
)
require.NoError(t, err)
_, err = confirmService.Execute(context.Background(), confirmemailcode.Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: blockFlowPublicKey,
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeBlockedByPolicy, shared.CodeOf(err))
updatedChallenge, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, getErr)
assert.Equal(t, challenge.StatusFailed, updatedChallenge.Status)
}
@@ -0,0 +1,64 @@
package blockuser
import (
"bytes"
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
func TestExecuteLogsSafeOutcomeFields(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
sessionStore := &testkit.InMemorySessionStore{}
require.NoError(t, sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
logger, buffer := newObservedServiceLogger()
service, err := NewWithObservability(
userDirectory,
sessionStore,
&testkit.RecordingProjectionPublisher{},
testkit.FixedClock{Time: time.Unix(20, 0).UTC()},
logger,
nil,
)
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
logOutput := buffer.String()
assert.Contains(t, logOutput, "block_user")
assert.Contains(t, logOutput, "\"user_id\":\"user-1\"")
assert.Contains(t, logOutput, "\"reason_code\":\"policy_block\"")
assert.NotContains(t, logOutput, "pilot@example.com")
}
func newObservedServiceLogger() (*zap.Logger, *bytes.Buffer) {
buffer := &bytes.Buffer{}
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.TimeKey = ""
core := zapcore.NewCore(
zapcore.NewJSONEncoder(encoderConfig),
zapcore.AddSync(buffer),
zap.DebugLevel,
)
return zap.New(core), buffer
}
@@ -0,0 +1,294 @@
// Package blockuser implements the trusted internal block-user use case.
package blockuser
import (
"context"
"errors"
"fmt"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/telemetry"
"go.uber.org/zap"
)
const (
// SubjectKindUserID identifies a block request addressed by stable user id.
SubjectKindUserID = "user_id"
// SubjectKindEmail identifies a block request addressed by normalized e-mail
// address.
SubjectKindEmail = "email"
)
// Input describes one trusted internal block-user request.
type Input struct {
// UserID identifies the subject to block when the request is user-id based.
UserID string
// Email identifies the subject to block when the request is e-mail based.
Email string
// ReasonCode stores the machine-readable block reason code applied to the
// user directory.
ReasonCode string
// ActorType stores the machine-readable actor type for any derived session
// revocation.
ActorType string
// ActorID stores the optional stable actor identifier for any derived
// session revocation.
ActorID string
}
// Result describes the frozen internal block-user acknowledgement.
type Result struct {
// Outcome reports whether the block state was newly applied or already
// existed.
Outcome string
// SubjectKind reports whether the request targeted `user_id` or `email`.
SubjectKind string
// SubjectValue stores the normalized subject value addressed by the
// operation.
SubjectValue string
// AffectedSessionCount reports how many sessions changed state during the
// current call.
AffectedSessionCount int64
// AffectedDeviceSessionIDs lists every session identifier affected during
// the current call.
AffectedDeviceSessionIDs []string
}
// Service executes the trusted internal block-user use case.
type Service struct {
userDirectory ports.UserDirectory
sessionStore ports.SessionStore
publisher ports.GatewaySessionProjectionPublisher
clock ports.Clock
logger *zap.Logger
telemetry *telemetry.Runtime
}
// New returns a block-user service wired to the required ports.
func New(userDirectory ports.UserDirectory, sessionStore ports.SessionStore, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) {
return NewWithObservability(userDirectory, sessionStore, publisher, clock, nil, nil)
}
// NewWithObservability returns a block-user service wired to the required
// ports plus optional structured logging and telemetry dependencies.
func NewWithObservability(
userDirectory ports.UserDirectory,
sessionStore ports.SessionStore,
publisher ports.GatewaySessionProjectionPublisher,
clock ports.Clock,
logger *zap.Logger,
telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
switch {
case userDirectory == nil:
return nil, fmt.Errorf("blockuser: user directory must not be nil")
case sessionStore == nil:
return nil, fmt.Errorf("blockuser: session store must not be nil")
case publisher == nil:
return nil, fmt.Errorf("blockuser: projection publisher must not be nil")
case clock == nil:
return nil, fmt.Errorf("blockuser: clock must not be nil")
default:
return &Service{
userDirectory: userDirectory,
sessionStore: sessionStore,
publisher: publisher,
clock: clock,
logger: namedLogger(logger, "block_user"),
telemetry: telemetryRuntime,
}, nil
}
}
// Execute applies the requested block state and revokes any active sessions of
// the resolved user when one exists.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
logFields := []zap.Field{
zap.String("component", "service"),
zap.String("use_case", "block_user"),
}
defer func() {
if result.Outcome != "" {
logFields = append(logFields, zap.String("outcome", result.Outcome))
}
if result.SubjectKind != "" {
logFields = append(logFields, zap.String("subject_kind", result.SubjectKind))
}
if result.AffectedSessionCount > 0 {
logFields = append(logFields, zap.Int64("affected_session_count", result.AffectedSessionCount))
}
shared.LogServiceOutcome(s.logger, ctx, "block user completed", err, logFields...)
}()
subjectKind, subjectValue, storeResult, err := s.blockSubject(ctx, input)
if err != nil {
return Result{}, err
}
logFields = append(logFields, zap.String("reason_code", shared.NormalizeString(input.ReasonCode)))
if !storeResult.UserID.IsZero() {
logFields = append(logFields, zap.String("user_id", storeResult.UserID.String()))
}
affectedDeviceSessionIDs := []string{}
affectedSessionCount := int64(0)
if !storeResult.UserID.IsZero() {
revocation, err := shared.BuildRevocation(
devicesession.RevokeReasonUserBlocked.String(),
input.ActorType,
input.ActorID,
s.clock.Now(),
)
if err != nil {
return Result{}, err
}
revokeResult, err := s.sessionStore.RevokeAllByUserID(ctx, ports.RevokeUserSessionsInput{
UserID: storeResult.UserID,
Revocation: revocation,
})
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if err := revokeResult.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
for _, record := range revokeResult.Sessions {
if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "block_user"); err != nil {
return Result{}, err
}
affectedDeviceSessionIDs = append(affectedDeviceSessionIDs, record.ID.String())
}
if revokeResult.Outcome == ports.RevokeUserSessionsOutcomeNoActiveSessions {
if err := s.republishCurrentRevokedSessions(ctx, storeResult.UserID); err != nil {
return Result{}, err
}
}
affectedSessionCount = int64(len(revokeResult.Sessions))
if affectedSessionCount > 0 {
s.telemetry.RecordSessionRevocations(ctx, "block_user", devicesession.RevokeReasonUserBlocked.String(), affectedSessionCount)
}
}
result = Result{
Outcome: string(storeResult.Outcome),
SubjectKind: subjectKind,
SubjectValue: subjectValue,
AffectedSessionCount: affectedSessionCount,
AffectedDeviceSessionIDs: affectedDeviceSessionIDs,
}
return result, nil
}
func (s *Service) blockSubject(ctx context.Context, input Input) (string, string, ports.BlockUserResult, error) {
userID := shared.NormalizeString(input.UserID)
email := shared.NormalizeString(input.Email)
switch {
case userID == "" && email == "":
return "", "", ports.BlockUserResult{}, shared.InvalidRequest("exactly one of user_id or email must be provided")
case userID != "" && email != "":
return "", "", ports.BlockUserResult{}, shared.InvalidRequest("exactly one of user_id or email must be provided")
case userID != "":
parsedUserID, err := shared.ParseUserID(userID)
if err != nil {
return "", "", ports.BlockUserResult{}, err
}
reasonCode, err := parseBlockReasonCode(input.ReasonCode)
if err != nil {
return "", "", ports.BlockUserResult{}, err
}
result, err := s.userDirectory.BlockByUserID(ctx, ports.BlockUserByIDInput{
UserID: parsedUserID,
ReasonCode: reasonCode,
})
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return "", "", ports.BlockUserResult{}, shared.SubjectNotFound()
default:
return "", "", ports.BlockUserResult{}, shared.ServiceUnavailable(err)
}
}
if err := result.Validate(); err != nil {
return "", "", ports.BlockUserResult{}, shared.InternalError(err)
}
s.telemetry.RecordUserDirectoryOutcome(ctx, "block_by_user_id", string(result.Outcome))
return SubjectKindUserID, parsedUserID.String(), result, nil
default:
parsedEmail, err := shared.ParseEmail(email)
if err != nil {
return "", "", ports.BlockUserResult{}, err
}
reasonCode, err := parseBlockReasonCode(input.ReasonCode)
if err != nil {
return "", "", ports.BlockUserResult{}, err
}
result, err := s.userDirectory.BlockByEmail(ctx, ports.BlockUserByEmailInput{
Email: parsedEmail,
ReasonCode: reasonCode,
})
if err != nil {
return "", "", ports.BlockUserResult{}, shared.ServiceUnavailable(err)
}
if err := result.Validate(); err != nil {
return "", "", ports.BlockUserResult{}, shared.InternalError(err)
}
s.telemetry.RecordUserDirectoryOutcome(ctx, "block_by_email", string(result.Outcome))
return SubjectKindEmail, parsedEmail.String(), result, nil
}
}
func parseBlockReasonCode(value string) (userresolution.BlockReasonCode, error) {
reasonCode := userresolution.BlockReasonCode(shared.NormalizeString(value))
if err := reasonCode.Validate(); err != nil {
return "", shared.InvalidRequest(err.Error())
}
return reasonCode, nil
}
func (s *Service) republishCurrentRevokedSessions(ctx context.Context, userID common.UserID) error {
records, err := s.sessionStore.ListByUserID(ctx, userID)
if err != nil {
return shared.ServiceUnavailable(err)
}
for _, record := range records {
if record.Status != devicesession.StatusRevoked {
continue
}
if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "block_user_repair"); err != nil {
return err
}
}
return nil
}
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
if logger == nil {
logger = zap.NewNop()
}
return logger.Named(name)
}
@@ -0,0 +1,237 @@
package blockuser
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteBlocksByUserIDAndRevokesSessions(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{}
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, "blocked", result.Outcome)
assert.EqualValues(t, 1, result.AffectedSessionCount)
assert.Equal(t, SubjectKindUserID, result.SubjectKind)
assert.Equal(t, "user-1", result.SubjectValue)
assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs)
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
require.NoError(t, resolveErr)
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
published := publisher.PublishedSnapshots()
require.Len(t, published, 1)
assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
assert.Equal(t, common.RevokeActorType("admin"), published[0].RevokeActorType)
}
func TestExecuteBlocksByEmailWithoutExistingUser(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{}
service, err := New(userDirectory, &testkit.InMemorySessionStore{}, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
Email: "pilot@example.com",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, "blocked", result.Outcome)
assert.EqualValues(t, 0, result.AffectedSessionCount)
assert.Equal(t, SubjectKindEmail, result.SubjectKind)
assert.Equal(t, "pilot@example.com", result.SubjectValue)
require.NotNil(t, result.AffectedDeviceSessionIDs)
assert.Empty(t, result.AffectedDeviceSessionIDs)
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
require.NoError(t, resolveErr)
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
assert.Empty(t, publisher.PublishedSnapshots())
}
func TestExecuteBlocksByEmailWithExistingUserAndRevokesSessions(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{}
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
Email: "pilot@example.com",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, "blocked", result.Outcome)
assert.EqualValues(t, 1, result.AffectedSessionCount)
assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs)
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
require.NoError(t, resolveErr)
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
published := publisher.PublishedSnapshots()
require.Len(t, published, 1)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
}
func TestExecuteReturnsSubjectNotFoundForUnknownUserID(t *testing.T) {
t.Parallel()
service, err := New(&testkit.InMemoryUserDirectory{}, &testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
_, err = service.Execute(context.Background(), Input{
UserID: "missing",
ReasonCode: "policy_block",
ActorType: "admin",
})
assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err))
}
func TestExecuteAlreadyBlockedStillRevokesLingeringSessions(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{}
if err := userDirectory.SeedBlockedUser(common.Email("pilot@example.com"), common.UserID("user-1"), userresolution.BlockReasonCode("policy_block")); err != nil {
require.Failf(t, "test failed", "SeedBlockedUser() returned error: %v", err)
}
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
Email: "pilot@example.com",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, "already_blocked", result.Outcome)
assert.EqualValues(t, 1, result.AffectedSessionCount)
assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs)
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)
published := publisher.PublishedSnapshots()
require.Len(t, published, 1)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
}
func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) {
t.Parallel()
userDirectory := &testkit.InMemoryUserDirectory{}
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
require.NoError(t, resolveErr)
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
}
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: key,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
@@ -0,0 +1,60 @@
package blockuser
import (
"context"
"testing"
"time"
stubuserservice "galaxy/authsession/internal/adapters/userservice"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
t.Parallel()
t.Run("blocks by email through runtime stub", func(t *testing.T) {
t.Parallel()
userDirectory := &stubuserservice.StubDirectory{}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
store := &testkit.InMemorySessionStore{}
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
service, err := New(userDirectory, store, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
Email: "pilot@example.com",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, SubjectKindEmail, result.SubjectKind)
assert.Equal(t, "blocked", result.Outcome)
assert.EqualValues(t, 1, result.AffectedSessionCount)
})
t.Run("blocks by user id through runtime stub", func(t *testing.T) {
t.Parallel()
userDirectory := &stubuserservice.StubDirectory{}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
service, err := New(userDirectory, &testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "policy_block",
ActorType: "admin",
})
require.NoError(t, err)
assert.Equal(t, SubjectKindUserID, result.SubjectKind)
assert.Equal(t, "blocked", result.Outcome)
})
}
@@ -0,0 +1,39 @@
package confirmemailcode
import (
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/service/shared"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteReturnsInvalidCodeForThrottledChallengeWithoutConsumingAttempts(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
record.Status = challenge.StatusDeliveryThrottled
record.DeliveryState = challenge.DeliveryThrottled
require.NoError(t, record.Validate())
require.NoError(t, deps.challengeStore.Create(context.Background(), record))
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeInvalidCode, shared.CodeOf(err))
updated, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, getErr)
assert.Equal(t, 0, updated.Attempts.Confirm)
assert.Equal(t, challenge.StatusDeliveryThrottled, updated.Status)
}
@@ -0,0 +1,106 @@
package confirmemailcode
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/service/shared"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteConfirmsChallengeAfterTransientProjectionPublishFailures(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
deps.publisher.Errors = []error{errors.New("publish failed"), nil}
require.NoError(t, deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, deps.challengeStore.Create(
context.Background(),
sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
))
service := mustNewConfirmService(t, deps)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.NoError(t, err)
assert.Equal(t, "device-session-1", result.DeviceSessionID)
require.Len(t, deps.publisher.PublishedSnapshots(), 2)
}
func TestExecuteConfirmedRetryRepublishesAfterTransientProjectionPublishFailures(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
deps.publisher.Errors = []error{errors.New("publish failed"), nil}
key := mustClientPublicKey(t, publicKeyString())
require.NoError(t, deps.challengeStore.Create(
context.Background(),
confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
))
require.NoError(t, deps.sessionStore.Create(
context.Background(),
activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute)),
))
service := mustNewConfirmService(t, deps)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.NoError(t, err)
assert.Equal(t, "device-session-1", result.DeviceSessionID)
require.Len(t, deps.publisher.PublishedSnapshots(), 2)
}
func TestExecuteRepairsProjectionOnIdenticalRetryAfterExhaustedPublishRetries(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
deps.publisher.Err = errors.New("publish failed")
require.NoError(t, deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, deps.challengeStore.Create(
context.Background(),
sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
))
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
require.Len(t, deps.publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts)
sessionRecord, getErr := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
assert.Equal(t, devicesession.StatusActive, sessionRecord.Status)
challengeRecord, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, getErr)
assert.Equal(t, challenge.StatusConfirmedPendingExpire, challengeRecord.Status)
require.NotNil(t, challengeRecord.Confirmation)
deps.publisher.Err = nil
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.NoError(t, err)
assert.Equal(t, "device-session-1", result.DeviceSessionID)
require.Len(t, deps.publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1)
}
@@ -0,0 +1,588 @@
// Package confirmemailcode implements the public confirm-email-code use case.
package confirmemailcode
import (
"context"
"errors"
"fmt"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/sessionlimit"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/telemetry"
"go.uber.org/zap"
)
const (
revokeReasonConfirmRace common.RevokeReasonCode = "confirm_race_repair"
revokeActorTypeService common.RevokeActorType = "service"
revokeActorIDService = "confirmemailcode"
)
// Input describes one public confirm-email-code request.
type Input struct {
// ChallengeID identifies the challenge that should be confirmed.
ChallengeID string
// Code is the cleartext confirmation code submitted by the caller.
Code string
// ClientPublicKey is the base64-encoded raw 32-byte Ed25519 public key that
// should be registered for the created device session.
ClientPublicKey string
}
// Result describes one public confirm-email-code response.
type Result struct {
// DeviceSessionID is the stable identifier of the created or idempotently
// recovered device session.
DeviceSessionID string
}
// Service executes the public confirm-email-code use case.
type Service struct {
challengeStore ports.ChallengeStore
sessionStore ports.SessionStore
userDirectory ports.UserDirectory
configProvider ports.ConfigProvider
publisher ports.GatewaySessionProjectionPublisher
idGenerator ports.IDGenerator
codeHasher ports.CodeHasher
clock ports.Clock
logger *zap.Logger
telemetry *telemetry.Runtime
}
// New returns a confirm-email-code service wired to the required ports.
func New(
challengeStore ports.ChallengeStore,
sessionStore ports.SessionStore,
userDirectory ports.UserDirectory,
configProvider ports.ConfigProvider,
publisher ports.GatewaySessionProjectionPublisher,
idGenerator ports.IDGenerator,
codeHasher ports.CodeHasher,
clock ports.Clock,
) (*Service, error) {
return NewWithTelemetry(
challengeStore,
sessionStore,
userDirectory,
configProvider,
publisher,
idGenerator,
codeHasher,
clock,
nil,
)
}
// NewWithTelemetry returns a confirm-email-code service wired to the required
// ports plus the optional Stage-17 telemetry runtime.
func NewWithTelemetry(
challengeStore ports.ChallengeStore,
sessionStore ports.SessionStore,
userDirectory ports.UserDirectory,
configProvider ports.ConfigProvider,
publisher ports.GatewaySessionProjectionPublisher,
idGenerator ports.IDGenerator,
codeHasher ports.CodeHasher,
clock ports.Clock,
telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
return NewWithObservability(
challengeStore,
sessionStore,
userDirectory,
configProvider,
publisher,
idGenerator,
codeHasher,
clock,
nil,
telemetryRuntime,
)
}
// NewWithObservability returns a confirm-email-code service wired to the
// required ports plus optional structured logging and telemetry dependencies.
func NewWithObservability(
challengeStore ports.ChallengeStore,
sessionStore ports.SessionStore,
userDirectory ports.UserDirectory,
configProvider ports.ConfigProvider,
publisher ports.GatewaySessionProjectionPublisher,
idGenerator ports.IDGenerator,
codeHasher ports.CodeHasher,
clock ports.Clock,
logger *zap.Logger,
telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
switch {
case challengeStore == nil:
return nil, fmt.Errorf("confirmemailcode: challenge store must not be nil")
case sessionStore == nil:
return nil, fmt.Errorf("confirmemailcode: session store must not be nil")
case userDirectory == nil:
return nil, fmt.Errorf("confirmemailcode: user directory must not be nil")
case configProvider == nil:
return nil, fmt.Errorf("confirmemailcode: config provider must not be nil")
case publisher == nil:
return nil, fmt.Errorf("confirmemailcode: projection publisher must not be nil")
case idGenerator == nil:
return nil, fmt.Errorf("confirmemailcode: id generator must not be nil")
case codeHasher == nil:
return nil, fmt.Errorf("confirmemailcode: code hasher must not be nil")
case clock == nil:
return nil, fmt.Errorf("confirmemailcode: clock must not be nil")
default:
return &Service{
challengeStore: challengeStore,
sessionStore: sessionStore,
userDirectory: userDirectory,
configProvider: configProvider,
publisher: publisher,
idGenerator: idGenerator,
codeHasher: codeHasher,
clock: clock,
logger: namedLogger(logger, "confirm_email_code"),
telemetry: telemetryRuntime,
}, nil
}
}
// Execute validates one challenge confirmation attempt, creates a device
// session when policy allows it, and handles short-window idempotent retries.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
logFields := []zap.Field{
zap.String("component", "service"),
zap.String("use_case", "confirm_email_code"),
}
defer func() {
outcome := string(telemetry.ConfirmEmailCodeOutcomeSuccess)
if err != nil {
outcome = shared.CodeOf(err)
if outcome == "" {
outcome = shared.ErrorCodeServiceUnavailable
}
}
s.telemetry.RecordConfirmEmailCode(ctx, outcome)
logFields = append(logFields, zap.String("outcome", outcome))
if result.DeviceSessionID != "" {
logFields = append(logFields, zap.String("device_session_id", result.DeviceSessionID))
}
shared.LogServiceOutcome(s.logger, ctx, "confirm email code completed", err, logFields...)
}()
challengeID, err := shared.ParseChallengeID(input.ChallengeID)
if err != nil {
return Result{}, err
}
logFields = append(logFields, zap.String("challenge_id", challengeID.String()))
code, err := shared.ParseRequiredCode(input.Code)
if err != nil {
return Result{}, err
}
clientPublicKey, err := shared.ParseClientPublicKey(input.ClientPublicKey)
if err != nil {
return Result{}, err
}
for attempt := 0; attempt < shared.MaxCompareAndSwapRetries; attempt++ {
current, err := s.challengeStore.Get(ctx, challengeID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return Result{}, shared.ChallengeNotFound()
default:
return Result{}, shared.ServiceUnavailable(err)
}
}
now := s.clock.Now().UTC()
if expired, err := s.ensureChallengeNotExpired(ctx, current, now); err != nil {
if errors.Is(err, ports.ErrConflict) {
continue
}
return Result{}, err
} else if expired {
return Result{}, shared.ChallengeExpired()
}
switch {
case current.Status.IsConfirmedRetryState():
return s.handleConfirmedRetry(ctx, current, code, clientPublicKey)
case !current.Status.AcceptsFreshConfirm():
return Result{}, shared.InvalidCode()
}
match, err := s.codeHasher.Compare(current.CodeHash, code)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if !match {
if err := s.recordInvalidConfirmAttempt(ctx, current, now); err != nil {
if errors.Is(err, ports.ErrConflict) {
continue
}
return Result{}, err
}
return Result{}, shared.InvalidCode()
}
ensureUserResult, err := s.userDirectory.EnsureUserByEmail(ctx, current.Email)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if err := ensureUserResult.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
s.telemetry.RecordUserDirectoryOutcome(ctx, "ensure_user_by_email", string(ensureUserResult.Outcome))
if !ensureUserResult.UserID.IsZero() {
logFields = append(logFields, zap.String("user_id", ensureUserResult.UserID.String()))
}
if ensureUserResult.Outcome == ports.EnsureUserOutcomeBlocked {
if err := s.markChallengeFailed(ctx, current, now); err != nil {
if errors.Is(err, ports.ErrConflict) {
continue
}
return Result{}, err
}
return Result{}, shared.BlockedByPolicy()
}
limitConfig, err := s.configProvider.LoadSessionLimit(ctx)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
decision, err := s.evaluateSessionLimit(ctx, ensureUserResult.UserID, limitConfig)
if err != nil {
return Result{}, err
}
if decision.Kind == sessionlimit.KindExceeded {
s.telemetry.RecordSessionLimitRejection(ctx)
return Result{}, shared.SessionLimitExceeded()
}
sessionRecord, err := s.createSession(ctx, ensureUserResult.UserID, clientPublicKey, now)
if err != nil {
return Result{}, err
}
next := current
next.Status = challenge.StatusConfirmedPendingExpire
next.ExpiresAt = now.Add(challenge.ConfirmedRetention)
next.Abuse.LastAttemptAt = &now
next.Confirmation = &challenge.Confirmation{
SessionID: sessionRecord.ID,
ClientPublicKey: clientPublicKey,
ConfirmedAt: now,
}
if err := next.Validate(); err != nil {
s.bestEffortRevokeSupersededSession(ctx, sessionRecord)
return Result{}, shared.InternalError(err)
}
if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
if errors.Is(err, ports.ErrConflict) {
return s.handleCreateSessionCASConflict(ctx, challengeID, code, clientPublicKey, sessionRecord)
}
s.bestEffortRevokeSupersededSession(ctx, sessionRecord)
return Result{}, shared.ServiceUnavailable(err)
}
// Publish the currently stored session view so a concurrent revoke/block
// cannot overwrite source of truth with a stale active projection.
currentSession, err := s.sessionStore.Get(ctx, sessionRecord.ID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: newly created session %q was not found", sessionRecord.ID))
default:
return Result{}, shared.ServiceUnavailable(err)
}
}
if err := s.publishSession(ctx, currentSession, "confirm_email_code"); err != nil {
return Result{}, err
}
return Result{DeviceSessionID: currentSession.ID.String()}, nil
}
return Result{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: compare-and-swap retry limit exceeded"))
}
func (s *Service) ensureChallengeNotExpired(ctx context.Context, current challenge.Challenge, now time.Time) (bool, error) {
if current.IsExpiredAt(now) {
if current.Status != challenge.StatusExpired && current.Status.CanTransitionTo(challenge.StatusExpired) {
next := current
next.Status = challenge.StatusExpired
next.Abuse.LastAttemptAt = &now
next.Confirmation = nil
if err := next.Validate(); err != nil {
return true, shared.InternalError(err)
}
if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
if !errors.Is(err, ports.ErrConflict) {
return true, shared.ServiceUnavailable(err)
}
return false, err
}
}
return true, nil
}
return false, nil
}
func (s *Service) handleConfirmedRetry(ctx context.Context, current challenge.Challenge, code string, clientPublicKey common.ClientPublicKey) (Result, error) {
match, err := s.codeHasher.Compare(current.CodeHash, code)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if !match {
return Result{}, shared.InvalidCode()
}
if current.Confirmation == nil {
return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: confirmed challenge is missing confirmation metadata"))
}
if current.Confirmation.ClientPublicKey.String() != clientPublicKey.String() {
return Result{}, shared.InvalidCode()
}
record, err := s.sessionStore.Get(ctx, current.Confirmation.SessionID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: confirmed session %q was not found", current.Confirmation.SessionID))
default:
return Result{}, shared.ServiceUnavailable(err)
}
}
if err := s.publishSession(ctx, record, "confirm_email_code_retry"); err != nil {
return Result{}, err
}
return Result{DeviceSessionID: record.ID.String()}, nil
}
func (s *Service) recordInvalidConfirmAttempt(ctx context.Context, current challenge.Challenge, now time.Time) error {
next := current
next.Attempts.Confirm++
next.Abuse.LastAttemptAt = &now
if next.Attempts.Confirm >= challenge.MaxInvalidConfirmAttempts {
next.Status = challenge.StatusFailed
}
if err := next.Validate(); err != nil {
return shared.InternalError(err)
}
if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
switch {
case errors.Is(err, ports.ErrConflict):
return err
default:
return shared.ServiceUnavailable(err)
}
}
return nil
}
func (s *Service) markChallengeFailed(ctx context.Context, current challenge.Challenge, now time.Time) error {
next := current
next.Status = challenge.StatusFailed
next.Abuse.LastAttemptAt = &now
if err := next.Validate(); err != nil {
return shared.InternalError(err)
}
if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
switch {
case errors.Is(err, ports.ErrConflict):
return err
default:
return shared.ServiceUnavailable(err)
}
}
return nil
}
func (s *Service) evaluateSessionLimit(ctx context.Context, userID common.UserID, config ports.SessionLimitConfig) (sessionlimit.Decision, error) {
activeSessionCount, err := s.sessionStore.CountActiveByUserID(ctx, userID)
if err != nil {
return sessionlimit.Decision{}, shared.ServiceUnavailable(err)
}
decision, err := shared.EvaluateSessionLimit(config, activeSessionCount)
if err != nil {
return sessionlimit.Decision{}, err
}
return decision, nil
}
func (s *Service) createSession(ctx context.Context, userID common.UserID, clientPublicKey common.ClientPublicKey, now time.Time) (devicesession.Session, error) {
for attempt := 0; attempt < shared.MaxCompareAndSwapRetries; attempt++ {
deviceSessionID, err := s.idGenerator.NewDeviceSessionID()
if err != nil {
return devicesession.Session{}, shared.ServiceUnavailable(err)
}
record := devicesession.Session{
ID: deviceSessionID,
UserID: userID,
ClientPublicKey: clientPublicKey,
Status: devicesession.StatusActive,
CreatedAt: now,
}
if err := record.Validate(); err != nil {
return devicesession.Session{}, shared.InternalError(err)
}
if err := s.sessionStore.Create(ctx, record); err != nil {
if errors.Is(err, ports.ErrConflict) {
continue
}
return devicesession.Session{}, shared.ServiceUnavailable(err)
}
s.telemetry.RecordSessionCreated(ctx)
return record, nil
}
return devicesession.Session{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: session id conflict retry limit exceeded"))
}
func (s *Service) handleCreateSessionCASConflict(
ctx context.Context,
challengeID common.ChallengeID,
code string,
clientPublicKey common.ClientPublicKey,
createdSession devicesession.Session,
) (Result, error) {
defer s.bestEffortRevokeSupersededSession(ctx, createdSession)
current, err := s.challengeStore.Get(ctx, challengeID)
if err != nil {
if errors.Is(err, ports.ErrNotFound) {
return Result{}, shared.ServiceUnavailable(err)
}
return Result{}, shared.ServiceUnavailable(err)
}
if current.Status != challenge.StatusConfirmedPendingExpire || current.Confirmation == nil {
return Result{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: challenge %q changed to unexpected status %q after create", challengeID, current.Status))
}
match, err := s.codeHasher.Compare(current.CodeHash, code)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if !match || current.Confirmation.ClientPublicKey.String() != clientPublicKey.String() {
return Result{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: challenge %q was confirmed by a different payload", challengeID))
}
winningSession, err := s.sessionStore.Get(ctx, current.Confirmation.SessionID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: winning session %q was not found", current.Confirmation.SessionID))
default:
return Result{}, shared.ServiceUnavailable(err)
}
}
if err := s.publishSession(ctx, winningSession, "confirm_email_code_race_winner"); err != nil {
return Result{}, err
}
return Result{DeviceSessionID: winningSession.ID.String()}, nil
}
func (s *Service) bestEffortRevokeSupersededSession(ctx context.Context, record devicesession.Session) {
revocation := devicesession.Revocation{
At: s.clock.Now().UTC(),
ReasonCode: revokeReasonConfirmRace,
ActorType: revokeActorTypeService,
ActorID: revokeActorIDService,
}
if err := revocation.Validate(); err != nil {
return
}
revokeResult, err := s.sessionStore.Revoke(ctx, ports.RevokeSessionInput{
DeviceSessionID: record.ID,
Revocation: revocation,
})
if err != nil {
s.logger.Warn(
"best-effort superseded session revoke failed",
zap.String("component", "service"),
zap.String("use_case", "confirm_email_code"),
zap.String("operation", "confirm_email_code_race_cleanup"),
zap.String("device_session_id", record.ID.String()),
zap.String("reason_code", revocation.ReasonCode.String()),
zap.Error(err),
)
return
}
if err := revokeResult.Validate(); err != nil {
s.logger.Warn(
"best-effort superseded session revoke produced invalid result",
zap.String("component", "service"),
zap.String("use_case", "confirm_email_code"),
zap.String("operation", "confirm_email_code_race_cleanup"),
zap.String("device_session_id", record.ID.String()),
zap.Error(err),
)
return
}
if revokeResult.Outcome == ports.RevokeSessionOutcomeRevoked {
s.telemetry.RecordSessionRevocations(ctx, "confirm_email_code_race_cleanup", revocation.ReasonCode.String(), 1)
}
snapshot, err := shared.ToGatewayProjectionSnapshot(revokeResult.Session)
if err != nil {
s.logger.Warn(
"best-effort superseded session snapshot mapping failed",
zap.String("component", "service"),
zap.String("use_case", "confirm_email_code"),
zap.String("operation", "confirm_email_code_race_cleanup"),
zap.String("device_session_id", revokeResult.Session.ID.String()),
zap.Error(err),
)
return
}
if err := shared.PublishProjectionSnapshotWithTelemetry(ctx, s.publisher, snapshot, s.telemetry, "confirm_email_code_race_cleanup"); err != nil {
s.logger.Warn(
"best-effort superseded session publish failed",
zap.String("component", "service"),
zap.String("use_case", "confirm_email_code"),
zap.String("operation", "confirm_email_code_race_cleanup"),
zap.String("device_session_id", revokeResult.Session.ID.String()),
zap.Error(err),
)
}
}
func (s *Service) publishSession(ctx context.Context, record devicesession.Session, operation string) error {
return shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, operation)
}
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
if logger == nil {
logger = zap.NewNop()
}
return logger.Named(name)
}
@@ -0,0 +1,682 @@
package confirmemailcode
import (
"context"
"errors"
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
)
func TestExecuteConfirmsChallengeForExistingUser(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if result.DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
}
record, err := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Status != devicesession.StatusActive {
require.Failf(t, "test failed", "session status = %q, want %q", record.Status, devicesession.StatusActive)
}
challengeRecord, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if challengeRecord.Status != challenge.StatusConfirmedPendingExpire || challengeRecord.Confirmation == nil {
require.Failf(t, "test failed", "challenge status = %q, confirmation = %+v", challengeRecord.Status, challengeRecord.Confirmation)
}
if len(deps.publisher.PublishedSnapshots()) != 1 {
require.Failf(t, "test failed", "PublishedSnapshots() length = %d, want 1", len(deps.publisher.PublishedSnapshots()))
}
}
func TestExecuteConfirmsChallengeByCreatingUser(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.userDirectory.QueueCreatedUserIDs(common.UserID("user-created")); err != nil {
require.Failf(t, "test failed", "QueueCreatedUserIDs() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "new@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if result.DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
}
record, err := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.UserID != common.UserID("user-created") {
require.Failf(t, "test failed", "session user id = %q, want %q", record.UserID, common.UserID("user-created"))
}
}
func TestExecuteConfirmsSuppressedChallenge(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
record.Status = challenge.StatusDeliverySuppressed
record.DeliveryState = challenge.DeliverySuppressed
if err := record.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if result.DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
}
}
func TestExecuteReturnsChallengeNotFound(t *testing.T) {
t.Parallel()
service := mustNewConfirmService(t, newConfirmDeps(t))
_, err := service.Execute(context.Background(), Input{
ChallengeID: "missing",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeChallengeNotFound {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeChallengeNotFound)
}
}
func TestExecuteReturnsChallengeExpiredAndMarksExpired(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-2*time.Minute), deps.now.Add(-time.Second))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeChallengeExpired {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeChallengeExpired)
}
record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Status != challenge.StatusExpired {
require.Failf(t, "test failed", "challenge status = %q, want %q", record.Status, challenge.StatusExpired)
}
}
func TestExecuteReturnsChallengeExpiredForConfirmedChallengeAfterRetentionWindow(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
key, err := shared.ParseClientPublicKey(publicKeyString())
if err != nil {
require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err)
}
record := confirmedChallengeFixture(
t,
deps.hasher,
"challenge-1",
"pilot@example.com",
"654321",
"device-session-1",
key,
deps.now.Add(-2*challenge.ConfirmedRetention),
deps.now.Add(-time.Second),
)
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err = service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeChallengeExpired {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeChallengeExpired)
}
updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if updated.Status != challenge.StatusExpired {
require.Failf(t, "test failed", "challenge status = %q, want %q", updated.Status, challenge.StatusExpired)
}
if updated.Confirmation != nil {
require.Failf(t, "test failed", "Confirmation = %+v, want nil after expiration", updated.Confirmation)
}
}
func TestExecuteReturnsInvalidClientPublicKey(t *testing.T) {
t.Parallel()
service := mustNewConfirmService(t, newConfirmDeps(t))
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: "invalid",
})
if shared.CodeOf(err) != shared.ErrorCodeInvalidClientPublicKey {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidClientPublicKey)
}
}
func TestExecuteInvalidCodeIncrementsAttempts(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "000000",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
}
record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Attempts.Confirm != 1 {
require.Failf(t, "test failed", "Attempts.Confirm = %d, want 1", record.Attempts.Confirm)
}
}
func TestExecuteFifthInvalidAttemptMarksChallengeFailed(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
record.Attempts.Confirm = 4
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "000000",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
}
updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if updated.Status != challenge.StatusFailed {
require.Failf(t, "test failed", "challenge status = %q, want %q", updated.Status, challenge.StatusFailed)
}
}
func TestExecuteDoesNotCreateSessionAfterTooManyAttempts(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
record.Attempts.Confirm = challenge.MaxInvalidConfirmAttempts
record.Status = challenge.StatusFailed
if err := record.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
}
if got, err := deps.sessionStore.CountActiveByUserID(context.Background(), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "CountActiveByUserID() returned error: %v", err)
} else if got != 0 {
require.Failf(t, "test failed", "CountActiveByUserID() = %d, want 0", got)
}
}
func TestExecuteReturnsSameSessionIDForIdempotentRetryAndRepublishes(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
key, err := shared.ParseClientPublicKey(publicKeyString())
if err != nil {
require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err)
}
record := confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
if err := deps.sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if result.DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
}
if len(deps.publisher.PublishedSnapshots()) != 1 {
require.Failf(t, "test failed", "PublishedSnapshots() length = %d, want 1", len(deps.publisher.PublishedSnapshots()))
}
}
func TestExecuteReturnsInvalidCodeForDifferentKeyDuringIdempotentRetry(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
key, err := shared.ParseClientPublicKey(publicKeyString())
if err != nil {
require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err)
}
record := confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
if err := deps.sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err = service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: alternatePublicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
}
updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if updated.Attempts.Confirm != 0 {
require.Failf(t, "test failed", "Attempts.Confirm = %d, want 0", updated.Attempts.Confirm)
}
if updated.Confirmation == nil {
require.FailNow(t, "Confirmation = nil, want metadata to stay intact")
}
if updated.Confirmation.SessionID != common.DeviceSessionID("device-session-1") {
require.Failf(t, "test failed", "Confirmation.SessionID = %q, want %q", updated.Confirmation.SessionID, common.DeviceSessionID("device-session-1"))
}
}
func TestExecuteReturnsInvalidCodeForNonConfirmableStates(t *testing.T) {
t.Parallel()
tests := []struct {
name string
status challenge.Status
deliveryState challenge.DeliveryState
}{
{name: "pending send", status: challenge.StatusPendingSend, deliveryState: challenge.DeliveryPending},
{name: "failed", status: challenge.StatusFailed, deliveryState: challenge.DeliveryFailed},
{name: "cancelled", status: challenge.StatusCancelled, deliveryState: challenge.DeliverySent},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
record.Status = tt.status
record.DeliveryState = tt.deliveryState
if err := record.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
}
updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if updated.Attempts.Confirm != 0 {
require.Failf(t, "test failed", "Attempts.Confirm = %d, want 0", updated.Attempts.Confirm)
}
})
}
}
func TestExecuteMarksChallengeFailedAndReturnsBlockedByPolicy(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")); err != nil {
require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeBlockedByPolicy {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeBlockedByPolicy)
}
record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Status != challenge.StatusFailed {
require.Failf(t, "test failed", "challenge status = %q, want %q", record.Status, challenge.StatusFailed)
}
}
func TestExecuteReturnsSessionLimitExceededWithoutConsumingChallenge(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
if err := deps.sessionStore.Create(context.Background(), activeSessionFixture("device-session-existing", "user-1", mustClientPublicKey(t, publicKeyString()), deps.now.Add(-2*time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
limit := 1
deps.configProvider.Config.ActiveSessionLimit = &limit
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeSessionLimitExceeded {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeSessionLimitExceeded)
}
record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Status != challenge.StatusSent {
require.Failf(t, "test failed", "challenge status = %q, want %q", record.Status, challenge.StatusSent)
}
if record.Attempts.Confirm != 0 {
require.Failf(t, "test failed", "Attempts.Confirm = %d, want 0", record.Attempts.Confirm)
}
}
func TestExecuteReturnsServiceUnavailableThenSucceedsIdempotentlyAfterPublishFailure(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
deps.publisher.Err = errors.New("publish failed")
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service := mustNewConfirmService(t, deps)
_, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if shared.CodeOf(err) != shared.ErrorCodeServiceUnavailable {
require.Failf(t, "test failed", "first Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeServiceUnavailable)
}
deps.publisher.Err = nil
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
if err != nil {
require.Failf(t, "test failed", "second Execute() returned error: %v", err)
}
if result.DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "second Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
}
}
type confirmDeps struct {
challengeStore *testkit.InMemoryChallengeStore
sessionStore *testkit.InMemorySessionStore
userDirectory *testkit.InMemoryUserDirectory
configProvider testkit.StaticConfigProvider
publisher *testkit.RecordingProjectionPublisher
idGenerator *testkit.SequenceIDGenerator
hasher testkit.DeterministicCodeHasher
now time.Time
}
func newConfirmDeps(t *testing.T) confirmDeps {
t.Helper()
return confirmDeps{
challengeStore: &testkit.InMemoryChallengeStore{},
sessionStore: &testkit.InMemorySessionStore{},
userDirectory: &testkit.InMemoryUserDirectory{},
configProvider: testkit.StaticConfigProvider{},
publisher: &testkit.RecordingProjectionPublisher{},
idGenerator: &testkit.SequenceIDGenerator{
DeviceSessionIDs: []common.DeviceSessionID{"device-session-1"},
},
hasher: testkit.DeterministicCodeHasher{},
now: time.Unix(20, 0).UTC(),
}
}
func mustNewConfirmService(t *testing.T, deps confirmDeps) *Service {
t.Helper()
service, err := New(
deps.challengeStore,
deps.sessionStore,
deps.userDirectory,
deps.configProvider,
deps.publisher,
deps.idGenerator,
deps.hasher,
testkit.FixedClock{Time: deps.now},
)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
return service
}
func sentChallengeFixture(
t *testing.T,
hasher testkit.DeterministicCodeHasher,
challengeID string,
email string,
code string,
createdAt time.Time,
expiresAt time.Time,
) challenge.Challenge {
t.Helper()
codeHash, err := hasher.Hash(code)
if err != nil {
require.Failf(t, "test failed", "Hash() returned error: %v", err)
}
record := challenge.Challenge{
ID: common.ChallengeID(challengeID),
Email: common.Email(email),
CodeHash: codeHash,
Status: challenge.StatusSent,
DeliveryState: challenge.DeliverySent,
CreatedAt: createdAt,
ExpiresAt: expiresAt,
}
if err := record.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
return record
}
func confirmedChallengeFixture(
t *testing.T,
hasher testkit.DeterministicCodeHasher,
challengeID string,
email string,
code string,
deviceSessionID string,
clientPublicKey common.ClientPublicKey,
createdAt time.Time,
expiresAt time.Time,
) challenge.Challenge {
t.Helper()
record := sentChallengeFixture(t, hasher, challengeID, email, code, createdAt, expiresAt)
record.Status = challenge.StatusConfirmedPendingExpire
record.Confirmation = &challenge.Confirmation{
SessionID: common.DeviceSessionID(deviceSessionID),
ClientPublicKey: clientPublicKey,
ConfirmedAt: createdAt.Add(time.Minute),
}
if err := record.Validate(); err != nil {
require.Failf(t, "test failed", "Validate() returned error: %v", err)
}
return record
}
func activeSessionFixture(deviceSessionID string, userID string, clientPublicKey common.ClientPublicKey, createdAt time.Time) devicesession.Session {
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: clientPublicKey,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
func mustClientPublicKey(t *testing.T, value string) common.ClientPublicKey {
t.Helper()
key, err := shared.ParseClientPublicKey(value)
if err != nil {
require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err)
}
return key
}
func publicKeyString() string {
return "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="
}
func alternatePublicKeyString() string {
return "AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQE="
}
@@ -0,0 +1,109 @@
package confirmemailcode
import (
"context"
"testing"
"time"
stubuserservice "galaxy/authsession/internal/adapters/userservice"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/service/shared"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
t.Parallel()
t.Run("creates user through EnsureUserByEmail", func(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
userDirectory := &stubuserservice.StubDirectory{}
require.NoError(t, userDirectory.QueueCreatedUserIDs(common.UserID("user-created")))
deps.userDirectory = nil
require.NoError(t, deps.challengeStore.Create(context.Background(), sentChallengeFixture(
t,
deps.hasher,
"challenge-1",
"pilot@example.com",
"654321",
deps.now.Add(-time.Minute),
deps.now.Add(time.Minute),
)))
service, err := New(
deps.challengeStore,
deps.sessionStore,
userDirectory,
deps.configProvider,
deps.publisher,
deps.idGenerator,
deps.hasher,
fixedClock(deps.now),
)
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.NoError(t, err)
assert.Equal(t, "device-session-1", result.DeviceSessionID)
sessionRecord, err := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, err)
assert.Equal(t, common.UserID("user-created"), sessionRecord.UserID)
})
t.Run("blocked email returns blocked by policy", func(t *testing.T) {
t.Parallel()
deps := newConfirmDeps(t)
userDirectory := &stubuserservice.StubDirectory{}
require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")))
require.NoError(t, deps.challengeStore.Create(context.Background(), sentChallengeFixture(
t,
deps.hasher,
"challenge-1",
"pilot@example.com",
"654321",
deps.now.Add(-time.Minute),
deps.now.Add(time.Minute),
)))
service, err := New(
deps.challengeStore,
deps.sessionStore,
userDirectory,
deps.configProvider,
deps.publisher,
deps.idGenerator,
deps.hasher,
fixedClock(deps.now),
)
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeBlockedByPolicy, shared.CodeOf(err))
record, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, getErr)
assert.Equal(t, challenge.StatusFailed, record.Status)
})
}
type fixedClock time.Time
func (c fixedClock) Now() time.Time {
return time.Time(c)
}
@@ -0,0 +1,104 @@
package confirmemailcode
import (
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
authtelemetry "galaxy/authsession/internal/telemetry"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/attribute"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
)
func TestExecuteRecordsInvalidCodeMetricForThrottledChallenge(t *testing.T) {
t.Parallel()
runtime, reader := newObservedConfirmTelemetryRuntime(t)
deps := newConfirmDeps(t)
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
record.Status = challenge.StatusDeliveryThrottled
record.DeliveryState = challenge.DeliveryThrottled
require.NoError(t, record.Validate())
require.NoError(t, deps.challengeStore.Create(context.Background(), record))
service, err := NewWithTelemetry(
deps.challengeStore,
deps.sessionStore,
deps.userDirectory,
deps.configProvider,
deps.publisher,
deps.idGenerator,
deps.hasher,
testkit.FixedClock{Time: deps.now},
runtime,
)
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
ChallengeID: "challenge-1",
Code: "654321",
ClientPublicKey: publicKeyString(),
})
require.Error(t, err)
assertConfirmMetricCount(t, reader, map[string]string{"outcome": "invalid_code"}, 1)
}
func newObservedConfirmTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader) {
t.Helper()
reader := sdkmetric.NewManualReader()
provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))
runtime, err := authtelemetry.New(provider)
require.NoError(t, err)
return runtime, reader
}
func assertConfirmMetricCount(t *testing.T, reader *sdkmetric.ManualReader, wantAttrs map[string]string, wantValue int64) {
t.Helper()
var resourceMetrics metricdata.ResourceMetrics
require.NoError(t, reader.Collect(context.Background(), &resourceMetrics))
for _, scopeMetrics := range resourceMetrics.ScopeMetrics {
for _, metric := range scopeMetrics.Metrics {
if metric.Name != "authsession.confirm_email_code.attempts" {
continue
}
sum, ok := metric.Data.(metricdata.Sum[int64])
require.True(t, ok)
for _, point := range sum.DataPoints {
if hasConfirmMetricAttributes(point.Attributes.ToSlice(), wantAttrs) {
assert.Equal(t, wantValue, point.Value)
return
}
}
}
}
require.Failf(t, "test failed", "confirm metric with attrs %v not found", wantAttrs)
}
func hasConfirmMetricAttributes(values []attribute.KeyValue, want map[string]string) bool {
if len(values) != len(want) {
return false
}
for _, value := range values {
if want[string(value.Key)] != value.Value.AsString() {
return false
}
}
return true
}
@@ -0,0 +1,65 @@
// Package getsession implements the trusted internal read use case for one
// device session.
package getsession
import (
"context"
"errors"
"fmt"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
)
// Input describes one trusted internal get-session request.
type Input struct {
// DeviceSessionID identifies the session that should be read.
DeviceSessionID string
}
// Result describes one trusted internal get-session response.
type Result struct {
// Session stores the frozen internal read-model DTO.
Session shared.Session
}
// Service executes the trusted internal get-session use case against the
// configured ports.
type Service struct {
sessionStore ports.SessionStore
}
// New returns a get-session service wired to sessionStore.
func New(sessionStore ports.SessionStore) (*Service, error) {
if sessionStore == nil {
return nil, fmt.Errorf("getsession: session store must not be nil")
}
return &Service{sessionStore: sessionStore}, nil
}
// Execute loads one source-of-truth session and projects it into the frozen
// internal read DTO shape.
func (s *Service) Execute(ctx context.Context, input Input) (Result, error) {
deviceSessionID, err := shared.ParseDeviceSessionID(input.DeviceSessionID)
if err != nil {
return Result{}, err
}
record, err := s.sessionStore.Get(ctx, deviceSessionID)
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return Result{}, shared.SessionNotFound()
default:
return Result{}, shared.ServiceUnavailable(err)
}
}
session, err := shared.ToSession(record)
if err != nil {
return Result{}, shared.InternalError(err)
}
return Result{Session: session}, nil
}
@@ -0,0 +1,68 @@
package getsession
import (
"context"
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
)
func TestExecuteReturnsMappedSession(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(store)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
result, err := service.Execute(context.Background(), Input{DeviceSessionID: " device-session-1 "})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if result.Session.DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "Execute().Session.DeviceSessionID = %q, want %q", result.Session.DeviceSessionID, "device-session-1")
}
if result.Session.CreatedAt != time.Unix(10, 0).UTC().Format(time.RFC3339) {
require.Failf(t, "test failed", "Execute().Session.CreatedAt = %q", result.Session.CreatedAt)
}
}
func TestExecuteReturnsSessionNotFound(t *testing.T) {
t.Parallel()
service, err := New(&testkit.InMemorySessionStore{})
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
_, err = service.Execute(context.Background(), Input{DeviceSessionID: "missing"})
if shared.CodeOf(err) != shared.ErrorCodeSessionNotFound {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeSessionNotFound)
}
}
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: key,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
@@ -0,0 +1,58 @@
// Package listusersessions implements the trusted internal read use case for
// listing all sessions of one user.
package listusersessions
import (
"context"
"fmt"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
)
// Input describes one trusted internal list-user-sessions request.
type Input struct {
// UserID identifies the owner whose sessions should be listed.
UserID string
}
// Result describes one trusted internal list-user-sessions response.
type Result struct {
// Sessions stores the frozen internal read-model DTO slice.
Sessions []shared.Session
}
// Service executes the trusted internal list-user-sessions use case.
type Service struct {
sessionStore ports.SessionStore
}
// New returns a list-user-sessions service wired to sessionStore.
func New(sessionStore ports.SessionStore) (*Service, error) {
if sessionStore == nil {
return nil, fmt.Errorf("listusersessions: session store must not be nil")
}
return &Service{sessionStore: sessionStore}, nil
}
// Execute loads all source-of-truth sessions for one user and projects them
// into the frozen internal read DTO shape.
func (s *Service) Execute(ctx context.Context, input Input) (Result, error) {
userID, err := shared.ParseUserID(input.UserID)
if err != nil {
return Result{}, err
}
records, err := s.sessionStore.ListByUserID(ctx, userID)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
sessions, err := shared.ToSessions(records)
if err != nil {
return Result{}, shared.InternalError(err)
}
return Result{Sessions: sessions}, nil
}
@@ -0,0 +1,73 @@
package listusersessions
import (
"context"
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/testkit"
)
func TestExecutePreservesNewestFirstOrder(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
older := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
newer := activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())
for _, record := range []devicesession.Session{older, newer} {
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
}
service, err := New(store)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
result, err := service.Execute(context.Background(), Input{UserID: "user-1"})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if len(result.Sessions) != 2 {
require.Failf(t, "test failed", "Execute().Sessions length = %d, want 2", len(result.Sessions))
}
if result.Sessions[0].DeviceSessionID != "device-session-2" || result.Sessions[1].DeviceSessionID != "device-session-1" {
require.Failf(t, "test failed", "Execute().Sessions order = [%q %q]", result.Sessions[0].DeviceSessionID, result.Sessions[1].DeviceSessionID)
}
}
func TestExecuteReturnsEmptyForUnknownUser(t *testing.T) {
t.Parallel()
service, err := New(&testkit.InMemorySessionStore{})
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
result, err := service.Execute(context.Background(), Input{UserID: "missing"})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if len(result.Sessions) != 0 {
require.Failf(t, "test failed", "Execute().Sessions length = %d, want 0", len(result.Sessions))
}
}
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: key,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
@@ -0,0 +1,106 @@
package revokeallusersessions
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteRetriesProjectionPublishesForBulkRevoke(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{
Errors: []error{
errors.New("publish failed"),
nil,
errors.New("publish failed"),
nil,
},
}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())))
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "revoked", result.Outcome)
assert.EqualValues(t, 2, result.AffectedSessionCount)
assert.Equal(t, []string{"device-session-2", "device-session-1"}, result.AffectedDeviceSessionIDs)
require.Len(t, publisher.PublishedSnapshots(), 4)
}
func TestExecuteRepublishesCurrentRevokedSessionsOnNoActiveSessionsRetry(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{
Errors: []error{
nil,
errors.New("publish failed"),
errors.New("publish failed"),
errors.New("publish failed"),
},
}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())))
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
require.Len(t, publisher.PublishedSnapshots(), 4)
for _, deviceSessionID := range []common.DeviceSessionID{"device-session-1", "device-session-2"} {
record, getErr := store.Get(context.Background(), deviceSessionID)
require.NoError(t, getErr)
require.NotNil(t, record.Revocation)
assert.Equal(t, devicesession.StatusRevoked, record.Status)
}
publisher.Errors = nil
publisher.Err = nil
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "no_active_sessions", result.Outcome)
assert.EqualValues(t, 0, result.AffectedSessionCount)
require.NotNil(t, result.AffectedDeviceSessionIDs)
assert.Empty(t, result.AffectedDeviceSessionIDs)
published := publisher.PublishedSnapshots()
require.Len(t, published, 6)
assert.Equal(t, []common.DeviceSessionID{"device-session-2", "device-session-1"}, []common.DeviceSessionID{
published[4].DeviceSessionID,
published[5].DeviceSessionID,
})
}
@@ -0,0 +1,200 @@
// Package revokeallusersessions implements the trusted internal bulk revoke
// use case for all sessions of one user.
package revokeallusersessions
import (
"context"
"fmt"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/telemetry"
"go.uber.org/zap"
)
// Input describes one trusted internal revoke-all-user-sessions request.
type Input struct {
// UserID identifies the owner whose sessions should be revoked.
UserID string
// ReasonCode stores the machine-readable revoke reason code.
ReasonCode string
// ActorType stores the machine-readable revoke actor type.
ActorType string
// ActorID stores the optional stable revoke actor identifier.
ActorID string
}
// Result describes the frozen internal bulk revoke acknowledgement.
type Result struct {
// Outcome reports whether active sessions were revoked during the current
// call.
Outcome string
// UserID identifies the user addressed by the operation.
UserID string
// AffectedSessionCount reports how many sessions changed state during the
// current call.
AffectedSessionCount int64
// AffectedDeviceSessionIDs lists every session identifier affected during
// the current call.
AffectedDeviceSessionIDs []string
}
// Service executes the trusted internal revoke-all-user-sessions use case.
type Service struct {
sessionStore ports.SessionStore
userDirectory ports.UserDirectory
publisher ports.GatewaySessionProjectionPublisher
clock ports.Clock
logger *zap.Logger
telemetry *telemetry.Runtime
}
// New returns a revoke-all-user-sessions service wired to the required ports.
func New(sessionStore ports.SessionStore, userDirectory ports.UserDirectory, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) {
return NewWithObservability(sessionStore, userDirectory, publisher, clock, nil, nil)
}
// NewWithObservability returns a revoke-all-user-sessions service wired to the
// required ports plus optional structured logging and telemetry dependencies.
func NewWithObservability(
sessionStore ports.SessionStore,
userDirectory ports.UserDirectory,
publisher ports.GatewaySessionProjectionPublisher,
clock ports.Clock,
logger *zap.Logger,
telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
switch {
case sessionStore == nil:
return nil, fmt.Errorf("revokeallusersessions: session store must not be nil")
case userDirectory == nil:
return nil, fmt.Errorf("revokeallusersessions: user directory must not be nil")
case publisher == nil:
return nil, fmt.Errorf("revokeallusersessions: projection publisher must not be nil")
case clock == nil:
return nil, fmt.Errorf("revokeallusersessions: clock must not be nil")
default:
return &Service{
sessionStore: sessionStore,
userDirectory: userDirectory,
publisher: publisher,
clock: clock,
logger: namedLogger(logger, "revoke_all_user_sessions"),
telemetry: telemetryRuntime,
}, nil
}
}
// Execute revokes all active sessions of one user and republishes revoked
// gateway projections for every affected session.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
logFields := []zap.Field{
zap.String("component", "service"),
zap.String("use_case", "revoke_all_user_sessions"),
}
defer func() {
shared.LogServiceOutcome(s.logger, ctx, "revoke all user sessions completed", err, logFields...)
}()
userID, err := shared.ParseUserID(input.UserID)
if err != nil {
return Result{}, err
}
logFields = append(logFields, zap.String("user_id", userID.String()))
revocation, err := shared.BuildRevocation(input.ReasonCode, input.ActorType, input.ActorID, s.clock.Now())
if err != nil {
return Result{}, err
}
logFields = append(logFields, zap.String("reason_code", revocation.ReasonCode.String()))
exists, err := s.userDirectory.ExistsByUserID(ctx, userID)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
s.telemetry.RecordUserDirectoryOutcome(ctx, "exists_by_user_id", boolOutcome(exists))
if !exists {
return Result{}, shared.SubjectNotFound()
}
storeResult, err := s.sessionStore.RevokeAllByUserID(ctx, ports.RevokeUserSessionsInput{
UserID: userID,
Revocation: revocation,
})
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if err := storeResult.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
logFields = append(logFields, zap.String("outcome", string(storeResult.Outcome)))
affectedDeviceSessionIDs := make([]string, 0, len(storeResult.Sessions))
for _, record := range storeResult.Sessions {
if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "revoke_all_user_sessions"); err != nil {
return Result{}, err
}
affectedDeviceSessionIDs = append(affectedDeviceSessionIDs, record.ID.String())
}
if storeResult.Outcome == ports.RevokeUserSessionsOutcomeNoActiveSessions {
if err := s.republishCurrentRevokedSessions(ctx, userID); err != nil {
return Result{}, err
}
}
affectedSessionCount := int64(len(storeResult.Sessions))
if affectedSessionCount > 0 {
s.telemetry.RecordSessionRevocations(ctx, "revoke_all_user_sessions", revocation.ReasonCode.String(), affectedSessionCount)
}
logFields = append(logFields, zap.Int64("affected_session_count", affectedSessionCount))
return Result{
Outcome: string(storeResult.Outcome),
UserID: storeResult.UserID.String(),
AffectedSessionCount: affectedSessionCount,
AffectedDeviceSessionIDs: affectedDeviceSessionIDs,
}, nil
}
func (s *Service) republishCurrentRevokedSessions(ctx context.Context, userID common.UserID) error {
records, err := s.sessionStore.ListByUserID(ctx, userID)
if err != nil {
return shared.ServiceUnavailable(err)
}
for _, record := range records {
if record.Status != devicesession.StatusRevoked {
continue
}
if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "revoke_all_user_sessions_repair"); err != nil {
return err
}
}
return nil
}
func boolOutcome(value bool) string {
if value {
return "exists"
}
return "missing"
}
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
if logger == nil {
logger = zap.NewNop()
}
return logger.Named(name)
}
@@ -0,0 +1,162 @@
package revokeallusersessions
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteRevokesExistingUserSessionsAndPublishes(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{}
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
for _, record := range []devicesession.Session{
activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC()),
activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC()),
} {
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
}
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "revoked", result.Outcome)
assert.EqualValues(t, 2, result.AffectedSessionCount)
assert.Equal(t, []string{"device-session-2", "device-session-1"}, result.AffectedDeviceSessionIDs)
for _, deviceSessionID := range result.AffectedDeviceSessionIDs {
stored, getErr := store.Get(context.Background(), common.DeviceSessionID(deviceSessionID))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
assert.Empty(t, stored.Revocation.ActorID)
assert.Equal(t, time.Unix(30, 0).UTC(), stored.Revocation.At)
}
published := publisher.PublishedSnapshots()
require.Len(t, published, 2)
assert.Equal(t, []common.DeviceSessionID{"device-session-2", "device-session-1"}, []common.DeviceSessionID{
published[0].DeviceSessionID,
published[1].DeviceSessionID,
})
for _, snapshot := range published {
assert.Equal(t, gatewayprojection.StatusRevoked, snapshot.Status)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, snapshot.RevokeReasonCode)
assert.Equal(t, common.RevokeActorType("system"), snapshot.RevokeActorType)
require.NotNil(t, snapshot.RevokedAt)
assert.Equal(t, time.Unix(30, 0).UTC(), *snapshot.RevokedAt)
}
}
func TestExecuteReturnsNoActiveSessionsForExistingUserWithoutActiveSessions(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{}
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "no_active_sessions", result.Outcome)
assert.EqualValues(t, 0, result.AffectedSessionCount)
require.NotNil(t, result.AffectedDeviceSessionIDs)
assert.Empty(t, result.AffectedDeviceSessionIDs)
assert.Empty(t, publisher.PublishedSnapshots())
}
func TestExecuteReturnsSubjectNotFoundForUnknownUser(t *testing.T) {
t.Parallel()
service, err := New(&testkit.InMemorySessionStore{}, &testkit.InMemoryUserDirectory{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
_, err = service.Execute(context.Background(), Input{
UserID: "missing",
ReasonCode: "logout_all",
ActorType: "system",
})
assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err))
}
func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
}
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
}
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: key,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
@@ -0,0 +1,53 @@
package revokeallusersessions
import (
"context"
"testing"
"time"
stubuserservice "galaxy/authsession/internal/adapters/userservice"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
t.Parallel()
t.Run("existing user uses ExistsByUserID and returns no active sessions", func(t *testing.T) {
t.Parallel()
userDirectory := &stubuserservice.StubDirectory{}
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
service, err := New(&testkit.InMemorySessionStore{}, userDirectory, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
UserID: "user-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "no_active_sessions", result.Outcome)
assert.Zero(t, result.AffectedSessionCount)
})
t.Run("unknown user returns subject not found", func(t *testing.T) {
t.Parallel()
service, err := New(&testkit.InMemorySessionStore{}, &stubuserservice.StubDirectory{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
UserID: "missing",
ReasonCode: "logout_all",
ActorType: "system",
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err))
})
}
@@ -0,0 +1,75 @@
package revokedevicesession
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteRetriesProjectionPublishUntilSuccess(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{
Errors: []error{errors.New("publish failed"), nil},
}
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
DeviceSessionID: "device-session-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "revoked", result.Outcome)
require.Len(t, publisher.PublishedSnapshots(), 2)
}
func TestExecuteRepairsProjectionOnRepeatedAlreadyRevokedRequest(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
DeviceSessionID: "device-session-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.Error(t, err)
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts)
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
publisher.Err = nil
result, err := service.Execute(context.Background(), Input{
DeviceSessionID: "device-session-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "already_revoked", result.Outcome)
assert.EqualValues(t, 0, result.AffectedSessionCount)
require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1)
}
@@ -0,0 +1,151 @@
// Package revokedevicesession implements the trusted internal single-session
// revoke use case.
package revokedevicesession
import (
"context"
"errors"
"fmt"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/telemetry"
"go.uber.org/zap"
)
// Input describes one trusted internal revoke-device-session request.
type Input struct {
// DeviceSessionID identifies the session that should be revoked.
DeviceSessionID string
// ReasonCode stores the machine-readable revoke reason code.
ReasonCode string
// ActorType stores the machine-readable revoke actor type.
ActorType string
// ActorID stores the optional stable revoke actor identifier.
ActorID string
}
// Result describes the frozen internal revoke-device-session acknowledgement.
type Result struct {
// Outcome reports whether the current call revoked the session or found it
// already revoked.
Outcome string
// DeviceSessionID identifies the session addressed by the operation.
DeviceSessionID string
// AffectedSessionCount reports how many sessions changed state during the
// current call.
AffectedSessionCount int64
}
// Service executes the trusted internal revoke-device-session use case.
type Service struct {
sessionStore ports.SessionStore
publisher ports.GatewaySessionProjectionPublisher
clock ports.Clock
logger *zap.Logger
telemetry *telemetry.Runtime
}
// New returns a revoke-device-session service wired to the required ports.
func New(sessionStore ports.SessionStore, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) {
return NewWithObservability(sessionStore, publisher, clock, nil, nil)
}
// NewWithObservability returns a revoke-device-session service wired to the
// required ports plus optional structured logging and telemetry dependencies.
func NewWithObservability(
sessionStore ports.SessionStore,
publisher ports.GatewaySessionProjectionPublisher,
clock ports.Clock,
logger *zap.Logger,
telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
switch {
case sessionStore == nil:
return nil, fmt.Errorf("revokedevicesession: session store must not be nil")
case publisher == nil:
return nil, fmt.Errorf("revokedevicesession: projection publisher must not be nil")
case clock == nil:
return nil, fmt.Errorf("revokedevicesession: clock must not be nil")
default:
return &Service{
sessionStore: sessionStore,
publisher: publisher,
clock: clock,
logger: namedLogger(logger, "revoke_device_session"),
telemetry: telemetryRuntime,
}, nil
}
}
// Execute revokes one device session and republishes the current gateway
// projection for the resulting source-of-truth session state.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
logFields := []zap.Field{
zap.String("component", "service"),
zap.String("use_case", "revoke_device_session"),
}
defer func() {
shared.LogServiceOutcome(s.logger, ctx, "revoke device session completed", err, logFields...)
}()
deviceSessionID, err := shared.ParseDeviceSessionID(input.DeviceSessionID)
if err != nil {
return Result{}, err
}
logFields = append(logFields, zap.String("device_session_id", deviceSessionID.String()))
revocation, err := shared.BuildRevocation(input.ReasonCode, input.ActorType, input.ActorID, s.clock.Now())
if err != nil {
return Result{}, err
}
logFields = append(logFields, zap.String("reason_code", revocation.ReasonCode.String()))
storeResult, err := s.sessionStore.Revoke(ctx, ports.RevokeSessionInput{
DeviceSessionID: deviceSessionID,
Revocation: revocation,
})
if err != nil {
switch {
case errors.Is(err, ports.ErrNotFound):
return Result{}, shared.SessionNotFound()
default:
return Result{}, shared.ServiceUnavailable(err)
}
}
if err := storeResult.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
logFields = append(logFields, zap.String("outcome", string(storeResult.Outcome)))
if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, storeResult.Session, s.telemetry, "revoke_device_session"); err != nil {
return Result{}, err
}
affectedSessionCount := int64(0)
if storeResult.Outcome == ports.RevokeSessionOutcomeRevoked {
affectedSessionCount = 1
s.telemetry.RecordSessionRevocations(ctx, "revoke_device_session", revocation.ReasonCode.String(), affectedSessionCount)
}
logFields = append(logFields, zap.Int64("affected_session_count", affectedSessionCount))
return Result{
Outcome: string(storeResult.Outcome),
DeviceSessionID: storeResult.Session.ID.String(),
AffectedSessionCount: affectedSessionCount,
}, nil
}
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
if logger == nil {
logger = zap.NewNop()
}
return logger.Named(name)
}
@@ -0,0 +1,166 @@
package revokedevicesession
import (
"context"
"errors"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteRevokesActiveSessionAndPublishes(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{}
record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
DeviceSessionID: "device-session-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "revoked", result.Outcome)
assert.EqualValues(t, 1, result.AffectedSessionCount)
assert.Equal(t, "device-session-1", result.DeviceSessionID)
stored, err := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, err)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
assert.Empty(t, stored.Revocation.ActorID)
assert.Equal(t, time.Unix(20, 0).UTC(), stored.Revocation.At)
published := publisher.PublishedSnapshots()
require.Len(t, published, 1)
assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status)
assert.Equal(t, common.DeviceSessionID("device-session-1"), published[0].DeviceSessionID)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, published[0].RevokeReasonCode)
assert.Equal(t, common.RevokeActorType("system"), published[0].RevokeActorType)
require.NotNil(t, published[0].RevokedAt)
assert.Equal(t, time.Unix(20, 0).UTC(), published[0].RevokedAt.UTC())
}
func TestExecuteAlreadyRevokedReturnsZeroAffectedAndRepublishes(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{}
record := revokedSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{
DeviceSessionID: "device-session-1",
ReasonCode: "logout_all",
ActorType: "system",
})
require.NoError(t, err)
assert.Equal(t, "already_revoked", result.Outcome)
assert.EqualValues(t, 0, result.AffectedSessionCount)
assert.Equal(t, "device-session-1", result.DeviceSessionID)
stored, err := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, err)
require.NotNil(t, stored.Revocation)
assert.Equal(t, *record.Revocation, *stored.Revocation)
published := publisher.PublishedSnapshots()
require.Len(t, published, 1)
assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, published[0].RevokeReasonCode)
assert.Equal(t, common.RevokeActorType("system"), published[0].RevokeActorType)
require.NotNil(t, published[0].RevokedAt)
assert.Equal(t, record.Revocation.At, *published[0].RevokedAt)
}
func TestExecuteReturnsSessionNotFound(t *testing.T) {
t.Parallel()
service, err := New(&testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
_, err = service.Execute(context.Background(), Input{
DeviceSessionID: "missing",
ReasonCode: "logout_all",
ActorType: "system",
})
assert.Equal(t, shared.ErrorCodeSessionNotFound, shared.CodeOf(err))
}
func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) {
t.Parallel()
store := &testkit.InMemorySessionStore{}
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
if err := store.Create(context.Background(), record); err != nil {
require.Failf(t, "test failed", "Create() returned error: %v", err)
}
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{
DeviceSessionID: "device-session-1",
ReasonCode: "logout_all",
ActorType: "system",
})
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
require.NoError(t, getErr)
require.NotNil(t, stored.Revocation)
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
}
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
return devicesession.Session{
ID: common.DeviceSessionID(deviceSessionID),
UserID: common.UserID(userID),
ClientPublicKey: key,
Status: devicesession.StatusActive,
CreatedAt: createdAt,
}
}
func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
record := activeSessionFixture(deviceSessionID, userID, createdAt)
record.Status = devicesession.StatusRevoked
record.Revocation = &devicesession.Revocation{
At: createdAt.Add(time.Minute),
ReasonCode: devicesession.RevokeReasonLogoutAll,
ActorType: common.RevokeActorType("system"),
}
return record
}
@@ -0,0 +1,167 @@
package sendemailcode
import (
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteCreatesThrottledChallengeWithoutUserDirectoryOrMail(t *testing.T) {
t.Parallel()
challengeStore := &testkit.InMemoryChallengeStore{}
abuseProtector := &testkit.InMemorySendEmailCodeAbuseProtector{}
now := time.Unix(10, 0).UTC()
require.NoError(t, reserveSendCooldown(abuseProtector, common.Email("pilot@example.com"), now))
userDirectory := &countingUserDirectory{}
mailSender := &testkit.RecordingMailSender{}
service, err := NewWithRuntime(
challengeStore,
userDirectory,
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
mailSender,
abuseProtector,
testkit.FixedClock{Time: now},
nil,
)
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
require.NoError(t, err)
assert.Equal(t, "challenge-1", result.ChallengeID)
assert.Zero(t, userDirectory.resolveCalls)
assert.Empty(t, mailSender.RecordedInputs())
record, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, getErr)
assert.Equal(t, challenge.StatusDeliveryThrottled, record.Status)
assert.Equal(t, challenge.DeliveryThrottled, record.DeliveryState)
assert.Equal(t, 1, record.Attempts.Send)
}
func TestExecuteBlockedEmailOutsideThrottleStillSuppressesDelivery(t *testing.T) {
t.Parallel()
challengeStore := &testkit.InMemoryChallengeStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")))
mailSender := &testkit.RecordingMailSender{}
service, err := NewWithRuntime(
challengeStore,
userDirectory,
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
mailSender,
&testkit.InMemorySendEmailCodeAbuseProtector{},
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
nil,
)
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
require.NoError(t, err)
assert.Equal(t, "challenge-1", result.ChallengeID)
assert.Empty(t, mailSender.RecordedInputs())
record, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, getErr)
assert.Equal(t, challenge.StatusDeliverySuppressed, record.Status)
assert.Equal(t, challenge.DeliverySuppressed, record.DeliveryState)
}
func TestExecuteAllowsAgainAfterCooldown(t *testing.T) {
t.Parallel()
challengeStore := &testkit.InMemoryChallengeStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
mailSender := &testkit.RecordingMailSender{}
abuseProtector := &testkit.InMemorySendEmailCodeAbuseProtector{}
clock := &mutableClock{time: time.Unix(10, 0).UTC()}
idGenerator := &testkit.SequenceIDGenerator{
ChallengeIDs: []common.ChallengeID{"challenge-1", "challenge-2"},
}
service, err := NewWithRuntime(
challengeStore,
userDirectory,
idGenerator,
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
mailSender,
abuseProtector,
clock,
nil,
)
require.NoError(t, err)
first, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
require.NoError(t, err)
assert.Equal(t, "challenge-1", first.ChallengeID)
clock.time = clock.time.Add(challenge.ResendThrottleCooldown)
second, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
require.NoError(t, err)
assert.Equal(t, "challenge-2", second.ChallengeID)
require.Len(t, mailSender.RecordedInputs(), 2)
secondRecord, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-2"))
require.NoError(t, getErr)
assert.Equal(t, challenge.StatusSent, secondRecord.Status)
assert.Equal(t, challenge.DeliverySent, secondRecord.DeliveryState)
}
func reserveSendCooldown(protector ports.SendEmailCodeAbuseProtector, email common.Email, now time.Time) error {
_, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
Email: email,
Now: now,
})
return err
}
type mutableClock struct {
time time.Time
}
func (c *mutableClock) Now() time.Time {
return c.time
}
type countingUserDirectory struct {
resolveCalls int
}
func (d *countingUserDirectory) ResolveByEmail(_ context.Context, _ common.Email) (userresolution.Result, error) {
d.resolveCalls++
return userresolution.Result{Kind: userresolution.KindCreatable}, nil
}
func (d *countingUserDirectory) ExistsByUserID(context.Context, common.UserID) (bool, error) {
return false, nil
}
func (d *countingUserDirectory) EnsureUserByEmail(context.Context, common.Email) (ports.EnsureUserResult, error) {
return ports.EnsureUserResult{}, nil
}
func (d *countingUserDirectory) BlockByUserID(context.Context, ports.BlockUserByIDInput) (ports.BlockUserResult, error) {
return ports.BlockUserResult{}, nil
}
func (d *countingUserDirectory) BlockByEmail(context.Context, ports.BlockUserByEmailInput) (ports.BlockUserResult, error) {
return ports.BlockUserResult{}, nil
}
@@ -0,0 +1,59 @@
package sendemailcode
import (
"bytes"
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
func TestExecuteLogsSafeOutcomeFields(t *testing.T) {
t.Parallel()
logger, buffer := newObservedServiceLogger()
service, err := NewWithObservability(
&testkit.InMemoryChallengeStore{},
&testkit.InMemoryUserDirectory{},
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
&testkit.RecordingMailSender{},
nil,
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
logger,
nil,
)
require.NoError(t, err)
_, err = service.Execute(context.Background(), Input{Email: "pilot@example.com"})
require.NoError(t, err)
logOutput := buffer.String()
assert.Contains(t, logOutput, "send_email_code")
assert.Contains(t, logOutput, "challenge-1")
assert.Contains(t, logOutput, "\"outcome\":\"sent\"")
assert.NotContains(t, logOutput, "pilot@example.com")
assert.NotContains(t, logOutput, "654321")
}
func newObservedServiceLogger() (*zap.Logger, *bytes.Buffer) {
buffer := &bytes.Buffer{}
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.TimeKey = ""
core := zapcore.NewCore(
zapcore.NewJSONEncoder(encoderConfig),
zapcore.AddSync(buffer),
zap.DebugLevel,
)
return zap.New(core), buffer
}
@@ -0,0 +1,331 @@
// Package sendemailcode implements the public send-email-code use case.
package sendemailcode
import (
"context"
"fmt"
"reflect"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/telemetry"
"go.uber.org/zap"
)
// Input describes one public send-email-code request.
type Input struct {
// Email is the user-supplied e-mail address that should receive the login
// code.
Email string
}
// Result describes one public send-email-code response.
type Result struct {
// ChallengeID is the stable challenge identifier returned to the caller.
ChallengeID string
}
// Service executes the public send-email-code use case.
type Service struct {
challengeStore ports.ChallengeStore
userDirectory ports.UserDirectory
idGenerator ports.IDGenerator
codeGenerator ports.CodeGenerator
codeHasher ports.CodeHasher
mailSender ports.MailSender
abuseProtector ports.SendEmailCodeAbuseProtector
clock ports.Clock
logger *zap.Logger
telemetry *telemetry.Runtime
}
// New returns a send-email-code service wired to the required ports.
func New(
challengeStore ports.ChallengeStore,
userDirectory ports.UserDirectory,
idGenerator ports.IDGenerator,
codeGenerator ports.CodeGenerator,
codeHasher ports.CodeHasher,
mailSender ports.MailSender,
clock ports.Clock,
) (*Service, error) {
return NewWithRuntime(
challengeStore,
userDirectory,
idGenerator,
codeGenerator,
codeHasher,
mailSender,
nil,
clock,
nil,
)
}
// NewWithRuntime returns a send-email-code service wired to the required
// ports plus the optional Stage-17 runtime collaborators.
func NewWithRuntime(
challengeStore ports.ChallengeStore,
userDirectory ports.UserDirectory,
idGenerator ports.IDGenerator,
codeGenerator ports.CodeGenerator,
codeHasher ports.CodeHasher,
mailSender ports.MailSender,
abuseProtector ports.SendEmailCodeAbuseProtector,
clock ports.Clock,
telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
return NewWithObservability(
challengeStore,
userDirectory,
idGenerator,
codeGenerator,
codeHasher,
mailSender,
abuseProtector,
clock,
nil,
telemetryRuntime,
)
}
// NewWithObservability returns a send-email-code service wired to the required
// ports plus optional structured logging and telemetry dependencies.
func NewWithObservability(
challengeStore ports.ChallengeStore,
userDirectory ports.UserDirectory,
idGenerator ports.IDGenerator,
codeGenerator ports.CodeGenerator,
codeHasher ports.CodeHasher,
mailSender ports.MailSender,
abuseProtector ports.SendEmailCodeAbuseProtector,
clock ports.Clock,
logger *zap.Logger,
telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
switch {
case challengeStore == nil:
return nil, fmt.Errorf("sendemailcode: challenge store must not be nil")
case userDirectory == nil:
return nil, fmt.Errorf("sendemailcode: user directory must not be nil")
case idGenerator == nil:
return nil, fmt.Errorf("sendemailcode: id generator must not be nil")
case codeGenerator == nil:
return nil, fmt.Errorf("sendemailcode: code generator must not be nil")
case codeHasher == nil:
return nil, fmt.Errorf("sendemailcode: code hasher must not be nil")
case mailSender == nil:
return nil, fmt.Errorf("sendemailcode: mail sender must not be nil")
case clock == nil:
return nil, fmt.Errorf("sendemailcode: clock must not be nil")
default:
return &Service{
challengeStore: challengeStore,
userDirectory: userDirectory,
idGenerator: idGenerator,
codeGenerator: codeGenerator,
codeHasher: codeHasher,
mailSender: mailSender,
abuseProtector: normalizeAbuseProtector(abuseProtector),
clock: clock,
logger: namedLogger(logger, "send_email_code"),
telemetry: telemetryRuntime,
}, nil
}
}
// Execute creates a fresh challenge for every request, stores only the hashed
// confirmation code, and records whether delivery was sent or intentionally
// suppressed.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
logFields := []zap.Field{
zap.String("component", "service"),
zap.String("use_case", "send_email_code"),
}
outcome := ""
defer func() {
if outcome != "" {
logFields = append(logFields, zap.String("outcome", outcome))
}
if result.ChallengeID != "" {
logFields = append(logFields, zap.String("challenge_id", result.ChallengeID))
}
shared.LogServiceOutcome(s.logger, ctx, "send email code completed", err, logFields...)
}()
email, err := shared.ParseEmail(input.Email)
if err != nil {
return Result{}, err
}
now := s.clock.Now().UTC()
abuseResult, err := s.abuseProtector.CheckAndReserve(ctx, ports.SendEmailCodeAbuseInput{
Email: email,
Now: now,
})
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if err := abuseResult.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
challengeID, err := s.idGenerator.NewChallengeID()
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
code, err := s.codeGenerator.Generate()
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
codeHash, err := s.codeHasher.Hash(code)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
pendingStatus, pendingDeliveryState, err := ports.SendEmailCodeThrottleStatusToChallengeStatus(abuseResult.Outcome)
if err != nil {
return Result{}, shared.InternalError(err)
}
pending := challenge.Challenge{
ID: challengeID,
Email: email,
CodeHash: codeHash,
Status: pendingStatus,
DeliveryState: pendingDeliveryState,
CreatedAt: now,
ExpiresAt: now.Add(challenge.InitialTTL),
}
if err := pending.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
if err := s.challengeStore.Create(ctx, pending); err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
s.telemetry.RecordChallengeCreated(ctx)
final := pending
final.Attempts.Send = 1
final.Abuse.LastAttemptAt = &now
if abuseResult.Outcome == ports.SendEmailCodeAbuseOutcomeThrottled {
result, err = s.finishChallenge(ctx, pending, final)
if err == nil {
outcome = string(telemetry.SendEmailCodeOutcomeThrottled)
s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeThrottled, telemetry.SendEmailCodeReasonThrottled)
}
return result, err
}
resolution, err := s.userDirectory.ResolveByEmail(ctx, email)
if err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
if err := resolution.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
s.telemetry.RecordUserDirectoryOutcome(ctx, "resolve_by_email", string(resolution.Kind))
switch resolution.Kind {
case userresolution.KindBlocked:
final.Status = challenge.StatusDeliverySuppressed
final.DeliveryState = challenge.DeliverySuppressed
result, err = s.finishChallenge(ctx, pending, final)
if err == nil {
outcome = string(telemetry.SendEmailCodeOutcomeSuppressed)
s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeSuppressed, telemetry.SendEmailCodeReasonBlocked)
}
return result, err
default:
deliveryResult, err := s.mailSender.SendLoginCode(ctx, ports.SendLoginCodeInput{
Email: email,
Code: code,
})
if err != nil {
final.Status = challenge.StatusFailed
final.DeliveryState = challenge.DeliveryFailed
if _, persistErr := s.finishChallenge(ctx, pending, final); persistErr != nil {
return Result{}, persistErr
}
outcome = string(telemetry.SendEmailCodeOutcomeFailed)
s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeFailed, telemetry.SendEmailCodeReasonMailSender)
return Result{}, shared.ServiceUnavailable(err)
}
if err := deliveryResult.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
switch deliveryResult.Outcome {
case ports.SendLoginCodeOutcomeSent:
final.Status = challenge.StatusSent
final.DeliveryState = challenge.DeliverySent
result, err = s.finishChallenge(ctx, pending, final)
if err == nil {
outcome = string(telemetry.SendEmailCodeOutcomeSent)
s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeSent, "")
}
return result, err
case ports.SendLoginCodeOutcomeSuppressed:
final.Status = challenge.StatusDeliverySuppressed
final.DeliveryState = challenge.DeliverySuppressed
result, err = s.finishChallenge(ctx, pending, final)
if err == nil {
outcome = string(telemetry.SendEmailCodeOutcomeSuppressed)
s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeSuppressed, telemetry.SendEmailCodeReasonMailSender)
}
return result, err
default:
return Result{}, shared.InternalError(fmt.Errorf("sendemailcode: unsupported delivery outcome %q", deliveryResult.Outcome))
}
}
}
func (s *Service) finishChallenge(ctx context.Context, pending challenge.Challenge, final challenge.Challenge) (Result, error) {
if err := final.Validate(); err != nil {
return Result{}, shared.InternalError(err)
}
if err := s.challengeStore.CompareAndSwap(ctx, pending, final); err != nil {
return Result{}, shared.ServiceUnavailable(err)
}
return Result{ChallengeID: final.ID.String()}, nil
}
func normalizeAbuseProtector(protector ports.SendEmailCodeAbuseProtector) ports.SendEmailCodeAbuseProtector {
if protector == nil {
return allowAllSendEmailCodeAbuseProtector{}
}
value := reflect.ValueOf(protector)
switch value.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
if value.IsNil() {
return allowAllSendEmailCodeAbuseProtector{}
}
}
return protector
}
type allowAllSendEmailCodeAbuseProtector struct{}
func (allowAllSendEmailCodeAbuseProtector) CheckAndReserve(_ context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
if err := input.Validate(); err != nil {
return ports.SendEmailCodeAbuseResult{}, err
}
return ports.SendEmailCodeAbuseResult{
Outcome: ports.SendEmailCodeAbuseOutcomeAllowed,
}, nil
}
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
if logger == nil {
logger = zap.NewNop()
}
return logger.Named(name)
}
@@ -0,0 +1,310 @@
package sendemailcode
import (
"context"
"errors"
"github.com/stretchr/testify/require"
"testing"
"time"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
)
func TestExecuteSendsChallengeForExistingAndCreatableUsers(t *testing.T) {
t.Parallel()
tests := []struct {
name string
seed func(*testkit.InMemoryUserDirectory) error
email string
}{
{
name: "existing",
seed: func(directory *testkit.InMemoryUserDirectory) error {
return directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))
},
email: " pilot@example.com ",
},
{
name: "creatable",
seed: func(*testkit.InMemoryUserDirectory) error { return nil },
email: "new@example.com",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
challengeStore := &testkit.InMemoryChallengeStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
if err := tt.seed(userDirectory); err != nil {
require.Failf(t, "test failed", "seed() returned error: %v", err)
}
mailSender := &testkit.RecordingMailSender{}
service, err := New(
challengeStore,
userDirectory,
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
mailSender,
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
result, err := service.Execute(context.Background(), Input{Email: tt.email})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if result.ChallengeID != "challenge-1" {
require.Failf(t, "test failed", "Execute().ChallengeID = %q, want %q", result.ChallengeID, "challenge-1")
}
if len(mailSender.RecordedInputs()) != 1 {
require.Failf(t, "test failed", "RecordedInputs() length = %d, want 1", len(mailSender.RecordedInputs()))
}
record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Status != challenge.StatusSent || record.DeliveryState != challenge.DeliverySent {
require.Failf(t, "test failed", "challenge state = %q/%q", record.Status, record.DeliveryState)
}
if record.Attempts.Send != 1 {
require.Failf(t, "test failed", "Attempts.Send = %d, want 1", record.Attempts.Send)
}
if string(record.CodeHash) == "654321" {
require.FailNow(t, "CodeHash stored cleartext code")
}
})
}
}
func TestExecuteSuppressesDeliveryForBlockedEmail(t *testing.T) {
t.Parallel()
challengeStore := &testkit.InMemoryChallengeStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
if err := userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")); err != nil {
require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err)
}
mailSender := &testkit.RecordingMailSender{}
service, err := New(
challengeStore,
userDirectory,
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
mailSender,
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
result, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
if result.ChallengeID != "challenge-1" {
require.Failf(t, "test failed", "Execute().ChallengeID = %q, want %q", result.ChallengeID, "challenge-1")
}
if len(mailSender.RecordedInputs()) != 0 {
require.Failf(t, "test failed", "RecordedInputs() length = %d, want 0", len(mailSender.RecordedInputs()))
}
record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Status != challenge.StatusDeliverySuppressed || record.DeliveryState != challenge.DeliverySuppressed {
require.Failf(t, "test failed", "challenge state = %q/%q", record.Status, record.DeliveryState)
}
}
func TestExecuteHandlesMailSenderSuppressedOutcome(t *testing.T) {
t.Parallel()
challengeStore := &testkit.InMemoryChallengeStore{}
mailSender := &testkit.RecordingMailSender{
DefaultResult: ports.SendLoginCodeResult{Outcome: ports.SendLoginCodeOutcomeSuppressed},
}
service, err := New(
challengeStore,
&testkit.InMemoryUserDirectory{},
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
mailSender,
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
_, err = service.Execute(context.Background(), Input{Email: "pilot@example.com"})
if err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Status != challenge.StatusDeliverySuppressed || record.DeliveryState != challenge.DeliverySuppressed {
require.Failf(t, "test failed", "challenge state = %q/%q", record.Status, record.DeliveryState)
}
}
func TestExecuteMarksChallengeFailedWhenMailSenderFails(t *testing.T) {
t.Parallel()
challengeStore := &testkit.InMemoryChallengeStore{}
mailSender := &testkit.RecordingMailSender{Err: errors.New("mail failed")}
service, err := New(
challengeStore,
&testkit.InMemoryUserDirectory{},
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
mailSender,
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
_, err = service.Execute(context.Background(), Input{Email: "pilot@example.com"})
if shared.CodeOf(err) != shared.ErrorCodeServiceUnavailable {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeServiceUnavailable)
}
record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
if record.Status != challenge.StatusFailed || record.DeliveryState != challenge.DeliveryFailed {
require.Failf(t, "test failed", "challenge state = %q/%q", record.Status, record.DeliveryState)
}
}
func TestExecuteReturnsInvalidRequestForBadEmail(t *testing.T) {
t.Parallel()
service, err := New(
&testkit.InMemoryChallengeStore{},
&testkit.InMemoryUserDirectory{},
&testkit.SequenceIDGenerator{},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
&testkit.RecordingMailSender{},
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
_, err = service.Execute(context.Background(), Input{Email: "pilot"})
if shared.CodeOf(err) != shared.ErrorCodeInvalidRequest {
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidRequest)
}
}
func TestExecuteCreatesFreshChallengeForRepeatedSend(t *testing.T) {
t.Parallel()
challengeStore := &testkit.InMemoryChallengeStore{}
mailSender := &testkit.RecordingMailSender{}
clock := testkit.FixedClock{Time: time.Unix(10, 0).UTC()}
service, err := New(
challengeStore,
&testkit.InMemoryUserDirectory{},
&testkit.SequenceIDGenerator{
ChallengeIDs: []common.ChallengeID{"challenge-1", "challenge-2"},
},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
mailSender,
clock,
)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
first, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
if err != nil {
require.Failf(t, "test failed", "first Execute() returned error: %v", err)
}
second, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
if err != nil {
require.Failf(t, "test failed", "second Execute() returned error: %v", err)
}
if first.ChallengeID == second.ChallengeID {
require.Failf(t, "test failed", "challenge ids are equal: %q", first.ChallengeID)
}
firstRecord, err := challengeStore.Get(context.Background(), common.ChallengeID(first.ChallengeID))
if err != nil {
require.Failf(t, "test failed", "Get(%q) returned error: %v", first.ChallengeID, err)
}
secondRecord, err := challengeStore.Get(context.Background(), common.ChallengeID(second.ChallengeID))
if err != nil {
require.Failf(t, "test failed", "Get(%q) returned error: %v", second.ChallengeID, err)
}
if firstRecord.Status != challenge.StatusSent {
require.Failf(t, "test failed", "first challenge status = %q, want %q", firstRecord.Status, challenge.StatusSent)
}
if secondRecord.Status != challenge.StatusSent {
require.Failf(t, "test failed", "second challenge status = %q, want %q", secondRecord.Status, challenge.StatusSent)
}
if len(mailSender.RecordedInputs()) != 2 {
require.Failf(t, "test failed", "RecordedInputs() length = %d, want 2", len(mailSender.RecordedInputs()))
}
}
func TestExecuteSetsChallengeExpirationFromInitialTTL(t *testing.T) {
t.Parallel()
now := time.Unix(10, 0).UTC()
challengeStore := &testkit.InMemoryChallengeStore{}
service, err := New(
challengeStore,
&testkit.InMemoryUserDirectory{},
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
&testkit.RecordingMailSender{},
testkit.FixedClock{Time: now},
)
if err != nil {
require.Failf(t, "test failed", "New() returned error: %v", err)
}
if _, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"}); err != nil {
require.Failf(t, "test failed", "Execute() returned error: %v", err)
}
record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
if err != nil {
require.Failf(t, "test failed", "Get() returned error: %v", err)
}
wantExpiresAt := now.Add(challenge.InitialTTL)
if !record.ExpiresAt.Equal(wantExpiresAt) {
require.Failf(t, "test failed", "ExpiresAt = %s, want %s", record.ExpiresAt, wantExpiresAt)
}
}
@@ -0,0 +1,98 @@
package sendemailcode
import (
"context"
"errors"
"testing"
"time"
stubmail "galaxy/authsession/internal/adapters/mail"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/service/shared"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteWithStubSender(t *testing.T) {
t.Parallel()
tests := []struct {
name string
sender *stubmail.StubSender
wantStatus challenge.Status
wantDeliveryState challenge.DeliveryState
wantErrorCode string
wantRecordedAttempt int
}{
{
name: "sent",
sender: &stubmail.StubSender{},
wantStatus: challenge.StatusSent,
wantDeliveryState: challenge.DeliverySent,
wantRecordedAttempt: 1,
},
{
name: "suppressed",
sender: &stubmail.StubSender{
DefaultMode: stubmail.StubModeSuppressed,
},
wantStatus: challenge.StatusDeliverySuppressed,
wantDeliveryState: challenge.DeliverySuppressed,
wantRecordedAttempt: 1,
},
{
name: "failed",
sender: &stubmail.StubSender{
DefaultMode: stubmail.StubModeFailed,
DefaultError: errors.New("stub delivery failed"),
},
wantStatus: challenge.StatusFailed,
wantDeliveryState: challenge.DeliveryFailed,
wantErrorCode: shared.ErrorCodeServiceUnavailable,
wantRecordedAttempt: 1,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
challengeStore := &testkit.InMemoryChallengeStore{}
service, err := New(
challengeStore,
&testkit.InMemoryUserDirectory{},
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
tt.sender,
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
)
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
if tt.wantErrorCode == "" {
require.NoError(t, err)
assert.Equal(t, "challenge-1", result.ChallengeID)
} else {
require.Error(t, err)
assert.Equal(t, tt.wantErrorCode, shared.CodeOf(err))
assert.Equal(t, Result{}, result)
}
record, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, getErr)
assert.Equal(t, tt.wantStatus, record.Status)
assert.Equal(t, tt.wantDeliveryState, record.DeliveryState)
attempts := tt.sender.RecordedAttempts()
require.Len(t, attempts, tt.wantRecordedAttempt)
assert.Equal(t, common.Email("pilot@example.com"), attempts[0].Input.Email)
assert.Equal(t, "654321", attempts[0].Input.Code)
})
}
}
@@ -0,0 +1,93 @@
package sendemailcode
import (
"context"
"testing"
"time"
stubmail "galaxy/authsession/internal/adapters/mail"
stubuserservice "galaxy/authsession/internal/adapters/userservice"
"galaxy/authsession/internal/domain/challenge"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
t.Parallel()
tests := []struct {
name string
seed func(*stubuserservice.StubDirectory) error
email string
wantStatus challenge.Status
wantDeliveryState challenge.DeliveryState
wantMailCalls int
}{
{
name: "existing user",
email: "pilot@example.com",
seed: func(directory *stubuserservice.StubDirectory) error {
return directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))
},
wantStatus: challenge.StatusSent,
wantDeliveryState: challenge.DeliverySent,
wantMailCalls: 1,
},
{
name: "creatable user",
email: "new@example.com",
seed: func(*stubuserservice.StubDirectory) error { return nil },
wantStatus: challenge.StatusSent,
wantDeliveryState: challenge.DeliverySent,
wantMailCalls: 1,
},
{
name: "blocked email",
email: "blocked@example.com",
seed: func(directory *stubuserservice.StubDirectory) error {
return directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block"))
},
wantStatus: challenge.StatusDeliverySuppressed,
wantDeliveryState: challenge.DeliverySuppressed,
wantMailCalls: 0,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
userDirectory := &stubuserservice.StubDirectory{}
require.NoError(t, tt.seed(userDirectory))
challengeStore := &testkit.InMemoryChallengeStore{}
mailSender := &stubmail.StubSender{}
service, err := New(
challengeStore,
userDirectory,
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
mailSender,
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
)
require.NoError(t, err)
result, err := service.Execute(context.Background(), Input{Email: tt.email})
require.NoError(t, err)
assert.Equal(t, "challenge-1", result.ChallengeID)
record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
require.NoError(t, err)
assert.Equal(t, tt.wantStatus, record.Status)
assert.Equal(t, tt.wantDeliveryState, record.DeliveryState)
assert.Len(t, mailSender.RecordedAttempts(), tt.wantMailCalls)
})
}
}
@@ -0,0 +1,171 @@
package sendemailcode
import (
"context"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/userresolution"
authtelemetry "galaxy/authsession/internal/telemetry"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/attribute"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
)
func TestExecuteRecordsSentMetric(t *testing.T) {
t.Parallel()
runtime, reader := newObservedTelemetryRuntime(t)
service, _, mailSender := newObservedSendService(t, observedSendOptions{
Telemetry: runtime,
})
_, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
require.NoError(t, err)
require.Len(t, mailSender.RecordedInputs(), 1)
assertMetricCount(t, reader, "authsession.send_email_code.attempts", map[string]string{
"outcome": "sent",
}, 1)
}
func TestExecuteRecordsBlockedSuppressedMetric(t *testing.T) {
t.Parallel()
runtime, reader := newObservedTelemetryRuntime(t)
service, _, _ := newObservedSendService(t, observedSendOptions{
Telemetry: runtime,
SeedBlockedEmail: true,
})
_, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
require.NoError(t, err)
assertMetricCount(t, reader, "authsession.send_email_code.attempts", map[string]string{
"outcome": "suppressed",
"reason": "blocked",
}, 1)
}
func TestExecuteRecordsThrottledMetric(t *testing.T) {
t.Parallel()
runtime, reader := newObservedTelemetryRuntime(t)
abuseProtector := &testkit.InMemorySendEmailCodeAbuseProtector{}
now := time.Unix(10, 0).UTC()
require.NoError(t, reserveSendCooldown(abuseProtector, common.Email("pilot@example.com"), now))
service, _, mailSender := newObservedSendService(t, observedSendOptions{
Telemetry: runtime,
AbuseProtector: abuseProtector,
Clock: testkit.FixedClock{Time: now},
})
_, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
require.NoError(t, err)
assert.Empty(t, mailSender.RecordedInputs())
assertMetricCount(t, reader, "authsession.send_email_code.attempts", map[string]string{
"outcome": "throttled",
"reason": "throttled",
}, 1)
}
type observedSendOptions struct {
Telemetry *authtelemetry.Runtime
AbuseProtector *testkit.InMemorySendEmailCodeAbuseProtector
SeedBlockedEmail bool
Clock portsClock
}
type portsClock interface {
Now() time.Time
}
func newObservedSendService(t *testing.T, options observedSendOptions) (*Service, *testkit.InMemoryChallengeStore, *testkit.RecordingMailSender) {
t.Helper()
challengeStore := &testkit.InMemoryChallengeStore{}
userDirectory := &testkit.InMemoryUserDirectory{}
if options.SeedBlockedEmail {
require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")))
}
mailSender := &testkit.RecordingMailSender{}
clock := options.Clock
if clock == nil {
clock = testkit.FixedClock{Time: time.Unix(10, 0).UTC()}
}
service, err := NewWithRuntime(
challengeStore,
userDirectory,
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
testkit.FixedCodeGenerator{Code: "654321"},
testkit.DeterministicCodeHasher{},
mailSender,
options.AbuseProtector,
clock,
options.Telemetry,
)
require.NoError(t, err)
return service, challengeStore, mailSender
}
func newObservedTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader) {
t.Helper()
reader := sdkmetric.NewManualReader()
provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))
runtime, err := authtelemetry.New(provider)
require.NoError(t, err)
return runtime, reader
}
func assertMetricCount(t *testing.T, reader *sdkmetric.ManualReader, metricName string, wantAttrs map[string]string, wantValue int64) {
t.Helper()
var resourceMetrics metricdata.ResourceMetrics
require.NoError(t, reader.Collect(context.Background(), &resourceMetrics))
for _, scopeMetrics := range resourceMetrics.ScopeMetrics {
for _, metric := range scopeMetrics.Metrics {
if metric.Name != metricName {
continue
}
sum, ok := metric.Data.(metricdata.Sum[int64])
require.True(t, ok)
for _, point := range sum.DataPoints {
if hasMetricAttributes(point.Attributes.ToSlice(), wantAttrs) {
assert.Equal(t, wantValue, point.Value)
return
}
}
}
}
require.Failf(t, "test failed", "metric %q with attrs %v not found", metricName, wantAttrs)
}
func hasMetricAttributes(values []attribute.KeyValue, want map[string]string) bool {
if len(values) != len(want) {
return false
}
for _, value := range values {
if want[string(value.Key)] != value.Value.AsString() {
return false
}
}
return true
}
@@ -0,0 +1,4 @@
// Package shared provides cross-use-case application helpers for auth/session
// services, including typed service errors, input normalization, DTO mapping,
// and application-level retry helpers.
package shared
@@ -0,0 +1,407 @@
package shared
import (
"errors"
"net/http"
"strings"
)
const (
// ErrorCodeInvalidRequest reports malformed or semantically invalid service
// input.
ErrorCodeInvalidRequest = "invalid_request"
// ErrorCodeChallengeNotFound reports that the requested challenge does not
// exist.
ErrorCodeChallengeNotFound = "challenge_not_found"
// ErrorCodeChallengeExpired reports that the requested challenge may no
// longer be confirmed.
ErrorCodeChallengeExpired = "challenge_expired"
// ErrorCodeInvalidCode reports that the submitted confirmation code does not
// match the stored challenge.
ErrorCodeInvalidCode = "invalid_code"
// ErrorCodeInvalidClientPublicKey reports that the submitted client public
// key does not satisfy the Ed25519/base64 contract.
ErrorCodeInvalidClientPublicKey = "invalid_client_public_key"
// ErrorCodeBlockedByPolicy reports that the auth flow is denied by current
// user or registration policy.
ErrorCodeBlockedByPolicy = "blocked_by_policy"
// ErrorCodeSessionLimitExceeded reports that creating another active session
// would violate the configured limit.
ErrorCodeSessionLimitExceeded = "session_limit_exceeded"
// ErrorCodeSessionNotFound reports that the requested device session does
// not exist.
ErrorCodeSessionNotFound = "session_not_found"
// ErrorCodeSubjectNotFound reports that the requested trusted internal
// subject does not exist.
ErrorCodeSubjectNotFound = "subject_not_found"
// ErrorCodeServiceUnavailable reports that a required dependency or
// propagation step is temporarily unavailable.
ErrorCodeServiceUnavailable = "service_unavailable"
// ErrorCodeInternalError reports that local state is inconsistent or an
// invariant was broken unexpectedly.
ErrorCodeInternalError = "internal_error"
)
const genericInvalidRequestMessage = "request is invalid"
var publicErrorStatusCodes = map[string]int{
ErrorCodeInvalidRequest: http.StatusBadRequest,
ErrorCodeInvalidClientPublicKey: http.StatusBadRequest,
ErrorCodeInvalidCode: http.StatusBadRequest,
ErrorCodeChallengeNotFound: http.StatusNotFound,
ErrorCodeChallengeExpired: http.StatusGone,
ErrorCodeBlockedByPolicy: http.StatusForbidden,
ErrorCodeSessionLimitExceeded: http.StatusConflict,
ErrorCodeServiceUnavailable: http.StatusServiceUnavailable,
}
var publicStableMessages = map[string]string{
ErrorCodeChallengeNotFound: "challenge not found",
ErrorCodeChallengeExpired: "challenge expired",
ErrorCodeInvalidCode: "confirmation code is invalid",
ErrorCodeInvalidClientPublicKey: "client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key",
ErrorCodeBlockedByPolicy: "authentication is blocked by policy",
ErrorCodeSessionLimitExceeded: "active session limit would be exceeded",
ErrorCodeServiceUnavailable: "service is unavailable",
}
var internalErrorStatusCodes = map[string]int{
ErrorCodeInvalidRequest: http.StatusBadRequest,
ErrorCodeSessionNotFound: http.StatusNotFound,
ErrorCodeSubjectNotFound: http.StatusNotFound,
ErrorCodeServiceUnavailable: http.StatusServiceUnavailable,
ErrorCodeInternalError: http.StatusInternalServerError,
}
var internalStableMessages = map[string]string{
ErrorCodeSessionNotFound: "session not found",
ErrorCodeSubjectNotFound: "subject not found",
ErrorCodeServiceUnavailable: "service is unavailable",
ErrorCodeInternalError: "internal server error",
}
// PublicErrorProjection describes one transport-ready public auth error after
// internal service errors have been normalized to the frozen client-safe
// surface.
type PublicErrorProjection struct {
// StatusCode is the HTTP status that should be returned to the public auth
// caller.
StatusCode int
// Code is the stable client-safe error code written into the public JSON
// envelope.
Code string
// Message is the client-safe error description exposed to the public auth
// caller.
Message string
}
// InternalErrorProjection describes one transport-ready internal API error
// after service-layer failures have been normalized to the frozen trusted
// caller surface.
type InternalErrorProjection struct {
// StatusCode is the HTTP status that should be returned to the trusted
// caller.
StatusCode int
// Code is the stable error code written into the internal JSON envelope.
Code string
// Message is the trusted-caller-safe error description exposed by the
// internal HTTP API.
Message string
}
// ServiceError projects one stable application-layer failure with a service
// error code and a caller-safe message.
type ServiceError struct {
// Code is the stable error code expected by later transport mapping.
Code string
// Message is the caller-safe error description.
Message string
// Err optionally stores the wrapped underlying cause.
Err error
}
// Error returns the caller-safe error description.
func (e *ServiceError) Error() string {
if e == nil {
return ""
}
switch {
case strings.TrimSpace(e.Message) != "":
return e.Message
case strings.TrimSpace(e.Code) != "":
return e.Code
case e.Err != nil:
return e.Err.Error()
default:
return ErrorCodeInternalError
}
}
// Unwrap returns the wrapped cause, if any.
func (e *ServiceError) Unwrap() error {
if e == nil {
return nil
}
return e.Err
}
// NewServiceError returns a new typed application-layer error.
func NewServiceError(code string, message string, err error) *ServiceError {
return &ServiceError{
Code: strings.TrimSpace(code),
Message: strings.TrimSpace(message),
Err: err,
}
}
// IsPublicErrorCode reports whether code belongs to the frozen public auth
// error surface.
func IsPublicErrorCode(code string) bool {
_, ok := publicErrorStatusCodes[strings.TrimSpace(code)]
return ok
}
// IsInternalOnlyErrorCode reports whether code is intentionally excluded from
// the public auth transport surface.
func IsInternalOnlyErrorCode(code string) bool {
switch strings.TrimSpace(code) {
case ErrorCodeSessionNotFound, ErrorCodeSubjectNotFound, ErrorCodeInternalError:
return true
default:
return false
}
}
// IsSendEmailCodePublicErrorCode reports whether code may be exposed by the
// public send-email-code route after public projection.
func IsSendEmailCodePublicErrorCode(code string) bool {
switch strings.TrimSpace(code) {
case ErrorCodeInvalidRequest, ErrorCodeServiceUnavailable:
return true
default:
return false
}
}
// IsConfirmEmailCodePublicErrorCode reports whether code may be exposed by the
// public confirm-email-code route after public projection.
func IsConfirmEmailCodePublicErrorCode(code string) bool {
switch strings.TrimSpace(code) {
case ErrorCodeInvalidRequest,
ErrorCodeChallengeNotFound,
ErrorCodeChallengeExpired,
ErrorCodeInvalidCode,
ErrorCodeInvalidClientPublicKey,
ErrorCodeBlockedByPolicy,
ErrorCodeSessionLimitExceeded,
ErrorCodeServiceUnavailable:
return true
default:
return false
}
}
// PublicHTTPStatusCode reports the frozen public HTTP status for code. Unknown
// or internal-only codes are normalized to 503 service_unavailable.
func PublicHTTPStatusCode(code string) int {
if statusCode, ok := publicErrorStatusCodes[strings.TrimSpace(code)]; ok {
return statusCode
}
return http.StatusServiceUnavailable
}
// ProjectPublicError normalizes err to the frozen public-auth error surface.
// Unknown and internal-only service failures are intentionally projected as
// 503 service_unavailable so internal invariants do not leak to public callers.
func ProjectPublicError(err error) PublicErrorProjection {
serviceErr, ok := errors.AsType[*ServiceError](err)
code := CodeOf(err)
if !IsPublicErrorCode(code) {
return PublicErrorProjection{
StatusCode: http.StatusServiceUnavailable,
Code: ErrorCodeServiceUnavailable,
Message: publicMessageForCode(ErrorCodeServiceUnavailable, ""),
}
}
message := ""
if ok && serviceErr != nil {
message = serviceErr.Message
}
return PublicErrorProjection{
StatusCode: PublicHTTPStatusCode(code),
Code: code,
Message: publicMessageForCode(code, message),
}
}
// InternalHTTPStatusCode reports the frozen internal HTTP status for code.
// Unknown codes are normalized to 500 internal_error.
func InternalHTTPStatusCode(code string) int {
if statusCode, ok := internalErrorStatusCodes[strings.TrimSpace(code)]; ok {
return statusCode
}
return http.StatusInternalServerError
}
// ProjectInternalError normalizes err to the frozen internal trusted HTTP
// error surface. Unknown failures are intentionally projected as
// 500 internal_error so transport callers do not depend on unclassified local
// failures.
func ProjectInternalError(err error) InternalErrorProjection {
serviceErr, ok := errors.AsType[*ServiceError](err)
code := CodeOf(err)
if _, known := internalErrorStatusCodes[code]; !known {
return InternalErrorProjection{
StatusCode: http.StatusInternalServerError,
Code: ErrorCodeInternalError,
Message: internalMessageForCode(ErrorCodeInternalError, ""),
}
}
message := ""
if ok && serviceErr != nil {
message = serviceErr.Message
}
return InternalErrorProjection{
StatusCode: InternalHTTPStatusCode(code),
Code: code,
Message: internalMessageForCode(code, message),
}
}
// InvalidRequest reports one malformed or semantically invalid caller input.
func InvalidRequest(message string) *ServiceError {
return NewServiceError(ErrorCodeInvalidRequest, message, nil)
}
// ChallengeNotFound reports that the requested challenge does not exist.
func ChallengeNotFound() *ServiceError {
return NewServiceError(ErrorCodeChallengeNotFound, "challenge not found", nil)
}
// ChallengeExpired reports that the requested challenge is expired.
func ChallengeExpired() *ServiceError {
return NewServiceError(ErrorCodeChallengeExpired, "challenge expired", nil)
}
// InvalidCode reports that the submitted confirmation code is invalid.
func InvalidCode() *ServiceError {
return NewServiceError(ErrorCodeInvalidCode, "confirmation code is invalid", nil)
}
// InvalidClientPublicKey reports that the submitted client public key does not
// satisfy the frozen contract.
func InvalidClientPublicKey() *ServiceError {
return NewServiceError(
ErrorCodeInvalidClientPublicKey,
"client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key",
nil,
)
}
// BlockedByPolicy reports that the current auth flow is denied by policy.
func BlockedByPolicy() *ServiceError {
return NewServiceError(ErrorCodeBlockedByPolicy, "authentication is blocked by policy", nil)
}
// SessionLimitExceeded reports that creating another active session would
// exceed the current configured limit.
func SessionLimitExceeded() *ServiceError {
return NewServiceError(ErrorCodeSessionLimitExceeded, "active session limit would be exceeded", nil)
}
// SessionNotFound reports that the requested session does not exist.
func SessionNotFound() *ServiceError {
return NewServiceError(ErrorCodeSessionNotFound, "session not found", nil)
}
// SubjectNotFound reports that the requested internal subject does not exist.
func SubjectNotFound() *ServiceError {
return NewServiceError(ErrorCodeSubjectNotFound, "subject not found", nil)
}
// ServiceUnavailable reports that a required dependency or propagation step is
// temporarily unavailable.
func ServiceUnavailable(err error) *ServiceError {
return NewServiceError(ErrorCodeServiceUnavailable, "service is unavailable", err)
}
// InternalError reports an invariant-breaking local failure.
func InternalError(err error) *ServiceError {
return NewServiceError(ErrorCodeInternalError, "internal error", err)
}
// CodeOf returns the stable service error code of err when err wraps a
// ServiceError. Otherwise it returns ErrorCodeInternalError.
func CodeOf(err error) string {
serviceErr, ok := errors.AsType[*ServiceError](err)
if !ok || serviceErr == nil || strings.TrimSpace(serviceErr.Code) == "" {
return ErrorCodeInternalError
}
return serviceErr.Code
}
func publicMessageForCode(code string, message string) string {
trimmedMessage := strings.TrimSpace(message)
switch strings.TrimSpace(code) {
case ErrorCodeInvalidRequest:
if trimmedMessage != "" {
return trimmedMessage
}
return genericInvalidRequestMessage
case ErrorCodeServiceUnavailable:
return publicStableMessages[ErrorCodeServiceUnavailable]
default:
if stableMessage, ok := publicStableMessages[strings.TrimSpace(code)]; ok {
return stableMessage
}
return publicStableMessages[ErrorCodeServiceUnavailable]
}
}
func internalMessageForCode(code string, message string) string {
trimmedMessage := strings.TrimSpace(message)
switch strings.TrimSpace(code) {
case ErrorCodeInvalidRequest:
if trimmedMessage != "" {
return trimmedMessage
}
return genericInvalidRequestMessage
case ErrorCodeSessionNotFound,
ErrorCodeSubjectNotFound,
ErrorCodeServiceUnavailable,
ErrorCodeInternalError:
if stableMessage, ok := internalStableMessages[strings.TrimSpace(code)]; ok {
return stableMessage
}
return internalStableMessages[ErrorCodeInternalError]
default:
return internalStableMessages[ErrorCodeInternalError]
}
}
@@ -0,0 +1,158 @@
package shared
import (
"crypto/ed25519"
"encoding/base64"
"fmt"
"strings"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
)
// NormalizeString trims surrounding Unicode whitespace from value.
func NormalizeString(value string) string {
return strings.TrimSpace(value)
}
// ParseEmail trims value and validates it against the frozen public e-mail
// contract.
func ParseEmail(value string) (common.Email, error) {
email := common.Email(NormalizeString(value))
if err := email.Validate(); err != nil {
return "", InvalidRequest(err.Error())
}
return email, nil
}
// ParseChallengeID trims value and validates it as one challenge identifier.
func ParseChallengeID(value string) (common.ChallengeID, error) {
challengeID := common.ChallengeID(NormalizeString(value))
if err := challengeID.Validate(); err != nil {
return "", InvalidRequest(err.Error())
}
return challengeID, nil
}
// ParseDeviceSessionID trims value and validates it as one device-session
// identifier.
func ParseDeviceSessionID(value string) (common.DeviceSessionID, error) {
deviceSessionID := common.DeviceSessionID(NormalizeString(value))
if err := deviceSessionID.Validate(); err != nil {
return "", InvalidRequest(err.Error())
}
return deviceSessionID, nil
}
// ParseUserID trims value and validates it as one user identifier.
func ParseUserID(value string) (common.UserID, error) {
userID := common.UserID(NormalizeString(value))
if err := userID.Validate(); err != nil {
return "", InvalidRequest(err.Error())
}
return userID, nil
}
// ParseRequiredCode trims value and validates it as a required non-empty
// confirmation code.
func ParseRequiredCode(value string) (string, error) {
code := NormalizeString(value)
if code == "" {
return "", InvalidRequest("code must not be empty")
}
return code, nil
}
// ParseClientPublicKey trims value and validates it as the standard
// base64-encoded raw 32-byte Ed25519 public key expected by the public auth
// contract.
func ParseClientPublicKey(value string) (common.ClientPublicKey, error) {
normalized := NormalizeString(value)
if normalized == "" {
return common.ClientPublicKey{}, InvalidClientPublicKey()
}
decoded, err := base64.StdEncoding.Strict().DecodeString(normalized)
if err != nil || len(decoded) != ed25519.PublicKeySize {
return common.ClientPublicKey{}, InvalidClientPublicKey()
}
key, err := common.NewClientPublicKey(ed25519.PublicKey(decoded))
if err != nil {
return common.ClientPublicKey{}, InvalidClientPublicKey()
}
return key, nil
}
// ParseRevokeReasonCode trims value and validates it as one machine-readable
// revoke reason code.
func ParseRevokeReasonCode(value string) (common.RevokeReasonCode, error) {
code := common.RevokeReasonCode(NormalizeString(value))
if err := code.Validate(); err != nil {
return "", InvalidRequest(err.Error())
}
return code, nil
}
// ParseRevokeActorType trims value and validates it as one machine-readable
// revoke actor type.
func ParseRevokeActorType(value string) (common.RevokeActorType, error) {
actorType := common.RevokeActorType(NormalizeString(value))
if err := actorType.Validate(); err != nil {
return "", InvalidRequest(err.Error())
}
return actorType, nil
}
// ParseOptionalActorID trims value and validates it as one optional stable
// actor identifier.
func ParseOptionalActorID(value string) (string, error) {
actorID := NormalizeString(value)
if actorID != value {
return "", InvalidRequest("actor_id must not contain surrounding whitespace")
}
return actorID, nil
}
// BuildRevocation validates one revoke request payload and returns the domain
// revocation metadata applied to a session mutation.
func BuildRevocation(reasonCode string, actorType string, actorID string, at time.Time) (devicesession.Revocation, error) {
if at.IsZero() {
return devicesession.Revocation{}, InternalError(fmt.Errorf("revocation time must not be zero"))
}
parsedReasonCode, err := ParseRevokeReasonCode(reasonCode)
if err != nil {
return devicesession.Revocation{}, err
}
parsedActorType, err := ParseRevokeActorType(actorType)
if err != nil {
return devicesession.Revocation{}, err
}
parsedActorID, err := ParseOptionalActorID(actorID)
if err != nil {
return devicesession.Revocation{}, err
}
revocation := devicesession.Revocation{
At: at.UTC(),
ReasonCode: parsedReasonCode,
ActorType: parsedActorType,
ActorID: parsedActorID,
}
if err := revocation.Validate(); err != nil {
return devicesession.Revocation{}, InternalError(fmt.Errorf("build revocation: %w", err))
}
return revocation, nil
}
@@ -0,0 +1,46 @@
package shared
import (
"context"
authlogging "galaxy/authsession/internal/logging"
"go.uber.org/zap"
)
// LogServiceOutcome writes one structured service-level outcome log with a
// stable severity derived from err and with trace fields attached when ctx
// carries an active span.
func LogServiceOutcome(logger *zap.Logger, ctx context.Context, message string, err error, fields ...zap.Field) {
if logger == nil {
logger = zap.NewNop()
}
fields = append(fields, authlogging.TraceFieldsFromContext(ctx)...)
switch {
case err == nil:
logger.Info(message, fields...)
case isExpectedServiceErrorCode(CodeOf(err)):
logger.Warn(message, append(fields, zap.Error(err))...)
default:
logger.Error(message, append(fields, zap.Error(err))...)
}
}
func isExpectedServiceErrorCode(code string) bool {
switch code {
case ErrorCodeInvalidRequest,
ErrorCodeChallengeNotFound,
ErrorCodeChallengeExpired,
ErrorCodeInvalidCode,
ErrorCodeInvalidClientPublicKey,
ErrorCodeBlockedByPolicy,
ErrorCodeSessionLimitExceeded,
ErrorCodeSessionNotFound,
ErrorCodeSubjectNotFound:
return true
default:
return false
}
}
@@ -0,0 +1,11 @@
package shared
const (
// MaxCompareAndSwapRetries bounds application-level retry loops around
// compare-and-swap challenge updates.
MaxCompareAndSwapRetries = 3
// MaxProjectionPublishAttempts bounds synchronous request-path retries
// around gateway session projection publication.
MaxProjectionPublishAttempts = 3
)
@@ -0,0 +1,86 @@
package shared
import (
"context"
"errors"
"fmt"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/ports"
"galaxy/authsession/internal/telemetry"
)
// PublishProjectionSnapshot publishes snapshot through publisher with a small
// bounded retry loop suitable for request-path consistency repair.
func PublishProjectionSnapshot(ctx context.Context, publisher ports.GatewaySessionProjectionPublisher, snapshot gatewayprojection.Snapshot) error {
return PublishProjectionSnapshotWithTelemetry(ctx, publisher, snapshot, nil, "")
}
// PublishProjectionSnapshotWithTelemetry publishes snapshot through publisher
// with the bounded request-path retry policy and optional publish-failure
// telemetry.
func PublishProjectionSnapshotWithTelemetry(
ctx context.Context,
publisher ports.GatewaySessionProjectionPublisher,
snapshot gatewayprojection.Snapshot,
telemetryRuntime *telemetry.Runtime,
operation string,
) error {
if publisher == nil {
return InternalError(errors.New("projection publisher must not be nil"))
}
if ctx == nil {
return ServiceUnavailable(errors.New("projection publish context must not be nil"))
}
if err := snapshot.Validate(); err != nil {
return InternalError(fmt.Errorf("publish projection snapshot: %w", err))
}
var lastErr error
for attempt := 0; attempt < MaxProjectionPublishAttempts; attempt++ {
if err := ctx.Err(); err != nil {
return ServiceUnavailable(err)
}
if err := publisher.PublishSession(ctx, snapshot); err == nil {
return nil
} else {
lastErr = err
}
}
telemetryRuntime.RecordProjectionPublishFailure(ctx, operation)
return ServiceUnavailable(
fmt.Errorf(
"publish projection snapshot %q after %d attempts: %w",
snapshot.DeviceSessionID,
MaxProjectionPublishAttempts,
lastErr,
),
)
}
// PublishSessionProjection converts record into the gateway-facing snapshot and
// publishes it with the bounded request-path retry policy.
func PublishSessionProjection(ctx context.Context, publisher ports.GatewaySessionProjectionPublisher, record devicesession.Session) error {
return PublishSessionProjectionWithTelemetry(ctx, publisher, record, nil, "")
}
// PublishSessionProjectionWithTelemetry converts record into the
// gateway-facing snapshot and publishes it with the bounded request-path retry
// policy and optional publish-failure telemetry.
func PublishSessionProjectionWithTelemetry(
ctx context.Context,
publisher ports.GatewaySessionProjectionPublisher,
record devicesession.Session,
telemetryRuntime *telemetry.Runtime,
operation string,
) error {
snapshot, err := ToGatewayProjectionSnapshot(record)
if err != nil {
return InternalError(err)
}
return PublishProjectionSnapshotWithTelemetry(ctx, publisher, snapshot, telemetryRuntime, operation)
}
@@ -0,0 +1,119 @@
package shared
import (
"context"
"errors"
"testing"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/testkit"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPublishSessionProjectionRetriesUntilSuccess(t *testing.T) {
t.Parallel()
tests := []struct {
name string
errors []error
wantAttempts int
}{
{
name: "success on second attempt",
errors: []error{errors.New("transient publish failure"), nil},
wantAttempts: 2,
},
{
name: "success on third attempt",
errors: []error{errors.New("transient publish failure"), errors.New("transient publish failure"), nil},
wantAttempts: 3,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
publisher := &testkit.RecordingProjectionPublisher{Errors: tt.errors}
err := PublishSessionProjection(context.Background(), publisher, revokedSessionFixture())
require.NoError(t, err)
require.Len(t, publisher.PublishedSnapshots(), tt.wantAttempts)
})
}
}
func TestPublishSessionProjectionReturnsServiceUnavailableAfterExhaustedRetries(t *testing.T) {
t.Parallel()
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
err := PublishSessionProjection(context.Background(), publisher, revokedSessionFixture())
require.Error(t, err)
assert.Equal(t, ErrorCodeServiceUnavailable, CodeOf(err))
require.Len(t, publisher.PublishedSnapshots(), MaxProjectionPublishAttempts)
}
func TestPublishProjectionSnapshotStopsRetriesWhenContextIsCanceled(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
publisher := &cancelingProjectionPublisher{
cancel: cancel,
err: errors.New("publish failed"),
}
err := PublishProjectionSnapshot(ctx, publisher, mustProjectionSnapshot(t))
require.Error(t, err)
assert.Equal(t, ErrorCodeServiceUnavailable, CodeOf(err))
assert.Equal(t, 1, publisher.attempts)
}
func TestPublishSessionProjectionReturnsInternalErrorForInvalidLocalRecord(t *testing.T) {
t.Parallel()
publisher := &testkit.RecordingProjectionPublisher{}
err := PublishSessionProjection(context.Background(), publisher, invalidSessionFixture())
require.Error(t, err)
assert.Equal(t, ErrorCodeInternalError, CodeOf(err))
assert.Empty(t, publisher.PublishedSnapshots())
}
type cancelingProjectionPublisher struct {
attempts int
cancel context.CancelFunc
err error
}
func (p *cancelingProjectionPublisher) PublishSession(_ context.Context, snapshot gatewayprojection.Snapshot) error {
if err := snapshot.Validate(); err != nil {
return err
}
p.attempts++
if p.cancel != nil {
p.cancel()
p.cancel = nil
}
return p.err
}
func mustProjectionSnapshot(t *testing.T) gatewayprojection.Snapshot {
t.Helper()
snapshot, err := ToGatewayProjectionSnapshot(revokedSessionFixture())
require.NoError(t, err)
return snapshot
}
func invalidSessionFixture() devicesession.Session {
return devicesession.Session{}
}
@@ -0,0 +1,134 @@
package shared
import (
"fmt"
"time"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/gatewayprojection"
)
// Session mirrors the frozen internal read-model DTO used by later trusted
// transport handlers.
type Session struct {
// DeviceSessionID is the stable identifier of one device session.
DeviceSessionID string
// UserID is the stable identifier of the session owner.
UserID string
// ClientPublicKey is the base64-encoded raw 32-byte Ed25519 public key of
// the device session.
ClientPublicKey string
// Status reports whether the session is active or revoked.
Status string
// CreatedAt is the RFC3339 UTC timestamp at which the session was created.
CreatedAt string
// RevokedAt is the RFC3339 UTC timestamp at which the session was revoked,
// when the session is revoked.
RevokedAt *string
// RevokeReasonCode is the machine-readable revoke reason code when the
// session is revoked.
RevokeReasonCode *string
// RevokeActorType is the machine-readable revoke actor type when the
// session is revoked.
RevokeActorType *string
// RevokeActorID is the optional stable revoke actor identifier when the
// session is revoked.
RevokeActorID *string
}
// ToSession converts source-of-truth session into the frozen internal read DTO
// shape.
func ToSession(record devicesession.Session) (Session, error) {
if err := record.Validate(); err != nil {
return Session{}, fmt.Errorf("map session: %w", err)
}
result := Session{
DeviceSessionID: record.ID.String(),
UserID: record.UserID.String(),
ClientPublicKey: record.ClientPublicKey.String(),
Status: string(record.Status),
CreatedAt: formatTime(record.CreatedAt),
}
if record.Revocation != nil {
revokedAt := formatTime(record.Revocation.At)
reasonCode := record.Revocation.ReasonCode.String()
actorType := record.Revocation.ActorType.String()
result.RevokedAt = &revokedAt
result.RevokeReasonCode = &reasonCode
result.RevokeActorType = &actorType
if record.Revocation.ActorID != "" {
actorID := record.Revocation.ActorID
result.RevokeActorID = &actorID
}
}
return result, nil
}
// ToSessions converts every source-of-truth session into the frozen internal
// read DTO shape.
func ToSessions(records []devicesession.Session) ([]Session, error) {
result := make([]Session, 0, len(records))
for index, record := range records {
mapped, err := ToSession(record)
if err != nil {
return nil, fmt.Errorf("map session %d: %w", index, err)
}
result = append(result, mapped)
}
return result, nil
}
// ToGatewayProjectionSnapshot converts source-of-truth session into the
// separate gateway-facing projection model.
func ToGatewayProjectionSnapshot(record devicesession.Session) (gatewayprojection.Snapshot, error) {
if err := record.Validate(); err != nil {
return gatewayprojection.Snapshot{}, fmt.Errorf("map gateway projection snapshot: %w", err)
}
snapshot := gatewayprojection.Snapshot{
DeviceSessionID: record.ID,
UserID: record.UserID,
ClientPublicKey: record.ClientPublicKey.String(),
Status: gatewayprojection.Status(record.Status),
}
if record.Revocation != nil {
snapshot.RevokedAt = cloneTimePointer(commonTimePointer(record.Revocation.At.UTC()))
snapshot.RevokeReasonCode = record.Revocation.ReasonCode
snapshot.RevokeActorType = record.Revocation.ActorType
snapshot.RevokeActorID = record.Revocation.ActorID
}
if err := snapshot.Validate(); err != nil {
return gatewayprojection.Snapshot{}, fmt.Errorf("map gateway projection snapshot: %w", err)
}
return snapshot, nil
}
func formatTime(value time.Time) string {
return value.UTC().Format(time.RFC3339)
}
func commonTimePointer(value time.Time) *time.Time {
return &value
}
func cloneTimePointer(value *time.Time) *time.Time {
if value == nil {
return nil
}
cloned := *value
return &cloned
}
@@ -0,0 +1,40 @@
package shared
import (
"fmt"
"galaxy/authsession/internal/domain/sessionlimit"
"galaxy/authsession/internal/ports"
)
// EvaluateSessionLimit evaluates the Stage-4 active-session creation decision
// from the loaded configuration and current active-session count.
func EvaluateSessionLimit(config ports.SessionLimitConfig, activeSessionCount int) (sessionlimit.Decision, error) {
if err := config.Validate(); err != nil {
return sessionlimit.Decision{}, InternalError(fmt.Errorf("evaluate session limit: %w", err))
}
if activeSessionCount < 0 {
return sessionlimit.Decision{}, InternalError(fmt.Errorf("evaluate session limit: active session count %d is negative", activeSessionCount))
}
decision := sessionlimit.Decision{
ActiveSessionCount: activeSessionCount,
NextSessionCount: activeSessionCount + 1,
}
if config.ActiveSessionLimit == nil {
decision.Kind = sessionlimit.KindDisabled
} else {
decision.ConfiguredLimit = config.ActiveSessionLimit
if decision.NextSessionCount <= *config.ActiveSessionLimit {
decision.Kind = sessionlimit.KindAllowed
} else {
decision.Kind = sessionlimit.KindExceeded
}
}
if err := decision.Validate(); err != nil {
return sessionlimit.Decision{}, InternalError(fmt.Errorf("evaluate session limit: %w", err))
}
return decision, nil
}
@@ -0,0 +1,380 @@
package shared
import (
"errors"
"net/http"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/domain/devicesession"
"galaxy/authsession/internal/domain/gatewayprojection"
"galaxy/authsession/internal/domain/sessionlimit"
"galaxy/authsession/internal/ports"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNormalizeString(t *testing.T) {
t.Parallel()
assert.Equal(t, "pilot@example.com", NormalizeString(" pilot@example.com \n"))
}
func TestParseClientPublicKey(t *testing.T) {
t.Parallel()
key, err := ParseClientPublicKey(" AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8= ")
require.NoError(t, err)
assert.Equal(t, "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", key.String())
_, err = ParseClientPublicKey("invalid")
require.Error(t, err)
assert.Equal(t, ErrorCodeInvalidClientPublicKey, CodeOf(err))
}
func TestToSession(t *testing.T) {
t.Parallel()
record := revokedSessionFixture()
dto, err := ToSession(record)
require.NoError(t, err)
assert.Equal(t, record.ID.String(), dto.DeviceSessionID)
require.NotNil(t, dto.RevokedAt)
assert.Equal(t, record.Revocation.At.UTC().Format(time.RFC3339), *dto.RevokedAt)
}
func TestToGatewayProjectionSnapshot(t *testing.T) {
t.Parallel()
record := revokedSessionFixture()
snapshot, err := ToGatewayProjectionSnapshot(record)
require.NoError(t, err)
assert.Equal(t, gatewayprojection.StatusRevoked, snapshot.Status)
}
func TestEvaluateSessionLimit(t *testing.T) {
t.Parallel()
limit := 2
tests := []struct {
name string
config ports.SessionLimitConfig
active int
want sessionlimit.Kind
}{
{name: "disabled", config: ports.SessionLimitConfig{}, active: 3, want: sessionlimit.KindDisabled},
{name: "allowed", config: ports.SessionLimitConfig{ActiveSessionLimit: &limit}, active: 1, want: sessionlimit.KindAllowed},
{name: "exceeded", config: ports.SessionLimitConfig{ActiveSessionLimit: &limit}, active: 2, want: sessionlimit.KindExceeded},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
decision, err := EvaluateSessionLimit(tt.config, tt.active)
require.NoError(t, err)
assert.Equal(t, tt.want, decision.Kind)
})
}
}
func TestServiceErrorCodePreservation(t *testing.T) {
t.Parallel()
baseErr := errors.New("base")
err := ServiceUnavailable(baseErr)
assert.Equal(t, ErrorCodeServiceUnavailable, CodeOf(err))
assert.ErrorIs(t, err, baseErr)
}
func TestErrorCodeClassification(t *testing.T) {
t.Parallel()
publicCodes := []string{
ErrorCodeInvalidRequest,
ErrorCodeChallengeNotFound,
ErrorCodeChallengeExpired,
ErrorCodeInvalidCode,
ErrorCodeInvalidClientPublicKey,
ErrorCodeBlockedByPolicy,
ErrorCodeSessionLimitExceeded,
ErrorCodeServiceUnavailable,
}
for _, code := range publicCodes {
assert.Truef(t, IsPublicErrorCode(code), "IsPublicErrorCode(%q)", code)
assert.Falsef(t, IsInternalOnlyErrorCode(code), "IsInternalOnlyErrorCode(%q)", code)
}
internalOnlyCodes := []string{
ErrorCodeSessionNotFound,
ErrorCodeSubjectNotFound,
ErrorCodeInternalError,
}
for _, code := range internalOnlyCodes {
assert.Falsef(t, IsPublicErrorCode(code), "IsPublicErrorCode(%q)", code)
assert.Truef(t, IsInternalOnlyErrorCode(code), "IsInternalOnlyErrorCode(%q)", code)
}
}
func TestPublicUseCaseErrorCodeSets(t *testing.T) {
t.Parallel()
assert.True(t, IsSendEmailCodePublicErrorCode(ErrorCodeInvalidRequest))
assert.True(t, IsSendEmailCodePublicErrorCode(ErrorCodeServiceUnavailable))
assert.False(t, IsSendEmailCodePublicErrorCode(ErrorCodeBlockedByPolicy))
assert.False(t, IsSendEmailCodePublicErrorCode(ErrorCodeChallengeNotFound))
confirmCodes := []string{
ErrorCodeInvalidRequest,
ErrorCodeChallengeNotFound,
ErrorCodeChallengeExpired,
ErrorCodeInvalidCode,
ErrorCodeInvalidClientPublicKey,
ErrorCodeBlockedByPolicy,
ErrorCodeSessionLimitExceeded,
ErrorCodeServiceUnavailable,
}
for _, code := range confirmCodes {
assert.Truef(t, IsConfirmEmailCodePublicErrorCode(code), "IsConfirmEmailCodePublicErrorCode(%q)", code)
}
assert.False(t, IsConfirmEmailCodePublicErrorCode(ErrorCodeInternalError))
assert.False(t, IsConfirmEmailCodePublicErrorCode(ErrorCodeSessionNotFound))
}
func TestPublicHTTPStatusCode(t *testing.T) {
t.Parallel()
tests := []struct {
name string
code string
want int
}{
{name: "invalid request", code: ErrorCodeInvalidRequest, want: http.StatusBadRequest},
{name: "invalid client public key", code: ErrorCodeInvalidClientPublicKey, want: http.StatusBadRequest},
{name: "invalid code", code: ErrorCodeInvalidCode, want: http.StatusBadRequest},
{name: "challenge not found", code: ErrorCodeChallengeNotFound, want: http.StatusNotFound},
{name: "challenge expired", code: ErrorCodeChallengeExpired, want: http.StatusGone},
{name: "blocked by policy", code: ErrorCodeBlockedByPolicy, want: http.StatusForbidden},
{name: "session limit exceeded", code: ErrorCodeSessionLimitExceeded, want: http.StatusConflict},
{name: "service unavailable", code: ErrorCodeServiceUnavailable, want: http.StatusServiceUnavailable},
{name: "internal error normalized", code: ErrorCodeInternalError, want: http.StatusServiceUnavailable},
{name: "unknown normalized", code: "unknown", want: http.StatusServiceUnavailable},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.want, PublicHTTPStatusCode(tt.code))
})
}
}
func TestInternalHTTPStatusCode(t *testing.T) {
t.Parallel()
tests := []struct {
name string
code string
want int
}{
{name: "invalid request", code: ErrorCodeInvalidRequest, want: http.StatusBadRequest},
{name: "session not found", code: ErrorCodeSessionNotFound, want: http.StatusNotFound},
{name: "subject not found", code: ErrorCodeSubjectNotFound, want: http.StatusNotFound},
{name: "service unavailable", code: ErrorCodeServiceUnavailable, want: http.StatusServiceUnavailable},
{name: "internal error", code: ErrorCodeInternalError, want: http.StatusInternalServerError},
{name: "unknown normalized", code: "unknown", want: http.StatusInternalServerError},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.want, InternalHTTPStatusCode(tt.code))
})
}
}
func TestProjectPublicError(t *testing.T) {
t.Parallel()
tests := []struct {
name string
err error
want PublicErrorProjection
}{
{
name: "invalid request keeps detailed message",
err: InvalidRequest("email must be a single valid email address"),
want: PublicErrorProjection{
StatusCode: http.StatusBadRequest,
Code: ErrorCodeInvalidRequest,
Message: "email must be a single valid email address",
},
},
{
name: "invalid code keeps canonical message",
err: NewServiceError(ErrorCodeInvalidCode, "custom detail should not leak", nil),
want: PublicErrorProjection{
StatusCode: http.StatusBadRequest,
Code: ErrorCodeInvalidCode,
Message: "confirmation code is invalid",
},
},
{
name: "service unavailable keeps generic message",
err: NewServiceError(ErrorCodeServiceUnavailable, "dependency timeout", errors.New("dependency timeout")),
want: PublicErrorProjection{
StatusCode: http.StatusServiceUnavailable,
Code: ErrorCodeServiceUnavailable,
Message: "service is unavailable",
},
},
{
name: "internal error is hidden",
err: InternalError(errors.New("broken invariant")),
want: PublicErrorProjection{
StatusCode: http.StatusServiceUnavailable,
Code: ErrorCodeServiceUnavailable,
Message: "service is unavailable",
},
},
{
name: "internal only session not found is hidden",
err: SessionNotFound(),
want: PublicErrorProjection{
StatusCode: http.StatusServiceUnavailable,
Code: ErrorCodeServiceUnavailable,
Message: "service is unavailable",
},
},
{
name: "non service error is hidden",
err: errors.New("boom"),
want: PublicErrorProjection{
StatusCode: http.StatusServiceUnavailable,
Code: ErrorCodeServiceUnavailable,
Message: "service is unavailable",
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.want, ProjectPublicError(tt.err))
})
}
}
func TestProjectInternalError(t *testing.T) {
t.Parallel()
tests := []struct {
name string
err error
want InternalErrorProjection
}{
{
name: "invalid request keeps detailed message",
err: InvalidRequest("reason_code must not be empty"),
want: InternalErrorProjection{
StatusCode: http.StatusBadRequest,
Code: ErrorCodeInvalidRequest,
Message: "reason_code must not be empty",
},
},
{
name: "session not found keeps canonical message",
err: NewServiceError(ErrorCodeSessionNotFound, "custom detail should not leak", nil),
want: InternalErrorProjection{
StatusCode: http.StatusNotFound,
Code: ErrorCodeSessionNotFound,
Message: "session not found",
},
},
{
name: "subject not found keeps canonical message",
err: SubjectNotFound(),
want: InternalErrorProjection{
StatusCode: http.StatusNotFound,
Code: ErrorCodeSubjectNotFound,
Message: "subject not found",
},
},
{
name: "service unavailable keeps generic message",
err: NewServiceError(ErrorCodeServiceUnavailable, "redis timeout", errors.New("redis timeout")),
want: InternalErrorProjection{
StatusCode: http.StatusServiceUnavailable,
Code: ErrorCodeServiceUnavailable,
Message: "service is unavailable",
},
},
{
name: "internal error uses internal server error message",
err: InternalError(errors.New("broken invariant")),
want: InternalErrorProjection{
StatusCode: http.StatusInternalServerError,
Code: ErrorCodeInternalError,
Message: "internal server error",
},
},
{
name: "unexpected error is hidden",
err: errors.New("boom"),
want: InternalErrorProjection{
StatusCode: http.StatusInternalServerError,
Code: ErrorCodeInternalError,
Message: "internal server error",
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tt.want, ProjectInternalError(tt.err))
})
}
}
func revokedSessionFixture() devicesession.Session {
key, err := common.NewClientPublicKey(make([]byte, 32))
if err != nil {
panic(err)
}
revokedAt := time.Unix(20, 0).UTC()
return devicesession.Session{
ID: common.DeviceSessionID("device-session-1"),
UserID: common.UserID("user-1"),
ClientPublicKey: key,
Status: devicesession.StatusRevoked,
CreatedAt: time.Unix(10, 0).UTC(),
Revocation: &devicesession.Revocation{
At: revokedAt,
ReasonCode: devicesession.RevokeReasonLogoutAll,
ActorType: common.RevokeActorType("system"),
ActorID: "actor-1",
},
}
}