feat: authsession service
This commit is contained in:
@@ -0,0 +1,88 @@
|
||||
package blockuser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecuteRetriesProjectionPublishesForBlockFlow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{
|
||||
Errors: []error{errors.New("publish failed"), nil},
|
||||
}
|
||||
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
|
||||
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
|
||||
|
||||
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
UserID: "user-1",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "blocked", result.Outcome)
|
||||
assert.EqualValues(t, 1, result.AffectedSessionCount)
|
||||
require.Len(t, publisher.PublishedSnapshots(), 2)
|
||||
}
|
||||
|
||||
func TestExecuteRepairsProjectionOnRepeatedAlreadyBlockedRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
|
||||
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
|
||||
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
|
||||
|
||||
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
UserID: "user-1",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
|
||||
require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts)
|
||||
|
||||
sessionRecord, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, getErr)
|
||||
require.NotNil(t, sessionRecord.Revocation)
|
||||
assert.Equal(t, devicesession.StatusRevoked, sessionRecord.Status)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, sessionRecord.Revocation.ReasonCode)
|
||||
|
||||
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
require.NoError(t, resolveErr)
|
||||
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
|
||||
|
||||
publisher.Err = nil
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
UserID: "user-1",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "already_blocked", result.Outcome)
|
||||
assert.EqualValues(t, 0, result.AffectedSessionCount)
|
||||
require.NotNil(t, result.AffectedDeviceSessionIDs)
|
||||
assert.Empty(t, result.AffectedDeviceSessionIDs)
|
||||
require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1)
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package blockuser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/service/confirmemailcode"
|
||||
"galaxy/authsession/internal/service/sendemailcode"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// blockFlowPublicKey is a base64-encoded 32-byte client public key fixture
// used when confirming an e-mail code in the block-flow integration test.
const blockFlowPublicKey = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="
|
||||
|
||||
func TestBlockUserAffectsLaterSendAndConfirmFlows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
challengeStore := &testkit.InMemoryChallengeStore{}
|
||||
sessionStore := &testkit.InMemorySessionStore{}
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{}
|
||||
idGenerator := &testkit.SequenceIDGenerator{
|
||||
ChallengeIDs: []common.ChallengeID{"challenge-1"},
|
||||
DeviceSessionIDs: []common.DeviceSessionID{"device-session-1"},
|
||||
}
|
||||
hasher := testkit.DeterministicCodeHasher{}
|
||||
mailSender := &testkit.RecordingMailSender{}
|
||||
now := time.Unix(20, 0).UTC()
|
||||
clock := testkit.FixedClock{Time: now}
|
||||
|
||||
blockService, err := New(userDirectory, sessionStore, publisher, clock)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = blockService.Execute(context.Background(), Input{
|
||||
Email: "pilot@example.com",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
sendService, err := sendemailcode.New(
|
||||
challengeStore,
|
||||
userDirectory,
|
||||
idGenerator,
|
||||
testkit.FixedCodeGenerator{Code: "654321"},
|
||||
hasher,
|
||||
mailSender,
|
||||
clock,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
sendResult, err := sendService.Execute(context.Background(), sendemailcode.Input{Email: "pilot@example.com"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "challenge-1", sendResult.ChallengeID)
|
||||
assert.Empty(t, mailSender.RecordedInputs())
|
||||
|
||||
challengeRecord, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, challenge.StatusDeliverySuppressed, challengeRecord.Status)
|
||||
assert.Equal(t, challenge.DeliverySuppressed, challengeRecord.DeliveryState)
|
||||
|
||||
confirmService, err := confirmemailcode.New(
|
||||
challengeStore,
|
||||
sessionStore,
|
||||
userDirectory,
|
||||
testkit.StaticConfigProvider{},
|
||||
publisher,
|
||||
idGenerator,
|
||||
hasher,
|
||||
clock,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = confirmService.Execute(context.Background(), confirmemailcode.Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: blockFlowPublicKey,
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, shared.ErrorCodeBlockedByPolicy, shared.CodeOf(err))
|
||||
|
||||
updatedChallenge, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
require.NoError(t, getErr)
|
||||
assert.Equal(t, challenge.StatusFailed, updatedChallenge.Status)
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package blockuser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func TestExecuteLogsSafeOutcomeFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
|
||||
|
||||
sessionStore := &testkit.InMemorySessionStore{}
|
||||
require.NoError(t, sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
|
||||
|
||||
logger, buffer := newObservedServiceLogger()
|
||||
service, err := NewWithObservability(
|
||||
userDirectory,
|
||||
sessionStore,
|
||||
&testkit.RecordingProjectionPublisher{},
|
||||
testkit.FixedClock{Time: time.Unix(20, 0).UTC()},
|
||||
logger,
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
UserID: "user-1",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
logOutput := buffer.String()
|
||||
assert.Contains(t, logOutput, "block_user")
|
||||
assert.Contains(t, logOutput, "\"user_id\":\"user-1\"")
|
||||
assert.Contains(t, logOutput, "\"reason_code\":\"policy_block\"")
|
||||
assert.NotContains(t, logOutput, "pilot@example.com")
|
||||
}
|
||||
|
||||
func newObservedServiceLogger() (*zap.Logger, *bytes.Buffer) {
|
||||
buffer := &bytes.Buffer{}
|
||||
encoderConfig := zap.NewProductionEncoderConfig()
|
||||
encoderConfig.TimeKey = ""
|
||||
|
||||
core := zapcore.NewCore(
|
||||
zapcore.NewJSONEncoder(encoderConfig),
|
||||
zapcore.AddSync(buffer),
|
||||
zap.DebugLevel,
|
||||
)
|
||||
|
||||
return zap.New(core), buffer
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
// Package blockuser implements the trusted internal block-user use case.
|
||||
package blockuser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/ports"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/telemetry"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Subject kinds reported in Result.SubjectKind.
const (
	// SubjectKindUserID identifies a block request addressed by stable user id.
	SubjectKindUserID = "user_id"

	// SubjectKindEmail identifies a block request addressed by normalized e-mail
	// address.
	SubjectKindEmail = "email"
)
|
||||
|
||||
// Input describes one trusted internal block-user request.
//
// Exactly one of UserID or Email must be set; the service rejects requests
// that set both or neither.
type Input struct {
	// UserID identifies the subject to block when the request is user-id based.
	UserID string

	// Email identifies the subject to block when the request is e-mail based.
	Email string

	// ReasonCode stores the machine-readable block reason code applied to the
	// user directory.
	ReasonCode string

	// ActorType stores the machine-readable actor type for any derived session
	// revocation.
	ActorType string

	// ActorID stores the optional stable actor identifier for any derived
	// session revocation.
	ActorID string
}
|
||||
|
||||
// Result describes the frozen internal block-user acknowledgement.
type Result struct {
	// Outcome reports whether the block state was newly applied or already
	// existed.
	Outcome string

	// SubjectKind reports whether the request targeted `user_id` or `email`.
	SubjectKind string

	// SubjectValue stores the normalized subject value addressed by the
	// operation.
	SubjectValue string

	// AffectedSessionCount reports how many sessions changed state during the
	// current call.
	AffectedSessionCount int64

	// AffectedDeviceSessionIDs lists every session identifier affected during
	// the current call. On success it is always non-nil, possibly empty.
	AffectedDeviceSessionIDs []string
}
|
||||
|
||||
// Service executes the trusted internal block-user use case.
type Service struct {
	userDirectory ports.UserDirectory                     // applies and records the block state
	sessionStore  ports.SessionStore                      // revokes and lists device sessions
	publisher     ports.GatewaySessionProjectionPublisher // pushes session snapshots to the gateway
	clock         ports.Clock                             // supplies revocation timestamps
	logger        *zap.Logger                             // never nil; defaults to a no-op logger
	telemetry     *telemetry.Runtime                      // may be nil when built via New
}
|
||||
|
||||
// New returns a block-user service wired to the required ports.
|
||||
func New(userDirectory ports.UserDirectory, sessionStore ports.SessionStore, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) {
|
||||
return NewWithObservability(userDirectory, sessionStore, publisher, clock, nil, nil)
|
||||
}
|
||||
|
||||
// NewWithObservability returns a block-user service wired to the required
|
||||
// ports plus optional structured logging and telemetry dependencies.
|
||||
func NewWithObservability(
|
||||
userDirectory ports.UserDirectory,
|
||||
sessionStore ports.SessionStore,
|
||||
publisher ports.GatewaySessionProjectionPublisher,
|
||||
clock ports.Clock,
|
||||
logger *zap.Logger,
|
||||
telemetryRuntime *telemetry.Runtime,
|
||||
) (*Service, error) {
|
||||
switch {
|
||||
case userDirectory == nil:
|
||||
return nil, fmt.Errorf("blockuser: user directory must not be nil")
|
||||
case sessionStore == nil:
|
||||
return nil, fmt.Errorf("blockuser: session store must not be nil")
|
||||
case publisher == nil:
|
||||
return nil, fmt.Errorf("blockuser: projection publisher must not be nil")
|
||||
case clock == nil:
|
||||
return nil, fmt.Errorf("blockuser: clock must not be nil")
|
||||
default:
|
||||
return &Service{
|
||||
userDirectory: userDirectory,
|
||||
sessionStore: sessionStore,
|
||||
publisher: publisher,
|
||||
clock: clock,
|
||||
logger: namedLogger(logger, "block_user"),
|
||||
telemetry: telemetryRuntime,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Execute applies the requested block state and revokes any active sessions of
// the resolved user when one exists. It returns the acknowledgement Result or
// a shared service error (invalid request, subject not found, unavailable,
// internal).
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
	logFields := []zap.Field{
		zap.String("component", "service"),
		zap.String("use_case", "block_user"),
	}
	// Named results let the deferred hook observe the final outcome and error
	// regardless of which return path fires.
	defer func() {
		if result.Outcome != "" {
			logFields = append(logFields, zap.String("outcome", result.Outcome))
		}
		if result.SubjectKind != "" {
			logFields = append(logFields, zap.String("subject_kind", result.SubjectKind))
		}
		if result.AffectedSessionCount > 0 {
			logFields = append(logFields, zap.Int64("affected_session_count", result.AffectedSessionCount))
		}
		shared.LogServiceOutcome(s.logger, ctx, "block user completed", err, logFields...)
	}()

	// Apply the block in the user directory first; session revocation is a
	// derived effect and only runs when a concrete user id was resolved.
	subjectKind, subjectValue, storeResult, err := s.blockSubject(ctx, input)
	if err != nil {
		return Result{}, err
	}
	logFields = append(logFields, zap.String("reason_code", shared.NormalizeString(input.ReasonCode)))
	if !storeResult.UserID.IsZero() {
		// Only the stable user id is logged; the e-mail subject is PII.
		logFields = append(logFields, zap.String("user_id", storeResult.UserID.String()))
	}

	// Non-nil empty slice so callers always see a list, never null.
	affectedDeviceSessionIDs := []string{}
	affectedSessionCount := int64(0)
	if !storeResult.UserID.IsZero() {
		revocation, err := shared.BuildRevocation(
			devicesession.RevokeReasonUserBlocked.String(),
			input.ActorType,
			input.ActorID,
			s.clock.Now(),
		)
		if err != nil {
			return Result{}, err
		}

		revokeResult, err := s.sessionStore.RevokeAllByUserID(ctx, ports.RevokeUserSessionsInput{
			UserID:     storeResult.UserID,
			Revocation: revocation,
		})
		if err != nil {
			return Result{}, shared.ServiceUnavailable(err)
		}
		if err := revokeResult.Validate(); err != nil {
			return Result{}, shared.InternalError(err)
		}

		// Publish a projection snapshot for every freshly revoked session.
		for _, record := range revokeResult.Sessions {
			if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "block_user"); err != nil {
				return Result{}, err
			}
			affectedDeviceSessionIDs = append(affectedDeviceSessionIDs, record.ID.String())
		}
		if revokeResult.Outcome == ports.RevokeUserSessionsOutcomeNoActiveSessions {
			// A repeated block request may follow a call whose publish failed
			// after revocation; republish the already-revoked sessions to
			// repair the gateway projection.
			if err := s.republishCurrentRevokedSessions(ctx, storeResult.UserID); err != nil {
				return Result{}, err
			}
		}
		affectedSessionCount = int64(len(revokeResult.Sessions))
		if affectedSessionCount > 0 {
			// NOTE(review): s.telemetry may be nil when built via New; this
			// assumes telemetry.Runtime methods tolerate a nil receiver —
			// confirm.
			s.telemetry.RecordSessionRevocations(ctx, "block_user", devicesession.RevokeReasonUserBlocked.String(), affectedSessionCount)
		}
	}

	result = Result{
		Outcome:                  string(storeResult.Outcome),
		SubjectKind:              subjectKind,
		SubjectValue:             subjectValue,
		AffectedSessionCount:     affectedSessionCount,
		AffectedDeviceSessionIDs: affectedDeviceSessionIDs,
	}

	return result, nil
}
|
||||
|
||||
// blockSubject validates the subject addressing (exactly one of user id or
// e-mail), applies the block in the user directory, and returns the subject
// kind constant, the normalized subject value, and the directory result.
func (s *Service) blockSubject(ctx context.Context, input Input) (string, string, ports.BlockUserResult, error) {
	userID := shared.NormalizeString(input.UserID)
	email := shared.NormalizeString(input.Email)

	switch {
	case userID == "" && email == "":
		return "", "", ports.BlockUserResult{}, shared.InvalidRequest("exactly one of user_id or email must be provided")
	case userID != "" && email != "":
		return "", "", ports.BlockUserResult{}, shared.InvalidRequest("exactly one of user_id or email must be provided")
	case userID != "":
		parsedUserID, err := shared.ParseUserID(userID)
		if err != nil {
			return "", "", ports.BlockUserResult{}, err
		}
		reasonCode, err := parseBlockReasonCode(input.ReasonCode)
		if err != nil {
			return "", "", ports.BlockUserResult{}, err
		}

		result, err := s.userDirectory.BlockByUserID(ctx, ports.BlockUserByIDInput{
			UserID:     parsedUserID,
			ReasonCode: reasonCode,
		})
		if err != nil {
			switch {
			case errors.Is(err, ports.ErrNotFound):
				// Blocking an unknown user id is a caller error.
				return "", "", ports.BlockUserResult{}, shared.SubjectNotFound()
			default:
				return "", "", ports.BlockUserResult{}, shared.ServiceUnavailable(err)
			}
		}
		if err := result.Validate(); err != nil {
			return "", "", ports.BlockUserResult{}, shared.InternalError(err)
		}
		// NOTE(review): assumes telemetry.Runtime methods tolerate a nil
		// receiver (s.telemetry is nil when built via New) — confirm.
		s.telemetry.RecordUserDirectoryOutcome(ctx, "block_by_user_id", string(result.Outcome))

		return SubjectKindUserID, parsedUserID.String(), result, nil
	default:
		parsedEmail, err := shared.ParseEmail(email)
		if err != nil {
			return "", "", ports.BlockUserResult{}, err
		}
		reasonCode, err := parseBlockReasonCode(input.ReasonCode)
		if err != nil {
			return "", "", ports.BlockUserResult{}, err
		}

		// Unlike the user-id path, blocking by e-mail has no not-found case:
		// the directory accepts blocks for addresses without an existing user.
		result, err := s.userDirectory.BlockByEmail(ctx, ports.BlockUserByEmailInput{
			Email:      parsedEmail,
			ReasonCode: reasonCode,
		})
		if err != nil {
			return "", "", ports.BlockUserResult{}, shared.ServiceUnavailable(err)
		}
		if err := result.Validate(); err != nil {
			return "", "", ports.BlockUserResult{}, shared.InternalError(err)
		}
		s.telemetry.RecordUserDirectoryOutcome(ctx, "block_by_email", string(result.Outcome))

		return SubjectKindEmail, parsedEmail.String(), result, nil
	}
}
|
||||
|
||||
func parseBlockReasonCode(value string) (userresolution.BlockReasonCode, error) {
|
||||
reasonCode := userresolution.BlockReasonCode(shared.NormalizeString(value))
|
||||
if err := reasonCode.Validate(); err != nil {
|
||||
return "", shared.InvalidRequest(err.Error())
|
||||
}
|
||||
|
||||
return reasonCode, nil
|
||||
}
|
||||
|
||||
func (s *Service) republishCurrentRevokedSessions(ctx context.Context, userID common.UserID) error {
|
||||
records, err := s.sessionStore.ListByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return shared.ServiceUnavailable(err)
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
if record.Status != devicesession.StatusRevoked {
|
||||
continue
|
||||
}
|
||||
if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "block_user_repair"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return logger.Named(name)
|
||||
}
|
||||
@@ -0,0 +1,237 @@
|
||||
package blockuser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecuteBlocksByUserIDAndRevokesSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{}
|
||||
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
UserID: "user-1",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "blocked", result.Outcome)
|
||||
assert.EqualValues(t, 1, result.AffectedSessionCount)
|
||||
assert.Equal(t, SubjectKindUserID, result.SubjectKind)
|
||||
assert.Equal(t, "user-1", result.SubjectValue)
|
||||
assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs)
|
||||
|
||||
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, getErr)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)
|
||||
|
||||
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
require.NoError(t, resolveErr)
|
||||
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
|
||||
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
|
||||
|
||||
published := publisher.PublishedSnapshots()
|
||||
require.Len(t, published, 1)
|
||||
assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("admin"), published[0].RevokeActorType)
|
||||
}
|
||||
|
||||
func TestExecuteBlocksByEmailWithoutExistingUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{}
|
||||
service, err := New(userDirectory, &testkit.InMemorySessionStore{}, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
Email: "pilot@example.com",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "blocked", result.Outcome)
|
||||
assert.EqualValues(t, 0, result.AffectedSessionCount)
|
||||
assert.Equal(t, SubjectKindEmail, result.SubjectKind)
|
||||
assert.Equal(t, "pilot@example.com", result.SubjectValue)
|
||||
require.NotNil(t, result.AffectedDeviceSessionIDs)
|
||||
assert.Empty(t, result.AffectedDeviceSessionIDs)
|
||||
|
||||
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
require.NoError(t, resolveErr)
|
||||
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
|
||||
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
|
||||
assert.Empty(t, publisher.PublishedSnapshots())
|
||||
}
|
||||
|
||||
func TestExecuteBlocksByEmailWithExistingUserAndRevokesSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{}
|
||||
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
Email: "pilot@example.com",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "blocked", result.Outcome)
|
||||
assert.EqualValues(t, 1, result.AffectedSessionCount)
|
||||
assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs)
|
||||
|
||||
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, getErr)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)
|
||||
|
||||
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
require.NoError(t, resolveErr)
|
||||
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
|
||||
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
|
||||
|
||||
published := publisher.PublishedSnapshots()
|
||||
require.Len(t, published, 1)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
|
||||
}
|
||||
|
||||
func TestExecuteReturnsSubjectNotFoundForUnknownUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service, err := New(&testkit.InMemoryUserDirectory{}, &testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
UserID: "missing",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err))
|
||||
}
|
||||
|
||||
func TestExecuteAlreadyBlockedStillRevokesLingeringSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{}
|
||||
if err := userDirectory.SeedBlockedUser(common.Email("pilot@example.com"), common.UserID("user-1"), userresolution.BlockReasonCode("policy_block")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedBlockedUser() returned error: %v", err)
|
||||
}
|
||||
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
Email: "pilot@example.com",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "already_blocked", result.Outcome)
|
||||
assert.EqualValues(t, 1, result.AffectedSessionCount)
|
||||
assert.Equal(t, []string{"device-session-1"}, result.AffectedDeviceSessionIDs)
|
||||
|
||||
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, getErr)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("admin"), stored.Revocation.ActorType)
|
||||
|
||||
published := publisher.PublishedSnapshots()
|
||||
require.Len(t, published, 1)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, published[0].RevokeReasonCode)
|
||||
}
|
||||
|
||||
func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
|
||||
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(userDirectory, store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
UserID: "user-1",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
|
||||
|
||||
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, getErr)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, devicesession.RevokeReasonUserBlocked, stored.Revocation.ReasonCode)
|
||||
|
||||
resolution, resolveErr := userDirectory.ResolveByEmail(context.Background(), common.Email("pilot@example.com"))
|
||||
require.NoError(t, resolveErr)
|
||||
assert.Equal(t, userresolution.KindBlocked, resolution.Kind)
|
||||
assert.Equal(t, userresolution.BlockReasonCode("policy_block"), resolution.BlockReasonCode)
|
||||
}
|
||||
|
||||
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
key, err := common.NewClientPublicKey(make([]byte, 32))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return devicesession.Session{
|
||||
ID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID(userID),
|
||||
ClientPublicKey: key,
|
||||
Status: devicesession.StatusActive,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package blockuser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
stubuserservice "galaxy/authsession/internal/adapters/userservice"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("blocks by email through runtime stub", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &stubuserservice.StubDirectory{}
|
||||
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
|
||||
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
|
||||
|
||||
service, err := New(userDirectory, store, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
Email: "pilot@example.com",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, SubjectKindEmail, result.SubjectKind)
|
||||
assert.Equal(t, "blocked", result.Outcome)
|
||||
assert.EqualValues(t, 1, result.AffectedSessionCount)
|
||||
})
|
||||
|
||||
t.Run("blocks by user id through runtime stub", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &stubuserservice.StubDirectory{}
|
||||
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
|
||||
|
||||
service, err := New(userDirectory, &testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
UserID: "user-1",
|
||||
ReasonCode: "policy_block",
|
||||
ActorType: "admin",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, SubjectKindUserID, result.SubjectKind)
|
||||
assert.Equal(t, "blocked", result.Outcome)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package confirmemailcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecuteReturnsInvalidCodeForThrottledChallengeWithoutConsumingAttempts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
|
||||
record.Status = challenge.StatusDeliveryThrottled
|
||||
record.DeliveryState = challenge.DeliveryThrottled
|
||||
require.NoError(t, record.Validate())
|
||||
require.NoError(t, deps.challengeStore.Create(context.Background(), record))
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
_, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, shared.ErrorCodeInvalidCode, shared.CodeOf(err))
|
||||
|
||||
updated, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
require.NoError(t, getErr)
|
||||
assert.Equal(t, 0, updated.Attempts.Confirm)
|
||||
assert.Equal(t, challenge.StatusDeliveryThrottled, updated.Status)
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package confirmemailcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecuteConfirmsChallengeAfterTransientProjectionPublishFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
deps.publisher.Errors = []error{errors.New("publish failed"), nil}
|
||||
require.NoError(t, deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
|
||||
require.NoError(t, deps.challengeStore.Create(
|
||||
context.Background(),
|
||||
sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
|
||||
))
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "device-session-1", result.DeviceSessionID)
|
||||
require.Len(t, deps.publisher.PublishedSnapshots(), 2)
|
||||
}
|
||||
|
||||
func TestExecuteConfirmedRetryRepublishesAfterTransientProjectionPublishFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
deps.publisher.Errors = []error{errors.New("publish failed"), nil}
|
||||
key := mustClientPublicKey(t, publicKeyString())
|
||||
require.NoError(t, deps.challengeStore.Create(
|
||||
context.Background(),
|
||||
confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
|
||||
))
|
||||
require.NoError(t, deps.sessionStore.Create(
|
||||
context.Background(),
|
||||
activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute)),
|
||||
))
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "device-session-1", result.DeviceSessionID)
|
||||
require.Len(t, deps.publisher.PublishedSnapshots(), 2)
|
||||
}
|
||||
|
||||
func TestExecuteRepairsProjectionOnIdenticalRetryAfterExhaustedPublishRetries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
deps.publisher.Err = errors.New("publish failed")
|
||||
require.NoError(t, deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
|
||||
require.NoError(t, deps.challengeStore.Create(
|
||||
context.Background(),
|
||||
sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute)),
|
||||
))
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
_, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
|
||||
require.Len(t, deps.publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts)
|
||||
|
||||
sessionRecord, getErr := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, getErr)
|
||||
assert.Equal(t, devicesession.StatusActive, sessionRecord.Status)
|
||||
|
||||
challengeRecord, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
require.NoError(t, getErr)
|
||||
assert.Equal(t, challenge.StatusConfirmedPendingExpire, challengeRecord.Status)
|
||||
require.NotNil(t, challengeRecord.Confirmation)
|
||||
|
||||
deps.publisher.Err = nil
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "device-session-1", result.DeviceSessionID)
|
||||
require.Len(t, deps.publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1)
|
||||
}
|
||||
@@ -0,0 +1,588 @@
|
||||
// Package confirmemailcode implements the public confirm-email-code use case.
|
||||
package confirmemailcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/sessionlimit"
|
||||
"galaxy/authsession/internal/ports"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/telemetry"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Revocation metadata stamped onto a session that this service revokes after
// losing a confirm race (see bestEffortRevokeSupersededSession).
const (
	revokeReasonConfirmRace common.RevokeReasonCode = "confirm_race_repair"
	revokeActorTypeService  common.RevokeActorType  = "service"
	// revokeActorIDService identifies this use case as the revoking actor.
	revokeActorIDService = "confirmemailcode"
)
|
||||
|
||||
// Input describes one public confirm-email-code request. All fields arrive as
// raw strings and are parsed/validated inside Execute.
type Input struct {
	// ChallengeID identifies the challenge that should be confirmed.
	ChallengeID string

	// Code is the cleartext confirmation code submitted by the caller.
	Code string

	// ClientPublicKey is the base64-encoded raw 32-byte Ed25519 public key that
	// should be registered for the created device session.
	ClientPublicKey string
}
|
||||
|
||||
// Result describes one public confirm-email-code response.
type Result struct {
	// DeviceSessionID is the stable identifier of the created or idempotently
	// recovered device session. It is empty when Execute returns an error.
	DeviceSessionID string
}
|
||||
|
||||
// Service executes the public confirm-email-code use case.
type Service struct {
	challengeStore ports.ChallengeStore                    // persisted challenge records, updated via compare-and-swap
	sessionStore   ports.SessionStore                      // device-session source of truth
	userDirectory  ports.UserDirectory                     // ensures/resolves users by email
	configProvider ports.ConfigProvider                    // supplies the session-limit configuration
	publisher      ports.GatewaySessionProjectionPublisher // pushes session snapshots to the gateway projection
	idGenerator    ports.IDGenerator                       // mints new device-session IDs
	codeHasher     ports.CodeHasher                        // compares submitted codes against stored hashes
	clock          ports.Clock                             // injected time source (UTC normalization happens at call sites)
	logger         *zap.Logger                             // never nil; see namedLogger
	telemetry      *telemetry.Runtime                      // optional; nil-safe per its method receivers — TODO confirm
}
|
||||
|
||||
// New returns a confirm-email-code service wired to the required ports.
|
||||
func New(
|
||||
challengeStore ports.ChallengeStore,
|
||||
sessionStore ports.SessionStore,
|
||||
userDirectory ports.UserDirectory,
|
||||
configProvider ports.ConfigProvider,
|
||||
publisher ports.GatewaySessionProjectionPublisher,
|
||||
idGenerator ports.IDGenerator,
|
||||
codeHasher ports.CodeHasher,
|
||||
clock ports.Clock,
|
||||
) (*Service, error) {
|
||||
return NewWithTelemetry(
|
||||
challengeStore,
|
||||
sessionStore,
|
||||
userDirectory,
|
||||
configProvider,
|
||||
publisher,
|
||||
idGenerator,
|
||||
codeHasher,
|
||||
clock,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
// NewWithTelemetry returns a confirm-email-code service wired to the required
// ports plus the optional Stage-17 telemetry runtime. It forwards to
// NewWithObservability with a nil logger (a no-op logger is substituted there).
func NewWithTelemetry(
	challengeStore ports.ChallengeStore,
	sessionStore ports.SessionStore,
	userDirectory ports.UserDirectory,
	configProvider ports.ConfigProvider,
	publisher ports.GatewaySessionProjectionPublisher,
	idGenerator ports.IDGenerator,
	codeHasher ports.CodeHasher,
	clock ports.Clock,
	telemetryRuntime *telemetry.Runtime,
) (*Service, error) {
	return NewWithObservability(
		challengeStore,
		sessionStore,
		userDirectory,
		configProvider,
		publisher,
		idGenerator,
		codeHasher,
		clock,
		nil,
		telemetryRuntime,
	)
}
|
||||
|
||||
// NewWithObservability returns a confirm-email-code service wired to the
|
||||
// required ports plus optional structured logging and telemetry dependencies.
|
||||
func NewWithObservability(
|
||||
challengeStore ports.ChallengeStore,
|
||||
sessionStore ports.SessionStore,
|
||||
userDirectory ports.UserDirectory,
|
||||
configProvider ports.ConfigProvider,
|
||||
publisher ports.GatewaySessionProjectionPublisher,
|
||||
idGenerator ports.IDGenerator,
|
||||
codeHasher ports.CodeHasher,
|
||||
clock ports.Clock,
|
||||
logger *zap.Logger,
|
||||
telemetryRuntime *telemetry.Runtime,
|
||||
) (*Service, error) {
|
||||
switch {
|
||||
case challengeStore == nil:
|
||||
return nil, fmt.Errorf("confirmemailcode: challenge store must not be nil")
|
||||
case sessionStore == nil:
|
||||
return nil, fmt.Errorf("confirmemailcode: session store must not be nil")
|
||||
case userDirectory == nil:
|
||||
return nil, fmt.Errorf("confirmemailcode: user directory must not be nil")
|
||||
case configProvider == nil:
|
||||
return nil, fmt.Errorf("confirmemailcode: config provider must not be nil")
|
||||
case publisher == nil:
|
||||
return nil, fmt.Errorf("confirmemailcode: projection publisher must not be nil")
|
||||
case idGenerator == nil:
|
||||
return nil, fmt.Errorf("confirmemailcode: id generator must not be nil")
|
||||
case codeHasher == nil:
|
||||
return nil, fmt.Errorf("confirmemailcode: code hasher must not be nil")
|
||||
case clock == nil:
|
||||
return nil, fmt.Errorf("confirmemailcode: clock must not be nil")
|
||||
default:
|
||||
return &Service{
|
||||
challengeStore: challengeStore,
|
||||
sessionStore: sessionStore,
|
||||
userDirectory: userDirectory,
|
||||
configProvider: configProvider,
|
||||
publisher: publisher,
|
||||
idGenerator: idGenerator,
|
||||
codeHasher: codeHasher,
|
||||
clock: clock,
|
||||
logger: namedLogger(logger, "confirm_email_code"),
|
||||
telemetry: telemetryRuntime,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Execute validates one challenge confirmation attempt, creates a device
// session when policy allows it, and handles short-window idempotent retries.
//
// The body runs a compare-and-swap loop of up to shared.MaxCompareAndSwapRetries
// iterations: each iteration reloads the challenge, evaluates it, and attempts
// one CAS; a ports.ErrConflict from any intermediate mutation restarts the
// loop with fresh state. Named results let the deferred block record telemetry
// and a structured log line for every exit path.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
	logFields := []zap.Field{
		zap.String("component", "service"),
		zap.String("use_case", "confirm_email_code"),
	}
	// Deferred outcome recording: maps the final error (or success) to a
	// telemetry outcome string and emits one summary log entry.
	defer func() {
		outcome := string(telemetry.ConfirmEmailCodeOutcomeSuccess)
		if err != nil {
			outcome = shared.CodeOf(err)
			if outcome == "" {
				// Errors without a shared code are reported as unavailable.
				outcome = shared.ErrorCodeServiceUnavailable
			}
		}
		s.telemetry.RecordConfirmEmailCode(ctx, outcome)
		logFields = append(logFields, zap.String("outcome", outcome))
		if result.DeviceSessionID != "" {
			logFields = append(logFields, zap.String("device_session_id", result.DeviceSessionID))
		}
		shared.LogServiceOutcome(s.logger, ctx, "confirm email code completed", err, logFields...)
	}()

	// Parse and validate all raw inputs before touching any store.
	challengeID, err := shared.ParseChallengeID(input.ChallengeID)
	if err != nil {
		return Result{}, err
	}
	logFields = append(logFields, zap.String("challenge_id", challengeID.String()))
	code, err := shared.ParseRequiredCode(input.Code)
	if err != nil {
		return Result{}, err
	}
	clientPublicKey, err := shared.ParseClientPublicKey(input.ClientPublicKey)
	if err != nil {
		return Result{}, err
	}

	for attempt := 0; attempt < shared.MaxCompareAndSwapRetries; attempt++ {
		current, err := s.challengeStore.Get(ctx, challengeID)
		if err != nil {
			switch {
			case errors.Is(err, ports.ErrNotFound):
				return Result{}, shared.ChallengeNotFound()
			default:
				return Result{}, shared.ServiceUnavailable(err)
			}
		}

		// Lazily expire the challenge; a CAS conflict means someone else
		// mutated it concurrently, so reload and re-evaluate.
		now := s.clock.Now().UTC()
		if expired, err := s.ensureChallengeNotExpired(ctx, current, now); err != nil {
			if errors.Is(err, ports.ErrConflict) {
				continue
			}
			return Result{}, err
		} else if expired {
			return Result{}, shared.ChallengeExpired()
		}

		switch {
		case current.Status.IsConfirmedRetryState():
			// Already confirmed: serve the idempotent retry path.
			return s.handleConfirmedRetry(ctx, current, code, clientPublicKey)
		case !current.Status.AcceptsFreshConfirm():
			return Result{}, shared.InvalidCode()
		}

		match, err := s.codeHasher.Compare(current.CodeHash, code)
		if err != nil {
			return Result{}, shared.ServiceUnavailable(err)
		}
		if !match {
			// Record the failed attempt (may fail the challenge once the
			// attempt budget is exhausted) before reporting invalid-code.
			if err := s.recordInvalidConfirmAttempt(ctx, current, now); err != nil {
				if errors.Is(err, ports.ErrConflict) {
					continue
				}
				return Result{}, err
			}

			return Result{}, shared.InvalidCode()
		}

		// Resolve (or create) the user behind the challenge's email.
		ensureUserResult, err := s.userDirectory.EnsureUserByEmail(ctx, current.Email)
		if err != nil {
			return Result{}, shared.ServiceUnavailable(err)
		}
		if err := ensureUserResult.Validate(); err != nil {
			return Result{}, shared.InternalError(err)
		}
		s.telemetry.RecordUserDirectoryOutcome(ctx, "ensure_user_by_email", string(ensureUserResult.Outcome))
		if !ensureUserResult.UserID.IsZero() {
			logFields = append(logFields, zap.String("user_id", ensureUserResult.UserID.String()))
		}
		if ensureUserResult.Outcome == ports.EnsureUserOutcomeBlocked {
			// Blocked users fail the challenge so further retries cannot pass.
			if err := s.markChallengeFailed(ctx, current, now); err != nil {
				if errors.Is(err, ports.ErrConflict) {
					continue
				}
				return Result{}, err
			}

			return Result{}, shared.BlockedByPolicy()
		}

		// Enforce the configured per-user active-session limit.
		limitConfig, err := s.configProvider.LoadSessionLimit(ctx)
		if err != nil {
			return Result{}, shared.ServiceUnavailable(err)
		}
		decision, err := s.evaluateSessionLimit(ctx, ensureUserResult.UserID, limitConfig)
		if err != nil {
			return Result{}, err
		}
		if decision.Kind == sessionlimit.KindExceeded {
			s.telemetry.RecordSessionLimitRejection(ctx)
			return Result{}, shared.SessionLimitExceeded()
		}

		// Create the session first, then CAS the challenge into the confirmed
		// state; if the CAS loses, the just-created session is superseded and
		// revoked best-effort.
		sessionRecord, err := s.createSession(ctx, ensureUserResult.UserID, clientPublicKey, now)
		if err != nil {
			return Result{}, err
		}

		next := current
		next.Status = challenge.StatusConfirmedPendingExpire
		next.ExpiresAt = now.Add(challenge.ConfirmedRetention)
		next.Abuse.LastAttemptAt = &now
		next.Confirmation = &challenge.Confirmation{
			SessionID:       sessionRecord.ID,
			ClientPublicKey: clientPublicKey,
			ConfirmedAt:     now,
		}
		if err := next.Validate(); err != nil {
			s.bestEffortRevokeSupersededSession(ctx, sessionRecord)
			return Result{}, shared.InternalError(err)
		}

		if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
			if errors.Is(err, ports.ErrConflict) {
				// A concurrent confirm won the race; resolve against its result.
				return s.handleCreateSessionCASConflict(ctx, challengeID, code, clientPublicKey, sessionRecord)
			}

			s.bestEffortRevokeSupersededSession(ctx, sessionRecord)
			return Result{}, shared.ServiceUnavailable(err)
		}

		// Publish the currently stored session view so a concurrent revoke/block
		// cannot overwrite source of truth with a stale active projection.
		currentSession, err := s.sessionStore.Get(ctx, sessionRecord.ID)
		if err != nil {
			switch {
			case errors.Is(err, ports.ErrNotFound):
				return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: newly created session %q was not found", sessionRecord.ID))
			default:
				return Result{}, shared.ServiceUnavailable(err)
			}
		}
		if err := s.publishSession(ctx, currentSession, "confirm_email_code"); err != nil {
			return Result{}, err
		}

		return Result{DeviceSessionID: currentSession.ID.String()}, nil
	}

	return Result{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: compare-and-swap retry limit exceeded"))
}
|
||||
|
||||
// ensureChallengeNotExpired lazily transitions an expired challenge into
// StatusExpired and reports whether the challenge is expired at now.
//
// Return contract: (true, nil) when expired (and persisted, if a transition
// was allowed); (false, nil) when still live; (false, ports.ErrConflict) when
// a concurrent writer changed the record mid-CAS — the caller is expected to
// reload and retry; (true, wrapped error) for validation/store failures.
func (s *Service) ensureChallengeNotExpired(ctx context.Context, current challenge.Challenge, now time.Time) (bool, error) {
	if current.IsExpiredAt(now) {
		// Only write the transition when the state machine permits it;
		// otherwise the record is already terminal and just reported expired.
		if current.Status != challenge.StatusExpired && current.Status.CanTransitionTo(challenge.StatusExpired) {
			next := current
			next.Status = challenge.StatusExpired
			next.Abuse.LastAttemptAt = &now
			next.Confirmation = nil
			if err := next.Validate(); err != nil {
				return true, shared.InternalError(err)
			}
			if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
				if !errors.Is(err, ports.ErrConflict) {
					return true, shared.ServiceUnavailable(err)
				}
				// Conflict: surface the raw error with expired=false so the
				// caller's CAS loop re-reads the fresh record.
				return false, err
			}
		}

		return true, nil
	}

	return false, nil
}
|
||||
|
||||
// handleConfirmedRetry serves an idempotent retry against an already-confirmed
// challenge: it re-verifies the submitted code and client public key against
// the stored confirmation, republishes the existing session's projection, and
// returns the original device-session ID instead of creating a new session.
func (s *Service) handleConfirmedRetry(ctx context.Context, current challenge.Challenge, code string, clientPublicKey common.ClientPublicKey) (Result, error) {
	// The code is checked first; a mismatch is reported exactly like a fresh
	// invalid attempt.
	match, err := s.codeHasher.Compare(current.CodeHash, code)
	if err != nil {
		return Result{}, shared.ServiceUnavailable(err)
	}
	if !match {
		return Result{}, shared.InvalidCode()
	}
	if current.Confirmation == nil {
		// A confirmed-retry status without confirmation metadata breaks the
		// domain invariant — internal error, not a caller problem.
		return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: confirmed challenge is missing confirmation metadata"))
	}
	if current.Confirmation.ClientPublicKey.String() != clientPublicKey.String() {
		// A different key is treated as an invalid code, same as a mismatch.
		return Result{}, shared.InvalidCode()
	}

	record, err := s.sessionStore.Get(ctx, current.Confirmation.SessionID)
	if err != nil {
		switch {
		case errors.Is(err, ports.ErrNotFound):
			return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: confirmed session %q was not found", current.Confirmation.SessionID))
		default:
			return Result{}, shared.ServiceUnavailable(err)
		}
	}
	// Republish so a projection lost to an earlier publish failure is repaired.
	if err := s.publishSession(ctx, record, "confirm_email_code_retry"); err != nil {
		return Result{}, err
	}

	return Result{DeviceSessionID: record.ID.String()}, nil
}
|
||||
|
||||
func (s *Service) recordInvalidConfirmAttempt(ctx context.Context, current challenge.Challenge, now time.Time) error {
|
||||
next := current
|
||||
next.Attempts.Confirm++
|
||||
next.Abuse.LastAttemptAt = &now
|
||||
if next.Attempts.Confirm >= challenge.MaxInvalidConfirmAttempts {
|
||||
next.Status = challenge.StatusFailed
|
||||
}
|
||||
if err := next.Validate(); err != nil {
|
||||
return shared.InternalError(err)
|
||||
}
|
||||
|
||||
if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrConflict):
|
||||
return err
|
||||
default:
|
||||
return shared.ServiceUnavailable(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) markChallengeFailed(ctx context.Context, current challenge.Challenge, now time.Time) error {
|
||||
next := current
|
||||
next.Status = challenge.StatusFailed
|
||||
next.Abuse.LastAttemptAt = &now
|
||||
if err := next.Validate(); err != nil {
|
||||
return shared.InternalError(err)
|
||||
}
|
||||
|
||||
if err := s.challengeStore.CompareAndSwap(ctx, current, next); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrConflict):
|
||||
return err
|
||||
default:
|
||||
return shared.ServiceUnavailable(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) evaluateSessionLimit(ctx context.Context, userID common.UserID, config ports.SessionLimitConfig) (sessionlimit.Decision, error) {
|
||||
activeSessionCount, err := s.sessionStore.CountActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return sessionlimit.Decision{}, shared.ServiceUnavailable(err)
|
||||
}
|
||||
|
||||
decision, err := shared.EvaluateSessionLimit(config, activeSessionCount)
|
||||
if err != nil {
|
||||
return sessionlimit.Decision{}, err
|
||||
}
|
||||
|
||||
return decision, nil
|
||||
}
|
||||
|
||||
// createSession mints and persists a fresh active device session for userID.
// ID-generation collisions (ports.ErrConflict from Create) are retried up to
// shared.MaxCompareAndSwapRetries times with a newly generated ID each time.
func (s *Service) createSession(ctx context.Context, userID common.UserID, clientPublicKey common.ClientPublicKey, now time.Time) (devicesession.Session, error) {
	for attempt := 0; attempt < shared.MaxCompareAndSwapRetries; attempt++ {
		deviceSessionID, err := s.idGenerator.NewDeviceSessionID()
		if err != nil {
			return devicesession.Session{}, shared.ServiceUnavailable(err)
		}

		record := devicesession.Session{
			ID:              deviceSessionID,
			UserID:          userID,
			ClientPublicKey: clientPublicKey,
			Status:          devicesession.StatusActive,
			CreatedAt:       now,
		}
		if err := record.Validate(); err != nil {
			return devicesession.Session{}, shared.InternalError(err)
		}

		if err := s.sessionStore.Create(ctx, record); err != nil {
			// Conflict means the generated ID already exists; try a new one.
			if errors.Is(err, ports.ErrConflict) {
				continue
			}
			return devicesession.Session{}, shared.ServiceUnavailable(err)
		}
		// Telemetry is recorded only after a successful durable create.
		s.telemetry.RecordSessionCreated(ctx)

		return record, nil
	}

	return devicesession.Session{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: session id conflict retry limit exceeded"))
}
|
||||
|
||||
func (s *Service) handleCreateSessionCASConflict(
|
||||
ctx context.Context,
|
||||
challengeID common.ChallengeID,
|
||||
code string,
|
||||
clientPublicKey common.ClientPublicKey,
|
||||
createdSession devicesession.Session,
|
||||
) (Result, error) {
|
||||
defer s.bestEffortRevokeSupersededSession(ctx, createdSession)
|
||||
|
||||
current, err := s.challengeStore.Get(ctx, challengeID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ports.ErrNotFound) {
|
||||
return Result{}, shared.ServiceUnavailable(err)
|
||||
}
|
||||
return Result{}, shared.ServiceUnavailable(err)
|
||||
}
|
||||
|
||||
if current.Status != challenge.StatusConfirmedPendingExpire || current.Confirmation == nil {
|
||||
return Result{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: challenge %q changed to unexpected status %q after create", challengeID, current.Status))
|
||||
}
|
||||
|
||||
match, err := s.codeHasher.Compare(current.CodeHash, code)
|
||||
if err != nil {
|
||||
return Result{}, shared.ServiceUnavailable(err)
|
||||
}
|
||||
if !match || current.Confirmation.ClientPublicKey.String() != clientPublicKey.String() {
|
||||
return Result{}, shared.ServiceUnavailable(fmt.Errorf("confirmemailcode: challenge %q was confirmed by a different payload", challengeID))
|
||||
}
|
||||
|
||||
winningSession, err := s.sessionStore.Get(ctx, current.Confirmation.SessionID)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return Result{}, shared.InternalError(fmt.Errorf("confirmemailcode: winning session %q was not found", current.Confirmation.SessionID))
|
||||
default:
|
||||
return Result{}, shared.ServiceUnavailable(err)
|
||||
}
|
||||
}
|
||||
if err := s.publishSession(ctx, winningSession, "confirm_email_code_race_winner"); err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
return Result{DeviceSessionID: winningSession.ID.String()}, nil
|
||||
}
|
||||
|
||||
// bestEffortRevokeSupersededSession revokes a session that lost a confirm race
// and publishes the revoked projection. Every failure along the way is logged
// at Warn and swallowed — this is cleanup, and the caller's own result must
// not be affected by it.
func (s *Service) bestEffortRevokeSupersededSession(ctx context.Context, record devicesession.Session) {
	revocation := devicesession.Revocation{
		At:         s.clock.Now().UTC(),
		ReasonCode: revokeReasonConfirmRace,
		ActorType:  revokeActorTypeService,
		ActorID:    revokeActorIDService,
	}
	if err := revocation.Validate(); err != nil {
		// Invalid revocation metadata is a programming error; silently abandon
		// the cleanup rather than crash the confirm flow.
		return
	}

	revokeResult, err := s.sessionStore.Revoke(ctx, ports.RevokeSessionInput{
		DeviceSessionID: record.ID,
		Revocation:      revocation,
	})
	if err != nil {
		s.logger.Warn(
			"best-effort superseded session revoke failed",
			zap.String("component", "service"),
			zap.String("use_case", "confirm_email_code"),
			zap.String("operation", "confirm_email_code_race_cleanup"),
			zap.String("device_session_id", record.ID.String()),
			zap.String("reason_code", revocation.ReasonCode.String()),
			zap.Error(err),
		)
		return
	}
	if err := revokeResult.Validate(); err != nil {
		s.logger.Warn(
			"best-effort superseded session revoke produced invalid result",
			zap.String("component", "service"),
			zap.String("use_case", "confirm_email_code"),
			zap.String("operation", "confirm_email_code_race_cleanup"),
			zap.String("device_session_id", record.ID.String()),
			zap.Error(err),
		)
		return
	}
	// Only a real revocation (not, e.g., an already-revoked outcome) counts
	// toward the revocation metric.
	if revokeResult.Outcome == ports.RevokeSessionOutcomeRevoked {
		s.telemetry.RecordSessionRevocations(ctx, "confirm_email_code_race_cleanup", revocation.ReasonCode.String(), 1)
	}

	// Publish the post-revoke session view so the gateway projection does not
	// keep the superseded session active.
	snapshot, err := shared.ToGatewayProjectionSnapshot(revokeResult.Session)
	if err != nil {
		s.logger.Warn(
			"best-effort superseded session snapshot mapping failed",
			zap.String("component", "service"),
			zap.String("use_case", "confirm_email_code"),
			zap.String("operation", "confirm_email_code_race_cleanup"),
			zap.String("device_session_id", revokeResult.Session.ID.String()),
			zap.Error(err),
		)
		return
	}
	if err := shared.PublishProjectionSnapshotWithTelemetry(ctx, s.publisher, snapshot, s.telemetry, "confirm_email_code_race_cleanup"); err != nil {
		s.logger.Warn(
			"best-effort superseded session publish failed",
			zap.String("component", "service"),
			zap.String("use_case", "confirm_email_code"),
			zap.String("operation", "confirm_email_code_race_cleanup"),
			zap.String("device_session_id", revokeResult.Session.ID.String()),
			zap.Error(err),
		)
	}
}
|
||||
|
||||
// publishSession pushes the session's gateway projection, tagging telemetry
// with the originating operation name.
func (s *Service) publishSession(ctx context.Context, record devicesession.Session, operation string) error {
	return shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, operation)
}
|
||||
|
||||
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return logger.Named(name)
|
||||
}
|
||||
@@ -0,0 +1,682 @@
|
||||
package confirmemailcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
)
|
||||
|
||||
func TestExecuteConfirmsChallengeForExistingUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Execute() returned error: %v", err)
|
||||
}
|
||||
if result.DeviceSessionID != "device-session-1" {
|
||||
require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
|
||||
}
|
||||
|
||||
record, err := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if record.Status != devicesession.StatusActive {
|
||||
require.Failf(t, "test failed", "session status = %q, want %q", record.Status, devicesession.StatusActive)
|
||||
}
|
||||
|
||||
challengeRecord, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if challengeRecord.Status != challenge.StatusConfirmedPendingExpire || challengeRecord.Confirmation == nil {
|
||||
require.Failf(t, "test failed", "challenge status = %q, confirmation = %+v", challengeRecord.Status, challengeRecord.Confirmation)
|
||||
}
|
||||
if len(deps.publisher.PublishedSnapshots()) != 1 {
|
||||
require.Failf(t, "test failed", "PublishedSnapshots() length = %d, want 1", len(deps.publisher.PublishedSnapshots()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteConfirmsChallengeByCreatingUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
if err := deps.userDirectory.QueueCreatedUserIDs(common.UserID("user-created")); err != nil {
|
||||
require.Failf(t, "test failed", "QueueCreatedUserIDs() returned error: %v", err)
|
||||
}
|
||||
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "new@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Execute() returned error: %v", err)
|
||||
}
|
||||
if result.DeviceSessionID != "device-session-1" {
|
||||
require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
|
||||
}
|
||||
|
||||
record, err := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if record.UserID != common.UserID("user-created") {
|
||||
require.Failf(t, "test failed", "session user id = %q, want %q", record.UserID, common.UserID("user-created"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteConfirmsSuppressedChallenge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
|
||||
record.Status = challenge.StatusDeliverySuppressed
|
||||
record.DeliveryState = challenge.DeliverySuppressed
|
||||
if err := record.Validate(); err != nil {
|
||||
require.Failf(t, "test failed", "Validate() returned error: %v", err)
|
||||
}
|
||||
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Execute() returned error: %v", err)
|
||||
}
|
||||
if result.DeviceSessionID != "device-session-1" {
|
||||
require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsChallengeNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := mustNewConfirmService(t, newConfirmDeps(t))
|
||||
|
||||
_, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "missing",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeChallengeNotFound {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeChallengeNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsChallengeExpiredAndMarksExpired(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-2*time.Minute), deps.now.Add(-time.Second))); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
|
||||
_, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeChallengeExpired {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeChallengeExpired)
|
||||
}
|
||||
|
||||
record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if record.Status != challenge.StatusExpired {
|
||||
require.Failf(t, "test failed", "challenge status = %q, want %q", record.Status, challenge.StatusExpired)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsChallengeExpiredForConfirmedChallengeAfterRetentionWindow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
key, err := shared.ParseClientPublicKey(publicKeyString())
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err)
|
||||
}
|
||||
record := confirmedChallengeFixture(
|
||||
t,
|
||||
deps.hasher,
|
||||
"challenge-1",
|
||||
"pilot@example.com",
|
||||
"654321",
|
||||
"device-session-1",
|
||||
key,
|
||||
deps.now.Add(-2*challenge.ConfirmedRetention),
|
||||
deps.now.Add(-time.Second),
|
||||
)
|
||||
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeChallengeExpired {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeChallengeExpired)
|
||||
}
|
||||
|
||||
updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if updated.Status != challenge.StatusExpired {
|
||||
require.Failf(t, "test failed", "challenge status = %q, want %q", updated.Status, challenge.StatusExpired)
|
||||
}
|
||||
if updated.Confirmation != nil {
|
||||
require.Failf(t, "test failed", "Confirmation = %+v, want nil after expiration", updated.Confirmation)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsInvalidClientPublicKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := mustNewConfirmService(t, newConfirmDeps(t))
|
||||
|
||||
_, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: "invalid",
|
||||
})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeInvalidClientPublicKey {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidClientPublicKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteInvalidCodeIncrementsAttempts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
_, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "000000",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
|
||||
}
|
||||
|
||||
record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if record.Attempts.Confirm != 1 {
|
||||
require.Failf(t, "test failed", "Attempts.Confirm = %d, want 1", record.Attempts.Confirm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteFifthInvalidAttemptMarksChallengeFailed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
|
||||
record.Attempts.Confirm = 4
|
||||
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
_, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "000000",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
|
||||
}
|
||||
|
||||
updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if updated.Status != challenge.StatusFailed {
|
||||
require.Failf(t, "test failed", "challenge status = %q, want %q", updated.Status, challenge.StatusFailed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteDoesNotCreateSessionAfterTooManyAttempts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
|
||||
record.Attempts.Confirm = challenge.MaxInvalidConfirmAttempts
|
||||
record.Status = challenge.StatusFailed
|
||||
if err := record.Validate(); err != nil {
|
||||
require.Failf(t, "test failed", "Validate() returned error: %v", err)
|
||||
}
|
||||
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
_, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
|
||||
}
|
||||
|
||||
if got, err := deps.sessionStore.CountActiveByUserID(context.Background(), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "CountActiveByUserID() returned error: %v", err)
|
||||
} else if got != 0 {
|
||||
require.Failf(t, "test failed", "CountActiveByUserID() = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsSameSessionIDForIdempotentRetryAndRepublishes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
key, err := shared.ParseClientPublicKey(publicKeyString())
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err)
|
||||
}
|
||||
record := confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
|
||||
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
if err := deps.sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute))); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Execute() returned error: %v", err)
|
||||
}
|
||||
if result.DeviceSessionID != "device-session-1" {
|
||||
require.Failf(t, "test failed", "Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
|
||||
}
|
||||
if len(deps.publisher.PublishedSnapshots()) != 1 {
|
||||
require.Failf(t, "test failed", "PublishedSnapshots() length = %d, want 1", len(deps.publisher.PublishedSnapshots()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsInvalidCodeForDifferentKeyDuringIdempotentRetry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
key, err := shared.ParseClientPublicKey(publicKeyString())
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err)
|
||||
}
|
||||
record := confirmedChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", "device-session-1", key, deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
|
||||
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
if err := deps.sessionStore.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", key, deps.now.Add(-time.Minute))); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: alternatePublicKeyString(),
|
||||
})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
|
||||
}
|
||||
|
||||
updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if updated.Attempts.Confirm != 0 {
|
||||
require.Failf(t, "test failed", "Attempts.Confirm = %d, want 0", updated.Attempts.Confirm)
|
||||
}
|
||||
if updated.Confirmation == nil {
|
||||
require.FailNow(t, "Confirmation = nil, want metadata to stay intact")
|
||||
}
|
||||
if updated.Confirmation.SessionID != common.DeviceSessionID("device-session-1") {
|
||||
require.Failf(t, "test failed", "Confirmation.SessionID = %q, want %q", updated.Confirmation.SessionID, common.DeviceSessionID("device-session-1"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsInvalidCodeForNonConfirmableStates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
status challenge.Status
|
||||
deliveryState challenge.DeliveryState
|
||||
}{
|
||||
{name: "pending send", status: challenge.StatusPendingSend, deliveryState: challenge.DeliveryPending},
|
||||
{name: "failed", status: challenge.StatusFailed, deliveryState: challenge.DeliveryFailed},
|
||||
{name: "cancelled", status: challenge.StatusCancelled, deliveryState: challenge.DeliverySent},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
|
||||
record.Status = tt.status
|
||||
record.DeliveryState = tt.deliveryState
|
||||
if err := record.Validate(); err != nil {
|
||||
require.Failf(t, "test failed", "Validate() returned error: %v", err)
|
||||
}
|
||||
if err := deps.challengeStore.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
_, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeInvalidCode {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidCode)
|
||||
}
|
||||
|
||||
updated, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if updated.Attempts.Confirm != 0 {
|
||||
require.Failf(t, "test failed", "Attempts.Confirm = %d, want 0", updated.Attempts.Confirm)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteMarksChallengeFailedAndReturnsBlockedByPolicy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
if err := deps.userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err)
|
||||
}
|
||||
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
_, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeBlockedByPolicy {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeBlockedByPolicy)
|
||||
}
|
||||
|
||||
record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if record.Status != challenge.StatusFailed {
|
||||
require.Failf(t, "test failed", "challenge status = %q, want %q", record.Status, challenge.StatusFailed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsSessionLimitExceededWithoutConsumingChallenge(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
if err := deps.sessionStore.Create(context.Background(), activeSessionFixture("device-session-existing", "user-1", mustClientPublicKey(t, publicKeyString()), deps.now.Add(-2*time.Minute))); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
limit := 1
|
||||
deps.configProvider.Config.ActiveSessionLimit = &limit
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
_, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeSessionLimitExceeded {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeSessionLimitExceeded)
|
||||
}
|
||||
|
||||
record, err := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if record.Status != challenge.StatusSent {
|
||||
require.Failf(t, "test failed", "challenge status = %q, want %q", record.Status, challenge.StatusSent)
|
||||
}
|
||||
if record.Attempts.Confirm != 0 {
|
||||
require.Failf(t, "test failed", "Attempts.Confirm = %d, want 0", record.Attempts.Confirm)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsServiceUnavailableThenSucceedsIdempotentlyAfterPublishFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
deps.publisher.Err = errors.New("publish failed")
|
||||
if err := deps.userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
if err := deps.challengeStore.Create(context.Background(), sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service := mustNewConfirmService(t, deps)
|
||||
_, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeServiceUnavailable {
|
||||
require.Failf(t, "test failed", "first Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeServiceUnavailable)
|
||||
}
|
||||
|
||||
deps.publisher.Err = nil
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "second Execute() returned error: %v", err)
|
||||
}
|
||||
if result.DeviceSessionID != "device-session-1" {
|
||||
require.Failf(t, "test failed", "second Execute().DeviceSessionID = %q, want %q", result.DeviceSessionID, "device-session-1")
|
||||
}
|
||||
}
|
||||
|
||||
// confirmDeps bundles the fakes and fixed values needed to assemble a
// confirm-email-code Service under test. Each test gets an independent
// set via newConfirmDeps.
type confirmDeps struct {
	challengeStore *testkit.InMemoryChallengeStore
	sessionStore   *testkit.InMemorySessionStore
	userDirectory  *testkit.InMemoryUserDirectory
	configProvider testkit.StaticConfigProvider
	publisher      *testkit.RecordingProjectionPublisher
	idGenerator    *testkit.SequenceIDGenerator
	hasher         testkit.DeterministicCodeHasher
	// now is the frozen "current" instant shared by the service clock
	// and the challenge/session fixtures.
	now time.Time
}
|
||||
|
||||
func newConfirmDeps(t *testing.T) confirmDeps {
|
||||
t.Helper()
|
||||
|
||||
return confirmDeps{
|
||||
challengeStore: &testkit.InMemoryChallengeStore{},
|
||||
sessionStore: &testkit.InMemorySessionStore{},
|
||||
userDirectory: &testkit.InMemoryUserDirectory{},
|
||||
configProvider: testkit.StaticConfigProvider{},
|
||||
publisher: &testkit.RecordingProjectionPublisher{},
|
||||
idGenerator: &testkit.SequenceIDGenerator{
|
||||
DeviceSessionIDs: []common.DeviceSessionID{"device-session-1"},
|
||||
},
|
||||
hasher: testkit.DeterministicCodeHasher{},
|
||||
now: time.Unix(20, 0).UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func mustNewConfirmService(t *testing.T, deps confirmDeps) *Service {
|
||||
t.Helper()
|
||||
|
||||
service, err := New(
|
||||
deps.challengeStore,
|
||||
deps.sessionStore,
|
||||
deps.userDirectory,
|
||||
deps.configProvider,
|
||||
deps.publisher,
|
||||
deps.idGenerator,
|
||||
deps.hasher,
|
||||
testkit.FixedClock{Time: deps.now},
|
||||
)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
func sentChallengeFixture(
|
||||
t *testing.T,
|
||||
hasher testkit.DeterministicCodeHasher,
|
||||
challengeID string,
|
||||
email string,
|
||||
code string,
|
||||
createdAt time.Time,
|
||||
expiresAt time.Time,
|
||||
) challenge.Challenge {
|
||||
t.Helper()
|
||||
|
||||
codeHash, err := hasher.Hash(code)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Hash() returned error: %v", err)
|
||||
}
|
||||
|
||||
record := challenge.Challenge{
|
||||
ID: common.ChallengeID(challengeID),
|
||||
Email: common.Email(email),
|
||||
CodeHash: codeHash,
|
||||
Status: challenge.StatusSent,
|
||||
DeliveryState: challenge.DeliverySent,
|
||||
CreatedAt: createdAt,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
if err := record.Validate(); err != nil {
|
||||
require.Failf(t, "test failed", "Validate() returned error: %v", err)
|
||||
}
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
func confirmedChallengeFixture(
|
||||
t *testing.T,
|
||||
hasher testkit.DeterministicCodeHasher,
|
||||
challengeID string,
|
||||
email string,
|
||||
code string,
|
||||
deviceSessionID string,
|
||||
clientPublicKey common.ClientPublicKey,
|
||||
createdAt time.Time,
|
||||
expiresAt time.Time,
|
||||
) challenge.Challenge {
|
||||
t.Helper()
|
||||
|
||||
record := sentChallengeFixture(t, hasher, challengeID, email, code, createdAt, expiresAt)
|
||||
record.Status = challenge.StatusConfirmedPendingExpire
|
||||
record.Confirmation = &challenge.Confirmation{
|
||||
SessionID: common.DeviceSessionID(deviceSessionID),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
ConfirmedAt: createdAt.Add(time.Minute),
|
||||
}
|
||||
if err := record.Validate(); err != nil {
|
||||
require.Failf(t, "test failed", "Validate() returned error: %v", err)
|
||||
}
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
func activeSessionFixture(deviceSessionID string, userID string, clientPublicKey common.ClientPublicKey, createdAt time.Time) devicesession.Session {
|
||||
return devicesession.Session{
|
||||
ID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID(userID),
|
||||
ClientPublicKey: clientPublicKey,
|
||||
Status: devicesession.StatusActive,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
|
||||
func mustClientPublicKey(t *testing.T, value string) common.ClientPublicKey {
|
||||
t.Helper()
|
||||
|
||||
key, err := shared.ParseClientPublicKey(value)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "ParseClientPublicKey() returned error: %v", err)
|
||||
}
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
func publicKeyString() string {
|
||||
return "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8="
|
||||
}
|
||||
|
||||
func alternatePublicKeyString() string {
|
||||
return "AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQE="
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package confirmemailcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
stubuserservice "galaxy/authsession/internal/adapters/userservice"
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("creates user through EnsureUserByEmail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
userDirectory := &stubuserservice.StubDirectory{}
|
||||
require.NoError(t, userDirectory.QueueCreatedUserIDs(common.UserID("user-created")))
|
||||
deps.userDirectory = nil
|
||||
require.NoError(t, deps.challengeStore.Create(context.Background(), sentChallengeFixture(
|
||||
t,
|
||||
deps.hasher,
|
||||
"challenge-1",
|
||||
"pilot@example.com",
|
||||
"654321",
|
||||
deps.now.Add(-time.Minute),
|
||||
deps.now.Add(time.Minute),
|
||||
)))
|
||||
|
||||
service, err := New(
|
||||
deps.challengeStore,
|
||||
deps.sessionStore,
|
||||
userDirectory,
|
||||
deps.configProvider,
|
||||
deps.publisher,
|
||||
deps.idGenerator,
|
||||
deps.hasher,
|
||||
fixedClock(deps.now),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "device-session-1", result.DeviceSessionID)
|
||||
|
||||
sessionRecord, err := deps.sessionStore.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, common.UserID("user-created"), sessionRecord.UserID)
|
||||
})
|
||||
|
||||
t.Run("blocked email returns blocked by policy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps := newConfirmDeps(t)
|
||||
userDirectory := &stubuserservice.StubDirectory{}
|
||||
require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")))
|
||||
require.NoError(t, deps.challengeStore.Create(context.Background(), sentChallengeFixture(
|
||||
t,
|
||||
deps.hasher,
|
||||
"challenge-1",
|
||||
"pilot@example.com",
|
||||
"654321",
|
||||
deps.now.Add(-time.Minute),
|
||||
deps.now.Add(time.Minute),
|
||||
)))
|
||||
|
||||
service, err := New(
|
||||
deps.challengeStore,
|
||||
deps.sessionStore,
|
||||
userDirectory,
|
||||
deps.configProvider,
|
||||
deps.publisher,
|
||||
deps.idGenerator,
|
||||
deps.hasher,
|
||||
fixedClock(deps.now),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
ChallengeID: "challenge-1",
|
||||
Code: "654321",
|
||||
ClientPublicKey: publicKeyString(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, shared.ErrorCodeBlockedByPolicy, shared.CodeOf(err))
|
||||
|
||||
record, getErr := deps.challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
require.NoError(t, getErr)
|
||||
assert.Equal(t, challenge.StatusFailed, record.Status)
|
||||
})
|
||||
}
|
||||
|
||||
type fixedClock time.Time
|
||||
|
||||
func (c fixedClock) Now() time.Time {
|
||||
return time.Time(c)
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package confirmemailcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
authtelemetry "galaxy/authsession/internal/telemetry"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||
"go.opentelemetry.io/otel/sdk/metric/metricdata"
|
||||
)
|
||||
|
||||
// TestExecuteRecordsInvalidCodeMetricForThrottledChallenge verifies that a
// confirm attempt against a delivery-throttled challenge fails and is counted
// under the "invalid_code" outcome label of the confirm-attempts metric.
func TestExecuteRecordsInvalidCodeMetricForThrottledChallenge(t *testing.T) {
	t.Parallel()

	// Observable telemetry runtime backed by a manual reader we can collect from.
	runtime, reader := newObservedConfirmTelemetryRuntime(t)
	deps := newConfirmDeps(t)
	// Seed a challenge whose code would otherwise match, then force it into the
	// delivery-throttled state so the confirm attempt must be rejected.
	record := sentChallengeFixture(t, deps.hasher, "challenge-1", "pilot@example.com", "654321", deps.now.Add(-time.Minute), deps.now.Add(time.Minute))
	record.Status = challenge.StatusDeliveryThrottled
	record.DeliveryState = challenge.DeliveryThrottled
	require.NoError(t, record.Validate())
	require.NoError(t, deps.challengeStore.Create(context.Background(), record))

	service, err := NewWithTelemetry(
		deps.challengeStore,
		deps.sessionStore,
		deps.userDirectory,
		deps.configProvider,
		deps.publisher,
		deps.idGenerator,
		deps.hasher,
		testkit.FixedClock{Time: deps.now},
		runtime,
	)
	require.NoError(t, err)

	// The submitted code matches the stored one; the failure must come from the
	// throttled state alone.
	_, err = service.Execute(context.Background(), Input{
		ChallengeID:     "challenge-1",
		Code:            "654321",
		ClientPublicKey: publicKeyString(),
	})
	require.Error(t, err)

	assertConfirmMetricCount(t, reader, map[string]string{"outcome": "invalid_code"}, 1)
}
|
||||
|
||||
// newObservedConfirmTelemetryRuntime builds a telemetry runtime whose metrics
// are captured by a manual reader, so tests can collect and assert on the
// recorded data points synchronously.
func newObservedConfirmTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader) {
	t.Helper()

	reader := sdkmetric.NewManualReader()
	provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))

	runtime, err := authtelemetry.New(provider)
	require.NoError(t, err)

	return runtime, reader
}
|
||||
|
||||
func assertConfirmMetricCount(t *testing.T, reader *sdkmetric.ManualReader, wantAttrs map[string]string, wantValue int64) {
|
||||
t.Helper()
|
||||
|
||||
var resourceMetrics metricdata.ResourceMetrics
|
||||
require.NoError(t, reader.Collect(context.Background(), &resourceMetrics))
|
||||
|
||||
for _, scopeMetrics := range resourceMetrics.ScopeMetrics {
|
||||
for _, metric := range scopeMetrics.Metrics {
|
||||
if metric.Name != "authsession.confirm_email_code.attempts" {
|
||||
continue
|
||||
}
|
||||
|
||||
sum, ok := metric.Data.(metricdata.Sum[int64])
|
||||
require.True(t, ok)
|
||||
|
||||
for _, point := range sum.DataPoints {
|
||||
if hasConfirmMetricAttributes(point.Attributes.ToSlice(), wantAttrs) {
|
||||
assert.Equal(t, wantValue, point.Value)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Failf(t, "test failed", "confirm metric with attrs %v not found", wantAttrs)
|
||||
}
|
||||
|
||||
func hasConfirmMetricAttributes(values []attribute.KeyValue, want map[string]string) bool {
|
||||
if len(values) != len(want) {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
if want[string(value.Key)] != value.Value.AsString() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
// Package getsession implements the trusted internal read use case for one
|
||||
// device session.
|
||||
package getsession
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"galaxy/authsession/internal/ports"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
)
|
||||
|
||||
// Input describes one trusted internal get-session request.
type Input struct {
	// DeviceSessionID identifies the session that should be read. Surrounding
	// whitespace is tolerated; parsing happens in Execute.
	DeviceSessionID string
}

// Result describes one trusted internal get-session response.
type Result struct {
	// Session stores the frozen internal read-model DTO.
	Session shared.Session
}

// Service executes the trusted internal get-session use case against the
// configured ports.
type Service struct {
	sessionStore ports.SessionStore // source-of-truth store the session is read from
}

// New returns a get-session service wired to sessionStore. It returns an
// error when sessionStore is nil.
func New(sessionStore ports.SessionStore) (*Service, error) {
	if sessionStore == nil {
		return nil, fmt.Errorf("getsession: session store must not be nil")
	}

	return &Service{sessionStore: sessionStore}, nil
}
|
||||
|
||||
// Execute loads one source-of-truth session and projects it into the frozen
|
||||
// internal read DTO shape.
|
||||
func (s *Service) Execute(ctx context.Context, input Input) (Result, error) {
|
||||
deviceSessionID, err := shared.ParseDeviceSessionID(input.DeviceSessionID)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
record, err := s.sessionStore.Get(ctx, deviceSessionID)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ports.ErrNotFound):
|
||||
return Result{}, shared.SessionNotFound()
|
||||
default:
|
||||
return Result{}, shared.ServiceUnavailable(err)
|
||||
}
|
||||
}
|
||||
|
||||
session, err := shared.ToSession(record)
|
||||
if err != nil {
|
||||
return Result{}, shared.InternalError(err)
|
||||
}
|
||||
|
||||
return Result{Session: session}, nil
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package getsession
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
)
|
||||
|
||||
func TestExecuteReturnsMappedSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
|
||||
if err := store.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(store)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{DeviceSessionID: " device-session-1 "})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Execute() returned error: %v", err)
|
||||
}
|
||||
if result.Session.DeviceSessionID != "device-session-1" {
|
||||
require.Failf(t, "test failed", "Execute().Session.DeviceSessionID = %q, want %q", result.Session.DeviceSessionID, "device-session-1")
|
||||
}
|
||||
if result.Session.CreatedAt != time.Unix(10, 0).UTC().Format(time.RFC3339) {
|
||||
require.Failf(t, "test failed", "Execute().Session.CreatedAt = %q", result.Session.CreatedAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsSessionNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service, err := New(&testkit.InMemorySessionStore{})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{DeviceSessionID: "missing"})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeSessionNotFound {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeSessionNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
key, err := common.NewClientPublicKey(make([]byte, 32))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return devicesession.Session{
|
||||
ID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID(userID),
|
||||
ClientPublicKey: key,
|
||||
Status: devicesession.StatusActive,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
// Package listusersessions implements the trusted internal read use case for
|
||||
// listing all sessions of one user.
|
||||
package listusersessions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"galaxy/authsession/internal/ports"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
)
|
||||
|
||||
// Input describes one trusted internal list-user-sessions request.
type Input struct {
	// UserID identifies the owner whose sessions should be listed.
	UserID string
}

// Result describes one trusted internal list-user-sessions response.
type Result struct {
	// Sessions stores the frozen internal read-model DTO slice.
	Sessions []shared.Session
}

// Service executes the trusted internal list-user-sessions use case.
type Service struct {
	sessionStore ports.SessionStore // source-of-truth store the sessions are read from
}

// New returns a list-user-sessions service wired to sessionStore. It returns
// an error when sessionStore is nil.
func New(sessionStore ports.SessionStore) (*Service, error) {
	if sessionStore == nil {
		return nil, fmt.Errorf("listusersessions: session store must not be nil")
	}

	return &Service{sessionStore: sessionStore}, nil
}
|
||||
|
||||
// Execute loads all source-of-truth sessions for one user and projects them
|
||||
// into the frozen internal read DTO shape.
|
||||
func (s *Service) Execute(ctx context.Context, input Input) (Result, error) {
|
||||
userID, err := shared.ParseUserID(input.UserID)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
|
||||
records, err := s.sessionStore.ListByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return Result{}, shared.ServiceUnavailable(err)
|
||||
}
|
||||
|
||||
sessions, err := shared.ToSessions(records)
|
||||
if err != nil {
|
||||
return Result{}, shared.InternalError(err)
|
||||
}
|
||||
|
||||
return Result{Sessions: sessions}, nil
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package listusersessions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
)
|
||||
|
||||
func TestExecutePreservesNewestFirstOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
older := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
|
||||
newer := activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())
|
||||
for _, record := range []devicesession.Session{older, newer} {
|
||||
if err := store.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
service, err := New(store)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{UserID: "user-1"})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Execute() returned error: %v", err)
|
||||
}
|
||||
if len(result.Sessions) != 2 {
|
||||
require.Failf(t, "test failed", "Execute().Sessions length = %d, want 2", len(result.Sessions))
|
||||
}
|
||||
if result.Sessions[0].DeviceSessionID != "device-session-2" || result.Sessions[1].DeviceSessionID != "device-session-1" {
|
||||
require.Failf(t, "test failed", "Execute().Sessions order = [%q %q]", result.Sessions[0].DeviceSessionID, result.Sessions[1].DeviceSessionID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsEmptyForUnknownUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service, err := New(&testkit.InMemorySessionStore{})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{UserID: "missing"})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Execute() returned error: %v", err)
|
||||
}
|
||||
if len(result.Sessions) != 0 {
|
||||
require.Failf(t, "test failed", "Execute().Sessions length = %d, want 0", len(result.Sessions))
|
||||
}
|
||||
}
|
||||
|
||||
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
key, err := common.NewClientPublicKey(make([]byte, 32))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return devicesession.Session{
|
||||
ID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID(userID),
|
||||
ClientPublicKey: key,
|
||||
Status: devicesession.StatusActive,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package revokeallusersessions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestExecuteRetriesProjectionPublishesForBulkRevoke verifies that a bulk
// revoke still succeeds when each projection publish fails once before
// succeeding: four recorded publish attempts for two sessions implies one
// retry per session. (Presumably the shared publish helper performs the
// retry; confirm against shared.PublishSessionProjectionWithTelemetry.)
func TestExecuteRetriesProjectionPublishesForBulkRevoke(t *testing.T) {
	t.Parallel()

	store := &testkit.InMemorySessionStore{}
	userDirectory := &testkit.InMemoryUserDirectory{}
	// Per-call error script: fail, succeed, fail, succeed — one failure and
	// one successful retry for each of the two sessions.
	publisher := &testkit.RecordingProjectionPublisher{
		Errors: []error{
			errors.New("publish failed"),
			nil,
			errors.New("publish failed"),
			nil,
		},
	}
	require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
	require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
	require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())))

	service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
	require.NoError(t, err)

	result, err := service.Execute(context.Background(), Input{
		UserID:     "user-1",
		ReasonCode: "logout_all",
		ActorType:  "system",
	})
	require.NoError(t, err)
	assert.Equal(t, "revoked", result.Outcome)
	assert.EqualValues(t, 2, result.AffectedSessionCount)
	// Newest session first, matching the ordering asserted elsewhere.
	assert.Equal(t, []string{"device-session-2", "device-session-1"}, result.AffectedDeviceSessionIDs)
	require.Len(t, publisher.PublishedSnapshots(), 4)
}
|
||||
|
||||
// TestExecuteRepublishesCurrentRevokedSessionsOnNoActiveSessionsRetry drives
// the repair path: a first Execute revokes both sessions in the store but
// fails publishing, and a second Execute — finding no active sessions left —
// republishes the already-revoked projections so the gateway converges.
func TestExecuteRepublishesCurrentRevokedSessionsOnNoActiveSessionsRetry(t *testing.T) {
	t.Parallel()

	store := &testkit.InMemorySessionStore{}
	userDirectory := &testkit.InMemoryUserDirectory{}
	// First publish succeeds, then every attempt fails — so the first Execute
	// revokes state in the store but cannot complete its projection publishes.
	publisher := &testkit.RecordingProjectionPublisher{
		Errors: []error{
			nil,
			errors.New("publish failed"),
			errors.New("publish failed"),
			errors.New("publish failed"),
		},
	}
	require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
	require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
	require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC())))

	service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
	require.NoError(t, err)

	// First call: revocation is persisted, but the publish failures surface as
	// service-unavailable after four recorded attempts.
	_, err = service.Execute(context.Background(), Input{
		UserID:     "user-1",
		ReasonCode: "logout_all",
		ActorType:  "system",
	})
	require.Error(t, err)
	assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
	require.Len(t, publisher.PublishedSnapshots(), 4)

	// Despite the publish failure, both sessions must already be revoked in
	// the source-of-truth store.
	for _, deviceSessionID := range []common.DeviceSessionID{"device-session-1", "device-session-2"} {
		record, getErr := store.Get(context.Background(), deviceSessionID)
		require.NoError(t, getErr)
		require.NotNil(t, record.Revocation)
		assert.Equal(t, devicesession.StatusRevoked, record.Status)
	}

	// Clear the failure script so the retry can publish successfully.
	publisher.Errors = nil
	publisher.Err = nil

	// Second call: no active sessions remain, so the service reports
	// no_active_sessions and repairs the projection by republishing.
	result, err := service.Execute(context.Background(), Input{
		UserID:     "user-1",
		ReasonCode: "logout_all",
		ActorType:  "system",
	})
	require.NoError(t, err)
	assert.Equal(t, "no_active_sessions", result.Outcome)
	assert.EqualValues(t, 0, result.AffectedSessionCount)
	require.NotNil(t, result.AffectedDeviceSessionIDs)
	assert.Empty(t, result.AffectedDeviceSessionIDs)

	// Two repair publishes (indices 4 and 5) follow the four earlier attempts,
	// again newest session first.
	published := publisher.PublishedSnapshots()
	require.Len(t, published, 6)
	assert.Equal(t, []common.DeviceSessionID{"device-session-2", "device-session-1"}, []common.DeviceSessionID{
		published[4].DeviceSessionID,
		published[5].DeviceSessionID,
	})
}
|
||||
@@ -0,0 +1,200 @@
|
||||
// Package revokeallusersessions implements the trusted internal bulk revoke
|
||||
// use case for all sessions of one user.
|
||||
package revokeallusersessions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/ports"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/telemetry"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Input describes one trusted internal revoke-all-user-sessions request.
type Input struct {
	// UserID identifies the owner whose sessions should be revoked.
	UserID string

	// ReasonCode stores the machine-readable revoke reason code.
	ReasonCode string

	// ActorType stores the machine-readable revoke actor type.
	ActorType string

	// ActorID stores the optional stable revoke actor identifier. It may be
	// left empty when the actor has no stable identity.
	ActorID string
}

