859b157a59
Two problems showed up while trying to log in to the long-lived dev
environment with the dev-fixed code `123456`:
1. `ConfirmEmailCode` checked the per-challenge attempts ceiling
*before* the dev-fixed-code override. A developer who burned past
`ChallengeMaxAttempts` on an existing un-consumed challenge (easy
to trigger when the throttle reuses one challenge_id) hit
`ErrTooManyAttempts` and the UI rendered "code expired or already
used" even though the fixed code was correct. Reorder so the
dev-fixed-code branch runs first and bypasses both the bcrypt
verify and the attempts gate. Production stays unaffected
because production loaders refuse to set `DevFixedCode`.
2. `dev-deploy.yaml` only fires on push to `development`, so the
matching docker-compose default change for
`BACKEND_AUTH_DEV_FIXED_CODE` could not reach the running stack
before this PR merged. Add `workflow_dispatch: {}` so a developer
can deploy any branch — typically a feature branch under review —
from the Gitea Actions UI without waiting for the merge.
Covered by a new `TestConfirmEmailCodeDevFixedCodeBypassesAttemptsCeiling`
integration test that burns through the ceiling with wrong codes
then proves the dev-fixed code still produces a session.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
688 lines
19 KiB
Go
package auth_test
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"errors"
|
|
"net/url"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"galaxy/backend/internal/auth"
|
|
"galaxy/backend/internal/config"
|
|
backendpg "galaxy/backend/internal/postgres"
|
|
"galaxy/backend/internal/user"
|
|
pgshared "galaxy/postgres"
|
|
|
|
"github.com/google/uuid"
|
|
testcontainers "github.com/testcontainers/testcontainers-go"
|
|
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
|
"github.com/testcontainers/testcontainers-go/wait"
|
|
)
|
|
|
|
// Container image and credentials for the throwaway Postgres instance.
const (
	pgImage    = "postgres:16-alpine"
	pgUser     = "galaxy"
	pgPassword = "galaxy"
	pgDatabase = "galaxy_backend"
	// pgSchema is placed on the connection's search_path.
	pgSchema = "backend"
)

