Files
galaxy-game/backend/internal/auth/auth_e2e_test.go
T
2026-05-06 10:14:55 +03:00

512 lines
14 KiB
Go

package auth_test
import (
"context"
"crypto/rand"
"database/sql"
"errors"
"net/url"
"sync"
"testing"
"time"
"galaxy/backend/internal/auth"
"galaxy/backend/internal/config"
backendpg "galaxy/backend/internal/postgres"
"galaxy/backend/internal/user"
pgshared "galaxy/postgres"
"github.com/google/uuid"
testcontainers "github.com/testcontainers/testcontainers-go"
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
)
// Identity of the throwaway Postgres container that backs these tests.
// pgSchema is the schema startPostgres puts on the connection's
// search_path; the tests' tables (e.g. backend.accounts) live there.
const (
	pgImage    = "postgres:16-alpine"
	pgUser     = "galaxy"
	pgPassword = "galaxy"
	pgDatabase = "galaxy_backend"
	pgSchema   = "backend"
)

// Time budgets used while bringing the container up and talking to it.
const (
	pgStartup = 90 * time.Second // readiness-log wait in startPostgres
	pgOpTO    = 10 * time.Second // pgshared OperationTimeout (open/ping)
)
// startPostgres spins up a Postgres testcontainer, applies the backend
// migrations and returns a *sql.DB whose search_path is pgSchema. If the
// container cannot be started (for example, no Docker daemon is
// reachable) the calling test is skipped rather than failed. The returned
// *sql.DB is closed and the container terminated by t.Cleanup hooks.
func startPostgres(t *testing.T) *sql.DB {
	t.Helper()

	// Overall budget for container start-up, connect, ping and migrations.
	const setupTimeout = 3 * time.Minute
	ctx, cancel := context.WithTimeout(context.Background(), setupTimeout)
	t.Cleanup(cancel)

	pgContainer, err := tcpostgres.Run(ctx, pgImage,
		tcpostgres.WithDatabase(pgDatabase),
		tcpostgres.WithUsername(pgUser),
		tcpostgres.WithPassword(pgPassword),
		testcontainers.WithWaitStrategy(
			// Waits for the second readiness line. NOTE(review): the image's
			// entrypoint is understood to run a temporary init server first,
			// which logs this line once already — confirm if the image changes.
			wait.ForLog("database system is ready to accept connections").
				WithOccurrence(2).
				WithStartupTimeout(pgStartup),
		),
	)
	// Register termination BEFORE looking at err: Run can hand back a
	// container that was created but never became ready, and it must still
	// be removed. TerminateContainer treats a nil container as a no-op.
	t.Cleanup(func() {
		if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil {
			t.Errorf("terminate postgres container: %v", termErr)
		}
	})
	if err != nil {
		t.Skipf("postgres testcontainer unavailable, skipping: %v", err)
	}

	baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
	if err != nil {
		t.Fatalf("connection string: %v", err)
	}
	scopedDSN, err := dsnWithSearchPath(baseDSN, pgSchema)
	if err != nil {
		t.Fatalf("scope dsn: %v", err)
	}

	cfg := pgshared.DefaultConfig()
	cfg.PrimaryDSN = scopedDSN
	cfg.OperationTimeout = pgOpTO
	db, err := pgshared.OpenPrimary(ctx, cfg)
	if err != nil {
		t.Fatalf("open primary: %v", err)
	}
	// Cleanups run LIFO: the DB handle closes before the container goes away.
	t.Cleanup(func() { _ = db.Close() })

	if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil {
		t.Fatalf("ping: %v", err)
	}
	if err := backendpg.ApplyMigrations(ctx, db); err != nil {
		t.Fatalf("apply migrations: %v", err)
	}
	return db
}
// dsnWithSearchPath returns baseDSN (a URL-form Postgres DSN) with its
// search_path query parameter set to schema, overwriting any existing
// value. If the DSN carries no sslmode (or an empty one), sslmode=disable
// is added. Query parameters are re-encoded in sorted key order. A DSN
// that net/url cannot parse yields the parse error unchanged.
func dsnWithSearchPath(baseDSN, schema string) (string, error) {
	u, parseErr := url.Parse(baseDSN)
	if parseErr != nil {
		return "", parseErr
	}

	query := u.Query()
	query.Set("search_path", schema)
	if mode := query.Get("sslmode"); len(mode) == 0 {
		query.Set("sslmode", "disable")
	}

	u.RawQuery = query.Encode()
	return u.String(), nil
}
// recordingMailer implements auth.LoginCodeMailer as a test double. It
// sends nothing: each EnqueueLoginCode call overwrites the remembered
// recipient and code and bumps a call counter. Every field is guarded by
// mu, so it is safe for concurrent use.
type recordingMailer struct {
	mu       sync.Mutex
	lastCode string
	lastTo   string
	calls    int
}

