package auth_test

import (
	"context"
	"crypto/rand"
	"database/sql"
	"errors"
	"net/url"
	"sync"
	"testing"
	"time"

	"galaxy/backend/internal/auth"
	"galaxy/backend/internal/config"
	backendpg "galaxy/backend/internal/postgres"
	"galaxy/backend/internal/user"
	pgshared "galaxy/postgres"

	"github.com/google/uuid"
	testcontainers "github.com/testcontainers/testcontainers-go"
	tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
	"github.com/testcontainers/testcontainers-go/wait"
)

const (
	pgImage    = "postgres:16-alpine"
	pgUser     = "galaxy"
	pgPassword = "galaxy"
	pgDatabase = "galaxy_backend"
	// pgSchema is the schema the backend migrations target; it is injected
	// into the DSN as search_path by dsnWithSearchPath.
	pgSchema = "backend"
	// pgStartup bounds how long the wait strategy waits for readiness logs.
	pgStartup = 90 * time.Second
	// pgOpTO is the per-operation timeout handed to pgshared.
	pgOpTO = 10 * time.Second
	// pgSetupTO bounds the whole of startPostgres (container start, connect,
	// ping and migrations).
	pgSetupTO = 3 * time.Minute
)

// startPostgres spins up a Postgres testcontainer with the backend
// migrations applied and returns a *sql.DB whose search_path is scoped to
// pgSchema. The *sql.DB is closed and the container terminated by
// t.Cleanup hooks. If the container cannot be started, the test is skipped
// rather than failed.
func startPostgres(t *testing.T) *sql.DB {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), pgSetupTO)
	t.Cleanup(cancel)

	pgContainer, err := tcpostgres.Run(ctx, pgImage,
		tcpostgres.WithDatabase(pgDatabase),
		tcpostgres.WithUsername(pgUser),
		tcpostgres.WithPassword(pgPassword),
		testcontainers.WithWaitStrategy(
			// The official image typically logs this line twice (once for the
			// temporary init server, once for the real one), hence occurrence 2.
			wait.ForLog("database system is ready to accept connections").
				WithOccurrence(2).
				WithStartupTimeout(pgStartup),
		),
	)
	// Register termination BEFORE inspecting err: Run may hand back a
	// created-but-unhealthy container together with an error (e.g. the wait
	// strategy timed out). Skipping without this cleanup would leak it.
	t.Cleanup(func() {
		if pgContainer == nil {
			return
		}
		if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil {
			t.Errorf("terminate postgres container: %v", termErr)
		}
	})
	if err != nil {
		t.Skipf("postgres testcontainer unavailable, skipping: %v", err)
	}

	baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
	if err != nil {
		t.Fatalf("connection string: %v", err)
	}
	scopedDSN, err := dsnWithSearchPath(baseDSN, pgSchema)
	if err != nil {
		t.Fatalf("scope dsn: %v", err)
	}

	cfg := pgshared.DefaultConfig()
	cfg.PrimaryDSN = scopedDSN
	cfg.OperationTimeout = pgOpTO

	db, err := pgshared.OpenPrimary(ctx, cfg)
	if err != nil {
		t.Fatalf("open primary: %v", err)
	}
	// Best-effort close; the container is torn down right after anyway.
	t.Cleanup(func() { _ = db.Close() })

	if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil {
		t.Fatalf("ping: %v", err)
	}
	if err := backendpg.ApplyMigrations(ctx, db); err != nil {
		t.Fatalf("apply migrations: %v", err)
	}
	return db
}

// dsnWithSearchPath returns baseDSN with its search_path query parameter set
// to schema. It also sets sslmode=disable when baseDSN has no sslmode.
// Other query parameters are kept, although url.Values.Encode re-orders
// them by key.
func dsnWithSearchPath(baseDSN, schema string) (string, error) {
	parsed, err := url.Parse(baseDSN)
	if err != nil {
		return "", err
	}
	values := parsed.Query()
	values.Set("search_path", schema)
	if values.Get("sslmode") == "" {
		values.Set("sslmode", "disable")
	}
	parsed.RawQuery = values.Encode()
	return parsed.String(), nil
}

// recordingMailer implements auth.LoginCodeMailer and remembers the most
// recent enqueue. It is safe for concurrent use.
type recordingMailer struct {
	mu       sync.Mutex
	lastCode string
	lastTo   string
	calls    int
}