// Timeouts used while bringing the container up and talking to it.
const (
	// pgStartup bounds the log-based wait strategy.
	pgStartup = 90 * time.Second
	// pgOpTO is handed to pgshared as OperationTimeout.
	pgOpTO = 10 * time.Second
)
|
|
|
|
// startPostgres spins up a Postgres testcontainer with the backend
|
|
// migrations applied. The returned *sql.DB is closed and the container
|
|
// terminated by t.Cleanup hooks.
|
|
func startPostgres(t *testing.T) *sql.DB {
|
|
t.Helper()
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
|
t.Cleanup(cancel)
|
|
|
|
pgContainer, err := tcpostgres.Run(ctx, pgImage,
|
|
tcpostgres.WithDatabase(pgDatabase),
|
|
tcpostgres.WithUsername(pgUser),
|
|
tcpostgres.WithPassword(pgPassword),
|
|
testcontainers.WithWaitStrategy(
|
|
wait.ForLog("database system is ready to accept connections").
|
|
WithOccurrence(2).
|
|
WithStartupTimeout(pgStartup),
|
|
),
|
|
)
|
|
if err != nil {
|
|
t.Skipf("postgres testcontainer unavailable, skipping: %v", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil {
|
|
t.Errorf("terminate postgres container: %v", termErr)
|
|
}
|
|
})
|
|
|
|
baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
|
|
if err != nil {
|
|
t.Fatalf("connection string: %v", err)
|
|
}
|
|
scopedDSN, err := dsnWithSearchPath(baseDSN, pgSchema)
|
|
if err != nil {
|
|
t.Fatalf("scope dsn: %v", err)
|
|
}
|
|
|
|
cfg := pgshared.DefaultConfig()
|
|
cfg.PrimaryDSN = scopedDSN
|
|
cfg.OperationTimeout = pgOpTO
|
|
|
|
db, err := pgshared.OpenPrimary(ctx, cfg, backendpg.NoObservabilityOptions()...)
|
|
if err != nil {
|
|
t.Fatalf("open primary: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = db.Close() })
|
|
|
|
if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil {
|
|
t.Fatalf("ping: %v", err)
|
|
}
|
|
if err := backendpg.ApplyMigrations(ctx, db); err != nil {
|
|
t.Fatalf("apply migrations: %v", err)
|
|
}
|
|
return db
|
|
}
|
|
|
|
// dsnWithSearchPath returns baseDSN (a URL-form Postgres DSN) with its
// search_path query parameter set to schema. If the DSN carries no
// sslmode, sslmode=disable is added. Query parameters are re-encoded in
// sorted key order. A DSN that does not parse as a URL yields the parse
// error.
func dsnWithSearchPath(baseDSN, schema string) (string, error) {
	u, parseErr := url.Parse(baseDSN)
	if parseErr != nil {
		return "", parseErr
	}
	q := u.Query()
	if q.Get("sslmode") == "" {
		q.Set("sslmode", "disable")
	}
	q.Set("search_path", schema)
	u.RawQuery = q.Encode()
	return u.String(), nil
}
|
|
|
|
// recordingMailer is a test double for auth.LoginCodeMailer. It keeps the
// recipient and code of the most recent enqueue plus a running call
// count; all three fields are guarded by mu.
type recordingMailer struct {
	mu       sync.Mutex
	lastCode string
	lastTo   string
	calls    int
}

// newRecordingMailer returns a recordingMailer with nothing recorded.
func newRecordingMailer() *recordingMailer {
	return new(recordingMailer)
}

// EnqueueLoginCode stores email and code as the latest delivery and bumps
// the call counter. The context and TTL are ignored; it never fails.
func (m *recordingMailer) EnqueueLoginCode(_ context.Context, email, code string, _ time.Duration) error {
	m.mu.Lock()
	m.lastTo, m.lastCode = email, code
	m.calls++
	m.mu.Unlock()
	return nil
}

// snapshot returns, under the lock, the latest recipient, the latest code
// and the number of EnqueueLoginCode calls seen so far.
func (m *recordingMailer) snapshot() (to, code string, calls int) {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.lastTo, m.lastCode, m.calls
}
|
|
|
|
// recordingPush implements auth.SessionInvalidator and counts emissions.
|
|
type recordingPush struct {
|
|
mu sync.Mutex
|
|
calls []recordedPush
|
|
}
|
|
|
|
type recordedPush struct {
|
|
deviceSessionID, userID uuid.UUID
|
|
reason string
|
|
}
|
|
|
|
func newRecordingPush() *recordingPush { return &recordingPush{} }
|
|
|
|
func (p *recordingPush) PublishSessionInvalidation(_ context.Context, dsID, uid uuid.UUID, reason string) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
p.calls = append(p.calls, recordedPush{deviceSessionID: dsID, userID: uid, reason: reason})
|
|
}
|
|
|
|
func (p *recordingPush) snapshot() []recordedPush {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
out := make([]recordedPush, len(p.calls))
|
|
copy(out, p.calls)
|
|
return out
|
|
}
|
|
|
|
// stubGeo implements auth.GeoService with no real lookups. The country
|
|
// it returns is configurable per call via countryByIP.
|
|
type stubGeo struct {
|
|
countryByIP map[string]string
|
|
}
|
|
|
|
func newStubGeo() *stubGeo {
|
|
return &stubGeo{countryByIP: map[string]string{}}
|
|
}
|
|
|
|
func (g *stubGeo) LookupCountry(sourceIP string) string {
|
|
return g.countryByIP[sourceIP]
|
|
}
|
|
|
|
func (g *stubGeo) SetDeclaredCountryAtRegistration(_ context.Context, _ uuid.UUID, _ string) error {
|
|
return nil
|
|
}
|
|
|
|
// authConfig builds an AuthConfig suitable for tests.
|
|
func authConfig() config.AuthConfig {
|
|
return config.AuthConfig{
|
|
ChallengeTTL: 5 * time.Minute,
|
|
ChallengeMaxAttempts: 3,
|
|
ChallengeThrottle: config.AuthChallengeThrottleConfig{
|
|
Window: time.Minute,
|
|
Max: 3,
|
|
},
|
|
UserNameMaxRetries: 10,
|
|
}
|
|
}
|
|
|
|
// buildServiceWithConfig wires every dependency around db using cfg as
|
|
// the auth configuration. Returns only the service — assertions on the
|
|
// dev-mode override path do not inspect the recording fakes.
|
|
func buildServiceWithConfig(t *testing.T, db *sql.DB, cfg config.AuthConfig) *auth.Service {
|
|
t.Helper()
|
|
store := auth.NewStore(db)
|
|
cache := auth.NewCache()
|
|
if err := cache.Warm(context.Background(), store); err != nil {
|
|
t.Fatalf("warm cache: %v", err)
|
|
}
|
|
userStore := user.NewStore(db)
|
|
userSvc := user.NewService(user.Deps{
|
|
Store: userStore,
|
|
Cache: user.NewCache(),
|
|
UserNameMaxRetries: 10,
|
|
Now: time.Now,
|
|
})
|
|
return auth.NewService(auth.Deps{
|
|
Store: store,
|
|
Cache: cache,
|
|
User: userSvc,
|
|
Geo: newStubGeo(),
|
|
Mail: newRecordingMailer(),
|
|
Push: newRecordingPush(),
|
|
Config: cfg,
|
|
Now: time.Now,
|
|
})
|
|
}
|
|
|
|
// buildService wires every dependency around db and returns the service
|
|
// plus the recording fakes for assertions.
|
|
func buildService(t *testing.T, db *sql.DB) (*auth.Service, *recordingMailer, *recordingPush, *stubGeo) {
|
|
t.Helper()
|
|
store := auth.NewStore(db)
|
|
cache := auth.NewCache()
|
|
if err := cache.Warm(context.Background(), store); err != nil {
|
|
t.Fatalf("warm cache: %v", err)
|
|
}
|
|
mailer := newRecordingMailer()
|
|
pusher := newRecordingPush()
|
|
geo := newStubGeo()
|
|
userStore := user.NewStore(db)
|
|
userSvc := user.NewService(user.Deps{
|
|
|
|
Store: userStore,
|
|
Cache: user.NewCache(),
|
|
UserNameMaxRetries: 10,
|
|
Now: time.Now,
|
|
})
|
|
svc := auth.NewService(auth.Deps{
|
|
Store: store,
|
|
Cache: cache,
|
|
User: userSvc,
|
|
Geo: geo,
|
|
Mail: mailer,
|
|
Push: pusher,
|
|
Config: authConfig(),
|
|
Now: time.Now,
|
|
})
|
|
return svc, mailer, pusher, geo
|
|
}
|
|
|
|
// randomKey returns 32 bytes from crypto/rand, for use as a client public
// key. The test is aborted if the random source fails.
func randomKey(t *testing.T) []byte {
	t.Helper()
	var buf [32]byte
	if _, err := rand.Read(buf[:]); err != nil {
		t.Fatalf("rand: %v", err)
	}
	return buf[:]
}
|
|
|
|
func TestAuthEndToEnd(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc, mailer, pusher, _ := buildService(t, db)
|
|
ctx := context.Background()
|
|
|
|
challengeID, err := svc.SendEmailCode(ctx, "Alice@Example.Test", "ru", "", "")
|
|
if err != nil {
|
|
t.Fatalf("SendEmailCode: %v", err)
|
|
}
|
|
if challengeID == uuid.Nil {
|
|
t.Fatalf("SendEmailCode returned nil challenge_id")
|
|
}
|
|
gotEmail, gotCode, calls := mailer.snapshot()
|
|
if gotEmail != "alice@example.test" {
|
|
t.Fatalf("mailer email = %q, want lower-cased", gotEmail)
|
|
}
|
|
if len(gotCode) != auth.CodeLength {
|
|
t.Fatalf("mailer code = %q (len %d), want length %d", gotCode, len(gotCode), auth.CodeLength)
|
|
}
|
|
if calls != 1 {
|
|
t.Fatalf("mailer calls = %d, want 1", calls)
|
|
}
|
|
|
|
pubKey := randomKey(t)
|
|
session, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
|
ChallengeID: challengeID,
|
|
Code: gotCode,
|
|
ClientPublicKey: pubKey,
|
|
TimeZone: "Europe/Moscow",
|
|
SourceIP: "",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("ConfirmEmailCode: %v", err)
|
|
}
|
|
if session.UserID == uuid.Nil {
|
|
t.Fatalf("session has nil user_id")
|
|
}
|
|
if session.Status != auth.SessionStatusActive {
|
|
t.Fatalf("session.Status = %q, want %q", session.Status, auth.SessionStatusActive)
|
|
}
|
|
|
|
got, err := svc.GetSession(ctx, session.DeviceSessionID)
|
|
if err != nil {
|
|
t.Fatalf("GetSession: %v", err)
|
|
}
|
|
if got.UserID != session.UserID {
|
|
t.Fatalf("GetSession user_id = %s, want %s", got.UserID, session.UserID)
|
|
}
|
|
|
|
revoked, err := svc.RevokeSession(ctx, session.DeviceSessionID, auth.RevokeContext{
|
|
ActorKind: auth.ActorKindUserSelf,
|
|
ActorID: session.UserID.String(),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("RevokeSession: %v", err)
|
|
}
|
|
if revoked.Status != auth.SessionStatusRevoked {
|
|
t.Fatalf("revoked.Status = %q, want %q", revoked.Status, auth.SessionStatusRevoked)
|
|
}
|
|
if revoked.RevokedAt == nil {
|
|
t.Fatalf("revoked.RevokedAt nil after revoke")
|
|
}
|
|
|
|
if _, err := svc.GetSession(ctx, session.DeviceSessionID); !errors.Is(err, auth.ErrSessionNotFound) {
|
|
t.Fatalf("GetSession after revoke = %v, want ErrSessionNotFound", err)
|
|
}
|
|
|
|
again, err := svc.RevokeSession(ctx, session.DeviceSessionID, auth.RevokeContext{
|
|
ActorKind: auth.ActorKindUserSelf,
|
|
ActorID: session.UserID.String(),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("idempotent RevokeSession: %v", err)
|
|
}
|
|
if again.DeviceSessionID != session.DeviceSessionID || again.Status != auth.SessionStatusRevoked {
|
|
t.Fatalf("idempotent revoke shape mismatch: %+v", again)
|
|
}
|
|
|
|
pushes := pusher.snapshot()
|
|
if len(pushes) != 1 {
|
|
t.Fatalf("push emissions = %d, want 1", len(pushes))
|
|
}
|
|
if pushes[0].deviceSessionID != session.DeviceSessionID {
|
|
t.Fatalf("push device_session_id mismatch")
|
|
}
|
|
}
|
|
|
|
func TestSendEmailCodePermanentlyBlocked(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc, _, _, _ := buildService(t, db)
|
|
|
|
// Insert a permanent_block account directly.
|
|
if _, err := db.Exec(`
|
|
INSERT INTO backend.accounts (
|
|
user_id, email, user_name, preferred_language, time_zone, permanent_block
|
|
) VALUES ($1, $2, $3, $4, $5, true)
|
|
`, uuid.New(), "blocked@example.test", "Player-XXBLOCK1", "en", "UTC"); err != nil {
|
|
t.Fatalf("seed account: %v", err)
|
|
}
|
|
|
|
_, err := svc.SendEmailCode(context.Background(), "blocked@example.test", "", "", "")
|
|
if !errors.Is(err, auth.ErrEmailPermanentlyBlocked) {
|
|
t.Fatalf("SendEmailCode for blocked email = %v, want ErrEmailPermanentlyBlocked", err)
|
|
}
|
|
}
|
|
|
|
// TestConfirmEmailCodePermanentlyBlockedAfterSend covers the case where
|
|
// an admin applies permanent_block in the window between send and
|
|
// confirm. The send-time guard let the challenge through because the
|
|
// account was unblocked at that moment; the confirm-time guard must
|
|
// catch the late block and reject the registration.
|
|
func TestConfirmEmailCodePermanentlyBlockedAfterSend(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc, mailer, _, _ := buildService(t, db)
|
|
ctx := context.Background()
|
|
|
|
const email = "blockedlater@example.test"
|
|
|
|
if _, err := db.Exec(`
|
|
INSERT INTO backend.accounts (
|
|
user_id, email, user_name, preferred_language, time_zone
|
|
) VALUES ($1, $2, $3, $4, $5)
|
|
`, uuid.New(), email, "Player-XXBLATER", "en", "UTC"); err != nil {
|
|
t.Fatalf("seed account: %v", err)
|
|
}
|
|
|
|
id, err := svc.SendEmailCode(ctx, email, "en", "", "")
|
|
if err != nil {
|
|
t.Fatalf("SendEmailCode: %v", err)
|
|
}
|
|
_, code, _ := mailer.snapshot()
|
|
|
|
if _, err := db.Exec(`
|
|
UPDATE backend.accounts SET permanent_block = true WHERE email = $1
|
|
`, email); err != nil {
|
|
t.Fatalf("apply permanent_block: %v", err)
|
|
}
|
|
|
|
_, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
|
ChallengeID: id,
|
|
Code: code,
|
|
ClientPublicKey: randomKey(t),
|
|
TimeZone: "UTC",
|
|
})
|
|
if !errors.Is(err, auth.ErrEmailPermanentlyBlocked) {
|
|
t.Fatalf("ConfirmEmailCode after block = %v, want ErrEmailPermanentlyBlocked", err)
|
|
}
|
|
}
|
|
|
|
func TestSendEmailCodeThrottleReusesChallenge(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc, mailer, _, _ := buildService(t, db)
|
|
ctx := context.Background()
|
|
|
|
const email = "throttle@example.test"
|
|
cfg := authConfig()
|
|
var firstID uuid.UUID
|
|
for i := range cfg.ChallengeThrottle.Max {
|
|
id, err := svc.SendEmailCode(ctx, email, "", "", "")
|
|
if err != nil {
|
|
t.Fatalf("SendEmailCode #%d: %v", i, err)
|
|
}
|
|
if i == 0 {
|
|
firstID = id
|
|
}
|
|
}
|
|
_, _, callsBefore := mailer.snapshot()
|
|
|
|
// One more call — must reuse the latest challenge_id and skip mail.
|
|
id, err := svc.SendEmailCode(ctx, email, "", "", "")
|
|
if err != nil {
|
|
t.Fatalf("SendEmailCode (throttled): %v", err)
|
|
}
|
|
_, _, callsAfter := mailer.snapshot()
|
|
if callsAfter != callsBefore {
|
|
t.Fatalf("mail enqueue should be skipped on throttle: before=%d after=%d", callsBefore, callsAfter)
|
|
}
|
|
if id == uuid.Nil {
|
|
t.Fatalf("throttled call returned nil challenge_id")
|
|
}
|
|
if id == firstID {
|
|
t.Fatalf("throttled call returned the FIRST challenge — expected the latest")
|
|
}
|
|
}
|
|
|
|
func TestConfirmEmailCodeDevFixedCodeBypass(t *testing.T) {
|
|
db := startPostgres(t)
|
|
cfg := authConfig()
|
|
cfg.DevFixedCode = "999999"
|
|
svc := buildServiceWithConfig(t, db, cfg)
|
|
ctx := context.Background()
|
|
|
|
id, err := svc.SendEmailCode(ctx, "dev-bypass@example.test", "en", "", "")
|
|
if err != nil {
|
|
t.Fatalf("send: %v", err)
|
|
}
|
|
|
|
session, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
|
ChallengeID: id,
|
|
Code: "999999",
|
|
ClientPublicKey: randomKey(t),
|
|
TimeZone: "UTC",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("ConfirmEmailCode with dev fixed code: %v", err)
|
|
}
|
|
if session.DeviceSessionID == uuid.Nil {
|
|
t.Fatalf("dev fixed code did not produce a session")
|
|
}
|
|
}
|
|
|
|
func TestConfirmEmailCodeDevFixedCodeStillRejectsWrong(t *testing.T) {
|
|
db := startPostgres(t)
|
|
cfg := authConfig()
|
|
cfg.DevFixedCode = "999999"
|
|
svc := buildServiceWithConfig(t, db, cfg)
|
|
ctx := context.Background()
|
|
|
|
id, err := svc.SendEmailCode(ctx, "dev-bypass-wrong@example.test", "en", "", "")
|
|
if err != nil {
|
|
t.Fatalf("send: %v", err)
|
|
}
|
|
|
|
_, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
|
ChallengeID: id,
|
|
Code: "111111",
|
|
ClientPublicKey: randomKey(t),
|
|
TimeZone: "UTC",
|
|
})
|
|
if !errors.Is(err, auth.ErrCodeMismatch) {
|
|
t.Fatalf("ConfirmEmailCode with neither real nor dev code = %v, want ErrCodeMismatch", err)
|
|
}
|
|
}
|
|
|
|
func TestConfirmEmailCodeWrongCode(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc, mailer, _, _ := buildService(t, db)
|
|
ctx := context.Background()
|
|
|
|
id, err := svc.SendEmailCode(ctx, "wrong@example.test", "en", "", "")
|
|
if err != nil {
|
|
t.Fatalf("send: %v", err)
|
|
}
|
|
_, code, _ := mailer.snapshot()
|
|
wrong := flipDigit(code)
|
|
|
|
_, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
|
ChallengeID: id,
|
|
Code: wrong,
|
|
ClientPublicKey: randomKey(t),
|
|
TimeZone: "UTC",
|
|
})
|
|
if !errors.Is(err, auth.ErrCodeMismatch) {
|
|
t.Fatalf("ConfirmEmailCode wrong code = %v, want ErrCodeMismatch", err)
|
|
}
|
|
}
|
|
|
|
// TestConfirmEmailCodeDevFixedCodeBypassesAttemptsCeiling proves the
|
|
// dev-mode override is a true escape hatch: a developer who already
|
|
// burned past ChallengeMaxAttempts on a long-lived dev challenge
|
|
// (typically because the throttle merged repeated send-email-code
|
|
// calls onto one challenge_id) can still recover by submitting the
|
|
// fixed code without first waiting out the challenge TTL.
|
|
func TestConfirmEmailCodeDevFixedCodeBypassesAttemptsCeiling(t *testing.T) {
|
|
db := startPostgres(t)
|
|
cfg := authConfig()
|
|
cfg.DevFixedCode = "999999"
|
|
svc := buildServiceWithConfig(t, db, cfg)
|
|
ctx := context.Background()
|
|
|
|
id, err := svc.SendEmailCode(ctx, "dev-bypass-ceiling@example.test", "en", "", "")
|
|
if err != nil {
|
|
t.Fatalf("send: %v", err)
|
|
}
|
|
|
|
// Burn through the attempts ceiling with deliberately wrong codes.
|
|
for i := range cfg.ChallengeMaxAttempts + 1 {
|
|
_, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
|
ChallengeID: id,
|
|
Code: "111111",
|
|
ClientPublicKey: randomKey(t),
|
|
TimeZone: "UTC",
|
|
})
|
|
if err == nil {
|
|
t.Fatalf("attempt %d unexpectedly succeeded", i)
|
|
}
|
|
}
|
|
|
|
// The dev-fixed code still goes through.
|
|
session, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
|
ChallengeID: id,
|
|
Code: "999999",
|
|
ClientPublicKey: randomKey(t),
|
|
TimeZone: "UTC",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("dev-fixed-code after attempts exhausted: %v", err)
|
|
}
|
|
if session.DeviceSessionID == uuid.Nil {
|
|
t.Fatalf("dev-fixed-code did not produce a session")
|
|
}
|
|
}
|
|
|
|
func TestConfirmEmailCodeAttemptsCeiling(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc, mailer, _, _ := buildService(t, db)
|
|
ctx := context.Background()
|
|
|
|
id, err := svc.SendEmailCode(ctx, "ceiling@example.test", "en", "", "")
|
|
if err != nil {
|
|
t.Fatalf("send: %v", err)
|
|
}
|
|
_, code, _ := mailer.snapshot()
|
|
wrong := flipDigit(code)
|
|
|
|
// Burn `max` attempts with the wrong code.
|
|
for i := range authConfig().ChallengeMaxAttempts {
|
|
_, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
|
ChallengeID: id,
|
|
Code: wrong,
|
|
ClientPublicKey: randomKey(t),
|
|
TimeZone: "UTC",
|
|
})
|
|
if !errors.Is(err, auth.ErrCodeMismatch) {
|
|
t.Fatalf("attempt %d: %v, want ErrCodeMismatch", i, err)
|
|
}
|
|
}
|
|
// One past the ceiling — even with the right code, ErrTooManyAttempts.
|
|
_, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
|
ChallengeID: id,
|
|
Code: code,
|
|
ClientPublicKey: randomKey(t),
|
|
TimeZone: "UTC",
|
|
})
|
|
if !errors.Is(err, auth.ErrTooManyAttempts) {
|
|
t.Fatalf("post-ceiling = %v, want ErrTooManyAttempts", err)
|
|
}
|
|
}
|
|
|
|
func TestConfirmEmailCodeChallengeNotFound(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc, _, _, _ := buildService(t, db)
|
|
|
|
_, err := svc.ConfirmEmailCode(context.Background(), auth.ConfirmInputs{
|
|
ChallengeID: uuid.New(),
|
|
Code: "000000",
|
|
ClientPublicKey: randomKey(t),
|
|
TimeZone: "UTC",
|
|
})
|
|
if !errors.Is(err, auth.ErrChallengeNotFound) {
|
|
t.Fatalf("unknown challenge = %v, want ErrChallengeNotFound", err)
|
|
}
|
|
}
|
|
|
|
func TestRevokeAllForUser(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc, mailer, pusher, _ := buildService(t, db)
|
|
ctx := context.Background()
|
|
|
|
const email = "many@example.test"
|
|
const sessionsToCreate = 3
|
|
var userID uuid.UUID
|
|
deviceSessionIDs := make([]uuid.UUID, 0, sessionsToCreate)
|
|
for range sessionsToCreate {
|
|
id, err := svc.SendEmailCode(ctx, email, "en", "", "")
|
|
if err != nil {
|
|
t.Fatalf("send: %v", err)
|
|
}
|
|
_, code, _ := mailer.snapshot()
|
|
sess, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
|
ChallengeID: id,
|
|
Code: code,
|
|
ClientPublicKey: randomKey(t),
|
|
TimeZone: "UTC",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("confirm: %v", err)
|
|
}
|
|
userID = sess.UserID
|
|
deviceSessionIDs = append(deviceSessionIDs, sess.DeviceSessionID)
|
|
}
|
|
|
|
revoked, err := svc.RevokeAllForUser(ctx, userID, auth.RevokeContext{
|
|
ActorKind: auth.ActorKindUserSelf,
|
|
ActorID: userID.String(),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("RevokeAllForUser: %v", err)
|
|
}
|
|
if len(revoked) != sessionsToCreate {
|
|
t.Fatalf("revoked count = %d, want %d", len(revoked), sessionsToCreate)
|
|
}
|
|
for _, dsID := range deviceSessionIDs {
|
|
if _, err := svc.GetSession(ctx, dsID); !errors.Is(err, auth.ErrSessionNotFound) {
|
|
t.Fatalf("session %s still in cache: %v", dsID, err)
|
|
}
|
|
}
|
|
if got := len(pusher.snapshot()); got != sessionsToCreate {
|
|
t.Fatalf("push emissions = %d, want %d", got, sessionsToCreate)
|
|
}
|
|
|
|
// Idempotent: revoking again returns an empty slice.
|
|
again, err := svc.RevokeAllForUser(ctx, userID, auth.RevokeContext{
|
|
ActorKind: auth.ActorKindUserSelf,
|
|
ActorID: userID.String(),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("idempotent RevokeAllForUser: %v", err)
|
|
}
|
|
if len(again) != 0 {
|
|
t.Fatalf("idempotent RevokeAllForUser = %d sessions, want 0", len(again))
|
|
}
|
|
}
|
|
|
|
// flipDigit returns code with its first byte changed so the result is
// guaranteed to differ from the input. A leading digit d becomes
// (d+1) mod 10, so a digit-only code stays digit-only and keeps its
// length. A leading non-digit becomes '0'. The empty string maps to "0".
func flipDigit(code string) string {
	if len(code) == 0 {
		return "0"
	}
	buf := []byte(code)
	switch c := buf[0]; {
	case c >= '0' && c < '9':
		buf[0] = c + 1
	default:
		// '9' wraps around to '0'; any non-digit is also replaced by '0'.
		buf[0] = '0'
	}
	return string(buf)
}
|