feat: backend service
This commit is contained in:
@@ -0,0 +1,511 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/backend/internal/auth"
|
||||
"galaxy/backend/internal/config"
|
||||
backendpg "galaxy/backend/internal/postgres"
|
||||
"galaxy/backend/internal/user"
|
||||
pgshared "galaxy/postgres"
|
||||
|
||||
"github.com/google/uuid"
|
||||
testcontainers "github.com/testcontainers/testcontainers-go"
|
||||
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
const (
|
||||
pgImage = "postgres:16-alpine"
|
||||
pgUser = "galaxy"
|
||||
pgPassword = "galaxy"
|
||||
pgDatabase = "galaxy_backend"
|
||||
pgSchema = "backend"
|
||||
pgStartup = 90 * time.Second
|
||||
pgOpTO = 10 * time.Second
|
||||
)
|
||||
|
||||
// startPostgres spins up a Postgres testcontainer with the backend
|
||||
// migrations applied. The returned *sql.DB is closed and the container
|
||||
// terminated by t.Cleanup hooks.
|
||||
func startPostgres(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
pgContainer, err := tcpostgres.Run(ctx, pgImage,
|
||||
tcpostgres.WithDatabase(pgDatabase),
|
||||
tcpostgres.WithUsername(pgUser),
|
||||
tcpostgres.WithPassword(pgPassword),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(pgStartup),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
t.Skipf("postgres testcontainer unavailable, skipping: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil {
|
||||
t.Errorf("terminate postgres container: %v", termErr)
|
||||
}
|
||||
})
|
||||
|
||||
baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("connection string: %v", err)
|
||||
}
|
||||
scopedDSN, err := dsnWithSearchPath(baseDSN, pgSchema)
|
||||
if err != nil {
|
||||
t.Fatalf("scope dsn: %v", err)
|
||||
}
|
||||
|
||||
cfg := pgshared.DefaultConfig()
|
||||
cfg.PrimaryDSN = scopedDSN
|
||||
cfg.OperationTimeout = pgOpTO
|
||||
|
||||
db, err := pgshared.OpenPrimary(ctx, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("open primary: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil {
|
||||
t.Fatalf("ping: %v", err)
|
||||
}
|
||||
if err := backendpg.ApplyMigrations(ctx, db); err != nil {
|
||||
t.Fatalf("apply migrations: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func dsnWithSearchPath(baseDSN, schema string) (string, error) {
|
||||
parsed, err := url.Parse(baseDSN)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
values := parsed.Query()
|
||||
values.Set("search_path", schema)
|
||||
if values.Get("sslmode") == "" {
|
||||
values.Set("sslmode", "disable")
|
||||
}
|
||||
parsed.RawQuery = values.Encode()
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
// recordingMailer implements auth.LoginCodeMailer and remembers the most
|
||||
// recent enqueue.
|
||||
type recordingMailer struct {
|
||||
mu sync.Mutex
|
||||
lastCode string
|
||||
lastTo string
|
||||
calls int
|
||||
}
|
||||
|
||||
func newRecordingMailer() *recordingMailer { return &recordingMailer{} }
|
||||
|
||||
func (m *recordingMailer) EnqueueLoginCode(_ context.Context, email, code string, _ time.Duration) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.lastTo = email
|
||||
m.lastCode = code
|
||||
m.calls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *recordingMailer) snapshot() (string, string, int) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.lastTo, m.lastCode, m.calls
|
||||
}
|
||||
|
||||
// recordingPush implements auth.SessionInvalidator and counts emissions.
|
||||
type recordingPush struct {
|
||||
mu sync.Mutex
|
||||
calls []recordedPush
|
||||
}
|
||||
|
||||
type recordedPush struct {
|
||||
deviceSessionID, userID uuid.UUID
|
||||
reason string
|
||||
}
|
||||
|
||||
func newRecordingPush() *recordingPush { return &recordingPush{} }
|
||||
|
||||
func (p *recordingPush) PublishSessionInvalidation(_ context.Context, dsID, uid uuid.UUID, reason string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.calls = append(p.calls, recordedPush{deviceSessionID: dsID, userID: uid, reason: reason})
|
||||
}
|
||||
|
||||
func (p *recordingPush) snapshot() []recordedPush {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
out := make([]recordedPush, len(p.calls))
|
||||
copy(out, p.calls)
|
||||
return out
|
||||
}
|
||||
|
||||
// stubGeo implements auth.GeoService with no real lookups. The country
|
||||
// it returns is configurable per call via CountryForIP; LanguageForIP
|
||||
// returns "" so the auth flow exercises the "en" fallback path.
|
||||
type stubGeo struct {
|
||||
countryByIP map[string]string
|
||||
}
|
||||
|
||||
func newStubGeo() *stubGeo {
|
||||
return &stubGeo{countryByIP: map[string]string{}}
|
||||
}
|
||||
|
||||
func (g *stubGeo) LookupCountry(sourceIP string) string {
|
||||
return g.countryByIP[sourceIP]
|
||||
}
|
||||
|
||||
func (g *stubGeo) LanguageForIP(_ string) string { return "" }
|
||||
|
||||
func (g *stubGeo) SetDeclaredCountryAtRegistration(_ context.Context, _ uuid.UUID, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// authConfig builds an AuthConfig suitable for tests.
|
||||
func authConfig() config.AuthConfig {
|
||||
return config.AuthConfig{
|
||||
ChallengeTTL: 5 * time.Minute,
|
||||
ChallengeMaxAttempts: 3,
|
||||
ChallengeThrottle: config.AuthChallengeThrottleConfig{
|
||||
Window: time.Minute,
|
||||
Max: 3,
|
||||
},
|
||||
UserNameMaxRetries: 10,
|
||||
}
|
||||
}
|
||||
|
||||
// buildService wires every dependency around db and returns the service
|
||||
// plus the recording fakes for assertions.
|
||||
func buildService(t *testing.T, db *sql.DB) (*auth.Service, *recordingMailer, *recordingPush, *stubGeo) {
|
||||
t.Helper()
|
||||
store := auth.NewStore(db)
|
||||
cache := auth.NewCache()
|
||||
if err := cache.Warm(context.Background(), store); err != nil {
|
||||
t.Fatalf("warm cache: %v", err)
|
||||
}
|
||||
mailer := newRecordingMailer()
|
||||
pusher := newRecordingPush()
|
||||
geo := newStubGeo()
|
||||
userStore := user.NewStore(db)
|
||||
userSvc := user.NewService(user.Deps{
|
||||
|
||||
Store: userStore,
|
||||
Cache: user.NewCache(),
|
||||
UserNameMaxRetries: 10,
|
||||
Now: time.Now,
|
||||
})
|
||||
svc := auth.NewService(auth.Deps{
|
||||
Store: store,
|
||||
Cache: cache,
|
||||
User: userSvc,
|
||||
Geo: geo,
|
||||
Mail: mailer,
|
||||
Push: pusher,
|
||||
Config: authConfig(),
|
||||
Now: time.Now,
|
||||
})
|
||||
return svc, mailer, pusher, geo
|
||||
}
|
||||
|
||||
func randomKey(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
key := make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatalf("rand: %v", err)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func TestAuthEndToEnd(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, mailer, pusher, _ := buildService(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
challengeID, err := svc.SendEmailCode(ctx, "Alice@Example.Test", "ru", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("SendEmailCode: %v", err)
|
||||
}
|
||||
if challengeID == uuid.Nil {
|
||||
t.Fatalf("SendEmailCode returned nil challenge_id")
|
||||
}
|
||||
gotEmail, gotCode, calls := mailer.snapshot()
|
||||
if gotEmail != "alice@example.test" {
|
||||
t.Fatalf("mailer email = %q, want lower-cased", gotEmail)
|
||||
}
|
||||
if len(gotCode) != auth.CodeLength {
|
||||
t.Fatalf("mailer code = %q (len %d), want length %d", gotCode, len(gotCode), auth.CodeLength)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("mailer calls = %d, want 1", calls)
|
||||
}
|
||||
|
||||
pubKey := randomKey(t)
|
||||
session, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
||||
ChallengeID: challengeID,
|
||||
Code: gotCode,
|
||||
ClientPublicKey: pubKey,
|
||||
TimeZone: "Europe/Moscow",
|
||||
SourceIP: "",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ConfirmEmailCode: %v", err)
|
||||
}
|
||||
if session.UserID == uuid.Nil {
|
||||
t.Fatalf("session has nil user_id")
|
||||
}
|
||||
if session.Status != auth.SessionStatusActive {
|
||||
t.Fatalf("session.Status = %q, want %q", session.Status, auth.SessionStatusActive)
|
||||
}
|
||||
|
||||
got, err := svc.GetSession(ctx, session.DeviceSessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession: %v", err)
|
||||
}
|
||||
if got.UserID != session.UserID {
|
||||
t.Fatalf("GetSession user_id = %s, want %s", got.UserID, session.UserID)
|
||||
}
|
||||
|
||||
revoked, err := svc.RevokeSession(ctx, session.DeviceSessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("RevokeSession: %v", err)
|
||||
}
|
||||
if revoked.Status != auth.SessionStatusRevoked {
|
||||
t.Fatalf("revoked.Status = %q, want %q", revoked.Status, auth.SessionStatusRevoked)
|
||||
}
|
||||
if revoked.RevokedAt == nil {
|
||||
t.Fatalf("revoked.RevokedAt nil after revoke")
|
||||
}
|
||||
|
||||
if _, err := svc.GetSession(ctx, session.DeviceSessionID); !errors.Is(err, auth.ErrSessionNotFound) {
|
||||
t.Fatalf("GetSession after revoke = %v, want ErrSessionNotFound", err)
|
||||
}
|
||||
|
||||
again, err := svc.RevokeSession(ctx, session.DeviceSessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("idempotent RevokeSession: %v", err)
|
||||
}
|
||||
if again.DeviceSessionID != session.DeviceSessionID || again.Status != auth.SessionStatusRevoked {
|
||||
t.Fatalf("idempotent revoke shape mismatch: %+v", again)
|
||||
}
|
||||
|
||||
pushes := pusher.snapshot()
|
||||
if len(pushes) != 1 {
|
||||
t.Fatalf("push emissions = %d, want 1", len(pushes))
|
||||
}
|
||||
if pushes[0].deviceSessionID != session.DeviceSessionID {
|
||||
t.Fatalf("push device_session_id mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendEmailCodePermanentlyBlocked(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, _, _, _ := buildService(t, db)
|
||||
|
||||
// Insert a permanent_block account directly.
|
||||
if _, err := db.Exec(`
|
||||
INSERT INTO backend.accounts (
|
||||
user_id, email, user_name, preferred_language, time_zone, permanent_block
|
||||
) VALUES ($1, $2, $3, $4, $5, true)
|
||||
`, uuid.New(), "blocked@example.test", "Player-XXBLOCK1", "en", "UTC"); err != nil {
|
||||
t.Fatalf("seed account: %v", err)
|
||||
}
|
||||
|
||||
_, err := svc.SendEmailCode(context.Background(), "blocked@example.test", "", "", "")
|
||||
if !errors.Is(err, auth.ErrEmailPermanentlyBlocked) {
|
||||
t.Fatalf("SendEmailCode for blocked email = %v, want ErrEmailPermanentlyBlocked", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendEmailCodeThrottleReusesChallenge(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, mailer, _, _ := buildService(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
const email = "throttle@example.test"
|
||||
cfg := authConfig()
|
||||
var firstID uuid.UUID
|
||||
for i := range cfg.ChallengeThrottle.Max {
|
||||
id, err := svc.SendEmailCode(ctx, email, "", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("SendEmailCode #%d: %v", i, err)
|
||||
}
|
||||
if i == 0 {
|
||||
firstID = id
|
||||
}
|
||||
}
|
||||
_, _, callsBefore := mailer.snapshot()
|
||||
|
||||
// One more call — must reuse the latest challenge_id and skip mail.
|
||||
id, err := svc.SendEmailCode(ctx, email, "", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("SendEmailCode (throttled): %v", err)
|
||||
}
|
||||
_, _, callsAfter := mailer.snapshot()
|
||||
if callsAfter != callsBefore {
|
||||
t.Fatalf("mail enqueue should be skipped on throttle: before=%d after=%d", callsBefore, callsAfter)
|
||||
}
|
||||
if id == uuid.Nil {
|
||||
t.Fatalf("throttled call returned nil challenge_id")
|
||||
}
|
||||
if id == firstID {
|
||||
t.Fatalf("throttled call returned the FIRST challenge — expected the latest")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmEmailCodeWrongCode(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, mailer, _, _ := buildService(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := svc.SendEmailCode(ctx, "wrong@example.test", "en", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("send: %v", err)
|
||||
}
|
||||
_, code, _ := mailer.snapshot()
|
||||
wrong := flipDigit(code)
|
||||
|
||||
_, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
||||
ChallengeID: id,
|
||||
Code: wrong,
|
||||
ClientPublicKey: randomKey(t),
|
||||
TimeZone: "UTC",
|
||||
})
|
||||
if !errors.Is(err, auth.ErrCodeMismatch) {
|
||||
t.Fatalf("ConfirmEmailCode wrong code = %v, want ErrCodeMismatch", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmEmailCodeAttemptsCeiling(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, mailer, _, _ := buildService(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := svc.SendEmailCode(ctx, "ceiling@example.test", "en", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("send: %v", err)
|
||||
}
|
||||
_, code, _ := mailer.snapshot()
|
||||
wrong := flipDigit(code)
|
||||
|
||||
// Burn `max` attempts with the wrong code.
|
||||
for i := range authConfig().ChallengeMaxAttempts {
|
||||
_, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
||||
ChallengeID: id,
|
||||
Code: wrong,
|
||||
ClientPublicKey: randomKey(t),
|
||||
TimeZone: "UTC",
|
||||
})
|
||||
if !errors.Is(err, auth.ErrCodeMismatch) {
|
||||
t.Fatalf("attempt %d: %v, want ErrCodeMismatch", i, err)
|
||||
}
|
||||
}
|
||||
// One past the ceiling — even with the right code, ErrTooManyAttempts.
|
||||
_, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
||||
ChallengeID: id,
|
||||
Code: code,
|
||||
ClientPublicKey: randomKey(t),
|
||||
TimeZone: "UTC",
|
||||
})
|
||||
if !errors.Is(err, auth.ErrTooManyAttempts) {
|
||||
t.Fatalf("post-ceiling = %v, want ErrTooManyAttempts", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmEmailCodeChallengeNotFound(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, _, _, _ := buildService(t, db)
|
||||
|
||||
_, err := svc.ConfirmEmailCode(context.Background(), auth.ConfirmInputs{
|
||||
ChallengeID: uuid.New(),
|
||||
Code: "000000",
|
||||
ClientPublicKey: randomKey(t),
|
||||
TimeZone: "UTC",
|
||||
})
|
||||
if !errors.Is(err, auth.ErrChallengeNotFound) {
|
||||
t.Fatalf("unknown challenge = %v, want ErrChallengeNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevokeAllForUser(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, mailer, pusher, _ := buildService(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
const email = "many@example.test"
|
||||
const sessionsToCreate = 3
|
||||
var userID uuid.UUID
|
||||
deviceSessionIDs := make([]uuid.UUID, 0, sessionsToCreate)
|
||||
for range sessionsToCreate {
|
||||
id, err := svc.SendEmailCode(ctx, email, "en", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("send: %v", err)
|
||||
}
|
||||
_, code, _ := mailer.snapshot()
|
||||
sess, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
||||
ChallengeID: id,
|
||||
Code: code,
|
||||
ClientPublicKey: randomKey(t),
|
||||
TimeZone: "UTC",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("confirm: %v", err)
|
||||
}
|
||||
userID = sess.UserID
|
||||
deviceSessionIDs = append(deviceSessionIDs, sess.DeviceSessionID)
|
||||
}
|
||||
|
||||
revoked, err := svc.RevokeAllForUser(ctx, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("RevokeAllForUser: %v", err)
|
||||
}
|
||||
if len(revoked) != sessionsToCreate {
|
||||
t.Fatalf("revoked count = %d, want %d", len(revoked), sessionsToCreate)
|
||||
}
|
||||
for _, dsID := range deviceSessionIDs {
|
||||
if _, err := svc.GetSession(ctx, dsID); !errors.Is(err, auth.ErrSessionNotFound) {
|
||||
t.Fatalf("session %s still in cache: %v", dsID, err)
|
||||
}
|
||||
}
|
||||
if got := len(pusher.snapshot()); got != sessionsToCreate {
|
||||
t.Fatalf("push emissions = %d, want %d", got, sessionsToCreate)
|
||||
}
|
||||
|
||||
// Idempotent: revoking again returns an empty slice.
|
||||
again, err := svc.RevokeAllForUser(ctx, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("idempotent RevokeAllForUser: %v", err)
|
||||
}
|
||||
if len(again) != 0 {
|
||||
t.Fatalf("idempotent RevokeAllForUser = %d sessions, want 0", len(again))
|
||||
}
|
||||
}
|
||||
|
||||
// flipDigit returns code with its first digit replaced by ((digit+1) % 10)
|
||||
// so the resulting string is still a valid CodeLength-digit code but
|
||||
// guaranteed to differ.
|
||||
func flipDigit(code string) string {
|
||||
if code == "" {
|
||||
return "0"
|
||||
}
|
||||
bytes := []byte(code)
|
||||
if bytes[0] >= '0' && bytes[0] <= '9' {
|
||||
bytes[0] = '0' + ((bytes[0]-'0')+1)%10
|
||||
} else {
|
||||
bytes[0] = '0'
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
Reference in New Issue
Block a user