// Result describes the frozen internal bulk revoke acknowledgement.
type Result struct {
	// Outcome reports whether active sessions were revoked during the current
	// call (observed values include "revoked" and "no_active_sessions").
	Outcome string

	// UserID identifies the user addressed by the operation.
	UserID string

	// AffectedSessionCount reports how many sessions changed state during the
	// current call.
	AffectedSessionCount int64

	// AffectedDeviceSessionIDs lists every session identifier affected during
	// the current call.
	AffectedDeviceSessionIDs []string
}
|
||||
|
||||
// Service executes the trusted internal revoke-all-user-sessions use case.
type Service struct {
	sessionStore  ports.SessionStore                      // source of truth for device sessions
	userDirectory ports.UserDirectory                     // existence checks for the addressed user
	publisher     ports.GatewaySessionProjectionPublisher // emits gateway session projections
	clock         ports.Clock                             // supplies revocation timestamps
	logger        *zap.Logger                             // never nil after construction (no-op fallback)
	telemetry     *telemetry.Runtime                      // optional metrics runtime; may be nil
}

// New returns a revoke-all-user-sessions service wired to the required ports.
// Logging falls back to a no-op logger and telemetry stays disabled.
func New(sessionStore ports.SessionStore, userDirectory ports.UserDirectory, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) {
	return NewWithObservability(sessionStore, userDirectory, publisher, clock, nil, nil)
}
|
||||
|
||||
// NewWithObservability returns a revoke-all-user-sessions service wired to the
|
||||
// required ports plus optional structured logging and telemetry dependencies.
|
||||
func NewWithObservability(
|
||||
sessionStore ports.SessionStore,
|
||||
userDirectory ports.UserDirectory,
|
||||
publisher ports.GatewaySessionProjectionPublisher,
|
||||
clock ports.Clock,
|
||||
logger *zap.Logger,
|
||||
telemetryRuntime *telemetry.Runtime,
|
||||
) (*Service, error) {
|
||||
switch {
|
||||
case sessionStore == nil:
|
||||
return nil, fmt.Errorf("revokeallusersessions: session store must not be nil")
|
||||
case userDirectory == nil:
|
||||
return nil, fmt.Errorf("revokeallusersessions: user directory must not be nil")
|
||||
case publisher == nil:
|
||||
return nil, fmt.Errorf("revokeallusersessions: projection publisher must not be nil")
|
||||
case clock == nil:
|
||||
return nil, fmt.Errorf("revokeallusersessions: clock must not be nil")
|
||||
default:
|
||||
return &Service{
|
||||
sessionStore: sessionStore,
|
||||
userDirectory: userDirectory,
|
||||
publisher: publisher,
|
||||
clock: clock,
|
||||
logger: namedLogger(logger, "revoke_all_user_sessions"),
|
||||
telemetry: telemetryRuntime,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Execute revokes all active sessions of one user and republishes revoked
// gateway projections for every affected session.
//
// Flow: parse/validate input, confirm the user exists, revoke all active
// sessions atomically via the store, publish one projection per revoked
// session, and — when nothing was active — repair previously-failed publishes
// by republishing already-revoked sessions.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
	logFields := []zap.Field{
		zap.String("component", "service"),
		zap.String("use_case", "revoke_all_user_sessions"),
	}
	// Named results let this deferred call log the final outcome (success or
	// error) exactly once, with every field accumulated along the way.
	defer func() {
		shared.LogServiceOutcome(s.logger, ctx, "revoke all user sessions completed", err, logFields...)
	}()

	userID, err := shared.ParseUserID(input.UserID)
	if err != nil {
		return Result{}, err
	}
	logFields = append(logFields, zap.String("user_id", userID.String()))

	// Build the revocation record up front so validation failures surface
	// before any store or directory access.
	revocation, err := shared.BuildRevocation(input.ReasonCode, input.ActorType, input.ActorID, s.clock.Now())
	if err != nil {
		return Result{}, err
	}
	logFields = append(logFields, zap.String("reason_code", revocation.ReasonCode.String()))

	exists, err := s.userDirectory.ExistsByUserID(ctx, userID)
	if err != nil {
		return Result{}, shared.ServiceUnavailable(err)
	}
	s.telemetry.RecordUserDirectoryOutcome(ctx, "exists_by_user_id", boolOutcome(exists))
	if !exists {
		return Result{}, shared.SubjectNotFound()
	}

	storeResult, err := s.sessionStore.RevokeAllByUserID(ctx, ports.RevokeUserSessionsInput{
		UserID:     userID,
		Revocation: revocation,
	})
	if err != nil {
		return Result{}, shared.ServiceUnavailable(err)
	}
	if err := storeResult.Validate(); err != nil {
		return Result{}, shared.InternalError(err)
	}
	logFields = append(logFields, zap.String("outcome", string(storeResult.Outcome)))

	// Publish one revoked projection per affected session; a publish failure
	// aborts, leaving the store revoked but the projection incomplete (the
	// repair branch below handles that on a later call).
	affectedDeviceSessionIDs := make([]string, 0, len(storeResult.Sessions))
	for _, record := range storeResult.Sessions {
		if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "revoke_all_user_sessions"); err != nil {
			return Result{}, err
		}
		affectedDeviceSessionIDs = append(affectedDeviceSessionIDs, record.ID.String())
	}
	if storeResult.Outcome == ports.RevokeUserSessionsOutcomeNoActiveSessions {
		// Nothing was active this call: republish already-revoked sessions so
		// an earlier partially-failed publish converges eventually.
		if err := s.republishCurrentRevokedSessions(ctx, userID); err != nil {
			return Result{}, err
		}
	}

	affectedSessionCount := int64(len(storeResult.Sessions))
	if affectedSessionCount > 0 {
		s.telemetry.RecordSessionRevocations(ctx, "revoke_all_user_sessions", revocation.ReasonCode.String(), affectedSessionCount)
	}
	logFields = append(logFields, zap.Int64("affected_session_count", affectedSessionCount))

	return Result{
		Outcome:                  string(storeResult.Outcome),
		UserID:                   storeResult.UserID.String(),
		AffectedSessionCount:     affectedSessionCount,
		AffectedDeviceSessionIDs: affectedDeviceSessionIDs,
	}, nil
}
|
||||
|
||||
func (s *Service) republishCurrentRevokedSessions(ctx context.Context, userID common.UserID) error {
|
||||
records, err := s.sessionStore.ListByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return shared.ServiceUnavailable(err)
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
if record.Status != devicesession.StatusRevoked {
|
||||
continue
|
||||
}
|
||||
if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, record, s.telemetry, "revoke_all_user_sessions_repair"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// boolOutcome maps a user-directory existence flag to its telemetry label.
func boolOutcome(value bool) string {
	outcome := "missing"
	if value {
		outcome = "exists"
	}

	return outcome
}
|
||||
|
||||
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return logger.Named(name)
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package revokeallusersessions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecuteRevokesExistingUserSessionsAndPublishes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{}
|
||||
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
for _, record := range []devicesession.Session{
|
||||
activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC()),
|
||||
activeSessionFixture("device-session-2", "user-1", time.Unix(20, 0).UTC()),
|
||||
} {
|
||||
if err := store.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
UserID: "user-1",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "revoked", result.Outcome)
|
||||
assert.EqualValues(t, 2, result.AffectedSessionCount)
|
||||
assert.Equal(t, []string{"device-session-2", "device-session-1"}, result.AffectedDeviceSessionIDs)
|
||||
|
||||
for _, deviceSessionID := range result.AffectedDeviceSessionIDs {
|
||||
stored, getErr := store.Get(context.Background(), common.DeviceSessionID(deviceSessionID))
|
||||
require.NoError(t, getErr)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
|
||||
assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
|
||||
assert.Empty(t, stored.Revocation.ActorID)
|
||||
assert.Equal(t, time.Unix(30, 0).UTC(), stored.Revocation.At)
|
||||
}
|
||||
|
||||
published := publisher.PublishedSnapshots()
|
||||
require.Len(t, published, 2)
|
||||
assert.Equal(t, []common.DeviceSessionID{"device-session-2", "device-session-1"}, []common.DeviceSessionID{
|
||||
published[0].DeviceSessionID,
|
||||
published[1].DeviceSessionID,
|
||||
})
|
||||
for _, snapshot := range published {
|
||||
assert.Equal(t, gatewayprojection.StatusRevoked, snapshot.Status)
|
||||
assert.Equal(t, devicesession.RevokeReasonLogoutAll, snapshot.RevokeReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("system"), snapshot.RevokeActorType)
|
||||
require.NotNil(t, snapshot.RevokedAt)
|
||||
assert.Equal(t, time.Unix(30, 0).UTC(), *snapshot.RevokedAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsNoActiveSessionsForExistingUserWithoutActiveSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{}
|
||||
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
UserID: "user-1",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "no_active_sessions", result.Outcome)
|
||||
assert.EqualValues(t, 0, result.AffectedSessionCount)
|
||||
require.NotNil(t, result.AffectedDeviceSessionIDs)
|
||||
assert.Empty(t, result.AffectedDeviceSessionIDs)
|
||||
assert.Empty(t, publisher.PublishedSnapshots())
|
||||
}
|
||||
|
||||
func TestExecuteReturnsSubjectNotFoundForUnknownUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service, err := New(&testkit.InMemorySessionStore{}, &testkit.InMemoryUserDirectory{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
UserID: "missing",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err))
|
||||
}
|
||||
|
||||
func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
|
||||
if err := userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedExisting() returned error: %v", err)
|
||||
}
|
||||
if err := store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(store, userDirectory, publisher, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
UserID: "user-1",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
|
||||
|
||||
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, getErr)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
|
||||
assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
|
||||
}
|
||||
|
||||
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
key, err := common.NewClientPublicKey(make([]byte, 32))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return devicesession.Session{
|
||||
ID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID(userID),
|
||||
ClientPublicKey: key,
|
||||
Status: devicesession.StatusActive,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package revokeallusersessions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
stubuserservice "galaxy/authsession/internal/adapters/userservice"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("existing user uses ExistsByUserID and returns no active sessions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
userDirectory := &stubuserservice.StubDirectory{}
|
||||
require.NoError(t, userDirectory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1")))
|
||||
|
||||
service, err := New(&testkit.InMemorySessionStore{}, userDirectory, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
UserID: "user-1",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "no_active_sessions", result.Outcome)
|
||||
assert.Zero(t, result.AffectedSessionCount)
|
||||
})
|
||||
|
||||
t.Run("unknown user returns subject not found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service, err := New(&testkit.InMemorySessionStore{}, &stubuserservice.StubDirectory{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(30, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
UserID: "missing",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, shared.ErrorCodeSubjectNotFound, shared.CodeOf(err))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package revokedevicesession
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecuteRetriesProjectionPublishUntilSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{
|
||||
Errors: []error{errors.New("publish failed"), nil},
|
||||
}
|
||||
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
|
||||
|
||||
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
DeviceSessionID: "device-session-1",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "revoked", result.Outcome)
|
||||
require.Len(t, publisher.PublishedSnapshots(), 2)
|
||||
}
|
||||
|
||||
func TestExecuteRepairsProjectionOnRepeatedAlreadyRevokedRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
|
||||
require.NoError(t, store.Create(context.Background(), activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())))
|
||||
|
||||
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
DeviceSessionID: "device-session-1",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
|
||||
require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts)
|
||||
|
||||
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, getErr)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
|
||||
|
||||
publisher.Err = nil
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
DeviceSessionID: "device-session-1",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "already_revoked", result.Outcome)
|
||||
assert.EqualValues(t, 0, result.AffectedSessionCount)
|
||||
require.Len(t, publisher.PublishedSnapshots(), shared.MaxProjectionPublishAttempts+1)
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
// Package revokedevicesession implements the trusted internal single-session
|
||||
// revoke use case.
|
||||
package revokedevicesession
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"galaxy/authsession/internal/ports"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/telemetry"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Input describes one trusted internal revoke-device-session request.
|
||||
type Input struct {
|
||||
// DeviceSessionID identifies the session that should be revoked.
|
||||
DeviceSessionID string
|
||||
|
||||
// ReasonCode stores the machine-readable revoke reason code.
|
||||
ReasonCode string
|
||||
|
||||
// ActorType stores the machine-readable revoke actor type.
|
||||
ActorType string
|
||||
|
||||
// ActorID stores the optional stable revoke actor identifier.
|
||||
ActorID string
|
||||
}
|
||||
|
||||
// Result describes the frozen internal revoke-device-session acknowledgement.
|
||||
type Result struct {
|
||||
// Outcome reports whether the current call revoked the session or found it
|
||||
// already revoked.
|
||||
Outcome string
|
||||
|
||||
// DeviceSessionID identifies the session addressed by the operation.
|
||||
DeviceSessionID string
|
||||
|
||||
// AffectedSessionCount reports how many sessions changed state during the
|
||||
// current call.
|
||||
AffectedSessionCount int64
|
||||
}
|
||||
|
||||
// Service executes the trusted internal revoke-device-session use case.
|
||||
type Service struct {
|
||||
sessionStore ports.SessionStore
|
||||
publisher ports.GatewaySessionProjectionPublisher
|
||||
clock ports.Clock
|
||||
logger *zap.Logger
|
||||
telemetry *telemetry.Runtime
|
||||
}
|
||||
|
||||
// New returns a revoke-device-session service wired to the required ports.
// Logging and telemetry are disabled (nil) in this constructor; use
// NewWithObservability to supply them.
func New(sessionStore ports.SessionStore, publisher ports.GatewaySessionProjectionPublisher, clock ports.Clock) (*Service, error) {
	return NewWithObservability(sessionStore, publisher, clock, nil, nil)
}
|
||||
|
||||
// NewWithObservability returns a revoke-device-session service wired to the
|
||||
// required ports plus optional structured logging and telemetry dependencies.
|
||||
func NewWithObservability(
|
||||
sessionStore ports.SessionStore,
|
||||
publisher ports.GatewaySessionProjectionPublisher,
|
||||
clock ports.Clock,
|
||||
logger *zap.Logger,
|
||||
telemetryRuntime *telemetry.Runtime,
|
||||
) (*Service, error) {
|
||||
switch {
|
||||
case sessionStore == nil:
|
||||
return nil, fmt.Errorf("revokedevicesession: session store must not be nil")
|
||||
case publisher == nil:
|
||||
return nil, fmt.Errorf("revokedevicesession: projection publisher must not be nil")
|
||||
case clock == nil:
|
||||
return nil, fmt.Errorf("revokedevicesession: clock must not be nil")
|
||||
default:
|
||||
return &Service{
|
||||
sessionStore: sessionStore,
|
||||
publisher: publisher,
|
||||
clock: clock,
|
||||
logger: namedLogger(logger, "revoke_device_session"),
|
||||
telemetry: telemetryRuntime,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Execute revokes one device session and republishes the current gateway
// projection for the resulting source-of-truth session state.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
	// logFields accumulates structured context for the single completion log
	// line emitted by the deferred call below; named results let the defer
	// observe the final err value.
	logFields := []zap.Field{
		zap.String("component", "service"),
		zap.String("use_case", "revoke_device_session"),
	}
	defer func() {
		shared.LogServiceOutcome(s.logger, ctx, "revoke device session completed", err, logFields...)
	}()

	// Validate the session identifier before touching the store.
	deviceSessionID, err := shared.ParseDeviceSessionID(input.DeviceSessionID)
	if err != nil {
		return Result{}, err
	}
	logFields = append(logFields, zap.String("device_session_id", deviceSessionID.String()))

	// Build the revocation record (reason, actor, timestamp) from the input.
	revocation, err := shared.BuildRevocation(input.ReasonCode, input.ActorType, input.ActorID, s.clock.Now())
	if err != nil {
		return Result{}, err
	}
	logFields = append(logFields, zap.String("reason_code", revocation.ReasonCode.String()))

	// Apply the revocation to the source-of-truth store.
	storeResult, err := s.sessionStore.Revoke(ctx, ports.RevokeSessionInput{
		DeviceSessionID: deviceSessionID,
		Revocation:      revocation,
	})
	if err != nil {
		// Map store errors to the frozen service error taxonomy.
		switch {
		case errors.Is(err, ports.ErrNotFound):
			return Result{}, shared.SessionNotFound()
		default:
			return Result{}, shared.ServiceUnavailable(err)
		}
	}
	if err := storeResult.Validate(); err != nil {
		return Result{}, shared.InternalError(err)
	}
	logFields = append(logFields, zap.String("outcome", string(storeResult.Outcome)))

	// Republish the projection even for already-revoked sessions (repair path).
	if err := shared.PublishSessionProjectionWithTelemetry(ctx, s.publisher, storeResult.Session, s.telemetry, "revoke_device_session"); err != nil {
		return Result{}, err
	}

	// Count a state change only when this call actually revoked the session.
	affectedSessionCount := int64(0)
	if storeResult.Outcome == ports.RevokeSessionOutcomeRevoked {
		affectedSessionCount = 1
		// NOTE(review): s.telemetry may be nil when constructed via New;
		// presumably telemetry.Runtime methods are nil-receiver safe — confirm.
		s.telemetry.RecordSessionRevocations(ctx, "revoke_device_session", revocation.ReasonCode.String(), affectedSessionCount)
	}
	logFields = append(logFields, zap.Int64("affected_session_count", affectedSessionCount))

	return Result{
		Outcome:              string(storeResult.Outcome),
		DeviceSessionID:      storeResult.Session.ID.String(),
		AffectedSessionCount: affectedSessionCount,
	}, nil
}
|
||||
|
||||
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return logger.Named(name)
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package revokedevicesession
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecuteRevokesActiveSessionAndPublishes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{}
|
||||
record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
|
||||
if err := store.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
DeviceSessionID: "device-session-1",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "revoked", result.Outcome)
|
||||
assert.EqualValues(t, 1, result.AffectedSessionCount)
|
||||
assert.Equal(t, "device-session-1", result.DeviceSessionID)
|
||||
|
||||
stored, err := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
|
||||
assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
|
||||
assert.Empty(t, stored.Revocation.ActorID)
|
||||
assert.Equal(t, time.Unix(20, 0).UTC(), stored.Revocation.At)
|
||||
|
||||
published := publisher.PublishedSnapshots()
|
||||
require.Len(t, published, 1)
|
||||
assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status)
|
||||
assert.Equal(t, common.DeviceSessionID("device-session-1"), published[0].DeviceSessionID)
|
||||
assert.Equal(t, devicesession.RevokeReasonLogoutAll, published[0].RevokeReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("system"), published[0].RevokeActorType)
|
||||
require.NotNil(t, published[0].RevokedAt)
|
||||
assert.Equal(t, time.Unix(20, 0).UTC(), published[0].RevokedAt.UTC())
|
||||
}
|
||||
|
||||
func TestExecuteAlreadyRevokedReturnsZeroAffectedAndRepublishes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{}
|
||||
record := revokedSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
|
||||
if err := store.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{
|
||||
DeviceSessionID: "device-session-1",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "already_revoked", result.Outcome)
|
||||
assert.EqualValues(t, 0, result.AffectedSessionCount)
|
||||
assert.Equal(t, "device-session-1", result.DeviceSessionID)
|
||||
|
||||
stored, err := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, *record.Revocation, *stored.Revocation)
|
||||
|
||||
published := publisher.PublishedSnapshots()
|
||||
require.Len(t, published, 1)
|
||||
assert.Equal(t, gatewayprojection.StatusRevoked, published[0].Status)
|
||||
assert.Equal(t, devicesession.RevokeReasonLogoutAll, published[0].RevokeReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("system"), published[0].RevokeActorType)
|
||||
require.NotNil(t, published[0].RevokedAt)
|
||||
assert.Equal(t, record.Revocation.At, *published[0].RevokedAt)
|
||||
}
|
||||
|
||||
func TestExecuteReturnsSessionNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service, err := New(&testkit.InMemorySessionStore{}, &testkit.RecordingProjectionPublisher{}, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
DeviceSessionID: "missing",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
assert.Equal(t, shared.ErrorCodeSessionNotFound, shared.CodeOf(err))
|
||||
}
|
||||
|
||||
func TestExecuteReturnsServiceUnavailableWhenPublishFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store := &testkit.InMemorySessionStore{}
|
||||
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
|
||||
record := activeSessionFixture("device-session-1", "user-1", time.Unix(10, 0).UTC())
|
||||
if err := store.Create(context.Background(), record); err != nil {
|
||||
require.Failf(t, "test failed", "Create() returned error: %v", err)
|
||||
}
|
||||
|
||||
service, err := New(store, publisher, testkit.FixedClock{Time: time.Unix(20, 0).UTC()})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{
|
||||
DeviceSessionID: "device-session-1",
|
||||
ReasonCode: "logout_all",
|
||||
ActorType: "system",
|
||||
})
|
||||
assert.Equal(t, shared.ErrorCodeServiceUnavailable, shared.CodeOf(err))
|
||||
|
||||
stored, getErr := store.Get(context.Background(), common.DeviceSessionID("device-session-1"))
|
||||
require.NoError(t, getErr)
|
||||
require.NotNil(t, stored.Revocation)
|
||||
assert.Equal(t, devicesession.StatusRevoked, stored.Status)
|
||||
assert.Equal(t, devicesession.RevokeReasonLogoutAll, stored.Revocation.ReasonCode)
|
||||
assert.Equal(t, common.RevokeActorType("system"), stored.Revocation.ActorType)
|
||||
}
|
||||
|
||||
func activeSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
key, err := common.NewClientPublicKey(make([]byte, 32))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return devicesession.Session{
|
||||
ID: common.DeviceSessionID(deviceSessionID),
|
||||
UserID: common.UserID(userID),
|
||||
ClientPublicKey: key,
|
||||
Status: devicesession.StatusActive,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
}
|
||||
|
||||
func revokedSessionFixture(deviceSessionID string, userID string, createdAt time.Time) devicesession.Session {
|
||||
record := activeSessionFixture(deviceSessionID, userID, createdAt)
|
||||
record.Status = devicesession.StatusRevoked
|
||||
record.Revocation = &devicesession.Revocation{
|
||||
At: createdAt.Add(time.Minute),
|
||||
ReasonCode: devicesession.RevokeReasonLogoutAll,
|
||||
ActorType: common.RevokeActorType("system"),
|
||||
}
|
||||
return record
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package sendemailcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/ports"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExecuteCreatesThrottledChallengeWithoutUserDirectoryOrMail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
challengeStore := &testkit.InMemoryChallengeStore{}
|
||||
abuseProtector := &testkit.InMemorySendEmailCodeAbuseProtector{}
|
||||
now := time.Unix(10, 0).UTC()
|
||||
require.NoError(t, reserveSendCooldown(abuseProtector, common.Email("pilot@example.com"), now))
|
||||
|
||||
userDirectory := &countingUserDirectory{}
|
||||
mailSender := &testkit.RecordingMailSender{}
|
||||
service, err := NewWithRuntime(
|
||||
challengeStore,
|
||||
userDirectory,
|
||||
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
|
||||
testkit.FixedCodeGenerator{Code: "654321"},
|
||||
testkit.DeterministicCodeHasher{},
|
||||
mailSender,
|
||||
abuseProtector,
|
||||
testkit.FixedClock{Time: now},
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "challenge-1", result.ChallengeID)
|
||||
assert.Zero(t, userDirectory.resolveCalls)
|
||||
assert.Empty(t, mailSender.RecordedInputs())
|
||||
|
||||
record, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
require.NoError(t, getErr)
|
||||
assert.Equal(t, challenge.StatusDeliveryThrottled, record.Status)
|
||||
assert.Equal(t, challenge.DeliveryThrottled, record.DeliveryState)
|
||||
assert.Equal(t, 1, record.Attempts.Send)
|
||||
}
|
||||
|
||||
func TestExecuteBlockedEmailOutsideThrottleStillSuppressesDelivery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
challengeStore := &testkit.InMemoryChallengeStore{}
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")))
|
||||
mailSender := &testkit.RecordingMailSender{}
|
||||
|
||||
service, err := NewWithRuntime(
|
||||
challengeStore,
|
||||
userDirectory,
|
||||
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
|
||||
testkit.FixedCodeGenerator{Code: "654321"},
|
||||
testkit.DeterministicCodeHasher{},
|
||||
mailSender,
|
||||
&testkit.InMemorySendEmailCodeAbuseProtector{},
|
||||
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "challenge-1", result.ChallengeID)
|
||||
assert.Empty(t, mailSender.RecordedInputs())
|
||||
|
||||
record, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
require.NoError(t, getErr)
|
||||
assert.Equal(t, challenge.StatusDeliverySuppressed, record.Status)
|
||||
assert.Equal(t, challenge.DeliverySuppressed, record.DeliveryState)
|
||||
}
|
||||
|
||||
func TestExecuteAllowsAgainAfterCooldown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
challengeStore := &testkit.InMemoryChallengeStore{}
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
mailSender := &testkit.RecordingMailSender{}
|
||||
abuseProtector := &testkit.InMemorySendEmailCodeAbuseProtector{}
|
||||
clock := &mutableClock{time: time.Unix(10, 0).UTC()}
|
||||
idGenerator := &testkit.SequenceIDGenerator{
|
||||
ChallengeIDs: []common.ChallengeID{"challenge-1", "challenge-2"},
|
||||
}
|
||||
|
||||
service, err := NewWithRuntime(
|
||||
challengeStore,
|
||||
userDirectory,
|
||||
idGenerator,
|
||||
testkit.FixedCodeGenerator{Code: "654321"},
|
||||
testkit.DeterministicCodeHasher{},
|
||||
mailSender,
|
||||
abuseProtector,
|
||||
clock,
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
first, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "challenge-1", first.ChallengeID)
|
||||
|
||||
clock.time = clock.time.Add(challenge.ResendThrottleCooldown)
|
||||
|
||||
second, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "challenge-2", second.ChallengeID)
|
||||
require.Len(t, mailSender.RecordedInputs(), 2)
|
||||
|
||||
secondRecord, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-2"))
|
||||
require.NoError(t, getErr)
|
||||
assert.Equal(t, challenge.StatusSent, secondRecord.Status)
|
||||
assert.Equal(t, challenge.DeliverySent, secondRecord.DeliveryState)
|
||||
}
|
||||
|
||||
func reserveSendCooldown(protector ports.SendEmailCodeAbuseProtector, email common.Email, now time.Time) error {
|
||||
_, err := protector.CheckAndReserve(context.Background(), ports.SendEmailCodeAbuseInput{
|
||||
Email: email,
|
||||
Now: now,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// mutableClock is a test clock whose reported time can be reassigned between
// calls to simulate the passage of time.
type mutableClock struct {
	// time is the instant returned by Now; tests mutate it directly.
	time time.Time
}

// Now returns the currently configured instant.
func (c *mutableClock) Now() time.Time {
	return c.time
}
|
||||
|
||||
// countingUserDirectory is a user-directory test double that counts
// ResolveByEmail calls and stubs the remaining directory operations with
// zero-value responses.
type countingUserDirectory struct {
	// resolveCalls counts how often ResolveByEmail was invoked.
	resolveCalls int
}

// ResolveByEmail records the call and reports the user as creatable.
func (d *countingUserDirectory) ResolveByEmail(_ context.Context, _ common.Email) (userresolution.Result, error) {
	d.resolveCalls++
	return userresolution.Result{Kind: userresolution.KindCreatable}, nil
}

// ExistsByUserID always reports that no user exists.
func (d *countingUserDirectory) ExistsByUserID(context.Context, common.UserID) (bool, error) {
	return false, nil
}

// EnsureUserByEmail is a no-op stub returning the zero result.
func (d *countingUserDirectory) EnsureUserByEmail(context.Context, common.Email) (ports.EnsureUserResult, error) {
	return ports.EnsureUserResult{}, nil
}

// BlockByUserID is a no-op stub returning the zero result.
func (d *countingUserDirectory) BlockByUserID(context.Context, ports.BlockUserByIDInput) (ports.BlockUserResult, error) {
	return ports.BlockUserResult{}, nil
}

// BlockByEmail is a no-op stub returning the zero result.
func (d *countingUserDirectory) BlockByEmail(context.Context, ports.BlockUserByEmailInput) (ports.BlockUserResult, error) {
	return ports.BlockUserResult{}, nil
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package sendemailcode
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func TestExecuteLogsSafeOutcomeFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, buffer := newObservedServiceLogger()
|
||||
service, err := NewWithObservability(
|
||||
&testkit.InMemoryChallengeStore{},
|
||||
&testkit.InMemoryUserDirectory{},
|
||||
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
|
||||
testkit.FixedCodeGenerator{Code: "654321"},
|
||||
testkit.DeterministicCodeHasher{},
|
||||
&testkit.RecordingMailSender{},
|
||||
nil,
|
||||
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
|
||||
logger,
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{Email: "pilot@example.com"})
|
||||
require.NoError(t, err)
|
||||
|
||||
logOutput := buffer.String()
|
||||
assert.Contains(t, logOutput, "send_email_code")
|
||||
assert.Contains(t, logOutput, "challenge-1")
|
||||
assert.Contains(t, logOutput, "\"outcome\":\"sent\"")
|
||||
assert.NotContains(t, logOutput, "pilot@example.com")
|
||||
assert.NotContains(t, logOutput, "654321")
|
||||
}
|
||||
|
||||
func newObservedServiceLogger() (*zap.Logger, *bytes.Buffer) {
|
||||
buffer := &bytes.Buffer{}
|
||||
encoderConfig := zap.NewProductionEncoderConfig()
|
||||
encoderConfig.TimeKey = ""
|
||||
|
||||
core := zapcore.NewCore(
|
||||
zapcore.NewJSONEncoder(encoderConfig),
|
||||
zapcore.AddSync(buffer),
|
||||
zap.DebugLevel,
|
||||
)
|
||||
|
||||
return zap.New(core), buffer
|
||||
}
|
||||
@@ -0,0 +1,331 @@
|
||||
// Package sendemailcode implements the public send-email-code use case.
|
||||
package sendemailcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/ports"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/telemetry"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Input describes one public send-email-code request.
|
||||
type Input struct {
|
||||
// Email is the user-supplied e-mail address that should receive the login
|
||||
// code.
|
||||
Email string
|
||||
}
|
||||
|
||||
// Result describes one public send-email-code response.
|
||||
type Result struct {
|
||||
// ChallengeID is the stable challenge identifier returned to the caller.
|
||||
ChallengeID string
|
||||
}
|
||||
|
||||
// Service executes the public send-email-code use case.
|
||||
type Service struct {
|
||||
challengeStore ports.ChallengeStore
|
||||
userDirectory ports.UserDirectory
|
||||
idGenerator ports.IDGenerator
|
||||
codeGenerator ports.CodeGenerator
|
||||
codeHasher ports.CodeHasher
|
||||
mailSender ports.MailSender
|
||||
abuseProtector ports.SendEmailCodeAbuseProtector
|
||||
clock ports.Clock
|
||||
logger *zap.Logger
|
||||
telemetry *telemetry.Runtime
|
||||
}
|
||||
|
||||
// New returns a send-email-code service wired to the required ports.
|
||||
func New(
|
||||
challengeStore ports.ChallengeStore,
|
||||
userDirectory ports.UserDirectory,
|
||||
idGenerator ports.IDGenerator,
|
||||
codeGenerator ports.CodeGenerator,
|
||||
codeHasher ports.CodeHasher,
|
||||
mailSender ports.MailSender,
|
||||
clock ports.Clock,
|
||||
) (*Service, error) {
|
||||
return NewWithRuntime(
|
||||
challengeStore,
|
||||
userDirectory,
|
||||
idGenerator,
|
||||
codeGenerator,
|
||||
codeHasher,
|
||||
mailSender,
|
||||
nil,
|
||||
clock,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
// NewWithRuntime returns a send-email-code service wired to the required
|
||||
// ports plus the optional Stage-17 runtime collaborators.
|
||||
func NewWithRuntime(
|
||||
challengeStore ports.ChallengeStore,
|
||||
userDirectory ports.UserDirectory,
|
||||
idGenerator ports.IDGenerator,
|
||||
codeGenerator ports.CodeGenerator,
|
||||
codeHasher ports.CodeHasher,
|
||||
mailSender ports.MailSender,
|
||||
abuseProtector ports.SendEmailCodeAbuseProtector,
|
||||
clock ports.Clock,
|
||||
telemetryRuntime *telemetry.Runtime,
|
||||
) (*Service, error) {
|
||||
return NewWithObservability(
|
||||
challengeStore,
|
||||
userDirectory,
|
||||
idGenerator,
|
||||
codeGenerator,
|
||||
codeHasher,
|
||||
mailSender,
|
||||
abuseProtector,
|
||||
clock,
|
||||
nil,
|
||||
telemetryRuntime,
|
||||
)
|
||||
}
|
||||
|
||||
// NewWithObservability returns a send-email-code service wired to the required
|
||||
// ports plus optional structured logging and telemetry dependencies.
|
||||
func NewWithObservability(
|
||||
challengeStore ports.ChallengeStore,
|
||||
userDirectory ports.UserDirectory,
|
||||
idGenerator ports.IDGenerator,
|
||||
codeGenerator ports.CodeGenerator,
|
||||
codeHasher ports.CodeHasher,
|
||||
mailSender ports.MailSender,
|
||||
abuseProtector ports.SendEmailCodeAbuseProtector,
|
||||
clock ports.Clock,
|
||||
logger *zap.Logger,
|
||||
telemetryRuntime *telemetry.Runtime,
|
||||
) (*Service, error) {
|
||||
switch {
|
||||
case challengeStore == nil:
|
||||
return nil, fmt.Errorf("sendemailcode: challenge store must not be nil")
|
||||
case userDirectory == nil:
|
||||
return nil, fmt.Errorf("sendemailcode: user directory must not be nil")
|
||||
case idGenerator == nil:
|
||||
return nil, fmt.Errorf("sendemailcode: id generator must not be nil")
|
||||
case codeGenerator == nil:
|
||||
return nil, fmt.Errorf("sendemailcode: code generator must not be nil")
|
||||
case codeHasher == nil:
|
||||
return nil, fmt.Errorf("sendemailcode: code hasher must not be nil")
|
||||
case mailSender == nil:
|
||||
return nil, fmt.Errorf("sendemailcode: mail sender must not be nil")
|
||||
case clock == nil:
|
||||
return nil, fmt.Errorf("sendemailcode: clock must not be nil")
|
||||
default:
|
||||
return &Service{
|
||||
challengeStore: challengeStore,
|
||||
userDirectory: userDirectory,
|
||||
idGenerator: idGenerator,
|
||||
codeGenerator: codeGenerator,
|
||||
codeHasher: codeHasher,
|
||||
mailSender: mailSender,
|
||||
abuseProtector: normalizeAbuseProtector(abuseProtector),
|
||||
clock: clock,
|
||||
logger: namedLogger(logger, "send_email_code"),
|
||||
telemetry: telemetryRuntime,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Execute creates a fresh challenge for every request, stores only the hashed
// confirmation code, and records whether delivery was sent or intentionally
// suppressed.
//
// Flow: parse email -> abuse check -> generate id/code/hash -> persist a
// pending challenge -> attempt (or suppress) delivery -> CompareAndSwap the
// final state in via finishChallenge. The named results feed the deferred
// logging hook below.
func (s *Service) Execute(ctx context.Context, input Input) (result Result, err error) {
	logFields := []zap.Field{
		zap.String("component", "service"),
		zap.String("use_case", "send_email_code"),
	}
	outcome := ""
	// Deferred hook logs the final outcome and challenge id exactly once for
	// every return path; it reads the named results set at return time.
	defer func() {
		if outcome != "" {
			logFields = append(logFields, zap.String("outcome", outcome))
		}
		if result.ChallengeID != "" {
			logFields = append(logFields, zap.String("challenge_id", result.ChallengeID))
		}
		shared.LogServiceOutcome(s.logger, ctx, "send email code completed", err, logFields...)
	}()

	email, err := shared.ParseEmail(input.Email)
	if err != nil {
		return Result{}, err
	}

	// Abuse check runs before any generation or directory work.
	now := s.clock.Now().UTC()
	abuseResult, err := s.abuseProtector.CheckAndReserve(ctx, ports.SendEmailCodeAbuseInput{
		Email: email,
		Now:   now,
	})
	if err != nil {
		return Result{}, shared.ServiceUnavailable(err)
	}
	if err := abuseResult.Validate(); err != nil {
		return Result{}, shared.InternalError(err)
	}

	// Generate the challenge identity and confirmation code; only the hash is
	// persisted, and the cleartext code leaves this function solely via the
	// mail sender.
	challengeID, err := s.idGenerator.NewChallengeID()
	if err != nil {
		return Result{}, shared.ServiceUnavailable(err)
	}
	code, err := s.codeGenerator.Generate()
	if err != nil {
		return Result{}, shared.ServiceUnavailable(err)
	}
	codeHash, err := s.codeHasher.Hash(code)
	if err != nil {
		return Result{}, shared.ServiceUnavailable(err)
	}

	// The throttle outcome determines the initial status/delivery state of the
	// pending record, which is persisted before delivery is attempted.
	pendingStatus, pendingDeliveryState, err := ports.SendEmailCodeThrottleStatusToChallengeStatus(abuseResult.Outcome)
	if err != nil {
		return Result{}, shared.InternalError(err)
	}
	pending := challenge.Challenge{
		ID:            challengeID,
		Email:         email,
		CodeHash:      codeHash,
		Status:        pendingStatus,
		DeliveryState: pendingDeliveryState,
		CreatedAt:     now,
		ExpiresAt:     now.Add(challenge.InitialTTL),
	}
	if err := pending.Validate(); err != nil {
		return Result{}, shared.InternalError(err)
	}
	if err := s.challengeStore.Create(ctx, pending); err != nil {
		return Result{}, shared.ServiceUnavailable(err)
	}
	s.telemetry.RecordChallengeCreated(ctx)

	// The final record counts this send attempt regardless of delivery result.
	final := pending
	final.Attempts.Send = 1
	final.Abuse.LastAttemptAt = &now
	if abuseResult.Outcome == ports.SendEmailCodeAbuseOutcomeThrottled {
		// Throttled: persist the attempt, but never resolve the user or send mail.
		result, err = s.finishChallenge(ctx, pending, final)
		if err == nil {
			outcome = string(telemetry.SendEmailCodeOutcomeThrottled)
			s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeThrottled, telemetry.SendEmailCodeReasonThrottled)
		}
		return result, err
	}

	resolution, err := s.userDirectory.ResolveByEmail(ctx, email)
	if err != nil {
		return Result{}, shared.ServiceUnavailable(err)
	}
	if err := resolution.Validate(); err != nil {
		return Result{}, shared.InternalError(err)
	}
	s.telemetry.RecordUserDirectoryOutcome(ctx, "resolve_by_email", string(resolution.Kind))

	switch resolution.Kind {
	case userresolution.KindBlocked:
		// Blocked users get a persisted-but-suppressed challenge and still
		// receive a challenge id — presumably so callers cannot probe block
		// status; confirm against product policy.
		final.Status = challenge.StatusDeliverySuppressed
		final.DeliveryState = challenge.DeliverySuppressed
		result, err = s.finishChallenge(ctx, pending, final)
		if err == nil {
			outcome = string(telemetry.SendEmailCodeOutcomeSuppressed)
			s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeSuppressed, telemetry.SendEmailCodeReasonBlocked)
		}
		return result, err
	default:
		deliveryResult, err := s.mailSender.SendLoginCode(ctx, ports.SendLoginCodeInput{
			Email: email,
			Code:  code,
		})
		if err != nil {
			// Delivery failed: mark the challenge failed before surfacing the
			// error; a persistence error takes precedence over the send error.
			final.Status = challenge.StatusFailed
			final.DeliveryState = challenge.DeliveryFailed
			if _, persistErr := s.finishChallenge(ctx, pending, final); persistErr != nil {
				return Result{}, persistErr
			}
			outcome = string(telemetry.SendEmailCodeOutcomeFailed)
			s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeFailed, telemetry.SendEmailCodeReasonMailSender)

			return Result{}, shared.ServiceUnavailable(err)
		}
		if err := deliveryResult.Validate(); err != nil {
			return Result{}, shared.InternalError(err)
		}

		switch deliveryResult.Outcome {
		case ports.SendLoginCodeOutcomeSent:
			final.Status = challenge.StatusSent
			final.DeliveryState = challenge.DeliverySent
			result, err = s.finishChallenge(ctx, pending, final)
			if err == nil {
				outcome = string(telemetry.SendEmailCodeOutcomeSent)
				s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeSent, "")
			}
			return result, err
		case ports.SendLoginCodeOutcomeSuppressed:
			// Sender-side suppression is persisted the same way as the
			// blocked-user path, but attributed to the mail sender.
			final.Status = challenge.StatusDeliverySuppressed
			final.DeliveryState = challenge.DeliverySuppressed
			result, err = s.finishChallenge(ctx, pending, final)
			if err == nil {
				outcome = string(telemetry.SendEmailCodeOutcomeSuppressed)
				s.telemetry.RecordSendEmailCode(ctx, telemetry.SendEmailCodeOutcomeSuppressed, telemetry.SendEmailCodeReasonMailSender)
			}
			return result, err
		default:
			return Result{}, shared.InternalError(fmt.Errorf("sendemailcode: unsupported delivery outcome %q", deliveryResult.Outcome))
		}
	}
}
|
||||
|
||||
func (s *Service) finishChallenge(ctx context.Context, pending challenge.Challenge, final challenge.Challenge) (Result, error) {
|
||||
if err := final.Validate(); err != nil {
|
||||
return Result{}, shared.InternalError(err)
|
||||
}
|
||||
if err := s.challengeStore.CompareAndSwap(ctx, pending, final); err != nil {
|
||||
return Result{}, shared.ServiceUnavailable(err)
|
||||
}
|
||||
|
||||
return Result{ChallengeID: final.ID.String()}, nil
|
||||
}
|
||||
|
||||
func normalizeAbuseProtector(protector ports.SendEmailCodeAbuseProtector) ports.SendEmailCodeAbuseProtector {
|
||||
if protector == nil {
|
||||
return allowAllSendEmailCodeAbuseProtector{}
|
||||
}
|
||||
|
||||
value := reflect.ValueOf(protector)
|
||||
switch value.Kind() {
|
||||
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
|
||||
if value.IsNil() {
|
||||
return allowAllSendEmailCodeAbuseProtector{}
|
||||
}
|
||||
}
|
||||
|
||||
return protector
|
||||
}
|
||||
|
||||
type allowAllSendEmailCodeAbuseProtector struct{}
|
||||
|
||||
func (allowAllSendEmailCodeAbuseProtector) CheckAndReserve(_ context.Context, input ports.SendEmailCodeAbuseInput) (ports.SendEmailCodeAbuseResult, error) {
|
||||
if err := input.Validate(); err != nil {
|
||||
return ports.SendEmailCodeAbuseResult{}, err
|
||||
}
|
||||
|
||||
return ports.SendEmailCodeAbuseResult{
|
||||
Outcome: ports.SendEmailCodeAbuseOutcomeAllowed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func namedLogger(logger *zap.Logger, name string) *zap.Logger {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return logger.Named(name)
|
||||
}
|
||||
@@ -0,0 +1,310 @@
|
||||
package sendemailcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/ports"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
)
|
||||
|
||||
func TestExecuteSendsChallengeForExistingAndCreatableUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
seed func(*testkit.InMemoryUserDirectory) error
|
||||
email string
|
||||
}{
|
||||
{
|
||||
name: "existing",
|
||||
seed: func(directory *testkit.InMemoryUserDirectory) error {
|
||||
return directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))
|
||||
},
|
||||
email: " pilot@example.com ",
|
||||
},
|
||||
{
|
||||
name: "creatable",
|
||||
seed: func(*testkit.InMemoryUserDirectory) error { return nil },
|
||||
email: "new@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
challengeStore := &testkit.InMemoryChallengeStore{}
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
if err := tt.seed(userDirectory); err != nil {
|
||||
require.Failf(t, "test failed", "seed() returned error: %v", err)
|
||||
}
|
||||
mailSender := &testkit.RecordingMailSender{}
|
||||
service, err := New(
|
||||
challengeStore,
|
||||
userDirectory,
|
||||
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
|
||||
testkit.FixedCodeGenerator{Code: "654321"},
|
||||
testkit.DeterministicCodeHasher{},
|
||||
mailSender,
|
||||
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
|
||||
)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{Email: tt.email})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Execute() returned error: %v", err)
|
||||
}
|
||||
if result.ChallengeID != "challenge-1" {
|
||||
require.Failf(t, "test failed", "Execute().ChallengeID = %q, want %q", result.ChallengeID, "challenge-1")
|
||||
}
|
||||
if len(mailSender.RecordedInputs()) != 1 {
|
||||
require.Failf(t, "test failed", "RecordedInputs() length = %d, want 1", len(mailSender.RecordedInputs()))
|
||||
}
|
||||
|
||||
record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if record.Status != challenge.StatusSent || record.DeliveryState != challenge.DeliverySent {
|
||||
require.Failf(t, "test failed", "challenge state = %q/%q", record.Status, record.DeliveryState)
|
||||
}
|
||||
if record.Attempts.Send != 1 {
|
||||
require.Failf(t, "test failed", "Attempts.Send = %d, want 1", record.Attempts.Send)
|
||||
}
|
||||
if string(record.CodeHash) == "654321" {
|
||||
require.FailNow(t, "CodeHash stored cleartext code")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSuppressesDeliveryForBlockedEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
challengeStore := &testkit.InMemoryChallengeStore{}
|
||||
userDirectory := &testkit.InMemoryUserDirectory{}
|
||||
if err := userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")); err != nil {
|
||||
require.Failf(t, "test failed", "SeedBlockedEmail() returned error: %v", err)
|
||||
}
|
||||
mailSender := &testkit.RecordingMailSender{}
|
||||
|
||||
service, err := New(
|
||||
challengeStore,
|
||||
userDirectory,
|
||||
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
|
||||
testkit.FixedCodeGenerator{Code: "654321"},
|
||||
testkit.DeterministicCodeHasher{},
|
||||
mailSender,
|
||||
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
|
||||
)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
result, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Execute() returned error: %v", err)
|
||||
}
|
||||
if result.ChallengeID != "challenge-1" {
|
||||
require.Failf(t, "test failed", "Execute().ChallengeID = %q, want %q", result.ChallengeID, "challenge-1")
|
||||
}
|
||||
if len(mailSender.RecordedInputs()) != 0 {
|
||||
require.Failf(t, "test failed", "RecordedInputs() length = %d, want 0", len(mailSender.RecordedInputs()))
|
||||
}
|
||||
|
||||
record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if record.Status != challenge.StatusDeliverySuppressed || record.DeliveryState != challenge.DeliverySuppressed {
|
||||
require.Failf(t, "test failed", "challenge state = %q/%q", record.Status, record.DeliveryState)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteHandlesMailSenderSuppressedOutcome(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
challengeStore := &testkit.InMemoryChallengeStore{}
|
||||
mailSender := &testkit.RecordingMailSender{
|
||||
DefaultResult: ports.SendLoginCodeResult{Outcome: ports.SendLoginCodeOutcomeSuppressed},
|
||||
}
|
||||
|
||||
service, err := New(
|
||||
challengeStore,
|
||||
&testkit.InMemoryUserDirectory{},
|
||||
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
|
||||
testkit.FixedCodeGenerator{Code: "654321"},
|
||||
testkit.DeterministicCodeHasher{},
|
||||
mailSender,
|
||||
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
|
||||
)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{Email: "pilot@example.com"})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Execute() returned error: %v", err)
|
||||
}
|
||||
|
||||
record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if record.Status != challenge.StatusDeliverySuppressed || record.DeliveryState != challenge.DeliverySuppressed {
|
||||
require.Failf(t, "test failed", "challenge state = %q/%q", record.Status, record.DeliveryState)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteMarksChallengeFailedWhenMailSenderFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
challengeStore := &testkit.InMemoryChallengeStore{}
|
||||
mailSender := &testkit.RecordingMailSender{Err: errors.New("mail failed")}
|
||||
|
||||
service, err := New(
|
||||
challengeStore,
|
||||
&testkit.InMemoryUserDirectory{},
|
||||
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
|
||||
testkit.FixedCodeGenerator{Code: "654321"},
|
||||
testkit.DeterministicCodeHasher{},
|
||||
mailSender,
|
||||
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
|
||||
)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{Email: "pilot@example.com"})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeServiceUnavailable {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeServiceUnavailable)
|
||||
}
|
||||
|
||||
record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
if record.Status != challenge.StatusFailed || record.DeliveryState != challenge.DeliveryFailed {
|
||||
require.Failf(t, "test failed", "challenge state = %q/%q", record.Status, record.DeliveryState)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteReturnsInvalidRequestForBadEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service, err := New(
|
||||
&testkit.InMemoryChallengeStore{},
|
||||
&testkit.InMemoryUserDirectory{},
|
||||
&testkit.SequenceIDGenerator{},
|
||||
testkit.FixedCodeGenerator{Code: "654321"},
|
||||
testkit.DeterministicCodeHasher{},
|
||||
&testkit.RecordingMailSender{},
|
||||
testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
|
||||
)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
_, err = service.Execute(context.Background(), Input{Email: "pilot"})
|
||||
if shared.CodeOf(err) != shared.ErrorCodeInvalidRequest {
|
||||
require.Failf(t, "test failed", "Execute() error code = %q, want %q", shared.CodeOf(err), shared.ErrorCodeInvalidRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteCreatesFreshChallengeForRepeatedSend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
challengeStore := &testkit.InMemoryChallengeStore{}
|
||||
mailSender := &testkit.RecordingMailSender{}
|
||||
clock := testkit.FixedClock{Time: time.Unix(10, 0).UTC()}
|
||||
|
||||
service, err := New(
|
||||
challengeStore,
|
||||
&testkit.InMemoryUserDirectory{},
|
||||
&testkit.SequenceIDGenerator{
|
||||
ChallengeIDs: []common.ChallengeID{"challenge-1", "challenge-2"},
|
||||
},
|
||||
testkit.FixedCodeGenerator{Code: "654321"},
|
||||
testkit.DeterministicCodeHasher{},
|
||||
mailSender,
|
||||
clock,
|
||||
)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
first, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "first Execute() returned error: %v", err)
|
||||
}
|
||||
second, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "second Execute() returned error: %v", err)
|
||||
}
|
||||
if first.ChallengeID == second.ChallengeID {
|
||||
require.Failf(t, "test failed", "challenge ids are equal: %q", first.ChallengeID)
|
||||
}
|
||||
|
||||
firstRecord, err := challengeStore.Get(context.Background(), common.ChallengeID(first.ChallengeID))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get(%q) returned error: %v", first.ChallengeID, err)
|
||||
}
|
||||
secondRecord, err := challengeStore.Get(context.Background(), common.ChallengeID(second.ChallengeID))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get(%q) returned error: %v", second.ChallengeID, err)
|
||||
}
|
||||
if firstRecord.Status != challenge.StatusSent {
|
||||
require.Failf(t, "test failed", "first challenge status = %q, want %q", firstRecord.Status, challenge.StatusSent)
|
||||
}
|
||||
if secondRecord.Status != challenge.StatusSent {
|
||||
require.Failf(t, "test failed", "second challenge status = %q, want %q", secondRecord.Status, challenge.StatusSent)
|
||||
}
|
||||
if len(mailSender.RecordedInputs()) != 2 {
|
||||
require.Failf(t, "test failed", "RecordedInputs() length = %d, want 2", len(mailSender.RecordedInputs()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSetsChallengeExpirationFromInitialTTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Unix(10, 0).UTC()
|
||||
challengeStore := &testkit.InMemoryChallengeStore{}
|
||||
|
||||
service, err := New(
|
||||
challengeStore,
|
||||
&testkit.InMemoryUserDirectory{},
|
||||
&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
|
||||
testkit.FixedCodeGenerator{Code: "654321"},
|
||||
testkit.DeterministicCodeHasher{},
|
||||
&testkit.RecordingMailSender{},
|
||||
testkit.FixedClock{Time: now},
|
||||
)
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "New() returned error: %v", err)
|
||||
}
|
||||
|
||||
if _, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"}); err != nil {
|
||||
require.Failf(t, "test failed", "Execute() returned error: %v", err)
|
||||
}
|
||||
|
||||
record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
|
||||
if err != nil {
|
||||
require.Failf(t, "test failed", "Get() returned error: %v", err)
|
||||
}
|
||||
wantExpiresAt := now.Add(challenge.InitialTTL)
|
||||
if !record.ExpiresAt.Equal(wantExpiresAt) {
|
||||
require.Failf(t, "test failed", "ExpiresAt = %s, want %s", record.ExpiresAt, wantExpiresAt)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package sendemailcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
stubmail "galaxy/authsession/internal/adapters/mail"
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/service/shared"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestExecuteWithStubSender runs Execute against the stub mail adapter in its
// sent, suppressed, and failed modes, asserting the persisted challenge state,
// the service error code, and the exact delivery attempt recorded by the stub.
func TestExecuteWithStubSender(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name                string
		sender              *stubmail.StubSender
		wantStatus          challenge.Status
		wantDeliveryState   challenge.DeliveryState
		wantErrorCode       string
		wantRecordedAttempt int
	}{
		{
			name:                "sent",
			sender:              &stubmail.StubSender{},
			wantStatus:          challenge.StatusSent,
			wantDeliveryState:   challenge.DeliverySent,
			wantRecordedAttempt: 1,
		},
		{
			name: "suppressed",
			sender: &stubmail.StubSender{
				DefaultMode: stubmail.StubModeSuppressed,
			},
			wantStatus:          challenge.StatusDeliverySuppressed,
			wantDeliveryState:   challenge.DeliverySuppressed,
			wantRecordedAttempt: 1,
		},
		{
			name: "failed",
			sender: &stubmail.StubSender{
				DefaultMode:  stubmail.StubModeFailed,
				DefaultError: errors.New("stub delivery failed"),
			},
			wantStatus:          challenge.StatusFailed,
			wantDeliveryState:   challenge.DeliveryFailed,
			wantErrorCode:       shared.ErrorCodeServiceUnavailable,
			wantRecordedAttempt: 1,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			challengeStore := &testkit.InMemoryChallengeStore{}
			service, err := New(
				challengeStore,
				&testkit.InMemoryUserDirectory{},
				&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
				testkit.FixedCodeGenerator{Code: "654321"},
				testkit.DeterministicCodeHasher{},
				tt.sender,
				testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
			)
			require.NoError(t, err)

			result, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
			if tt.wantErrorCode == "" {
				require.NoError(t, err)
				assert.Equal(t, "challenge-1", result.ChallengeID)
			} else {
				// On failure the result must be zero-valued, not partially filled.
				require.Error(t, err)
				assert.Equal(t, tt.wantErrorCode, shared.CodeOf(err))
				assert.Equal(t, Result{}, result)
			}

			// The challenge record is persisted in all three modes.
			record, getErr := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
			require.NoError(t, getErr)
			assert.Equal(t, tt.wantStatus, record.Status)
			assert.Equal(t, tt.wantDeliveryState, record.DeliveryState)

			// Even a failed delivery records exactly one attempt with the
			// original email and cleartext code handed to the sender.
			attempts := tt.sender.RecordedAttempts()
			require.Len(t, attempts, tt.wantRecordedAttempt)
			assert.Equal(t, common.Email("pilot@example.com"), attempts[0].Input.Email)
			assert.Equal(t, "654321", attempts[0].Input.Code)
		})
	}
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package sendemailcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
stubmail "galaxy/authsession/internal/adapters/mail"
|
||||
stubuserservice "galaxy/authsession/internal/adapters/userservice"
|
||||
"galaxy/authsession/internal/domain/challenge"
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestExecuteWithRuntimeStubUserDirectory exercises Execute against the stub
// user-directory adapter: existing and creatable users get a sent challenge,
// while a blocked email yields a suppressed challenge and no mail call.
func TestExecuteWithRuntimeStubUserDirectory(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name              string
		seed              func(*stubuserservice.StubDirectory) error
		email             string
		wantStatus        challenge.Status
		wantDeliveryState challenge.DeliveryState
		wantMailCalls     int
	}{
		{
			name:  "existing user",
			email: "pilot@example.com",
			seed: func(directory *stubuserservice.StubDirectory) error {
				return directory.SeedExisting(common.Email("pilot@example.com"), common.UserID("user-1"))
			},
			wantStatus:        challenge.StatusSent,
			wantDeliveryState: challenge.DeliverySent,
			wantMailCalls:     1,
		},
		{
			name:              "creatable user",
			email:             "new@example.com",
			seed:              func(*stubuserservice.StubDirectory) error { return nil },
			wantStatus:        challenge.StatusSent,
			wantDeliveryState: challenge.DeliverySent,
			wantMailCalls:     1,
		},
		{
			name:  "blocked email",
			email: "blocked@example.com",
			seed: func(directory *stubuserservice.StubDirectory) error {
				return directory.SeedBlockedEmail(common.Email("blocked@example.com"), userresolution.BlockReasonCode("policy_block"))
			},
			wantStatus:        challenge.StatusDeliverySuppressed,
			wantDeliveryState: challenge.DeliverySuppressed,
			wantMailCalls:     0,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			userDirectory := &stubuserservice.StubDirectory{}
			require.NoError(t, tt.seed(userDirectory))

			challengeStore := &testkit.InMemoryChallengeStore{}
			mailSender := &stubmail.StubSender{}
			service, err := New(
				challengeStore,
				userDirectory,
				&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
				testkit.FixedCodeGenerator{Code: "654321"},
				testkit.DeterministicCodeHasher{},
				mailSender,
				testkit.FixedClock{Time: time.Unix(10, 0).UTC()},
			)
			require.NoError(t, err)

			// A challenge id is returned in every case, blocked included.
			result, err := service.Execute(context.Background(), Input{Email: tt.email})
			require.NoError(t, err)
			assert.Equal(t, "challenge-1", result.ChallengeID)

			record, err := challengeStore.Get(context.Background(), common.ChallengeID("challenge-1"))
			require.NoError(t, err)
			assert.Equal(t, tt.wantStatus, record.Status)
			assert.Equal(t, tt.wantDeliveryState, record.DeliveryState)
			assert.Len(t, mailSender.RecordedAttempts(), tt.wantMailCalls)
		})
	}
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package sendemailcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/userresolution"
|
||||
authtelemetry "galaxy/authsession/internal/telemetry"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||
"go.opentelemetry.io/otel/sdk/metric/metricdata"
|
||||
)
|
||||
|
||||
// TestExecuteRecordsSentMetric asserts that a successful send increments the
// attempts counter once with outcome "sent".
func TestExecuteRecordsSentMetric(t *testing.T) {
	t.Parallel()

	runtime, reader := newObservedTelemetryRuntime(t)
	service, _, mailSender := newObservedSendService(t, observedSendOptions{
		Telemetry: runtime,
	})

	_, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
	require.NoError(t, err)
	require.Len(t, mailSender.RecordedInputs(), 1)

	assertMetricCount(t, reader, "authsession.send_email_code.attempts", map[string]string{
		"outcome": "sent",
	}, 1)
}
|
||||
|
||||
// TestExecuteRecordsBlockedSuppressedMetric asserts that a blocked email
// records an attempt with outcome "suppressed" and reason "blocked".
func TestExecuteRecordsBlockedSuppressedMetric(t *testing.T) {
	t.Parallel()

	runtime, reader := newObservedTelemetryRuntime(t)
	service, _, _ := newObservedSendService(t, observedSendOptions{
		Telemetry:        runtime,
		SeedBlockedEmail: true,
	})

	// Blocked emails still complete without error; only the metric differs.
	_, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
	require.NoError(t, err)

	assertMetricCount(t, reader, "authsession.send_email_code.attempts", map[string]string{
		"outcome": "suppressed",
		"reason":  "blocked",
	}, 1)
}
|
||||
|
||||
// TestExecuteRecordsThrottledMetric asserts that a pre-reserved send cooldown
// throttles the next Execute: no mail is delivered and the attempts counter
// records outcome/reason "throttled".
func TestExecuteRecordsThrottledMetric(t *testing.T) {
	t.Parallel()

	runtime, reader := newObservedTelemetryRuntime(t)
	abuseProtector := &testkit.InMemorySendEmailCodeAbuseProtector{}
	now := time.Unix(10, 0).UTC()
	// Reserve the cooldown up front so the very next Execute is throttled.
	require.NoError(t, reserveSendCooldown(abuseProtector, common.Email("pilot@example.com"), now))

	service, _, mailSender := newObservedSendService(t, observedSendOptions{
		Telemetry:      runtime,
		AbuseProtector: abuseProtector,
		Clock:          testkit.FixedClock{Time: now},
	})

	_, err := service.Execute(context.Background(), Input{Email: "pilot@example.com"})
	require.NoError(t, err)
	assert.Empty(t, mailSender.RecordedInputs())

	assertMetricCount(t, reader, "authsession.send_email_code.attempts", map[string]string{
		"outcome": "throttled",
		"reason":  "throttled",
	}, 1)
}
|
||||
|
||||
// observedSendOptions bundles the optional dependencies used when building an
// observed send-email-code service for the telemetry tests.
type observedSendOptions struct {
	// Telemetry is the metric runtime the service under test reports into.
	Telemetry *authtelemetry.Runtime
	// AbuseProtector optionally pre-seeds throttle state; may be left nil.
	AbuseProtector *testkit.InMemorySendEmailCodeAbuseProtector
	// SeedBlockedEmail, when true, seeds pilot@example.com as blocked.
	SeedBlockedEmail bool
	// Clock overrides the default fixed clock when non-nil.
	Clock portsClock
}
|
||||
|
||||
// portsClock is the minimal clock contract observedSendOptions accepts.
type portsClock interface {
	Now() time.Time
}
|
||||
|
||||
// newObservedSendService builds a send-email-code service wired to in-memory
// fakes plus the supplied telemetry runtime, returning the service together
// with the challenge store and mail sender for inspection.
func newObservedSendService(t *testing.T, options observedSendOptions) (*Service, *testkit.InMemoryChallengeStore, *testkit.RecordingMailSender) {
	t.Helper()

	challengeStore := &testkit.InMemoryChallengeStore{}
	userDirectory := &testkit.InMemoryUserDirectory{}
	if options.SeedBlockedEmail {
		require.NoError(t, userDirectory.SeedBlockedEmail(common.Email("pilot@example.com"), userresolution.BlockReasonCode("policy_block")))
	}
	mailSender := &testkit.RecordingMailSender{}
	clock := options.Clock
	if clock == nil {
		// Deterministic default so challenge timestamps are stable.
		clock = testkit.FixedClock{Time: time.Unix(10, 0).UTC()}
	}

	// NOTE(review): options.AbuseProtector is a typed pointer and may be nil,
	// which arrives at the constructor as a non-nil interface holding a nil
	// pointer — presumably NewWithRuntime normalizes this case; confirm.
	service, err := NewWithRuntime(
		challengeStore,
		userDirectory,
		&testkit.SequenceIDGenerator{ChallengeIDs: []common.ChallengeID{"challenge-1"}},
		testkit.FixedCodeGenerator{Code: "654321"},
		testkit.DeterministicCodeHasher{},
		mailSender,
		options.AbuseProtector,
		clock,
		options.Telemetry,
	)
	require.NoError(t, err)

	return service, challengeStore, mailSender
}
|
||||
|
||||
func newObservedTelemetryRuntime(t *testing.T) (*authtelemetry.Runtime, *sdkmetric.ManualReader) {
|
||||
t.Helper()
|
||||
|
||||
reader := sdkmetric.NewManualReader()
|
||||
provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))
|
||||
|
||||
runtime, err := authtelemetry.New(provider)
|
||||
require.NoError(t, err)
|
||||
|
||||
return runtime, reader
|
||||
}
|
||||
|
||||
// assertMetricCount collects the current metrics from reader and asserts that
// the named int64 sum metric has a data point whose attributes match wantAttrs
// with value wantValue, failing the test when no such point exists.
func assertMetricCount(t *testing.T, reader *sdkmetric.ManualReader, metricName string, wantAttrs map[string]string, wantValue int64) {
	t.Helper()

	var resourceMetrics metricdata.ResourceMetrics
	require.NoError(t, reader.Collect(context.Background(), &resourceMetrics))

	for _, scopeMetrics := range resourceMetrics.ScopeMetrics {
		for _, metric := range scopeMetrics.Metrics {
			if metric.Name != metricName {
				continue
			}

			// The counter is expected to be an int64 sum aggregation.
			sum, ok := metric.Data.(metricdata.Sum[int64])
			require.True(t, ok)

			// Only the first matching data point is checked.
			for _, point := range sum.DataPoints {
				if hasMetricAttributes(point.Attributes.ToSlice(), wantAttrs) {
					assert.Equal(t, wantValue, point.Value)
					return
				}
			}
		}
	}

	require.Failf(t, "test failed", "metric %q with attrs %v not found", metricName, wantAttrs)
}
|
||||
|
||||
func hasMetricAttributes(values []attribute.KeyValue, want map[string]string) bool {
|
||||
if len(values) != len(want) {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
if want[string(value.Key)] != value.Value.AsString() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
// Package shared provides cross-use-case application helpers for auth/session
|
||||
// services, including typed service errors, input normalization, DTO mapping,
|
||||
// and application-level retry helpers.
|
||||
package shared
|
||||
@@ -0,0 +1,407 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
	// ErrorCodeInvalidRequest reports malformed or semantically invalid service
	// input.
	ErrorCodeInvalidRequest = "invalid_request"

	// ErrorCodeChallengeNotFound reports that the requested challenge does not
	// exist.
	ErrorCodeChallengeNotFound = "challenge_not_found"

	// ErrorCodeChallengeExpired reports that the requested challenge may no
	// longer be confirmed.
	ErrorCodeChallengeExpired = "challenge_expired"

	// ErrorCodeInvalidCode reports that the submitted confirmation code does not
	// match the stored challenge.
	ErrorCodeInvalidCode = "invalid_code"

	// ErrorCodeInvalidClientPublicKey reports that the submitted client public
	// key does not satisfy the Ed25519/base64 contract.
	ErrorCodeInvalidClientPublicKey = "invalid_client_public_key"

	// ErrorCodeBlockedByPolicy reports that the auth flow is denied by current
	// user or registration policy.
	ErrorCodeBlockedByPolicy = "blocked_by_policy"

	// ErrorCodeSessionLimitExceeded reports that creating another active session
	// would violate the configured limit.
	ErrorCodeSessionLimitExceeded = "session_limit_exceeded"

	// ErrorCodeSessionNotFound reports that the requested device session does
	// not exist.
	ErrorCodeSessionNotFound = "session_not_found"

	// ErrorCodeSubjectNotFound reports that the requested trusted internal
	// subject does not exist.
	ErrorCodeSubjectNotFound = "subject_not_found"

	// ErrorCodeServiceUnavailable reports that a required dependency or
	// propagation step is temporarily unavailable.
	ErrorCodeServiceUnavailable = "service_unavailable"

	// ErrorCodeInternalError reports that local state is inconsistent or an
	// invariant was broken unexpectedly.
	ErrorCodeInternalError = "internal_error"
)

// genericInvalidRequestMessage is the fallback caller-safe message for
// invalid_request errors that carry no message of their own.
const genericInvalidRequestMessage = "request is invalid"

// publicErrorStatusCodes maps every client-safe public error code to its
// frozen HTTP status. Membership in this map defines the public error surface
// (see IsPublicErrorCode).
var publicErrorStatusCodes = map[string]int{
	ErrorCodeInvalidRequest:         http.StatusBadRequest,
	ErrorCodeInvalidClientPublicKey: http.StatusBadRequest,
	ErrorCodeInvalidCode:            http.StatusBadRequest,
	ErrorCodeChallengeNotFound:      http.StatusNotFound,
	ErrorCodeChallengeExpired:       http.StatusGone,
	ErrorCodeBlockedByPolicy:        http.StatusForbidden,
	ErrorCodeSessionLimitExceeded:   http.StatusConflict,
	ErrorCodeServiceUnavailable:     http.StatusServiceUnavailable,
}

// publicStableMessages maps public error codes to the frozen client-safe
// descriptions written into the public JSON envelope.
var publicStableMessages = map[string]string{
	ErrorCodeChallengeNotFound:      "challenge not found",
	ErrorCodeChallengeExpired:       "challenge expired",
	ErrorCodeInvalidCode:            "confirmation code is invalid",
	ErrorCodeInvalidClientPublicKey: "client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key",
	ErrorCodeBlockedByPolicy:        "authentication is blocked by policy",
	ErrorCodeSessionLimitExceeded:   "active session limit would be exceeded",
	ErrorCodeServiceUnavailable:     "service is unavailable",
}

// internalErrorStatusCodes maps trusted-caller error codes to the frozen HTTP
// statuses of the internal API surface.
var internalErrorStatusCodes = map[string]int{
	ErrorCodeInvalidRequest:     http.StatusBadRequest,
	ErrorCodeSessionNotFound:    http.StatusNotFound,
	ErrorCodeSubjectNotFound:    http.StatusNotFound,
	ErrorCodeServiceUnavailable: http.StatusServiceUnavailable,
	ErrorCodeInternalError:      http.StatusInternalServerError,
}

// internalStableMessages maps trusted-caller error codes to the frozen
// descriptions of the internal API surface.
var internalStableMessages = map[string]string{
	ErrorCodeSessionNotFound:    "session not found",
	ErrorCodeSubjectNotFound:    "subject not found",
	ErrorCodeServiceUnavailable: "service is unavailable",
	ErrorCodeInternalError:      "internal server error",
}
|
||||
|
||||
// PublicErrorProjection describes one transport-ready public auth error after
// internal service errors have been normalized to the frozen client-safe
// surface.
type PublicErrorProjection struct {
	// StatusCode is the HTTP status that should be returned to the public auth
	// caller.
	StatusCode int

	// Code is the stable client-safe error code written into the public JSON
	// envelope.
	Code string

	// Message is the client-safe error description exposed to the public auth
	// caller.
	Message string
}

// InternalErrorProjection describes one transport-ready internal API error
// after service-layer failures have been normalized to the frozen trusted
// caller surface.
type InternalErrorProjection struct {
	// StatusCode is the HTTP status that should be returned to the trusted
	// caller.
	StatusCode int

	// Code is the stable error code written into the internal JSON envelope.
	Code string

	// Message is the trusted-caller-safe error description exposed by the
	// internal HTTP API.
	Message string
}

// ServiceError projects one stable application-layer failure with a service
// error code and a caller-safe message. It implements error and supports
// errors.Is/errors.As chains through the wrapped Err.
type ServiceError struct {
	// Code is the stable error code expected by later transport mapping.
	Code string

	// Message is the caller-safe error description.
	Message string

	// Err optionally stores the wrapped underlying cause.
	Err error
}
|
||||
|
||||
// Error returns the caller-safe error description.
|
||||
func (e *ServiceError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(e.Message) != "":
|
||||
return e.Message
|
||||
case strings.TrimSpace(e.Code) != "":
|
||||
return e.Code
|
||||
case e.Err != nil:
|
||||
return e.Err.Error()
|
||||
default:
|
||||
return ErrorCodeInternalError
|
||||
}
|
||||
}
|
||||
|
||||
// Unwrap returns the wrapped cause, if any.
|
||||
func (e *ServiceError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// NewServiceError returns a new typed application-layer error.
|
||||
func NewServiceError(code string, message string, err error) *ServiceError {
|
||||
return &ServiceError{
|
||||
Code: strings.TrimSpace(code),
|
||||
Message: strings.TrimSpace(message),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// IsPublicErrorCode reports whether code belongs to the frozen public auth
|
||||
// error surface.
|
||||
func IsPublicErrorCode(code string) bool {
|
||||
_, ok := publicErrorStatusCodes[strings.TrimSpace(code)]
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsInternalOnlyErrorCode reports whether code is intentionally excluded from
|
||||
// the public auth transport surface.
|
||||
func IsInternalOnlyErrorCode(code string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case ErrorCodeSessionNotFound, ErrorCodeSubjectNotFound, ErrorCodeInternalError:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsSendEmailCodePublicErrorCode reports whether code may be exposed by the
|
||||
// public send-email-code route after public projection.
|
||||
func IsSendEmailCodePublicErrorCode(code string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case ErrorCodeInvalidRequest, ErrorCodeServiceUnavailable:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsConfirmEmailCodePublicErrorCode reports whether code may be exposed by the
|
||||
// public confirm-email-code route after public projection.
|
||||
func IsConfirmEmailCodePublicErrorCode(code string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case ErrorCodeInvalidRequest,
|
||||
ErrorCodeChallengeNotFound,
|
||||
ErrorCodeChallengeExpired,
|
||||
ErrorCodeInvalidCode,
|
||||
ErrorCodeInvalidClientPublicKey,
|
||||
ErrorCodeBlockedByPolicy,
|
||||
ErrorCodeSessionLimitExceeded,
|
||||
ErrorCodeServiceUnavailable:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// PublicHTTPStatusCode reports the frozen public HTTP status for code. Unknown
|
||||
// or internal-only codes are normalized to 503 service_unavailable.
|
||||
func PublicHTTPStatusCode(code string) int {
|
||||
if statusCode, ok := publicErrorStatusCodes[strings.TrimSpace(code)]; ok {
|
||||
return statusCode
|
||||
}
|
||||
|
||||
return http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
// ProjectPublicError normalizes err to the frozen public-auth error surface.
|
||||
// Unknown and internal-only service failures are intentionally projected as
|
||||
// 503 service_unavailable so internal invariants do not leak to public callers.
|
||||
func ProjectPublicError(err error) PublicErrorProjection {
|
||||
serviceErr, ok := errors.AsType[*ServiceError](err)
|
||||
code := CodeOf(err)
|
||||
if !IsPublicErrorCode(code) {
|
||||
return PublicErrorProjection{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Code: ErrorCodeServiceUnavailable,
|
||||
Message: publicMessageForCode(ErrorCodeServiceUnavailable, ""),
|
||||
}
|
||||
}
|
||||
|
||||
message := ""
|
||||
if ok && serviceErr != nil {
|
||||
message = serviceErr.Message
|
||||
}
|
||||
|
||||
return PublicErrorProjection{
|
||||
StatusCode: PublicHTTPStatusCode(code),
|
||||
Code: code,
|
||||
Message: publicMessageForCode(code, message),
|
||||
}
|
||||
}
|
||||
|
||||
// InternalHTTPStatusCode reports the frozen internal HTTP status for code.
|
||||
// Unknown codes are normalized to 500 internal_error.
|
||||
func InternalHTTPStatusCode(code string) int {
|
||||
if statusCode, ok := internalErrorStatusCodes[strings.TrimSpace(code)]; ok {
|
||||
return statusCode
|
||||
}
|
||||
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// ProjectInternalError normalizes err to the frozen internal trusted HTTP
|
||||
// error surface. Unknown failures are intentionally projected as
|
||||
// 500 internal_error so transport callers do not depend on unclassified local
|
||||
// failures.
|
||||
func ProjectInternalError(err error) InternalErrorProjection {
|
||||
serviceErr, ok := errors.AsType[*ServiceError](err)
|
||||
code := CodeOf(err)
|
||||
if _, known := internalErrorStatusCodes[code]; !known {
|
||||
return InternalErrorProjection{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Code: ErrorCodeInternalError,
|
||||
Message: internalMessageForCode(ErrorCodeInternalError, ""),
|
||||
}
|
||||
}
|
||||
|
||||
message := ""
|
||||
if ok && serviceErr != nil {
|
||||
message = serviceErr.Message
|
||||
}
|
||||
|
||||
return InternalErrorProjection{
|
||||
StatusCode: InternalHTTPStatusCode(code),
|
||||
Code: code,
|
||||
Message: internalMessageForCode(code, message),
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidRequest reports one malformed or semantically invalid caller input.
// The caller-supplied message is exposed (trimmed) to the caller.
func InvalidRequest(message string) *ServiceError {
	return NewServiceError(ErrorCodeInvalidRequest, message, nil)
}

// ChallengeNotFound reports that the requested challenge does not exist.
func ChallengeNotFound() *ServiceError {
	return NewServiceError(ErrorCodeChallengeNotFound, "challenge not found", nil)
}

// ChallengeExpired reports that the requested challenge is expired.
func ChallengeExpired() *ServiceError {
	return NewServiceError(ErrorCodeChallengeExpired, "challenge expired", nil)
}

// InvalidCode reports that the submitted confirmation code is invalid.
func InvalidCode() *ServiceError {
	return NewServiceError(ErrorCodeInvalidCode, "confirmation code is invalid", nil)
}

// InvalidClientPublicKey reports that the submitted client public key does not
// satisfy the frozen contract.
func InvalidClientPublicKey() *ServiceError {
	return NewServiceError(
		ErrorCodeInvalidClientPublicKey,
		"client_public_key is not a valid base64-encoded raw 32-byte Ed25519 public key",
		nil,
	)
}

// BlockedByPolicy reports that the current auth flow is denied by policy.
func BlockedByPolicy() *ServiceError {
	return NewServiceError(ErrorCodeBlockedByPolicy, "authentication is blocked by policy", nil)
}

// SessionLimitExceeded reports that creating another active session would
// exceed the current configured limit.
func SessionLimitExceeded() *ServiceError {
	return NewServiceError(ErrorCodeSessionLimitExceeded, "active session limit would be exceeded", nil)
}

// SessionNotFound reports that the requested session does not exist.
func SessionNotFound() *ServiceError {
	return NewServiceError(ErrorCodeSessionNotFound, "session not found", nil)
}

// SubjectNotFound reports that the requested internal subject does not exist.
func SubjectNotFound() *ServiceError {
	return NewServiceError(ErrorCodeSubjectNotFound, "subject not found", nil)
}

// ServiceUnavailable reports that a required dependency or propagation step is
// temporarily unavailable, wrapping err as the underlying cause.
func ServiceUnavailable(err error) *ServiceError {
	return NewServiceError(ErrorCodeServiceUnavailable, "service is unavailable", err)
}

// InternalError reports an invariant-breaking local failure, wrapping err as
// the underlying cause.
func InternalError(err error) *ServiceError {
	return NewServiceError(ErrorCodeInternalError, "internal error", err)
}
|
||||
|
||||
// CodeOf returns the stable service error code of err when err wraps a
|
||||
// ServiceError. Otherwise it returns ErrorCodeInternalError.
|
||||
func CodeOf(err error) string {
|
||||
serviceErr, ok := errors.AsType[*ServiceError](err)
|
||||
if !ok || serviceErr == nil || strings.TrimSpace(serviceErr.Code) == "" {
|
||||
return ErrorCodeInternalError
|
||||
}
|
||||
|
||||
return serviceErr.Code
|
||||
}
|
||||
|
||||
func publicMessageForCode(code string, message string) string {
|
||||
trimmedMessage := strings.TrimSpace(message)
|
||||
|
||||
switch strings.TrimSpace(code) {
|
||||
case ErrorCodeInvalidRequest:
|
||||
if trimmedMessage != "" {
|
||||
return trimmedMessage
|
||||
}
|
||||
return genericInvalidRequestMessage
|
||||
case ErrorCodeServiceUnavailable:
|
||||
return publicStableMessages[ErrorCodeServiceUnavailable]
|
||||
default:
|
||||
if stableMessage, ok := publicStableMessages[strings.TrimSpace(code)]; ok {
|
||||
return stableMessage
|
||||
}
|
||||
return publicStableMessages[ErrorCodeServiceUnavailable]
|
||||
}
|
||||
}
|
||||
|
||||
func internalMessageForCode(code string, message string) string {
|
||||
trimmedMessage := strings.TrimSpace(message)
|
||||
|
||||
switch strings.TrimSpace(code) {
|
||||
case ErrorCodeInvalidRequest:
|
||||
if trimmedMessage != "" {
|
||||
return trimmedMessage
|
||||
}
|
||||
return genericInvalidRequestMessage
|
||||
case ErrorCodeSessionNotFound,
|
||||
ErrorCodeSubjectNotFound,
|
||||
ErrorCodeServiceUnavailable,
|
||||
ErrorCodeInternalError:
|
||||
if stableMessage, ok := internalStableMessages[strings.TrimSpace(code)]; ok {
|
||||
return stableMessage
|
||||
}
|
||||
return internalStableMessages[ErrorCodeInternalError]
|
||||
default:
|
||||
return internalStableMessages[ErrorCodeInternalError]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
)
|
||||
|
||||
// NormalizeString trims surrounding Unicode whitespace from value.
func NormalizeString(value string) string {
	trimmed := strings.TrimSpace(value)

	return trimmed
}
|
||||
|
||||
// ParseEmail trims value and validates it against the frozen public e-mail
|
||||
// contract.
|
||||
func ParseEmail(value string) (common.Email, error) {
|
||||
email := common.Email(NormalizeString(value))
|
||||
if err := email.Validate(); err != nil {
|
||||
return "", InvalidRequest(err.Error())
|
||||
}
|
||||
|
||||
return email, nil
|
||||
}
|
||||
|
||||
// ParseChallengeID trims value and validates it as one challenge identifier.
|
||||
func ParseChallengeID(value string) (common.ChallengeID, error) {
|
||||
challengeID := common.ChallengeID(NormalizeString(value))
|
||||
if err := challengeID.Validate(); err != nil {
|
||||
return "", InvalidRequest(err.Error())
|
||||
}
|
||||
|
||||
return challengeID, nil
|
||||
}
|
||||
|
||||
// ParseDeviceSessionID trims value and validates it as one device-session
|
||||
// identifier.
|
||||
func ParseDeviceSessionID(value string) (common.DeviceSessionID, error) {
|
||||
deviceSessionID := common.DeviceSessionID(NormalizeString(value))
|
||||
if err := deviceSessionID.Validate(); err != nil {
|
||||
return "", InvalidRequest(err.Error())
|
||||
}
|
||||
|
||||
return deviceSessionID, nil
|
||||
}
|
||||
|
||||
// ParseUserID trims value and validates it as one user identifier.
|
||||
func ParseUserID(value string) (common.UserID, error) {
|
||||
userID := common.UserID(NormalizeString(value))
|
||||
if err := userID.Validate(); err != nil {
|
||||
return "", InvalidRequest(err.Error())
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// ParseRequiredCode trims value and validates it as a required non-empty
|
||||
// confirmation code.
|
||||
func ParseRequiredCode(value string) (string, error) {
|
||||
code := NormalizeString(value)
|
||||
if code == "" {
|
||||
return "", InvalidRequest("code must not be empty")
|
||||
}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// ParseClientPublicKey trims value and validates it as the standard
|
||||
// base64-encoded raw 32-byte Ed25519 public key expected by the public auth
|
||||
// contract.
|
||||
func ParseClientPublicKey(value string) (common.ClientPublicKey, error) {
|
||||
normalized := NormalizeString(value)
|
||||
if normalized == "" {
|
||||
return common.ClientPublicKey{}, InvalidClientPublicKey()
|
||||
}
|
||||
|
||||
decoded, err := base64.StdEncoding.Strict().DecodeString(normalized)
|
||||
if err != nil || len(decoded) != ed25519.PublicKeySize {
|
||||
return common.ClientPublicKey{}, InvalidClientPublicKey()
|
||||
}
|
||||
|
||||
key, err := common.NewClientPublicKey(ed25519.PublicKey(decoded))
|
||||
if err != nil {
|
||||
return common.ClientPublicKey{}, InvalidClientPublicKey()
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// ParseRevokeReasonCode trims value and validates it as one machine-readable
|
||||
// revoke reason code.
|
||||
func ParseRevokeReasonCode(value string) (common.RevokeReasonCode, error) {
|
||||
code := common.RevokeReasonCode(NormalizeString(value))
|
||||
if err := code.Validate(); err != nil {
|
||||
return "", InvalidRequest(err.Error())
|
||||
}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// ParseRevokeActorType trims value and validates it as one machine-readable
|
||||
// revoke actor type.
|
||||
func ParseRevokeActorType(value string) (common.RevokeActorType, error) {
|
||||
actorType := common.RevokeActorType(NormalizeString(value))
|
||||
if err := actorType.Validate(); err != nil {
|
||||
return "", InvalidRequest(err.Error())
|
||||
}
|
||||
|
||||
return actorType, nil
|
||||
}
|
||||
|
||||
// ParseOptionalActorID trims value and validates it as one optional stable
|
||||
// actor identifier.
|
||||
func ParseOptionalActorID(value string) (string, error) {
|
||||
actorID := NormalizeString(value)
|
||||
if actorID != value {
|
||||
return "", InvalidRequest("actor_id must not contain surrounding whitespace")
|
||||
}
|
||||
|
||||
return actorID, nil
|
||||
}
|
||||
|
||||
// BuildRevocation validates one revoke request payload and returns the domain
|
||||
// revocation metadata applied to a session mutation.
|
||||
func BuildRevocation(reasonCode string, actorType string, actorID string, at time.Time) (devicesession.Revocation, error) {
|
||||
if at.IsZero() {
|
||||
return devicesession.Revocation{}, InternalError(fmt.Errorf("revocation time must not be zero"))
|
||||
}
|
||||
|
||||
parsedReasonCode, err := ParseRevokeReasonCode(reasonCode)
|
||||
if err != nil {
|
||||
return devicesession.Revocation{}, err
|
||||
}
|
||||
parsedActorType, err := ParseRevokeActorType(actorType)
|
||||
if err != nil {
|
||||
return devicesession.Revocation{}, err
|
||||
}
|
||||
parsedActorID, err := ParseOptionalActorID(actorID)
|
||||
if err != nil {
|
||||
return devicesession.Revocation{}, err
|
||||
}
|
||||
|
||||
revocation := devicesession.Revocation{
|
||||
At: at.UTC(),
|
||||
ReasonCode: parsedReasonCode,
|
||||
ActorType: parsedActorType,
|
||||
ActorID: parsedActorID,
|
||||
}
|
||||
if err := revocation.Validate(); err != nil {
|
||||
return devicesession.Revocation{}, InternalError(fmt.Errorf("build revocation: %w", err))
|
||||
}
|
||||
|
||||
return revocation, nil
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
authlogging "galaxy/authsession/internal/logging"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// LogServiceOutcome writes one structured service-level outcome log with a
|
||||
// stable severity derived from err and with trace fields attached when ctx
|
||||
// carries an active span.
|
||||
func LogServiceOutcome(logger *zap.Logger, ctx context.Context, message string, err error, fields ...zap.Field) {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
fields = append(fields, authlogging.TraceFieldsFromContext(ctx)...)
|
||||
|
||||
switch {
|
||||
case err == nil:
|
||||
logger.Info(message, fields...)
|
||||
case isExpectedServiceErrorCode(CodeOf(err)):
|
||||
logger.Warn(message, append(fields, zap.Error(err))...)
|
||||
default:
|
||||
logger.Error(message, append(fields, zap.Error(err))...)
|
||||
}
|
||||
}
|
||||
|
||||
func isExpectedServiceErrorCode(code string) bool {
|
||||
switch code {
|
||||
case ErrorCodeInvalidRequest,
|
||||
ErrorCodeChallengeNotFound,
|
||||
ErrorCodeChallengeExpired,
|
||||
ErrorCodeInvalidCode,
|
||||
ErrorCodeInvalidClientPublicKey,
|
||||
ErrorCodeBlockedByPolicy,
|
||||
ErrorCodeSessionLimitExceeded,
|
||||
ErrorCodeSessionNotFound,
|
||||
ErrorCodeSubjectNotFound:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package shared
|
||||
|
||||
const (
	// MaxCompareAndSwapRetries bounds application-level retry loops around
	// compare-and-swap challenge updates.
	MaxCompareAndSwapRetries = 3

	// MaxProjectionPublishAttempts bounds synchronous request-path retries
	// around gateway session projection publication.
	MaxProjectionPublishAttempts = 3
)
|
||||
@@ -0,0 +1,86 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
"galaxy/authsession/internal/ports"
|
||||
"galaxy/authsession/internal/telemetry"
|
||||
)
|
||||
|
||||
// PublishProjectionSnapshot publishes snapshot through publisher with a small
// bounded retry loop suitable for request-path consistency repair. It is the
// telemetry-free convenience form: no runtime and no operation label are
// forwarded.
func PublishProjectionSnapshot(ctx context.Context, publisher ports.GatewaySessionProjectionPublisher, snapshot gatewayprojection.Snapshot) error {
	return PublishProjectionSnapshotWithTelemetry(ctx, publisher, snapshot, nil, "")
}
|
||||
|
||||
// PublishProjectionSnapshotWithTelemetry publishes snapshot through publisher
|
||||
// with the bounded request-path retry policy and optional publish-failure
|
||||
// telemetry.
|
||||
func PublishProjectionSnapshotWithTelemetry(
|
||||
ctx context.Context,
|
||||
publisher ports.GatewaySessionProjectionPublisher,
|
||||
snapshot gatewayprojection.Snapshot,
|
||||
telemetryRuntime *telemetry.Runtime,
|
||||
operation string,
|
||||
) error {
|
||||
if publisher == nil {
|
||||
return InternalError(errors.New("projection publisher must not be nil"))
|
||||
}
|
||||
if ctx == nil {
|
||||
return ServiceUnavailable(errors.New("projection publish context must not be nil"))
|
||||
}
|
||||
if err := snapshot.Validate(); err != nil {
|
||||
return InternalError(fmt.Errorf("publish projection snapshot: %w", err))
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < MaxProjectionPublishAttempts; attempt++ {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return ServiceUnavailable(err)
|
||||
}
|
||||
|
||||
if err := publisher.PublishSession(ctx, snapshot); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
telemetryRuntime.RecordProjectionPublishFailure(ctx, operation)
|
||||
return ServiceUnavailable(
|
||||
fmt.Errorf(
|
||||
"publish projection snapshot %q after %d attempts: %w",
|
||||
snapshot.DeviceSessionID,
|
||||
MaxProjectionPublishAttempts,
|
||||
lastErr,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// PublishSessionProjection converts record into the gateway-facing snapshot
// and publishes it with the bounded request-path retry policy. It is the
// telemetry-free convenience form: no runtime and no operation label are
// forwarded.
func PublishSessionProjection(ctx context.Context, publisher ports.GatewaySessionProjectionPublisher, record devicesession.Session) error {
	return PublishSessionProjectionWithTelemetry(ctx, publisher, record, nil, "")
}
|
||||
|
||||
// PublishSessionProjectionWithTelemetry converts record into the
|
||||
// gateway-facing snapshot and publishes it with the bounded request-path retry
|
||||
// policy and optional publish-failure telemetry.
|
||||
func PublishSessionProjectionWithTelemetry(
|
||||
ctx context.Context,
|
||||
publisher ports.GatewaySessionProjectionPublisher,
|
||||
record devicesession.Session,
|
||||
telemetryRuntime *telemetry.Runtime,
|
||||
operation string,
|
||||
) error {
|
||||
snapshot, err := ToGatewayProjectionSnapshot(record)
|
||||
if err != nil {
|
||||
return InternalError(err)
|
||||
}
|
||||
|
||||
return PublishProjectionSnapshotWithTelemetry(ctx, publisher, snapshot, telemetryRuntime, operation)
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
"galaxy/authsession/internal/testkit"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPublishSessionProjectionRetriesUntilSuccess verifies that transient
// publish failures are retried and that publishing stops on the first
// successful attempt.
func TestPublishSessionProjectionRetriesUntilSuccess(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		// errors is the per-attempt result sequence the recording publisher
		// replays; a nil entry means that attempt succeeds.
		errors []error
		// wantAttempts is the expected total number of publish calls.
		wantAttempts int
	}{
		{
			name:         "success on second attempt",
			errors:       []error{errors.New("transient publish failure"), nil},
			wantAttempts: 2,
		},
		{
			name:         "success on third attempt",
			errors:       []error{errors.New("transient publish failure"), errors.New("transient publish failure"), nil},
			wantAttempts: 3,
		},
	}

	for _, tt := range tests {
		tt := tt // loop-variable capture guard (pre-Go 1.22 range semantics)

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			publisher := &testkit.RecordingProjectionPublisher{Errors: tt.errors}

			err := PublishSessionProjection(context.Background(), publisher, revokedSessionFixture())
			require.NoError(t, err)
			require.Len(t, publisher.PublishedSnapshots(), tt.wantAttempts)
		})
	}
}
|
||||
|
||||
func TestPublishSessionProjectionReturnsServiceUnavailableAfterExhaustedRetries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
publisher := &testkit.RecordingProjectionPublisher{Err: errors.New("publish failed")}
|
||||
|
||||
err := PublishSessionProjection(context.Background(), publisher, revokedSessionFixture())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, ErrorCodeServiceUnavailable, CodeOf(err))
|
||||
require.Len(t, publisher.PublishedSnapshots(), MaxProjectionPublishAttempts)
|
||||
}
|
||||
|
||||
func TestPublishProjectionSnapshotStopsRetriesWhenContextIsCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
publisher := &cancelingProjectionPublisher{
|
||||
cancel: cancel,
|
||||
err: errors.New("publish failed"),
|
||||
}
|
||||
|
||||
err := PublishProjectionSnapshot(ctx, publisher, mustProjectionSnapshot(t))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, ErrorCodeServiceUnavailable, CodeOf(err))
|
||||
assert.Equal(t, 1, publisher.attempts)
|
||||
}
|
||||
|
||||
func TestPublishSessionProjectionReturnsInternalErrorForInvalidLocalRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
publisher := &testkit.RecordingProjectionPublisher{}
|
||||
|
||||
err := PublishSessionProjection(context.Background(), publisher, invalidSessionFixture())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, ErrorCodeInternalError, CodeOf(err))
|
||||
assert.Empty(t, publisher.PublishedSnapshots())
|
||||
}
|
||||
|
||||
// cancelingProjectionPublisher is a test double that cancels a context on its
// first publish attempt and fails every attempt, letting tests observe that
// the retry loop honors context cancellation.
type cancelingProjectionPublisher struct {
	// attempts counts how many publish calls were made.
	attempts int
	// cancel, when non-nil, is invoked exactly once on the first attempt.
	cancel context.CancelFunc
	// err is returned from every publish attempt.
	err error
}
|
||||
|
||||
func (p *cancelingProjectionPublisher) PublishSession(_ context.Context, snapshot gatewayprojection.Snapshot) error {
|
||||
if err := snapshot.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.attempts++
|
||||
if p.cancel != nil {
|
||||
p.cancel()
|
||||
p.cancel = nil
|
||||
}
|
||||
|
||||
return p.err
|
||||
}
|
||||
|
||||
func mustProjectionSnapshot(t *testing.T) gatewayprojection.Snapshot {
|
||||
t.Helper()
|
||||
|
||||
snapshot, err := ToGatewayProjectionSnapshot(revokedSessionFixture())
|
||||
require.NoError(t, err)
|
||||
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func invalidSessionFixture() devicesession.Session {
|
||||
return devicesession.Session{}
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
)
|
||||
|
||||
// Session mirrors the frozen internal read-model DTO used by later trusted
// transport handlers. Pointer fields are nil for active sessions and set only
// once the session has been revoked.
type Session struct {
	// DeviceSessionID is the stable identifier of one device session.
	DeviceSessionID string

	// UserID is the stable identifier of the session owner.
	UserID string

	// ClientPublicKey is the base64-encoded raw 32-byte Ed25519 public key of
	// the device session.
	ClientPublicKey string

	// Status reports whether the session is active or revoked.
	Status string

	// CreatedAt is the RFC3339 UTC timestamp at which the session was created.
	CreatedAt string

	// RevokedAt is the RFC3339 UTC timestamp at which the session was revoked,
	// when the session is revoked.
	RevokedAt *string

	// RevokeReasonCode is the machine-readable revoke reason code when the
	// session is revoked.
	RevokeReasonCode *string

	// RevokeActorType is the machine-readable revoke actor type when the
	// session is revoked.
	RevokeActorType *string

	// RevokeActorID is the optional stable revoke actor identifier when the
	// session is revoked.
	RevokeActorID *string
}
|
||||
|
||||
// ToSession converts source-of-truth session into the frozen internal read DTO
|
||||
// shape.
|
||||
func ToSession(record devicesession.Session) (Session, error) {
|
||||
if err := record.Validate(); err != nil {
|
||||
return Session{}, fmt.Errorf("map session: %w", err)
|
||||
}
|
||||
|
||||
result := Session{
|
||||
DeviceSessionID: record.ID.String(),
|
||||
UserID: record.UserID.String(),
|
||||
ClientPublicKey: record.ClientPublicKey.String(),
|
||||
Status: string(record.Status),
|
||||
CreatedAt: formatTime(record.CreatedAt),
|
||||
}
|
||||
|
||||
if record.Revocation != nil {
|
||||
revokedAt := formatTime(record.Revocation.At)
|
||||
reasonCode := record.Revocation.ReasonCode.String()
|
||||
actorType := record.Revocation.ActorType.String()
|
||||
result.RevokedAt = &revokedAt
|
||||
result.RevokeReasonCode = &reasonCode
|
||||
result.RevokeActorType = &actorType
|
||||
if record.Revocation.ActorID != "" {
|
||||
actorID := record.Revocation.ActorID
|
||||
result.RevokeActorID = &actorID
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ToSessions converts every source-of-truth session into the frozen internal
|
||||
// read DTO shape.
|
||||
func ToSessions(records []devicesession.Session) ([]Session, error) {
|
||||
result := make([]Session, 0, len(records))
|
||||
for index, record := range records {
|
||||
mapped, err := ToSession(record)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("map session %d: %w", index, err)
|
||||
}
|
||||
result = append(result, mapped)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ToGatewayProjectionSnapshot converts source-of-truth session into the
|
||||
// separate gateway-facing projection model.
|
||||
func ToGatewayProjectionSnapshot(record devicesession.Session) (gatewayprojection.Snapshot, error) {
|
||||
if err := record.Validate(); err != nil {
|
||||
return gatewayprojection.Snapshot{}, fmt.Errorf("map gateway projection snapshot: %w", err)
|
||||
}
|
||||
|
||||
snapshot := gatewayprojection.Snapshot{
|
||||
DeviceSessionID: record.ID,
|
||||
UserID: record.UserID,
|
||||
ClientPublicKey: record.ClientPublicKey.String(),
|
||||
Status: gatewayprojection.Status(record.Status),
|
||||
}
|
||||
if record.Revocation != nil {
|
||||
snapshot.RevokedAt = cloneTimePointer(commonTimePointer(record.Revocation.At.UTC()))
|
||||
snapshot.RevokeReasonCode = record.Revocation.ReasonCode
|
||||
snapshot.RevokeActorType = record.Revocation.ActorType
|
||||
snapshot.RevokeActorID = record.Revocation.ActorID
|
||||
}
|
||||
if err := snapshot.Validate(); err != nil {
|
||||
return gatewayprojection.Snapshot{}, fmt.Errorf("map gateway projection snapshot: %w", err)
|
||||
}
|
||||
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
func formatTime(value time.Time) string {
|
||||
return value.UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
func commonTimePointer(value time.Time) *time.Time {
|
||||
return &value
|
||||
}
|
||||
|
||||
func cloneTimePointer(value *time.Time) *time.Time {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := *value
|
||||
return &cloned
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"galaxy/authsession/internal/domain/sessionlimit"
|
||||
"galaxy/authsession/internal/ports"
|
||||
)
|
||||
|
||||
// EvaluateSessionLimit evaluates the Stage-4 active-session creation decision
|
||||
// from the loaded configuration and current active-session count.
|
||||
func EvaluateSessionLimit(config ports.SessionLimitConfig, activeSessionCount int) (sessionlimit.Decision, error) {
|
||||
if err := config.Validate(); err != nil {
|
||||
return sessionlimit.Decision{}, InternalError(fmt.Errorf("evaluate session limit: %w", err))
|
||||
}
|
||||
if activeSessionCount < 0 {
|
||||
return sessionlimit.Decision{}, InternalError(fmt.Errorf("evaluate session limit: active session count %d is negative", activeSessionCount))
|
||||
}
|
||||
|
||||
decision := sessionlimit.Decision{
|
||||
ActiveSessionCount: activeSessionCount,
|
||||
NextSessionCount: activeSessionCount + 1,
|
||||
}
|
||||
|
||||
if config.ActiveSessionLimit == nil {
|
||||
decision.Kind = sessionlimit.KindDisabled
|
||||
} else {
|
||||
decision.ConfiguredLimit = config.ActiveSessionLimit
|
||||
if decision.NextSessionCount <= *config.ActiveSessionLimit {
|
||||
decision.Kind = sessionlimit.KindAllowed
|
||||
} else {
|
||||
decision.Kind = sessionlimit.KindExceeded
|
||||
}
|
||||
}
|
||||
if err := decision.Validate(); err != nil {
|
||||
return sessionlimit.Decision{}, InternalError(fmt.Errorf("evaluate session limit: %w", err))
|
||||
}
|
||||
|
||||
return decision, nil
|
||||
}
|
||||
@@ -0,0 +1,380 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/authsession/internal/domain/common"
|
||||
"galaxy/authsession/internal/domain/devicesession"
|
||||
"galaxy/authsession/internal/domain/gatewayprojection"
|
||||
"galaxy/authsession/internal/domain/sessionlimit"
|
||||
"galaxy/authsession/internal/ports"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNormalizeString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, "pilot@example.com", NormalizeString(" pilot@example.com \n"))
|
||||
}
|
||||
|
||||
func TestParseClientPublicKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
key, err := ParseClientPublicKey(" AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8= ")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", key.String())
|
||||
|
||||
_, err = ParseClientPublicKey("invalid")
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, ErrorCodeInvalidClientPublicKey, CodeOf(err))
|
||||
}
|
||||
|
||||
func TestToSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
record := revokedSessionFixture()
|
||||
|
||||
dto, err := ToSession(record)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, record.ID.String(), dto.DeviceSessionID)
|
||||
require.NotNil(t, dto.RevokedAt)
|
||||
assert.Equal(t, record.Revocation.At.UTC().Format(time.RFC3339), *dto.RevokedAt)
|
||||
}
|
||||
|
||||
func TestToGatewayProjectionSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
record := revokedSessionFixture()
|
||||
|
||||
snapshot, err := ToGatewayProjectionSnapshot(record)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, gatewayprojection.StatusRevoked, snapshot.Status)
|
||||
}
|
||||
|
||||
// TestEvaluateSessionLimit covers the three decision kinds: no configured
// limit (disabled), creation within the limit (allowed), and creation that
// would exceed the limit (exceeded).
func TestEvaluateSessionLimit(t *testing.T) {
	t.Parallel()

	limit := 2

	tests := []struct {
		name   string
		config ports.SessionLimitConfig
		active int
		want   sessionlimit.Kind
	}{
		{name: "disabled", config: ports.SessionLimitConfig{}, active: 3, want: sessionlimit.KindDisabled},
		{name: "allowed", config: ports.SessionLimitConfig{ActiveSessionLimit: &limit}, active: 1, want: sessionlimit.KindAllowed},
		{name: "exceeded", config: ports.SessionLimitConfig{ActiveSessionLimit: &limit}, active: 2, want: sessionlimit.KindExceeded},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			decision, err := EvaluateSessionLimit(tt.config, tt.active)
			require.NoError(t, err)
			assert.Equal(t, tt.want, decision.Kind)
		})
	}
}
|
||||
|
||||
func TestServiceErrorCodePreservation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
baseErr := errors.New("base")
|
||||
err := ServiceUnavailable(baseErr)
|
||||
|
||||
assert.Equal(t, ErrorCodeServiceUnavailable, CodeOf(err))
|
||||
assert.ErrorIs(t, err, baseErr)
|
||||
}
|
||||
|
||||
// TestErrorCodeClassification asserts that every known error code is
// classified as exactly one of public or internal-only.
func TestErrorCodeClassification(t *testing.T) {
	t.Parallel()

	// Codes that may be surfaced to untrusted clients.
	publicCodes := []string{
		ErrorCodeInvalidRequest,
		ErrorCodeChallengeNotFound,
		ErrorCodeChallengeExpired,
		ErrorCodeInvalidCode,
		ErrorCodeInvalidClientPublicKey,
		ErrorCodeBlockedByPolicy,
		ErrorCodeSessionLimitExceeded,
		ErrorCodeServiceUnavailable,
	}
	for _, code := range publicCodes {
		assert.Truef(t, IsPublicErrorCode(code), "IsPublicErrorCode(%q)", code)
		assert.Falsef(t, IsInternalOnlyErrorCode(code), "IsInternalOnlyErrorCode(%q)", code)
	}

	// Codes reserved for trusted internal transports only.
	internalOnlyCodes := []string{
		ErrorCodeSessionNotFound,
		ErrorCodeSubjectNotFound,
		ErrorCodeInternalError,
	}
	for _, code := range internalOnlyCodes {
		assert.Falsef(t, IsPublicErrorCode(code), "IsPublicErrorCode(%q)", code)
		assert.Truef(t, IsInternalOnlyErrorCode(code), "IsInternalOnlyErrorCode(%q)", code)
	}
}
|
||||
|
||||
// TestPublicUseCaseErrorCodeSets verifies the per-use-case public error code
// allow-lists for the send-email-code and confirm-email-code flows.
func TestPublicUseCaseErrorCodeSets(t *testing.T) {
	t.Parallel()

	// Send-email-code exposes only request/availability failures.
	assert.True(t, IsSendEmailCodePublicErrorCode(ErrorCodeInvalidRequest))
	assert.True(t, IsSendEmailCodePublicErrorCode(ErrorCodeServiceUnavailable))
	assert.False(t, IsSendEmailCodePublicErrorCode(ErrorCodeBlockedByPolicy))
	assert.False(t, IsSendEmailCodePublicErrorCode(ErrorCodeChallengeNotFound))

	// Confirm-email-code exposes the full public code set, but never
	// internal-only codes.
	confirmCodes := []string{
		ErrorCodeInvalidRequest,
		ErrorCodeChallengeNotFound,
		ErrorCodeChallengeExpired,
		ErrorCodeInvalidCode,
		ErrorCodeInvalidClientPublicKey,
		ErrorCodeBlockedByPolicy,
		ErrorCodeSessionLimitExceeded,
		ErrorCodeServiceUnavailable,
	}
	for _, code := range confirmCodes {
		assert.Truef(t, IsConfirmEmailCodePublicErrorCode(code), "IsConfirmEmailCodePublicErrorCode(%q)", code)
	}
	assert.False(t, IsConfirmEmailCodePublicErrorCode(ErrorCodeInternalError))
	assert.False(t, IsConfirmEmailCodePublicErrorCode(ErrorCodeSessionNotFound))
}
|
||||
|
||||
// TestPublicHTTPStatusCode verifies the HTTP status mapped to each public
// error code, and that internal-only or unknown codes are normalized to
// 503 Service Unavailable so nothing internal is exposed.
func TestPublicHTTPStatusCode(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		code string
		want int
	}{
		{name: "invalid request", code: ErrorCodeInvalidRequest, want: http.StatusBadRequest},
		{name: "invalid client public key", code: ErrorCodeInvalidClientPublicKey, want: http.StatusBadRequest},
		{name: "invalid code", code: ErrorCodeInvalidCode, want: http.StatusBadRequest},
		{name: "challenge not found", code: ErrorCodeChallengeNotFound, want: http.StatusNotFound},
		{name: "challenge expired", code: ErrorCodeChallengeExpired, want: http.StatusGone},
		{name: "blocked by policy", code: ErrorCodeBlockedByPolicy, want: http.StatusForbidden},
		{name: "session limit exceeded", code: ErrorCodeSessionLimitExceeded, want: http.StatusConflict},
		{name: "service unavailable", code: ErrorCodeServiceUnavailable, want: http.StatusServiceUnavailable},
		{name: "internal error normalized", code: ErrorCodeInternalError, want: http.StatusServiceUnavailable},
		{name: "unknown normalized", code: "unknown", want: http.StatusServiceUnavailable},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			assert.Equal(t, tt.want, PublicHTTPStatusCode(tt.code))
		})
	}
}
|
||||
|
||||
// TestInternalHTTPStatusCode verifies the HTTP status mapped to each code on
// trusted internal transports, where internal-only codes keep their own
// statuses and unknown codes default to 500.
func TestInternalHTTPStatusCode(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		code string
		want int
	}{
		{name: "invalid request", code: ErrorCodeInvalidRequest, want: http.StatusBadRequest},
		{name: "session not found", code: ErrorCodeSessionNotFound, want: http.StatusNotFound},
		{name: "subject not found", code: ErrorCodeSubjectNotFound, want: http.StatusNotFound},
		{name: "service unavailable", code: ErrorCodeServiceUnavailable, want: http.StatusServiceUnavailable},
		{name: "internal error", code: ErrorCodeInternalError, want: http.StatusInternalServerError},
		{name: "unknown normalized", code: "unknown", want: http.StatusInternalServerError},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			assert.Equal(t, tt.want, InternalHTTPStatusCode(tt.code))
		})
	}
}
|
||||
|
||||
// TestProjectPublicError verifies the public error projection: detailed
// messages survive only for invalid-request, custom details are replaced by
// canonical per-code messages, and internal-only or unexpected errors are
// hidden behind a generic 503 service-unavailable projection.
func TestProjectPublicError(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		err  error
		want PublicErrorProjection
	}{
		{
			name: "invalid request keeps detailed message",
			err:  InvalidRequest("email must be a single valid email address"),
			want: PublicErrorProjection{
				StatusCode: http.StatusBadRequest,
				Code:       ErrorCodeInvalidRequest,
				Message:    "email must be a single valid email address",
			},
		},
		{
			name: "invalid code keeps canonical message",
			err:  NewServiceError(ErrorCodeInvalidCode, "custom detail should not leak", nil),
			want: PublicErrorProjection{
				StatusCode: http.StatusBadRequest,
				Code:       ErrorCodeInvalidCode,
				Message:    "confirmation code is invalid",
			},
		},
		{
			name: "service unavailable keeps generic message",
			err:  NewServiceError(ErrorCodeServiceUnavailable, "dependency timeout", errors.New("dependency timeout")),
			want: PublicErrorProjection{
				StatusCode: http.StatusServiceUnavailable,
				Code:       ErrorCodeServiceUnavailable,
				Message:    "service is unavailable",
			},
		},
		{
			name: "internal error is hidden",
			err:  InternalError(errors.New("broken invariant")),
			want: PublicErrorProjection{
				StatusCode: http.StatusServiceUnavailable,
				Code:       ErrorCodeServiceUnavailable,
				Message:    "service is unavailable",
			},
		},
		{
			name: "internal only session not found is hidden",
			err:  SessionNotFound(),
			want: PublicErrorProjection{
				StatusCode: http.StatusServiceUnavailable,
				Code:       ErrorCodeServiceUnavailable,
				Message:    "service is unavailable",
			},
		},
		{
			name: "non service error is hidden",
			err:  errors.New("boom"),
			want: PublicErrorProjection{
				StatusCode: http.StatusServiceUnavailable,
				Code:       ErrorCodeServiceUnavailable,
				Message:    "service is unavailable",
			},
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			assert.Equal(t, tt.want, ProjectPublicError(tt.err))
		})
	}
}
|
||||
|
||||
// TestProjectInternalError verifies the internal error projection: detailed
// messages survive for invalid-request, canonical messages replace custom
// details per code, and unexpected errors are normalized to a generic 500
// internal-server-error projection.
func TestProjectInternalError(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		err  error
		want InternalErrorProjection
	}{
		{
			name: "invalid request keeps detailed message",
			err:  InvalidRequest("reason_code must not be empty"),
			want: InternalErrorProjection{
				StatusCode: http.StatusBadRequest,
				Code:       ErrorCodeInvalidRequest,
				Message:    "reason_code must not be empty",
			},
		},
		{
			name: "session not found keeps canonical message",
			err:  NewServiceError(ErrorCodeSessionNotFound, "custom detail should not leak", nil),
			want: InternalErrorProjection{
				StatusCode: http.StatusNotFound,
				Code:       ErrorCodeSessionNotFound,
				Message:    "session not found",
			},
		},
		{
			name: "subject not found keeps canonical message",
			err:  SubjectNotFound(),
			want: InternalErrorProjection{
				StatusCode: http.StatusNotFound,
				Code:       ErrorCodeSubjectNotFound,
				Message:    "subject not found",
			},
		},
		{
			name: "service unavailable keeps generic message",
			err:  NewServiceError(ErrorCodeServiceUnavailable, "redis timeout", errors.New("redis timeout")),
			want: InternalErrorProjection{
				StatusCode: http.StatusServiceUnavailable,
				Code:       ErrorCodeServiceUnavailable,
				Message:    "service is unavailable",
			},
		},
		{
			name: "internal error uses internal server error message",
			err:  InternalError(errors.New("broken invariant")),
			want: InternalErrorProjection{
				StatusCode: http.StatusInternalServerError,
				Code:       ErrorCodeInternalError,
				Message:    "internal server error",
			},
		},
		{
			name: "unexpected error is hidden",
			err:  errors.New("boom"),
			want: InternalErrorProjection{
				StatusCode: http.StatusInternalServerError,
				Code:       ErrorCodeInternalError,
				Message:    "internal server error",
			},
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			assert.Equal(t, tt.want, ProjectInternalError(tt.err))
		})
	}
}
|
||||
|
||||
func revokedSessionFixture() devicesession.Session {
|
||||
key, err := common.NewClientPublicKey(make([]byte, 32))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
revokedAt := time.Unix(20, 0).UTC()
|
||||
return devicesession.Session{
|
||||
ID: common.DeviceSessionID("device-session-1"),
|
||||
UserID: common.UserID("user-1"),
|
||||
ClientPublicKey: key,
|
||||
Status: devicesession.StatusRevoked,
|
||||
CreatedAt: time.Unix(10, 0).UTC(),
|
||||
Revocation: &devicesession.Revocation{
|
||||
At: revokedAt,
|
||||
ReasonCode: devicesession.RevokeReasonLogoutAll,
|
||||
ActorType: common.RevokeActorType("system"),
|
||||
ActorID: "actor-1",
|
||||
},
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user