// newRecordingMailer returns a mailer that has seen no calls yet.
func newRecordingMailer() *recordingMailer {
	m := new(recordingMailer)
	return m
}

// EnqueueLoginCode records email and code as the latest enqueue and
// always succeeds. The context and TTL are ignored.
func (m *recordingMailer) EnqueueLoginCode(_ context.Context, email, code string, _ time.Duration) error {
	m.mu.Lock()
	m.calls++
	m.lastTo, m.lastCode = email, code
	m.mu.Unlock()
	return nil
}

// snapshot returns, in order, the latest recipient, the latest code and
// the total number of EnqueueLoginCode calls.
func (m *recordingMailer) snapshot() (string, string, int) {
	m.mu.Lock()
	to, code, n := m.lastTo, m.lastCode, m.calls
	m.mu.Unlock()
	return to, code, n
}
// recordingPush implements auth.SessionInvalidator as a test double: each
// PublishSessionInvalidation call is appended, in call order, to calls
// under mu.
type recordingPush struct {
	mu    sync.Mutex
	calls []recordedPush
}

// recordedPush captures the arguments of one PublishSessionInvalidation call.
type recordedPush struct {
	deviceSessionID, userID uuid.UUID
	reason                  string
}

// newRecordingPush returns a recorder with no captured calls.
func newRecordingPush() *recordingPush { return new(recordingPush) }

// PublishSessionInvalidation records the call; the context is ignored.
func (p *recordingPush) PublishSessionInvalidation(_ context.Context, dsID, uid uuid.UUID, reason string) {
	entry := recordedPush{deviceSessionID: dsID, userID: uid, reason: reason}
	p.mu.Lock()
	defer p.mu.Unlock()
	p.calls = append(p.calls, entry)
}

// snapshot returns a copy of the captured calls that the caller may keep
// or modify without affecting the recorder.
func (p *recordingPush) snapshot() []recordedPush {
	p.mu.Lock()
	defer p.mu.Unlock()
	return append(make([]recordedPush, 0, len(p.calls)), p.calls...)
}
// stubGeo implements auth.GeoService without any real lookups.
// LookupCountry answers from the countryByIP map, keyed by source IP. The
// map starts empty, so every IP maps to "" unless a test adds an entry.
// LanguageForIP always returns "" so the auth flow exercises the "en"
// fallback path. SetDeclaredCountryAtRegistration accepts every call and
// stores nothing.
type stubGeo struct {
	// countryByIP maps a source IP to the country LookupCountry reports.
	countryByIP map[string]string
}

// newStubGeo returns a stubGeo with an empty, ready-to-fill country map.
func newStubGeo() *stubGeo {
	return &stubGeo{countryByIP: map[string]string{}}
}

// LookupCountry returns the country registered for sourceIP in
// countryByIP, or "" when there is none.
func (g *stubGeo) LookupCountry(sourceIP string) string {
	return g.countryByIP[sourceIP]
}

// LanguageForIP always reports no language ("").
func (g *stubGeo) LanguageForIP(_ string) string { return "" }