func newRecordingMailer() *recordingMailer { return &recordingMailer{} }

// EnqueueLoginCode records the recipient and code, bumps the call counter
// and always succeeds. The context and TTL are ignored.
func (m *recordingMailer) EnqueueLoginCode(_ context.Context, email, code string, _ time.Duration) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.lastTo = email
	m.lastCode = code
	m.calls++
	return nil
}

// snapshot returns the last recipient, the last code and the total number
// of EnqueueLoginCode calls observed so far.
func (m *recordingMailer) snapshot() (to, code string, calls int) {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.lastTo, m.lastCode, m.calls
}

// recordingPush implements auth.SessionInvalidator and records every emission.
type recordingPush struct { mu sync.Mutex calls []recordedPush } type recordedPush struct { deviceSessionID, userID uuid.UUID reason string } func newRecordingPush() *recordingPush { return &recordingPush{} } func (p *recordingPush) PublishSessionInvalidation(_ context.Context, dsID, uid uuid.UUID, reason string) { p.mu.Lock() defer p.mu.Unlock() p.calls = append(p.calls, recordedPush{deviceSessionID: dsID, userID: uid, reason: reason}) } func (p *recordingPush) snapshot() []recordedPush { p.mu.Lock() defer p.mu.Unlock() out := make([]recordedPush, len(p.calls)) copy(out, p.calls) return out } // stubGeo implements auth.GeoService with no real lookups. The country // it returns is configurable per call via CountryForIP; LanguageForIP // returns "" so the auth flow exercises the "en" fallback path. type stubGeo struct { countryByIP map[string]string } func newStubGeo() *stubGeo { return &stubGeo{countryByIP: map[string]string{}} } func (g *stubGeo) LookupCountry(sourceIP string) string { return g.countryByIP[sourceIP] } func (g *stubGeo) LanguageForIP(_ string) string { return "" } func (g *stubGeo) SetDeclaredCountryAtRegistration(_ context.Context, _ uuid.UUID, _ string) error { return nil } // authConfig builds an AuthConfig suitable for tests. func authConfig() config.AuthConfig { return config.AuthConfig{ ChallengeTTL: 5 * time.Minute, ChallengeMaxAttempts: 3, ChallengeThrottle: config.AuthChallengeThrottleConfig{ Window: time.Minute, Max: 3, }, UserNameMaxRetries: 10, } } // buildService wires every dependency around db and returns the service // plus the recording fakes for assertions. 
func buildService(t *testing.T, db *sql.DB) (*auth.Service, *recordingMailer, *recordingPush, *stubGeo) { t.Helper() store := auth.NewStore(db) cache := auth.NewCache() if err := cache.Warm(context.Background(), store); err != nil { t.Fatalf("warm cache: %v", err) } mailer := newRecordingMailer() pusher := newRecordingPush() geo := newStubGeo() userStore := user.NewStore(db) userSvc := user.NewService(user.Deps{ Store: userStore, Cache: user.NewCache(), UserNameMaxRetries: 10, Now: time.Now, }) svc := auth.NewService(auth.Deps{ Store: store, Cache: cache, User: userSvc, Geo: geo, Mail: mailer, Push: pusher, Config: authConfig(), Now: time.Now, }) return svc, mailer, pusher, geo } func randomKey(t *testing.T) []byte { t.Helper() key := make([]byte, 32) if _, err := rand.Read(key); err != nil { t.Fatalf("rand: %v", err) } return key } func TestAuthEndToEnd(t *testing.T) { db := startPostgres(t) svc, mailer, pusher, _ := buildService(t, db) ctx := context.Background() challengeID, err := svc.SendEmailCode(ctx, "Alice@Example.Test", "ru", "", "") if err != nil { t.Fatalf("SendEmailCode: %v", err) } if challengeID == uuid.Nil { t.Fatalf("SendEmailCode returned nil challenge_id") } gotEmail, gotCode, calls := mailer.snapshot() if gotEmail != "alice@example.test" { t.Fatalf("mailer email = %q, want lower-cased", gotEmail) } if len(gotCode) != auth.CodeLength { t.Fatalf("mailer code = %q (len %d), want length %d", gotCode, len(gotCode), auth.CodeLength) } if calls != 1 { t.Fatalf("mailer calls = %d, want 1", calls) } pubKey := randomKey(t) session, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{ ChallengeID: challengeID, Code: gotCode, ClientPublicKey: pubKey, TimeZone: "Europe/Moscow", SourceIP: "", }) if err != nil { t.Fatalf("ConfirmEmailCode: %v", err) } if session.UserID == uuid.Nil { t.Fatalf("session has nil user_id") } if session.Status != auth.SessionStatusActive { t.Fatalf("session.Status = %q, want %q", session.Status, auth.SessionStatusActive) } got, err 
:= svc.GetSession(ctx, session.DeviceSessionID) if err != nil { t.Fatalf("GetSession: %v", err) } if got.UserID != session.UserID { t.Fatalf("GetSession user_id = %s, want %s", got.UserID, session.UserID) } revoked, err := svc.RevokeSession(ctx, session.DeviceSessionID) if err != nil { t.Fatalf("RevokeSession: %v", err) } if revoked.Status != auth.SessionStatusRevoked { t.Fatalf("revoked.Status = %q, want %q", revoked.Status, auth.SessionStatusRevoked) } if revoked.RevokedAt == nil { t.Fatalf("revoked.RevokedAt nil after revoke") } if _, err := svc.GetSession(ctx, session.DeviceSessionID); !errors.Is(err, auth.ErrSessionNotFound) { t.Fatalf("GetSession after revoke = %v, want ErrSessionNotFound", err) } again, err := svc.RevokeSession(ctx, session.DeviceSessionID) if err != nil { t.Fatalf("idempotent RevokeSession: %v", err) } if again.DeviceSessionID != session.DeviceSessionID || again.Status != auth.SessionStatusRevoked { t.Fatalf("idempotent revoke shape mismatch: %+v", again) } pushes := pusher.snapshot() if len(pushes) != 1 { t.Fatalf("push emissions = %d, want 1", len(pushes)) } if pushes[0].deviceSessionID != session.DeviceSessionID { t.Fatalf("push device_session_id mismatch") } } func TestSendEmailCodePermanentlyBlocked(t *testing.T) { db := startPostgres(t) svc, _, _, _ := buildService(t, db) // Insert a permanent_block account directly. 
if _, err := db.Exec(` INSERT INTO backend.accounts ( user_id, email, user_name, preferred_language, time_zone, permanent_block ) VALUES ($1, $2, $3, $4, $5, true) `, uuid.New(), "blocked@example.test", "Player-XXBLOCK1", "en", "UTC"); err != nil { t.Fatalf("seed account: %v", err) } _, err := svc.SendEmailCode(context.Background(), "blocked@example.test", "", "", "") if !errors.Is(err, auth.ErrEmailPermanentlyBlocked) { t.Fatalf("SendEmailCode for blocked email = %v, want ErrEmailPermanentlyBlocked", err) } } func TestSendEmailCodeThrottleReusesChallenge(t *testing.T) { db := startPostgres(t) svc, mailer, _, _ := buildService(t, db) ctx := context.Background() const email = "throttle@example.test" cfg := authConfig() var firstID uuid.UUID for i := range cfg.ChallengeThrottle.Max { id, err := svc.SendEmailCode(ctx, email, "", "", "") if err != nil { t.Fatalf("SendEmailCode #%d: %v", i, err) } if i == 0 { firstID = id } } _, _, callsBefore := mailer.snapshot() // One more call — must reuse the latest challenge_id and skip mail. 
id, err := svc.SendEmailCode(ctx, email, "", "", "") if err != nil { t.Fatalf("SendEmailCode (throttled): %v", err) } _, _, callsAfter := mailer.snapshot() if callsAfter != callsBefore { t.Fatalf("mail enqueue should be skipped on throttle: before=%d after=%d", callsBefore, callsAfter) } if id == uuid.Nil { t.Fatalf("throttled call returned nil challenge_id") } if id == firstID { t.Fatalf("throttled call returned the FIRST challenge — expected the latest") } } func TestConfirmEmailCodeWrongCode(t *testing.T) { db := startPostgres(t) svc, mailer, _, _ := buildService(t, db) ctx := context.Background() id, err := svc.SendEmailCode(ctx, "wrong@example.test", "en", "", "") if err != nil { t.Fatalf("send: %v", err) } _, code, _ := mailer.snapshot() wrong := flipDigit(code) _, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{ ChallengeID: id, Code: wrong, ClientPublicKey: randomKey(t), TimeZone: "UTC", }) if !errors.Is(err, auth.ErrCodeMismatch) { t.Fatalf("ConfirmEmailCode wrong code = %v, want ErrCodeMismatch", err) } } func TestConfirmEmailCodeAttemptsCeiling(t *testing.T) { db := startPostgres(t) svc, mailer, _, _ := buildService(t, db) ctx := context.Background() id, err := svc.SendEmailCode(ctx, "ceiling@example.test", "en", "", "") if err != nil { t.Fatalf("send: %v", err) } _, code, _ := mailer.snapshot() wrong := flipDigit(code) // Burn `max` attempts with the wrong code. for i := range authConfig().ChallengeMaxAttempts { _, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{ ChallengeID: id, Code: wrong, ClientPublicKey: randomKey(t), TimeZone: "UTC", }) if !errors.Is(err, auth.ErrCodeMismatch) { t.Fatalf("attempt %d: %v, want ErrCodeMismatch", i, err) } } // One past the ceiling — even with the right code, ErrTooManyAttempts. 
_, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{ ChallengeID: id, Code: code, ClientPublicKey: randomKey(t), TimeZone: "UTC", }) if !errors.Is(err, auth.ErrTooManyAttempts) { t.Fatalf("post-ceiling = %v, want ErrTooManyAttempts", err) } } func TestConfirmEmailCodeChallengeNotFound(t *testing.T) { db := startPostgres(t) svc, _, _, _ := buildService(t, db) _, err := svc.ConfirmEmailCode(context.Background(), auth.ConfirmInputs{ ChallengeID: uuid.New(), Code: "000000", ClientPublicKey: randomKey(t), TimeZone: "UTC", }) if !errors.Is(err, auth.ErrChallengeNotFound) { t.Fatalf("unknown challenge = %v, want ErrChallengeNotFound", err) } } func TestRevokeAllForUser(t *testing.T) { db := startPostgres(t) svc, mailer, pusher, _ := buildService(t, db) ctx := context.Background() const email = "many@example.test" const sessionsToCreate = 3 var userID uuid.UUID deviceSessionIDs := make([]uuid.UUID, 0, sessionsToCreate) for range sessionsToCreate { id, err := svc.SendEmailCode(ctx, email, "en", "", "") if err != nil { t.Fatalf("send: %v", err) } _, code, _ := mailer.snapshot() sess, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{ ChallengeID: id, Code: code, ClientPublicKey: randomKey(t), TimeZone: "UTC", }) if err != nil { t.Fatalf("confirm: %v", err) } userID = sess.UserID deviceSessionIDs = append(deviceSessionIDs, sess.DeviceSessionID) } revoked, err := svc.RevokeAllForUser(ctx, userID) if err != nil { t.Fatalf("RevokeAllForUser: %v", err) } if len(revoked) != sessionsToCreate { t.Fatalf("revoked count = %d, want %d", len(revoked), sessionsToCreate) } for _, dsID := range deviceSessionIDs { if _, err := svc.GetSession(ctx, dsID); !errors.Is(err, auth.ErrSessionNotFound) { t.Fatalf("session %s still in cache: %v", dsID, err) } } if got := len(pusher.snapshot()); got != sessionsToCreate { t.Fatalf("push emissions = %d, want %d", got, sessionsToCreate) } // Idempotent: revoking again returns an empty slice. 
again, err := svc.RevokeAllForUser(ctx, userID) if err != nil { t.Fatalf("idempotent RevokeAllForUser: %v", err) } if len(again) != 0 { t.Fatalf("idempotent RevokeAllForUser = %d sessions, want 0", len(again)) } } // flipDigit returns code with its first digit replaced by ((digit+1) % 10) // so the resulting string is still a valid CodeLength-digit code but // guaranteed to differ. func flipDigit(code string) string { if code == "" { return "0" } bytes := []byte(code) if bytes[0] >= '0' && bytes[0] <= '9' { bytes[0] = '0' + ((bytes[0]-'0')+1)%10 } else { bytes[0] = '0' } return string(bytes) }