// SetDeclaredCountryAtRegistration is a no-op that always succeeds.
func (g *stubGeo) SetDeclaredCountryAtRegistration(_ context.Context, _ uuid.UUID, _ string) error {
	return nil
}
// authConfig returns the auth settings shared by every test in this file:
// a 5-minute challenge TTL, 3 confirm attempts per challenge, a challenge
// throttle of Max 3 per 1-minute Window, and 10 user-name retries. Any
// AuthConfig field not set here keeps its zero value.
func authConfig() config.AuthConfig {
	var cfg config.AuthConfig
	cfg.ChallengeTTL = 5 * time.Minute
	cfg.ChallengeMaxAttempts = 3
	cfg.ChallengeThrottle.Window = time.Minute
	cfg.ChallengeThrottle.Max = 3
	cfg.UserNameMaxRetries = 10
	return cfg
}
// buildService wires an auth.Service around db. It uses real auth and user
// stores and caches, plus recording/stub fakes for the mail, push and geo
// collaborators. It returns the service together with those fakes so
// tests can assert on what the service emitted. The auth cache is warmed
// from the store first; a warm-up failure fails the test.
func buildService(t *testing.T, db *sql.DB) (*auth.Service, *recordingMailer, *recordingPush, *stubGeo) {
	t.Helper()
	cfg := authConfig()

	store := auth.NewStore(db)
	cache := auth.NewCache()
	if err := cache.Warm(context.Background(), store); err != nil {
		t.Fatalf("warm cache: %v", err)
	}

	mailer := newRecordingMailer()
	pusher := newRecordingPush()
	geo := newStubGeo()

	// Take the user-name retry budget from the same config the auth service
	// receives, so the two layers cannot drift apart when authConfig changes.
	userSvc := user.NewService(user.Deps{
		Store:              user.NewStore(db),
		Cache:              user.NewCache(),
		UserNameMaxRetries: cfg.UserNameMaxRetries,
		Now:                time.Now,
	})

	svc := auth.NewService(auth.Deps{
		Store:  store,
		Cache:  cache,
		User:   userSvc,
		Geo:    geo,
		Mail:   mailer,
		Push:   pusher,
		Config: cfg,
		Now:    time.Now,
	})
	return svc, mailer, pusher, geo
}
// randomKey returns 32 bytes from crypto/rand, used as a throwaway client
// public key. The test fails if the random source reports an error.
func randomKey(t *testing.T) []byte {
	t.Helper()
	var buf [32]byte
	if _, err := rand.Read(buf[:]); err != nil {
		t.Fatalf("rand: %v", err)
	}
	return buf[:]
}
// TestAuthEndToEnd walks one device session through its whole lifecycle
// against a real Postgres: request a login code, confirm it, read the
// session back, revoke it, then revoke it again. It checks that:
//   - the mailer received a lower-cased address and a CodeLength-long code
//     exactly once;
//   - the confirmed session is active and belongs to a non-nil user;
//   - a revoked session is no longer returned by GetSession;
//   - a repeat revoke succeeds and reports the same revoked session;
//   - only one session-invalidation push was emitted in total.
func TestAuthEndToEnd(t *testing.T) {
	db := startPostgres(t)
	svc, mailer, pusher, _ := buildService(t, db)
	ctx := context.Background()
	// Mixed-case address on purpose: the mailer check below expects the
	// service to have lower-cased it.
	challengeID, err := svc.SendEmailCode(ctx, "Alice@Example.Test", "ru", "", "")
	if err != nil {
		t.Fatalf("SendEmailCode: %v", err)
	}
	if challengeID == uuid.Nil {
		t.Fatalf("SendEmailCode returned nil challenge_id")
	}
	gotEmail, gotCode, calls := mailer.snapshot()
	if gotEmail != "alice@example.test" {
		t.Fatalf("mailer email = %q, want lower-cased", gotEmail)
	}
	if len(gotCode) != auth.CodeLength {
		t.Fatalf("mailer code = %q (len %d), want length %d", gotCode, len(gotCode), auth.CodeLength)
	}
	if calls != 1 {
		t.Fatalf("mailer calls = %d, want 1", calls)
	}
	// Confirm with the code the mailer captured. SourceIP is empty, and
	// the stub geo service has no entry for it.
	pubKey := randomKey(t)
	session, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
		ChallengeID:     challengeID,
		Code:            gotCode,
		ClientPublicKey: pubKey,
		TimeZone:        "Europe/Moscow",
		SourceIP:        "",
	})
	if err != nil {
		t.Fatalf("ConfirmEmailCode: %v", err)
	}
	if session.UserID == uuid.Nil {
		t.Fatalf("session has nil user_id")
	}
	if session.Status != auth.SessionStatusActive {
		t.Fatalf("session.Status = %q, want %q", session.Status, auth.SessionStatusActive)
	}
	// The freshly confirmed session must be readable back.
	got, err := svc.GetSession(ctx, session.DeviceSessionID)
	if err != nil {
		t.Fatalf("GetSession: %v", err)
	}
	if got.UserID != session.UserID {
		t.Fatalf("GetSession user_id = %s, want %s", got.UserID, session.UserID)
	}
	// First revoke: the status flips and RevokedAt is stamped.
	revoked, err := svc.RevokeSession(ctx, session.DeviceSessionID)
	if err != nil {
		t.Fatalf("RevokeSession: %v", err)
	}
	if revoked.Status != auth.SessionStatusRevoked {
		t.Fatalf("revoked.Status = %q, want %q", revoked.Status, auth.SessionStatusRevoked)
	}
	if revoked.RevokedAt == nil {
		t.Fatalf("revoked.RevokedAt nil after revoke")
	}
	// A revoked session is reported as not found.
	if _, err := svc.GetSession(ctx, session.DeviceSessionID); !errors.Is(err, auth.ErrSessionNotFound) {
		t.Fatalf("GetSession after revoke = %v, want ErrSessionNotFound", err)
	}
	// Second revoke must succeed and describe the same revoked session.
	again, err := svc.RevokeSession(ctx, session.DeviceSessionID)
	if err != nil {
		t.Fatalf("idempotent RevokeSession: %v", err)
	}
	if again.DeviceSessionID != session.DeviceSessionID || again.Status != auth.SessionStatusRevoked {
		t.Fatalf("idempotent revoke shape mismatch: %+v", again)
	}
	// Exactly one push overall: the repeat revoke above must not emit
	// another invalidation.
	pushes := pusher.snapshot()
	if len(pushes) != 1 {
		t.Fatalf("push emissions = %d, want 1", len(pushes))
	}
	if pushes[0].deviceSessionID != session.DeviceSessionID {
		t.Fatalf("push device_session_id mismatch")
	}
}
// TestSendEmailCodePermanentlyBlocked seeds an account with
// permanent_block = true directly in the database. It then checks that
// requesting a login code for that email fails with
// auth.ErrEmailPermanentlyBlocked.
func TestSendEmailCodePermanentlyBlocked(t *testing.T) {
	db := startPostgres(t)
	svc, _, _, _ := buildService(t, db)

	const blockedEmail = "blocked@example.test"

	// Bypass the service and write the blocked account straight to the table.
	_, seedErr := db.Exec(`
INSERT INTO backend.accounts (
user_id, email, user_name, preferred_language, time_zone, permanent_block
) VALUES ($1, $2, $3, $4, $5, true)
`, uuid.New(), blockedEmail, "Player-XXBLOCK1", "en", "UTC")
	if seedErr != nil {
		t.Fatalf("seed account: %v", seedErr)
	}

	if _, err := svc.SendEmailCode(context.Background(), blockedEmail, "", "", ""); !errors.Is(err, auth.ErrEmailPermanentlyBlocked) {
		t.Fatalf("SendEmailCode for blocked email = %v, want ErrEmailPermanentlyBlocked", err)
	}
}
// TestSendEmailCodeThrottleReusesChallenge issues ChallengeThrottle.Max
// code requests for one address, then one more. The extra request must
// succeed without enqueueing mail, and it must return the challenge_id
// of the most recent request that was still within the throttle budget.
func TestSendEmailCodeThrottleReusesChallenge(t *testing.T) {
	db := startPostgres(t)
	svc, mailer, _, _ := buildService(t, db)
	ctx := context.Background()
	const email = "throttle@example.test"
	cfg := authConfig()

	// latestID tracks the challenge created by the last request issued
	// inside the throttle budget.
	var latestID uuid.UUID
	for i := range cfg.ChallengeThrottle.Max {
		id, err := svc.SendEmailCode(ctx, email, "", "", "")
		if err != nil {
			t.Fatalf("SendEmailCode #%d: %v", i, err)
		}
		latestID = id
	}
	_, _, callsBefore := mailer.snapshot()

	// One more call — must reuse the latest challenge_id and skip mail.
	id, err := svc.SendEmailCode(ctx, email, "", "", "")
	if err != nil {
		t.Fatalf("SendEmailCode (throttled): %v", err)
	}
	_, _, callsAfter := mailer.snapshot()
	if callsAfter != callsBefore {
		t.Fatalf("mail enqueue should be skipped on throttle: before=%d after=%d", callsBefore, callsAfter)
	}
	if id == uuid.Nil {
		t.Fatalf("throttled call returned nil challenge_id")
	}
	// Pin the latest challenge itself instead of only ruling out the first
	// one. When Max == 1 the first challenge IS the latest, so a
	// "not the first" check would wrongly fail.
	if id != latestID {
		t.Fatalf("throttled call returned challenge %s, want the latest %s", id, latestID)
	}
}
func TestConfirmEmailCodeWrongCode(t *testing.T) {
db := startPostgres(t)
svc, mailer, _, _ := buildService(t, db)
ctx := context.Background()
id, err := svc.SendEmailCode(ctx, "wrong@example.test", "en", "", "")
if err != nil {
t.Fatalf("send: %v", err)
}
_, code, _ := mailer.snapshot()
wrong := flipDigit(code)
_, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
ChallengeID: id,
Code: wrong,
ClientPublicKey: randomKey(t),
TimeZone: "UTC",
})
if !errors.Is(err, auth.ErrCodeMismatch) {
t.Fatalf("ConfirmEmailCode wrong code = %v, want ErrCodeMismatch", err)
}
}
// TestConfirmEmailCodeAttemptsCeiling spends every allowed confirm
// attempt on a wrong code, each failing with auth.ErrCodeMismatch. It then
// checks that even the correct code is refused with
// auth.ErrTooManyAttempts.
func TestConfirmEmailCodeAttemptsCeiling(t *testing.T) {
	db := startPostgres(t)
	svc, mailer, _, _ := buildService(t, db)
	ctx := context.Background()

	challenge, err := svc.SendEmailCode(ctx, "ceiling@example.test", "en", "", "")
	if err != nil {
		t.Fatalf("send: %v", err)
	}
	_, goodCode, _ := mailer.snapshot()
	badCode := flipDigit(goodCode)

	// confirm submits code for the challenge with a fresh client key and
	// returns only the error.
	confirm := func(code string) error {
		_, confirmErr := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
			ChallengeID:     challenge,
			Code:            code,
			ClientPublicKey: randomKey(t),
			TimeZone:        "UTC",
		})
		return confirmErr
	}

	// Burn `max` attempts with the wrong code.
	maxAttempts := authConfig().ChallengeMaxAttempts
	for attempt := range maxAttempts {
		if err := confirm(badCode); !errors.Is(err, auth.ErrCodeMismatch) {
			t.Fatalf("attempt %d: %v, want ErrCodeMismatch", attempt, err)
		}
	}

	// One past the ceiling — even with the right code, ErrTooManyAttempts.
	if err := confirm(goodCode); !errors.Is(err, auth.ErrTooManyAttempts) {
		t.Fatalf("post-ceiling = %v, want ErrTooManyAttempts", err)
	}
}
func TestConfirmEmailCodeChallengeNotFound(t *testing.T) {
db := startPostgres(t)
svc, _, _, _ := buildService(t, db)
_, err := svc.ConfirmEmailCode(context.Background(), auth.ConfirmInputs{
ChallengeID: uuid.New(),
Code: "000000",
ClientPublicKey: randomKey(t),
TimeZone: "UTC",
})
if !errors.Is(err, auth.ErrChallengeNotFound) {
t.Fatalf("unknown challenge = %v, want ErrChallengeNotFound", err)
}
}
// TestRevokeAllForUser logs one email in several times, then revokes all
// of that user's sessions at once. It checks the returned count, that
// none of the sessions can still be fetched, that one push was emitted
// per session, and that a second RevokeAllForUser returns nothing.
func TestRevokeAllForUser(t *testing.T) {
	db := startPostgres(t)
	svc, mailer, pusher, _ := buildService(t, db)
	ctx := context.Background()

	const (
		email            = "many@example.test"
		sessionsToCreate = 3
	)

	// login runs one send+confirm round for email and returns the
	// resulting user and device-session IDs.
	login := func() (uuid.UUID, uuid.UUID) {
		t.Helper()
		challenge, err := svc.SendEmailCode(ctx, email, "en", "", "")
		if err != nil {
			t.Fatalf("send: %v", err)
		}
		_, code, _ := mailer.snapshot()
		sess, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
			ChallengeID:     challenge,
			Code:            code,
			ClientPublicKey: randomKey(t),
			TimeZone:        "UTC",
		})
		if err != nil {
			t.Fatalf("confirm: %v", err)
		}
		return sess.UserID, sess.DeviceSessionID
	}

	// Every round sends a new code for the same address, so
	// sessionsToCreate must stay within authConfig()'s ChallengeThrottle.Max.
	// Otherwise later sends hit the throttle
	// (see TestSendEmailCodeThrottleReusesChallenge).
	var userID uuid.UUID
	deviceSessionIDs := make([]uuid.UUID, 0, sessionsToCreate)
	for range sessionsToCreate {
		uid, dsID := login()
		userID = uid
		deviceSessionIDs = append(deviceSessionIDs, dsID)
	}

	revoked, err := svc.RevokeAllForUser(ctx, userID)
	if err != nil {
		t.Fatalf("RevokeAllForUser: %v", err)
	}
	if n := len(revoked); n != sessionsToCreate {
		t.Fatalf("revoked count = %d, want %d", n, sessionsToCreate)
	}
	for _, dsID := range deviceSessionIDs {
		_, getErr := svc.GetSession(ctx, dsID)
		if !errors.Is(getErr, auth.ErrSessionNotFound) {
			t.Fatalf("session %s still in cache: %v", dsID, getErr)
		}
	}
	if got := len(pusher.snapshot()); got != sessionsToCreate {
		t.Fatalf("push emissions = %d, want %d", got, sessionsToCreate)
	}

	// Idempotent: revoking again returns an empty slice.
	again, err := svc.RevokeAllForUser(ctx, userID)
	if err != nil {
		t.Fatalf("idempotent RevokeAllForUser: %v", err)
	}
	if n := len(again); n != 0 {
		t.Fatalf("idempotent RevokeAllForUser = %d sessions, want 0", n)
	}
}
// flipDigit returns a variant of code guaranteed to differ from it. A
// leading ASCII digit d becomes (d+1)%10, any other leading byte becomes
// '0', and an empty code yields "0". For an all-digit code the result
// keeps the same length and stays all digits, so it is still a
// well-formed (but wrong) login code.
func flipDigit(code string) string {
	if len(code) == 0 {
		return "0"
	}
	out := []byte(code)
	switch first := out[0]; {
	case '0' <= first && first <= '9':
		out[0] = '0' + (first-'0'+1)%10
	default:
		out[0] = '0'
	}
	return string(out)
}