feat: backend service

This commit is contained in:
Ilia Denisov
2026-05-06 10:14:55 +03:00
committed by GitHub
parent 3e2622757e
commit f446c6a2ac
1486 changed files with 49720 additions and 266401 deletions
+93
View File
@@ -0,0 +1,93 @@
// Package auth implements the email-code authentication flow and the
// active-session bookkeeping consumed by gateway. The package is
// described end-to-end in `backend/PLAN.md` §5.1.
//
// External dependencies that have not landed yet (mail in 5.6, push
// session_invalidation in 6) are injected through the LoginCodeMailer
// and SessionInvalidator interfaces; auth ships no-op implementations
// that satisfy the contract until the real services arrive.
package auth
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"time"
"galaxy/backend/internal/config"
"go.uber.org/zap"
)
// Deps aggregates every collaborator the Service depends on.
// Constructing the Service through Deps (rather than positional args)
// keeps wiring patches small when new dependencies are added.
//
// Cache and Store must be non-nil: GetSession reads through Cache,
// SendEmailCode and ConfirmEmailCode mutate Store. User, Geo, Mail and
// Push are tested-in-isolation interfaces; production wires the real
// `*user.Service`, `*geo.Service`, mail, and push implementations.
type Deps struct {
	// Store persists auth challenges and device sessions; must be non-nil.
	Store *Store
	// Cache is the in-memory projection of active sessions; must be non-nil.
	Cache *Cache
	// User ensures the account row exists at confirm time.
	User UserEnsurer
	// Geo supplies country/language lookups used for registration defaults.
	Geo GeoService
	// Mail enqueues the login-code email issued by SendEmailCode.
	Mail LoginCodeMailer
	// Push publishes session-invalidation events (see SessionInvalidator).
	Push SessionInvalidator
	// Config carries TTL, attempt-ceiling, and throttle settings.
	Config config.AuthConfig
	// Now overrides time.Now for deterministic tests. A nil Now defaults
	// to time.Now in NewService.
	Now func() time.Time
	// Logger is named under "auth" by NewService. Nil falls back to
	// zap.NewNop.
	Logger *zap.Logger
}
// Service is the auth-domain entry point. Construct it with NewService;
// the zero value is unusable.
type Service struct {
	// deps holds the collaborators wired in by NewService.
	deps Deps
	// emailHashKey keys the HMAC used to derive `email_hash` log fields.
	// A per-boot random key keeps email PII out of structured logs while
	// still letting operators correlate log entries within a single
	// process lifetime.
	emailHashKey []byte
}
// NewService constructs a Service from deps. A nil Now defaults to
// time.Now; a nil Logger defaults to zap.NewNop. The other dependencies
// must be supplied — calling Service methods with nil Cache/Store/User/
// Geo/Mail/Push will panic at first use, matching how main.go signals
// missing wiring.
func NewService(deps Deps) *Service {
	now := deps.Now
	if now == nil {
		now = time.Now
	}
	deps.Now = now

	logger := deps.Logger
	if logger == nil {
		logger = zap.NewNop()
	}
	// Even a caller-supplied logger is scoped under the "auth" name.
	deps.Logger = logger.Named("auth")

	// Per-boot random HMAC key for the email_hash log field.
	hashKey := make([]byte, 32)
	if _, err := rand.Read(hashKey); err != nil {
		// rand.Read should not fail in practice; if it does, fall back
		// to a deterministic key. Email hashing is a log-scoping aid,
		// not a security primitive, so a constant key is acceptable.
		copy(hashKey, []byte("galaxy-backend-auth-fallback-key"))
	}
	return &Service{deps: deps, emailHashKey: hashKey}
}
// hashEmail returns a stable, hex-encoded HMAC-SHA256 prefix of email
// suitable for use in structured logs. The key is per-process so the
// same email maps to the same hash across log lines emitted by this
// process, but never across process restarts. The 8-byte truncation
// gives operators enough collision-resistance for ad-hoc grep without
// keeping an offline key store.
func (s *Service) hashEmail(email string) string {
	h := hmac.New(sha256.New, s.emailHashKey)
	_, _ = h.Write([]byte(email))
	digest := h.Sum(nil)
	return hex.EncodeToString(digest[:8])
}
+511
View File
@@ -0,0 +1,511 @@
package auth_test
import (
"context"
"crypto/rand"
"database/sql"
"errors"
"net/url"
"sync"
"testing"
"time"
"galaxy/backend/internal/auth"
"galaxy/backend/internal/config"
backendpg "galaxy/backend/internal/postgres"
"galaxy/backend/internal/user"
pgshared "galaxy/postgres"
"github.com/google/uuid"
testcontainers "github.com/testcontainers/testcontainers-go"
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
)
// Postgres testcontainer settings shared by every test in this file.
const (
	pgImage    = "postgres:16-alpine" // pinned image so runs are reproducible
	pgUser     = "galaxy"
	pgPassword = "galaxy"
	pgDatabase = "galaxy_backend"
	pgSchema   = "backend" // schema the migrations create and queries target
	pgStartup  = 90 * time.Second // container boot budget (pull excluded)
	pgOpTO     = 10 * time.Second // per-operation DB timeout
)
// startPostgres spins up a Postgres testcontainer with the backend
// migrations applied. The returned *sql.DB is closed and the container
// terminated by t.Cleanup hooks.
func startPostgres(t *testing.T) *sql.DB {
	t.Helper()
	// Three minutes bounds the whole boot sequence: container start,
	// connect, and migrations.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
	t.Cleanup(cancel)
	pgContainer, err := tcpostgres.Run(ctx, pgImage,
		tcpostgres.WithDatabase(pgDatabase),
		tcpostgres.WithUsername(pgUser),
		tcpostgres.WithPassword(pgPassword),
		testcontainers.WithWaitStrategy(
			// Wait for the second "ready" line: the first comes from the
			// short-lived bootstrap server Postgres runs during init.
			wait.ForLog("database system is ready to accept connections").
				WithOccurrence(2).
				WithStartupTimeout(pgStartup),
		),
	)
	if err != nil {
		// Skip rather than fail so the suite still passes on machines
		// without a container runtime.
		t.Skipf("postgres testcontainer unavailable, skipping: %v", err)
	}
	t.Cleanup(func() {
		if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil {
			t.Errorf("terminate postgres container: %v", termErr)
		}
	})
	baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
	if err != nil {
		t.Fatalf("connection string: %v", err)
	}
	// Scope connections to the backend schema so unqualified table names
	// resolve without per-query SET search_path.
	scopedDSN, err := dsnWithSearchPath(baseDSN, pgSchema)
	if err != nil {
		t.Fatalf("scope dsn: %v", err)
	}
	cfg := pgshared.DefaultConfig()
	cfg.PrimaryDSN = scopedDSN
	cfg.OperationTimeout = pgOpTO
	db, err := pgshared.OpenPrimary(ctx, cfg)
	if err != nil {
		t.Fatalf("open primary: %v", err)
	}
	t.Cleanup(func() { _ = db.Close() })
	if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil {
		t.Fatalf("ping: %v", err)
	}
	if err := backendpg.ApplyMigrations(ctx, db); err != nil {
		t.Fatalf("apply migrations: %v", err)
	}
	return db
}
// dsnWithSearchPath returns baseDSN with its search_path query parameter
// set to schema, adding sslmode=disable when the DSN carries no explicit
// sslmode so test connections never attempt TLS.
func dsnWithSearchPath(baseDSN, schema string) (string, error) {
	u, err := url.Parse(baseDSN)
	if err != nil {
		return "", err
	}
	q := u.Query()
	q.Set("search_path", schema)
	if q.Get("sslmode") == "" {
		q.Set("sslmode", "disable")
	}
	u.RawQuery = q.Encode()
	return u.String(), nil
}
// recordingMailer is a thread-safe auth.LoginCodeMailer fake that keeps
// the most recent recipient/code pair and a running call count.
type recordingMailer struct {
	mu       sync.Mutex
	lastCode string
	lastTo   string
	calls    int
}

func newRecordingMailer() *recordingMailer { return &recordingMailer{} }

// EnqueueLoginCode records the call instead of sending mail; it never fails.
func (m *recordingMailer) EnqueueLoginCode(_ context.Context, email, code string, _ time.Duration) error {
	m.mu.Lock()
	m.lastTo, m.lastCode = email, code
	m.calls++
	m.mu.Unlock()
	return nil
}

// snapshot returns (recipient, code, call count) under the lock.
func (m *recordingMailer) snapshot() (string, string, int) {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.lastTo, m.lastCode, m.calls
}
// recordingPush is a thread-safe auth.SessionInvalidator fake that keeps
// a log of every invalidation it is asked to publish.
type recordingPush struct {
	mu    sync.Mutex
	calls []recordedPush
}

// recordedPush captures the arguments of one PublishSessionInvalidation call.
type recordedPush struct {
	deviceSessionID, userID uuid.UUID
	reason                  string
}

func newRecordingPush() *recordingPush { return &recordingPush{} }

// PublishSessionInvalidation appends the call's arguments to the log.
func (p *recordingPush) PublishSessionInvalidation(_ context.Context, dsID, uid uuid.UUID, reason string) {
	entry := recordedPush{deviceSessionID: dsID, userID: uid, reason: reason}
	p.mu.Lock()
	p.calls = append(p.calls, entry)
	p.mu.Unlock()
}

// snapshot returns a freshly allocated copy of the recorded calls.
func (p *recordingPush) snapshot() []recordedPush {
	p.mu.Lock()
	defer p.mu.Unlock()
	dup := make([]recordedPush, len(p.calls))
	copy(dup, p.calls)
	return dup
}
// stubGeo is an auth.GeoService fake with no real lookups: country
// answers come from a per-test map keyed by IP, LanguageForIP always
// misses (returns "") so the auth flow exercises the "en" fallback, and
// the declared-country backfill is a no-op.
type stubGeo struct {
	countryByIP map[string]string
}

func newStubGeo() *stubGeo {
	g := &stubGeo{}
	g.countryByIP = make(map[string]string)
	return g
}

// LookupCountry returns the configured country for sourceIP ("" on miss).
func (g *stubGeo) LookupCountry(sourceIP string) string {
	country := g.countryByIP[sourceIP]
	return country
}

// LanguageForIP always misses, forcing the default-language path.
func (g *stubGeo) LanguageForIP(_ string) string { return "" }

// SetDeclaredCountryAtRegistration is a no-op that always succeeds.
func (g *stubGeo) SetDeclaredCountryAtRegistration(_ context.Context, _ uuid.UUID, _ string) error {
	return nil
}
// authConfig builds an AuthConfig suitable for tests: a 5-minute code
// TTL, a 3-attempt ceiling per challenge, and a 3-per-minute send
// throttle. The small ceiling/throttle values keep the corresponding
// tests fast.
func authConfig() config.AuthConfig {
	return config.AuthConfig{
		ChallengeTTL:         5 * time.Minute,
		ChallengeMaxAttempts: 3,
		ChallengeThrottle: config.AuthChallengeThrottleConfig{
			Window: time.Minute,
			Max:    3,
		},
		UserNameMaxRetries: 10,
	}
}
// buildService wires every dependency around db and returns the service
// plus the recording fakes for assertions.
func buildService(t *testing.T, db *sql.DB) (*auth.Service, *recordingMailer, *recordingPush, *stubGeo) {
	t.Helper()
	store := auth.NewStore(db)
	cache := auth.NewCache()
	// Warm before use, mirroring the production boot order.
	if err := cache.Warm(context.Background(), store); err != nil {
		t.Fatalf("warm cache: %v", err)
	}
	mailer := newRecordingMailer()
	pusher := newRecordingPush()
	geo := newStubGeo()
	// Real user.Service on top of the same database, so account creation
	// during ConfirmEmailCode exercises the production path.
	userStore := user.NewStore(db)
	userSvc := user.NewService(user.Deps{
		Store:              userStore,
		Cache:              user.NewCache(),
		UserNameMaxRetries: 10,
		Now:                time.Now,
	})
	svc := auth.NewService(auth.Deps{
		Store:  store,
		Cache:  cache,
		User:   userSvc,
		Geo:    geo,
		Mail:   mailer,
		Push:   pusher,
		Config: authConfig(),
		Now:    time.Now,
	})
	return svc, mailer, pusher, geo
}
// randomKey returns 32 cryptographically random bytes, failing the test
// on the (practically impossible) rand.Read error.
func randomKey(t *testing.T) []byte {
	t.Helper()
	buf := make([]byte, 32)
	_, err := rand.Read(buf)
	if err != nil {
		t.Fatalf("rand: %v", err)
	}
	return buf
}
// TestAuthEndToEnd walks the happy path against real Postgres:
// send code → confirm → GetSession → revoke → idempotent re-revoke,
// asserting mail capture, email normalisation, cache eviction after
// revoke, and exactly one push emission.
func TestAuthEndToEnd(t *testing.T) {
	db := startPostgres(t)
	svc, mailer, pusher, _ := buildService(t, db)
	ctx := context.Background()
	// Mixed-case input checks that the mailer sees the normalised address.
	challengeID, err := svc.SendEmailCode(ctx, "Alice@Example.Test", "ru", "", "")
	if err != nil {
		t.Fatalf("SendEmailCode: %v", err)
	}
	if challengeID == uuid.Nil {
		t.Fatalf("SendEmailCode returned nil challenge_id")
	}
	gotEmail, gotCode, calls := mailer.snapshot()
	if gotEmail != "alice@example.test" {
		t.Fatalf("mailer email = %q, want lower-cased", gotEmail)
	}
	if len(gotCode) != auth.CodeLength {
		t.Fatalf("mailer code = %q (len %d), want length %d", gotCode, len(gotCode), auth.CodeLength)
	}
	if calls != 1 {
		t.Fatalf("mailer calls = %d, want 1", calls)
	}
	pubKey := randomKey(t)
	session, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
		ChallengeID:     challengeID,
		Code:            gotCode,
		ClientPublicKey: pubKey,
		TimeZone:        "Europe/Moscow",
		SourceIP:        "",
	})
	if err != nil {
		t.Fatalf("ConfirmEmailCode: %v", err)
	}
	if session.UserID == uuid.Nil {
		t.Fatalf("session has nil user_id")
	}
	if session.Status != auth.SessionStatusActive {
		t.Fatalf("session.Status = %q, want %q", session.Status, auth.SessionStatusActive)
	}
	got, err := svc.GetSession(ctx, session.DeviceSessionID)
	if err != nil {
		t.Fatalf("GetSession: %v", err)
	}
	if got.UserID != session.UserID {
		t.Fatalf("GetSession user_id = %s, want %s", got.UserID, session.UserID)
	}
	revoked, err := svc.RevokeSession(ctx, session.DeviceSessionID)
	if err != nil {
		t.Fatalf("RevokeSession: %v", err)
	}
	if revoked.Status != auth.SessionStatusRevoked {
		t.Fatalf("revoked.Status = %q, want %q", revoked.Status, auth.SessionStatusRevoked)
	}
	if revoked.RevokedAt == nil {
		t.Fatalf("revoked.RevokedAt nil after revoke")
	}
	// A revoked session must no longer be readable from the cache.
	if _, err := svc.GetSession(ctx, session.DeviceSessionID); !errors.Is(err, auth.ErrSessionNotFound) {
		t.Fatalf("GetSession after revoke = %v, want ErrSessionNotFound", err)
	}
	// Revoke is idempotent: a second call succeeds with the same shape.
	again, err := svc.RevokeSession(ctx, session.DeviceSessionID)
	if err != nil {
		t.Fatalf("idempotent RevokeSession: %v", err)
	}
	if again.DeviceSessionID != session.DeviceSessionID || again.Status != auth.SessionStatusRevoked {
		t.Fatalf("idempotent revoke shape mismatch: %+v", again)
	}
	pushes := pusher.snapshot()
	if len(pushes) != 1 {
		t.Fatalf("push emissions = %d, want 1", len(pushes))
	}
	if pushes[0].deviceSessionID != session.DeviceSessionID {
		t.Fatalf("push device_session_id mismatch")
	}
}
// TestSendEmailCodePermanentlyBlocked verifies that an account seeded
// with permanent_block = true is rejected with the dedicated sentinel
// before any challenge is issued.
func TestSendEmailCodePermanentlyBlocked(t *testing.T) {
	db := startPostgres(t)
	svc, _, _, _ := buildService(t, db)
	// Insert a permanent_block account directly.
	if _, err := db.Exec(`
		INSERT INTO backend.accounts (
			user_id, email, user_name, preferred_language, time_zone, permanent_block
		) VALUES ($1, $2, $3, $4, $5, true)
	`, uuid.New(), "blocked@example.test", "Player-XXBLOCK1", "en", "UTC"); err != nil {
		t.Fatalf("seed account: %v", err)
	}
	_, err := svc.SendEmailCode(context.Background(), "blocked@example.test", "", "", "")
	if !errors.Is(err, auth.ErrEmailPermanentlyBlocked) {
		t.Fatalf("SendEmailCode for blocked email = %v, want ErrEmailPermanentlyBlocked", err)
	}
}
// TestSendEmailCodeThrottleReusesChallenge exhausts the per-window send
// budget, then asserts the next call reuses the most recent challenge_id
// (not the first) and skips the mail enqueue.
func TestSendEmailCodeThrottleReusesChallenge(t *testing.T) {
	db := startPostgres(t)
	svc, mailer, _, _ := buildService(t, db)
	ctx := context.Background()
	const email = "throttle@example.test"
	cfg := authConfig()
	var firstID uuid.UUID
	// Burn the full throttle budget with fresh challenges.
	for i := range cfg.ChallengeThrottle.Max {
		id, err := svc.SendEmailCode(ctx, email, "", "", "")
		if err != nil {
			t.Fatalf("SendEmailCode #%d: %v", i, err)
		}
		if i == 0 {
			firstID = id
		}
	}
	_, _, callsBefore := mailer.snapshot()
	// One more call — must reuse the latest challenge_id and skip mail.
	id, err := svc.SendEmailCode(ctx, email, "", "", "")
	if err != nil {
		t.Fatalf("SendEmailCode (throttled): %v", err)
	}
	_, _, callsAfter := mailer.snapshot()
	if callsAfter != callsBefore {
		t.Fatalf("mail enqueue should be skipped on throttle: before=%d after=%d", callsBefore, callsAfter)
	}
	if id == uuid.Nil {
		t.Fatalf("throttled call returned nil challenge_id")
	}
	if id == firstID {
		t.Fatalf("throttled call returned the FIRST challenge — expected the latest")
	}
}
// TestConfirmEmailCodeWrongCode asserts that a code differing from the
// mailed one is rejected with ErrCodeMismatch.
func TestConfirmEmailCodeWrongCode(t *testing.T) {
	db := startPostgres(t)
	svc, mailer, _, _ := buildService(t, db)
	ctx := context.Background()
	id, err := svc.SendEmailCode(ctx, "wrong@example.test", "en", "", "")
	if err != nil {
		t.Fatalf("send: %v", err)
	}
	_, code, _ := mailer.snapshot()
	// flipDigit guarantees a well-formed but different code.
	wrong := flipDigit(code)
	_, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
		ChallengeID:     id,
		Code:            wrong,
		ClientPublicKey: randomKey(t),
		TimeZone:        "UTC",
	})
	if !errors.Is(err, auth.ErrCodeMismatch) {
		t.Fatalf("ConfirmEmailCode wrong code = %v, want ErrCodeMismatch", err)
	}
}
// TestConfirmEmailCodeAttemptsCeiling burns the configured number of
// attempts with a wrong code and then checks that even the CORRECT code
// is refused with ErrTooManyAttempts once the ceiling is hit.
func TestConfirmEmailCodeAttemptsCeiling(t *testing.T) {
	db := startPostgres(t)
	svc, mailer, _, _ := buildService(t, db)
	ctx := context.Background()
	id, err := svc.SendEmailCode(ctx, "ceiling@example.test", "en", "", "")
	if err != nil {
		t.Fatalf("send: %v", err)
	}
	_, code, _ := mailer.snapshot()
	wrong := flipDigit(code)
	// Burn `max` attempts with the wrong code.
	for i := range authConfig().ChallengeMaxAttempts {
		_, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
			ChallengeID:     id,
			Code:            wrong,
			ClientPublicKey: randomKey(t),
			TimeZone:        "UTC",
		})
		if !errors.Is(err, auth.ErrCodeMismatch) {
			t.Fatalf("attempt %d: %v, want ErrCodeMismatch", i, err)
		}
	}
	// One past the ceiling — even with the right code, ErrTooManyAttempts.
	_, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
		ChallengeID:     id,
		Code:            code,
		ClientPublicKey: randomKey(t),
		TimeZone:        "UTC",
	})
	if !errors.Is(err, auth.ErrTooManyAttempts) {
		t.Fatalf("post-ceiling = %v, want ErrTooManyAttempts", err)
	}
}
// TestConfirmEmailCodeChallengeNotFound checks that a random, never-issued
// challenge_id maps to ErrChallengeNotFound.
func TestConfirmEmailCodeChallengeNotFound(t *testing.T) {
	db := startPostgres(t)
	svc, _, _, _ := buildService(t, db)
	_, err := svc.ConfirmEmailCode(context.Background(), auth.ConfirmInputs{
		ChallengeID:     uuid.New(),
		Code:            "000000",
		ClientPublicKey: randomKey(t),
		TimeZone:        "UTC",
	})
	if !errors.Is(err, auth.ErrChallengeNotFound) {
		t.Fatalf("unknown challenge = %v, want ErrChallengeNotFound", err)
	}
}
// TestRevokeAllForUser creates several sessions for one account, bulk
// revokes them, and asserts cache eviction, per-session push emissions,
// and idempotency of the bulk revoke.
func TestRevokeAllForUser(t *testing.T) {
	db := startPostgres(t)
	svc, mailer, pusher, _ := buildService(t, db)
	ctx := context.Background()
	const email = "many@example.test"
	const sessionsToCreate = 3
	var userID uuid.UUID
	deviceSessionIDs := make([]uuid.UUID, 0, sessionsToCreate)
	// Each send+confirm round produces a fresh session for the same user.
	for range sessionsToCreate {
		id, err := svc.SendEmailCode(ctx, email, "en", "", "")
		if err != nil {
			t.Fatalf("send: %v", err)
		}
		_, code, _ := mailer.snapshot()
		sess, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
			ChallengeID:     id,
			Code:            code,
			ClientPublicKey: randomKey(t),
			TimeZone:        "UTC",
		})
		if err != nil {
			t.Fatalf("confirm: %v", err)
		}
		userID = sess.UserID
		deviceSessionIDs = append(deviceSessionIDs, sess.DeviceSessionID)
	}
	revoked, err := svc.RevokeAllForUser(ctx, userID)
	if err != nil {
		t.Fatalf("RevokeAllForUser: %v", err)
	}
	if len(revoked) != sessionsToCreate {
		t.Fatalf("revoked count = %d, want %d", len(revoked), sessionsToCreate)
	}
	for _, dsID := range deviceSessionIDs {
		if _, err := svc.GetSession(ctx, dsID); !errors.Is(err, auth.ErrSessionNotFound) {
			t.Fatalf("session %s still in cache: %v", dsID, err)
		}
	}
	if got := len(pusher.snapshot()); got != sessionsToCreate {
		t.Fatalf("push emissions = %d, want %d", got, sessionsToCreate)
	}
	// Idempotent: revoking again returns an empty slice.
	again, err := svc.RevokeAllForUser(ctx, userID)
	if err != nil {
		t.Fatalf("idempotent RevokeAllForUser: %v", err)
	}
	if len(again) != 0 {
		t.Fatalf("idempotent RevokeAllForUser = %d sessions, want 0", len(again))
	}
}
// flipDigit returns code with its first byte deterministically changed:
// a decimal digit becomes (digit+1) mod 10, any other byte becomes '0',
// and an empty input yields "0". The result is therefore always a valid
// code-shaped string that differs from the input.
func flipDigit(code string) string {
	if len(code) == 0 {
		return "0"
	}
	out := []byte(code)
	switch c := out[0]; {
	case c >= '0' && c <= '9':
		out[0] = '0' + (c-'0'+1)%10
	default:
		out[0] = '0'
	}
	return string(out)
}
+159
View File
@@ -0,0 +1,159 @@
package auth
import (
"context"
"sync"
"sync/atomic"
"github.com/google/uuid"
)
// Cache is the in-memory write-through projection of the active rows in
// `backend.device_sessions`. Reads (Get) are RLocked; writes (Add,
// Remove, RemoveByUser) are Locked. The cache holds two maps:
//
//   - byID maps device_session_id → Session.
//   - byUser maps user_id → set of device_session_ids belonging to that
//     user, used to satisfy bulk revoke without scanning byID.
//
// Both maps are updated atomically inside one Lock per mutation. The
// caller is expected to commit the corresponding database write *before*
// invoking Add or Remove so that the cache stays consistent under crash:
// a Postgres commit failure leaves the cache untouched, matching the
// previous DB state.
type Cache struct {
	mu     sync.RWMutex
	byID   map[uuid.UUID]Session
	byUser map[uuid.UUID]map[uuid.UUID]struct{}
	// ready flips to true after the first successful Warm; exposed via
	// Ready for the readiness probe.
	ready atomic.Bool
}
// NewCache constructs an empty Cache. The cache reports Ready() == false
// until Warm completes successfully.
func NewCache() *Cache {
	c := &Cache{}
	c.byID = map[uuid.UUID]Session{}
	c.byUser = map[uuid.UUID]map[uuid.UUID]struct{}{}
	return c
}
// Warm replaces the cache contents with every active session loaded from
// store. It is intended to be called exactly once at process boot before
// the HTTP listener accepts traffic; successful completion flips Ready
// to true. Subsequent calls re-warm the cache (useful in tests).
func (c *Cache) Warm(ctx context.Context, store *Store) error {
	sessions, err := store.ListActiveSessions(ctx)
	if err != nil {
		return err
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	// Replace (not merge) both maps so stale entries cannot survive.
	c.byID = make(map[uuid.UUID]Session, len(sessions))
	c.byUser = make(map[uuid.UUID]map[uuid.UUID]struct{})
	for _, sess := range sessions {
		c.byID[sess.DeviceSessionID] = sess
		users := c.byUser[sess.UserID]
		if users == nil {
			users = make(map[uuid.UUID]struct{})
			c.byUser[sess.UserID] = users
		}
		users[sess.DeviceSessionID] = struct{}{}
	}
	c.ready.Store(true)
	return nil
}
// Ready reports whether Warm has completed at least once. The HTTP
// readiness probe wires through this method so `/readyz` only flips to
// 200 after the cache is hydrated. Safe on a nil receiver.
func (c *Cache) Ready() bool {
	return c != nil && c.ready.Load()
}
// Size returns the number of cached active sessions. Useful in startup
// logs ("auth cache warmed: N sessions") and in tests. Safe on a nil
// receiver.
func (c *Cache) Size() int {
	if c == nil {
		return 0
	}
	c.mu.RLock()
	n := len(c.byID)
	c.mu.RUnlock()
	return n
}
// Get returns the session with deviceSessionID and a presence flag.
// Misses always return the zero Session and false; callers should not
// inspect the returned value when ok is false. Safe on a nil receiver.
func (c *Cache) Get(deviceSessionID uuid.UUID) (Session, bool) {
	if c == nil {
		return Session{}, false
	}
	c.mu.RLock()
	sess, found := c.byID[deviceSessionID]
	c.mu.RUnlock()
	return sess, found
}
// Add stores s in the cache. It is safe to call on an existing entry
// — both the primary map and the user index are updated to the latest
// snapshot. Safe on a nil receiver (no-op).
func (c *Cache) Add(s Session) {
	if c == nil {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.byID[s.DeviceSessionID] = s
	users := c.byUser[s.UserID]
	if users == nil {
		users = make(map[uuid.UUID]struct{})
		c.byUser[s.UserID] = users
	}
	users[s.DeviceSessionID] = struct{}{}
}
// Remove evicts the entry for deviceSessionID from both maps. Calling
// Remove on a missing entry is a no-op, as is calling it on a nil
// receiver.
func (c *Cache) Remove(deviceSessionID uuid.UUID) {
	if c == nil {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	sess, found := c.byID[deviceSessionID]
	if !found {
		return
	}
	delete(c.byID, deviceSessionID)
	users := c.byUser[sess.UserID]
	if users == nil {
		return
	}
	delete(users, deviceSessionID)
	// Drop empty per-user sets so byUser never accumulates dead keys.
	if len(users) == 0 {
		delete(c.byUser, sess.UserID)
	}
}
// RemoveByUser evicts every cached entry belonging to userID and returns
// the device_session_ids it removed. The returned slice is safe for the
// caller to hold past the call — it is freshly allocated. A user with no
// cached sessions yields nil.
func (c *Cache) RemoveByUser(userID uuid.UUID) []uuid.UUID {
	if c == nil {
		return nil
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	set := c.byUser[userID]
	if set == nil {
		return nil
	}
	evicted := make([]uuid.UUID, 0, len(set))
	for dsID := range set {
		delete(c.byID, dsID)
		evicted = append(evicted, dsID)
	}
	delete(c.byUser, userID)
	return evicted
}
+141
View File
@@ -0,0 +1,141 @@
package auth
import (
	"context"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/google/uuid"
)
// TestCacheGetAddRemove covers the single-entry lifecycle: fresh cache
// invariants, Add → Get hit, Remove → miss, and Remove idempotency.
func TestCacheGetAddRemove(t *testing.T) {
	c := NewCache()
	if c.Ready() {
		t.Fatalf("fresh cache should not be Ready before Warm")
	}
	if c.Size() != 0 {
		t.Fatalf("fresh cache size = %d, want 0", c.Size())
	}
	id := uuid.New()
	uid := uuid.New()
	s := Session{DeviceSessionID: id, UserID: uid, Status: SessionStatusActive}
	c.Add(s)
	if c.Size() != 1 {
		t.Fatalf("size after Add = %d, want 1", c.Size())
	}
	got, ok := c.Get(id)
	if !ok || got.DeviceSessionID != id {
		t.Fatalf("Get after Add: ok=%v session=%+v", ok, got)
	}
	c.Remove(id)
	if c.Size() != 0 {
		t.Fatalf("size after Remove = %d, want 0", c.Size())
	}
	if _, ok := c.Get(id); ok {
		t.Fatalf("Get after Remove returned a hit")
	}
	// Remove on already-evicted entry is a no-op.
	c.Remove(id)
}
// TestCacheRemoveByUser seeds two users and checks that bulk eviction
// removes exactly the target user's sessions, leaves the other user's
// entry intact, and returns nil on a second (empty) call.
func TestCacheRemoveByUser(t *testing.T) {
	c := NewCache()
	uid := uuid.New()
	other := uuid.New()
	c.Add(Session{DeviceSessionID: uuid.New(), UserID: uid, Status: SessionStatusActive})
	c.Add(Session{DeviceSessionID: uuid.New(), UserID: uid, Status: SessionStatusActive})
	c.Add(Session{DeviceSessionID: uuid.New(), UserID: other, Status: SessionStatusActive})
	removed := c.RemoveByUser(uid)
	if len(removed) != 2 {
		t.Fatalf("RemoveByUser removed %d, want 2", len(removed))
	}
	if c.Size() != 1 {
		t.Fatalf("size after RemoveByUser = %d, want 1", c.Size())
	}
	if got := c.RemoveByUser(uid); got != nil {
		t.Fatalf("RemoveByUser on empty user returned %v, want nil", got)
	}
}
// TestCacheWarmFlipsReady checks the Ready toggle in isolation.
func TestCacheWarmFlipsReady(t *testing.T) {
	// Constructing a Cache and calling Warm against a Store without a real
	// database is awkward — the e2e test exercises Warm against Postgres.
	// Here we manually populate to confirm Ready toggles.
	c := NewCache()
	if c.Ready() {
		t.Fatalf("Ready before Warm")
	}
	// Simulate a successful Warm by setting ready and inserting via Add.
	c.ready.Store(true)
	if !c.Ready() {
		t.Fatalf("Ready did not flip after store")
	}
}
// TestCacheConcurrentGetAddRemove hammers Add/Remove/Get from several
// goroutines (meaningful under -race) and verifies the cache drains back
// to empty, since every Add is paired with a Remove. A watchdog deadline
// converts a deadlock into a test failure instead of a hung `go test`.
func TestCacheConcurrentGetAddRemove(t *testing.T) {
	c := NewCache()
	const writers = 4
	const readers = 4
	const opsPerWorker = 1000
	uid := uuid.New()
	ids := make([]uuid.UUID, opsPerWorker)
	for i := range ids {
		ids[i] = uuid.New()
	}
	// BUG FIX: the original used context.WithCancel and never cancelled,
	// so <-ctx.Done() could never fire and the timeout branch below was
	// dead code — a deadlock would hang the test binary forever. A real
	// deadline makes the watchdog effective.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	var stop atomic.Bool
	var wg sync.WaitGroup
	for range writers {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := range opsPerWorker {
				if stop.Load() {
					return
				}
				c.Add(Session{DeviceSessionID: ids[i], UserID: uid, Status: SessionStatusActive})
				c.Remove(ids[i])
			}
		}()
	}
	for range readers {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := range opsPerWorker {
				if stop.Load() {
					return
				}
				_, _ = c.Get(ids[i%len(ids)])
			}
		}()
	}
	done := make(chan struct{})
	go func() { wg.Wait(); close(done) }()
	select {
	case <-done:
	case <-ctx.Done():
		// Ask the workers to bail out, then wait so no goroutine outlives
		// the test before reporting the failure.
		stop.Store(true)
		<-done
		t.Fatalf("cache concurrency test timed out")
	}
	// After all goroutines finish, the cache must be empty (every Add
	// is paired with a Remove).
	if c.Size() != 0 {
		t.Fatalf("cache size after concurrent run = %d, want 0", c.Size())
	}
}
+262
View File
@@ -0,0 +1,262 @@
package auth
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"github.com/google/uuid"
"go.uber.org/zap"
)
// SendEmailCode issues an email login challenge for email and returns
// its challenge_id. The wire shape is intentionally identical for new
// users, existing users, and throttled requesters; the only path that
// returns ErrEmailPermanentlyBlocked is when email maps to an account
// whose `permanent_block` column is true (handler maps that sentinel to
// 400 invalid_request).
//
// Throttle behaviour: when the count of un-consumed, non-expired
// challenges for email created within ChallengeThrottle.Window already
// equals or exceeds ChallengeThrottle.Max, SendEmailCode reuses the
// most recent existing challenge_id and skips the mail enqueue. This
// avoids a leak where an attacker who controls their own SMTP server
// could otherwise correlate "row created without mail" with
// throttle-state on the platform.
//
// locale (request body, BCP 47) takes precedence over acceptLanguage
// (the standard HTTP header forwarded by gateway) when both are
// supplied. The captured value is persisted on the challenge row as
// `preferred_language`, replayed at confirm-email-code, and used only
// for newly-registered accounts; existing accounts keep their stored
// language.
func (s *Service) SendEmailCode(
	ctx context.Context,
	email, locale, acceptLanguage, sourceIP string,
) (uuid.UUID, error) {
	normalised := normaliseEmail(email)
	if normalised == "" {
		return uuid.Nil, fmt.Errorf("auth: email is empty")
	}
	// Blocked accounts are rejected before any challenge state is created.
	permanent, err := s.deps.Store.IsEmailPermanentlyBlocked(ctx, normalised)
	if err != nil {
		return uuid.Nil, err
	}
	if permanent {
		return uuid.Nil, ErrEmailPermanentlyBlocked
	}
	captured := pickCapturedLocale(locale, acceptLanguage)
	now := s.deps.Now()
	windowStart := now.Add(-s.deps.Config.ChallengeThrottle.Window)
	count, err := s.deps.Store.CountRecentChallenges(ctx, normalised, windowStart)
	if err != nil {
		return uuid.Nil, err
	}
	if count >= s.deps.Config.ChallengeThrottle.Max {
		existing, lerr := s.deps.Store.LatestUnconsumedChallenge(ctx, normalised, windowStart)
		if lerr == nil {
			s.deps.Logger.Info("auth challenge reused (throttled)",
				zap.String("email_hash", s.hashEmail(normalised)),
				zap.String("challenge_id", existing.ChallengeID.String()),
				zap.Int("recent_count", count),
			)
			return existing.ChallengeID, nil
		}
		if !errors.Is(lerr, sql.ErrNoRows) {
			return uuid.Nil, lerr
		}
		// sql.ErrNoRows here is a race (a concurrent confirm consumed
		// the row between count and select); fall through and issue a
		// fresh challenge.
	}
	code, err := generateCode()
	if err != nil {
		return uuid.Nil, err
	}
	// Only the bcrypt hash of the code is persisted; the plaintext code
	// exists solely in the outbound mail.
	hash, err := hashCode(code)
	if err != nil {
		return uuid.Nil, fmt.Errorf("auth: hash code: %w", err)
	}
	challenge := Challenge{
		ChallengeID:       uuid.New(),
		Email:             normalised,
		CodeHash:          hash,
		ExpiresAt:         now.Add(s.deps.Config.ChallengeTTL),
		PreferredLanguage: captured,
	}
	if err := s.deps.Store.InsertChallenge(ctx, challenge); err != nil {
		return uuid.Nil, err
	}
	if err := s.deps.Mail.EnqueueLoginCode(ctx, normalised, code, s.deps.Config.ChallengeTTL); err != nil {
		// A mail-enqueue failure is logged but not surfaced — the user
		// can simply request another challenge. A real mailer may fail
		// transiently; the current no-op implementation never returns
		// an error.
		s.deps.Logger.Warn("auth: enqueue login code failed",
			zap.String("email_hash", s.hashEmail(normalised)),
			zap.String("challenge_id", challenge.ChallengeID.String()),
			zap.Error(err),
		)
	}
	s.deps.Logger.Info("auth challenge issued",
		zap.String("email_hash", s.hashEmail(normalised)),
		zap.String("challenge_id", challenge.ChallengeID.String()),
	)
	return challenge.ChallengeID, nil
}
// ConfirmInputs is the parsed-and-validated input to ConfirmEmailCode.
// Wire-format validation (base64 decode, 32-byte length, IANA time-zone
// parse, source-IP extraction) happens at the handler boundary so the
// service operates on already-typed values.
type ConfirmInputs struct {
	// ChallengeID identifies the challenge issued by SendEmailCode.
	ChallengeID uuid.UUID
	// Code is the decimal login code the user received by mail.
	Code string
	// ClientPublicKey must be exactly 32 bytes; ConfirmEmailCode re-checks.
	ClientPublicKey []byte
	// TimeZone is the client-declared zone; must be non-empty.
	TimeZone string
	// SourceIP feeds the Geo lookups at confirm time.
	SourceIP string
}
// ConfirmEmailCode redeems a challenge_id, ensures the corresponding
// `accounts` row exists, and creates an active `device_sessions` row.
// The returned Session is identical to the row stored in the database
// (including server-assigned timestamps).
//
// The flow runs in two transactions:
//
//  1. LoadAndIncrementChallenge increments the attempts counter under
//     SELECT FOR UPDATE so concurrent attempts cannot bypass the ceiling.
//  2. Out-of-band: ceiling check, bcrypt verify, EnsureByEmail.
//  3. MarkConsumedAndInsertSession atomically marks the challenge
//     consumed and inserts the device_session row, satisfying the
//     "single challenge → at most one session" invariant.
//
// Post-commit work (cache write-through, declared_country backfill) is
// best-effort: a failure does not roll the registration back.
func (s *Service) ConfirmEmailCode(ctx context.Context, in ConfirmInputs) (Session, error) {
	if in.ChallengeID == uuid.Nil {
		return Session{}, ErrChallengeNotFound
	}
	if len(in.ClientPublicKey) != 32 {
		return Session{}, fmt.Errorf("auth: client public key must be 32 bytes, got %d", len(in.ClientPublicKey))
	}
	if strings.TrimSpace(in.TimeZone) == "" {
		return Session{}, fmt.Errorf("auth: time_zone must not be empty")
	}
	loaded, err := s.deps.Store.LoadAndIncrementChallenge(ctx, in.ChallengeID)
	if err != nil {
		return Session{}, err
	}
	// The counter was already incremented, so a value past the ceiling
	// means this attempt (and all later ones) is refused outright.
	if int(loaded.Attempts) > s.deps.Config.ChallengeMaxAttempts {
		s.deps.Logger.Info("auth challenge attempts exhausted",
			zap.String("challenge_id", in.ChallengeID.String()),
			zap.Int32("attempts", loaded.Attempts),
		)
		return Session{}, ErrTooManyAttempts
	}
	if err := verifyCode(loaded.CodeHash, in.Code); err != nil {
		if errors.Is(err, ErrCodeMismatch) {
			s.deps.Logger.Info("auth challenge code mismatch",
				zap.String("challenge_id", in.ChallengeID.String()),
				zap.Int32("attempts", loaded.Attempts),
			)
			return Session{}, ErrCodeMismatch
		}
		return Session{}, err
	}
	// Language resolution order: challenge-captured locale → geoip →
	// the "en" default.
	preferredLang := loaded.PreferredLanguage
	if preferredLang == "" {
		preferredLang = s.deps.Geo.LanguageForIP(in.SourceIP)
	}
	if preferredLang == "" {
		preferredLang = defaultLanguage
	}
	declaredCountry := s.deps.Geo.LookupCountry(in.SourceIP)
	userID, err := s.deps.User.EnsureByEmail(ctx, loaded.Email, preferredLang, in.TimeZone, declaredCountry)
	if err != nil {
		return Session{}, fmt.Errorf("auth: ensure account by email: %w", err)
	}
	deviceSessionID := uuid.New()
	pending := Session{
		DeviceSessionID: deviceSessionID,
		UserID:          userID,
		Status:          SessionStatusActive,
		ClientPublicKey: cloneBytes(in.ClientPublicKey),
	}
	if err := s.deps.Store.MarkConsumedAndInsertSession(ctx, in.ChallengeID, pending); err != nil {
		return Session{}, err
	}
	// Reload so the caller sees server-assigned columns (timestamps etc.).
	persisted, err := s.deps.Store.LoadSession(ctx, deviceSessionID)
	if err != nil {
		return Session{}, fmt.Errorf("auth: reload created session: %w", err)
	}
	s.deps.Cache.Add(persisted)
	// NOTE(review): this passes the raw source IP, while EnsureByEmail
	// above received the resolved declaredCountry — confirm the Geo
	// interface expects an IP (and resolves internally) rather than a
	// country code here.
	if err := s.deps.Geo.SetDeclaredCountryAtRegistration(ctx, userID, in.SourceIP); err != nil {
		s.deps.Logger.Warn("auth: declared country backfill failed",
			zap.String("user_id", userID.String()),
			zap.Error(err),
		)
	}
	s.deps.Logger.Info("auth session created",
		zap.String("user_id", userID.String()),
		zap.String("device_session_id", deviceSessionID.String()),
	)
	return persisted, nil
}
// defaultLanguage is the fallback locale written when neither the body
// nor the Accept-Language header nor the geoip-derived language produce
// a value (see ConfirmEmailCode).
const defaultLanguage = "en"
// normaliseEmail canonicalises an address for storage and lookups:
// surrounding whitespace is stripped and the result lower-cased.
func normaliseEmail(email string) string {
	trimmed := strings.TrimSpace(email)
	return strings.ToLower(trimmed)
}
// pickCapturedLocale picks the locale to persist on the challenge row.
// The body field wins over the header. Header parsing is intentionally
// minimal: only the text before the first ',' or ';' is kept — auth just
// stores the value, and user.Service treats the captured string as
// opaque, so a richer Accept-Language parse would be wasted.
func pickCapturedLocale(locale, acceptLanguage string) string {
	if body := strings.TrimSpace(locale); body != "" {
		return body
	}
	if acceptLanguage == "" {
		return ""
	}
	head := acceptLanguage
	if cut := strings.IndexAny(head, ",;"); cut != -1 {
		head = head[:cut]
	}
	return strings.TrimSpace(head)
}
// cloneBytes returns an independent copy of b so callers cannot mutate
// the original backing array. A nil input stays nil; an empty non-nil
// input yields an empty non-nil copy.
func cloneBytes(b []byte) []byte {
	if b == nil {
		return nil
	}
	return append(make([]byte, 0, len(b)), b...)
}
+61
View File
@@ -0,0 +1,61 @@
package auth
import (
"crypto/rand"
"errors"
"fmt"
"strings"
"golang.org/x/crypto/bcrypt"
)
// CodeLength is the fixed length of the decimal code delivered by
// SendEmailCode. The OpenAPI description ("six-digit") locks the value
// at six; tests cannot lower it without breaking the contract test
// against the schema.
const CodeLength = 6

// codeBcryptCost is the bcrypt cost used to store the hashed code in
// auth_challenges.code_hash. Cost 10 (bcrypt.DefaultCost) matches the
// convention documented for admin password storage in
// `backend/README.md` §12. Six-digit codes span only 10^6 possible
// values (~20 bits), so the bcrypt slowdown is what bounds online
// attacks together with the per-challenge attempt ceiling.
const codeBcryptCost = bcrypt.DefaultCost
// generateCode produces a random decimal string of exactly CodeLength
// characters using crypto/rand. Reducing each byte modulo ten skews
// the distribution slightly toward 0–5; that bias is accepted for
// short-lived registration codes because the per-challenge attempt
// ceiling and the TTL bound abuse far more tightly than the bias helps.
func generateCode() (string, error) {
	raw := make([]byte, CodeLength)
	if _, err := rand.Read(raw); err != nil {
		return "", fmt.Errorf("auth: generate code: %w", err)
	}
	var out strings.Builder
	out.Grow(CodeLength)
	for _, v := range raw {
		out.WriteByte('0' + v%10)
	}
	return out.String(), nil
}
// hashCode derives the bcrypt digest of code, using the package-level
// codeBcryptCost; the result is what InsertChallenge persists in
// auth_challenges.code_hash.
func hashCode(code string) ([]byte, error) {
	digest, err := bcrypt.GenerateFromPassword([]byte(code), codeBcryptCost)
	if err != nil {
		return nil, err
	}
	return digest, nil
}
// verifyCode checks code against the stored bcrypt hash. It delegates to
// bcrypt.CompareHashAndPassword, which compares in constant time on the
// matching path. The result is nil on match, ErrCodeMismatch when
// bcrypt reports a mismatch, and a wrapped error for anything else
// (e.g. a malformed hash).
func verifyCode(hash []byte, code string) error {
	switch err := bcrypt.CompareHashAndPassword(hash, []byte(code)); {
	case err == nil:
		return nil
	case errors.Is(err, bcrypt.ErrMismatchedHashAndPassword):
		return ErrCodeMismatch
	default:
		return fmt.Errorf("auth: verify code: %w", err)
	}
}
+76
View File
@@ -0,0 +1,76 @@
package auth
import (
"strings"
"testing"
"errors"
)
func TestGenerateCodeShape(t *testing.T) {
for range 100 {
code, err := generateCode()
if err != nil {
t.Fatalf("generateCode: %v", err)
}
if len(code) != CodeLength {
t.Fatalf("len(code) = %d, want %d (got %q)", len(code), CodeLength, code)
}
for _, r := range code {
if r < '0' || r > '9' {
t.Fatalf("non-digit rune %q in code %q", r, code)
}
}
}
}
func TestGenerateCodeRandomness(t *testing.T) {
seen := make(map[string]struct{})
const trials = 50
for range trials {
code, err := generateCode()
if err != nil {
t.Fatalf("generateCode: %v", err)
}
seen[code] = struct{}{}
}
// 50 trials over a 10^6 space — duplicate is astronomically unlikely.
if len(seen) < trials-1 {
t.Fatalf("generateCode produced too many duplicates: %d/%d unique", len(seen), trials)
}
}
// TestHashAndVerifyCodeRoundTrip checks that a hashed code verifies
// against itself and that the stored value carries the bcrypt "$2"
// prefix.
func TestHashAndVerifyCodeRoundTrip(t *testing.T) {
	const code = "654321"
	hash, err := hashCode(code)
	if err != nil {
		t.Fatalf("hashCode: %v", err)
	}
	if rendered := string(hash); !strings.HasPrefix(rendered, "$2") {
		t.Fatalf("hash does not look like bcrypt: %q", rendered)
	}
	if err := verifyCode(hash, code); err != nil {
		t.Fatalf("verifyCode on matching code: %v", err)
	}
}
// TestVerifyCodeMismatch checks that a wrong code maps to the
// ErrCodeMismatch sentinel rather than a raw bcrypt error.
func TestVerifyCodeMismatch(t *testing.T) {
	hash, err := hashCode("111111")
	if err != nil {
		t.Fatalf("hashCode: %v", err)
	}
	if got := verifyCode(hash, "222222"); !errors.Is(got, ErrCodeMismatch) {
		t.Fatalf("verifyCode mismatch returned %v, want ErrCodeMismatch", got)
	}
}
// TestVerifyCodeMalformedHash checks that a garbage hash fails with an
// error that is NOT classified as a plain code mismatch.
func TestVerifyCodeMalformedHash(t *testing.T) {
	err := verifyCode([]byte("not-a-hash"), "111111")
	switch {
	case err == nil:
		t.Fatalf("verifyCode with garbage hash returned nil")
	case errors.Is(err, ErrCodeMismatch):
		t.Fatalf("malformed hash classified as mismatch: %v", err)
	}
}
+90
View File
@@ -0,0 +1,90 @@
package auth
import (
"context"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// LoginCodeMailer is the publisher contract auth uses to deliver a
// one-time login code to a user's mailbox. The canonical
// implementation lives in `backend/internal/mail`; tests can use
// `NewNoopLoginCodeMailer` to record the outbound code without wiring
// SMTP. ttl is the code's validity window at enqueue time — presumably
// surfaced in the mail body; confirm against the mail template.
type LoginCodeMailer interface {
	EnqueueLoginCode(ctx context.Context, email, code string, ttl time.Duration) error
}

// SessionInvalidator emits the gRPC push session_invalidation event
// when auth revokes one or more device sessions. The canonical
// implementation lives in `backend/internal/push`; tests can use
// `NewNoopSessionInvalidator` for an in-memory log-only fallback.
// Note the method returns nothing: delivery is fire-and-forget, so the
// revoke paths never fail on a push problem.
type SessionInvalidator interface {
	PublishSessionInvalidation(ctx context.Context, deviceSessionID, userID uuid.UUID, reason string)
}

// UserEnsurer binds a confirmed email to an `accounts.user_id`. The
// canonical implementation is `*user.Service`; tests can swap in a
// recording fake.
type UserEnsurer interface {
	EnsureByEmail(ctx context.Context, email, preferredLanguage, timeZone, declaredCountry string) (uuid.UUID, error)
}

// GeoService provides the geo helpers auth needs at confirm-email-code:
// a country lookup for the `preferred_language` fallback and a
// post-commit write of `accounts.declared_country`. Both methods are
// best-effort — auth never blocks the registration flow on geo failures
// (SetDeclaredCountryAtRegistration errors are only logged).
type GeoService interface {
	LookupCountry(sourceIP string) string
	LanguageForIP(sourceIP string) string
	SetDeclaredCountryAtRegistration(ctx context.Context, userID uuid.UUID, sourceIP string) error
}
// noopLoginCodeMailer is the log-only LoginCodeMailer fallback; it never
// talks to SMTP.
type noopLoginCodeMailer struct {
	logger *zap.Logger
}

// NewNoopLoginCodeMailer returns a LoginCodeMailer that logs the
// outbound code at info level and always succeeds. Production wiring
// uses the real `mail.Service`; this constructor exists for tests and
// local smoke runs that do not want to bring up an SMTP relay. A nil
// logger falls back to zap.NewNop.
func NewNoopLoginCodeMailer(logger *zap.Logger) LoginCodeMailer {
	base := logger
	if base == nil {
		base = zap.NewNop()
	}
	return &noopLoginCodeMailer{logger: base.Named("auth.mail.noop")}
}
// EnqueueLoginCode implements LoginCodeMailer by recording the code and
// TTL in the log instead of sending mail; it always returns nil.
func (m *noopLoginCodeMailer) EnqueueLoginCode(_ context.Context, email, code string, ttl time.Duration) error {
	fields := []zap.Field{
		zap.String("email", email),
		zap.String("code", code),
		zap.Duration("ttl", ttl),
	}
	m.logger.Info("auth login code (noop publisher)", fields...)
	return nil
}
// noopSessionInvalidator is the log-only SessionInvalidator fallback
// used where the real gRPC push service is not wired.
type noopSessionInvalidator struct {
	logger *zap.Logger
}

// NewNoopSessionInvalidator returns a SessionInvalidator that logs
// every invalidation at info level and never blocks. Production wiring
// uses the real `push.Service`; this constructor exists for tests that
// need a callable surface without bringing up gRPC. A nil logger falls
// back to zap.NewNop.
func NewNoopSessionInvalidator(logger *zap.Logger) SessionInvalidator {
	base := logger
	if base == nil {
		base = zap.NewNop()
	}
	return &noopSessionInvalidator{logger: base.Named("auth.push.noop")}
}
// PublishSessionInvalidation implements SessionInvalidator by recording
// the event in the log; nothing is pushed anywhere.
func (p *noopSessionInvalidator) PublishSessionInvalidation(_ context.Context, deviceSessionID, userID uuid.UUID, reason string) {
	fields := []zap.Field{
		zap.String("device_session_id", deviceSessionID.String()),
		zap.String("user_id", userID.String()),
		zap.String("reason", reason),
	}
	p.logger.Info("session invalidation (noop publisher)", fields...)
}
+39
View File
@@ -0,0 +1,39 @@
package auth
import "errors"
// Sentinel errors emitted by Service methods. Handlers translate them
// into HTTP responses; callers in tests can match on them with
// errors.Is.
var (
	// ErrChallengeNotFound is returned when a confirm-email-code request
	// references a challenge_id that does not exist, has already been
	// consumed, or has expired. Returned as a single sentinel because the
	// API surface deliberately does not differentiate between these cases
	// — distinguishing them would leak whether a challenge_id was ever
	// valid, which is a signal an attacker should not have.
	ErrChallengeNotFound = errors.New("auth: challenge is not redeemable")

	// ErrTooManyAttempts is returned when confirm-email-code increments
	// the attempts counter past the configured ceiling. The challenge row
	// remains in the database with its incremented counter so further
	// attempts on the same challenge_id continue to fail with the same
	// error until the row expires.
	ErrTooManyAttempts = errors.New("auth: too many attempts")

	// ErrCodeMismatch is returned when the supplied code does not match
	// the stored bcrypt hash. The challenge stays un-consumed so the user
	// can try again — bounded by ErrTooManyAttempts.
	ErrCodeMismatch = errors.New("auth: code is incorrect")

	// ErrEmailPermanentlyBlocked is returned by SendEmailCode when the
	// supplied email maps to an existing account whose `permanent_block`
	// column is true. This is the only path that does not return an
	// opaque success shape.
	ErrEmailPermanentlyBlocked = errors.New("auth: email is permanently blocked")

	// ErrSessionNotFound is returned by GetSession (and the revoke
	// helpers in their look-it-up-after-zero-rows fallback) when the
	// device_session_id does not name a row in `device_sessions`.
	ErrSessionNotFound = errors.New("auth: session not found")
)
+90
View File
@@ -0,0 +1,90 @@
package auth
import (
"context"
"errors"
"github.com/google/uuid"
"go.uber.org/zap"
)
// GetSession resolves deviceSessionID against the in-memory cache only.
// The cache is the write-through projection of active device_sessions
// rows, so a miss means the session is revoked or never existed; both
// cases surface as ErrSessionNotFound and gateway treats the caller as
// unauthenticated. A zero UUID short-circuits without touching the
// cache.
func (s *Service) GetSession(_ context.Context, deviceSessionID uuid.UUID) (Session, error) {
	if deviceSessionID == uuid.Nil {
		return Session{}, ErrSessionNotFound
	}
	if sess, ok := s.deps.Cache.Get(deviceSessionID); ok {
		return sess, nil
	}
	return Session{}, ErrSessionNotFound
}
// RevokeSession marks deviceSessionID revoked, drops it from the cache,
// and emits a session_invalidation push event. The call is idempotent:
// re-revoking an already-revoked session returns the existing row with
// status='revoked' (HTTP 200) instead of ErrSessionNotFound; an unknown
// device_session_id does yield ErrSessionNotFound.
//
// Cache eviction and the push emission happen only after the database
// UPDATE has committed, so a failed UPDATE leaves both the cache and
// the gateway view untouched.
func (s *Service) RevokeSession(ctx context.Context, deviceSessionID uuid.UUID) (Session, error) {
	if deviceSessionID == uuid.Nil {
		return Session{}, ErrSessionNotFound
	}
	revoked, changed, err := s.deps.Store.RevokeSession(ctx, deviceSessionID)
	if err != nil {
		return Session{}, err
	}
	if !changed {
		// The UPDATE matched nothing: either the session never existed or
		// it was revoked earlier. Re-read the row so an earlier revoke
		// still answers with the idempotent revoked shape rather than 404.
		existing, loadErr := s.deps.Store.LoadSession(ctx, deviceSessionID)
		if loadErr != nil {
			if errors.Is(loadErr, ErrSessionNotFound) {
				return Session{}, ErrSessionNotFound
			}
			return Session{}, loadErr
		}
		return existing, nil
	}
	s.deps.Cache.Remove(deviceSessionID)
	s.deps.Push.PublishSessionInvalidation(ctx, deviceSessionID, revoked.UserID, "auth.revoke_session")
	s.deps.Logger.Info("auth session revoked",
		zap.String("device_session_id", deviceSessionID.String()),
		zap.String("user_id", revoked.UserID.String()),
	)
	return revoked, nil
}
// RevokeAllForUser revokes every active session owned by userID,
// evicting each from the cache and publishing one session_invalidation
// push event per revoked row. The returned slice preserves the order
// Postgres produced. An empty result is a successful idempotent call
// (the handler reports revoked_count=0); a zero userID is a no-op.
func (s *Service) RevokeAllForUser(ctx context.Context, userID uuid.UUID) ([]Session, error) {
	if userID == uuid.Nil {
		return nil, nil
	}
	revoked, err := s.deps.Store.RevokeAllForUser(ctx, userID)
	if err != nil {
		return nil, err
	}
	for i := range revoked {
		s.deps.Cache.Remove(revoked[i].DeviceSessionID)
		s.deps.Push.PublishSessionInvalidation(ctx, revoked[i].DeviceSessionID, revoked[i].UserID, "auth.revoke_all_for_user")
	}
	if n := len(revoked); n > 0 {
		s.deps.Logger.Info("auth sessions revoked (bulk)",
			zap.String("user_id", userID.String()),
			zap.Int("count", n),
		)
	}
	return revoked, nil
}
+444
View File
@@ -0,0 +1,444 @@
package auth
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"galaxy/backend/internal/postgres/jet/backend/model"
"galaxy/backend/internal/postgres/jet/backend/table"
"github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/qrm"
"github.com/google/uuid"
)
// Challenge mirrors a row in `backend.auth_challenges` enriched with the
// PreferredLanguage column added by migration 00002. The CodeHash slice
// is the raw bcrypt hash; verifyCode wraps the comparison. Note that
// modelToChallenge copies the slice header, not the bytes — assumes the
// scan buffer is freshly allocated per query; confirm against qrm.
type Challenge struct {
	ChallengeID       uuid.UUID
	Email             string
	CodeHash          []byte // raw bcrypt digest persisted by InsertChallenge
	Attempts          int32  // post-increment value after LoadAndIncrementChallenge
	CreatedAt         time.Time
	ExpiresAt         time.Time
	ConsumedAt        *time.Time // nil until MarkConsumedAndInsertSession fires
	PreferredLanguage string
}

// Session mirrors a row in `backend.device_sessions`. The
// ClientPublicKey slice is the raw 32-byte Ed25519 key; the handler
// layer is responsible for base64 encoding/decoding on the wire.
type Session struct {
	DeviceSessionID uuid.UUID
	UserID          uuid.UUID
	Status          string // SessionStatusActive / SessionStatusRevoked; schema also allows 'blocked'
	ClientPublicKey []byte
	CreatedAt       time.Time
	RevokedAt       *time.Time // set by the revoke paths; nil while active
	LastSeenAt      *time.Time // not written anywhere in this package — presumably maintained by gateway; confirm
}
// SessionStatusActive and SessionStatusRevoked enumerate the status
// values this package writes to `device_sessions.status`. The CHECK
// constraint on the column also allows 'blocked', which the user
// package emits when applying a `permanent_block` sanction; auth never
// writes that value.
const (
	SessionStatusActive  = "active"
	SessionStatusRevoked = "revoked"
)
// Store is the Postgres-backed query surface for
// `backend.auth_challenges`, `backend.device_sessions` and the
// read-side `backend.accounts` lookup auth needs to detect
// permanently-blocked emails.
type Store struct {
	db *sql.DB // shared pool; Store adds no state of its own
}

// NewStore wraps db in a Store. The handle is stored as-is; lifecycle
// (open/close) stays with the caller.
func NewStore(db *sql.DB) *Store {
	s := Store{db: db}
	return &s
}
// challengeColumns is the shared projection for every `auth_challenges`
// read. Do not reorder: the ordering lines up with the
// model.AuthChallenges field order that QueryContext scans into.
func challengeColumns() postgres.ColumnList {
	cols := postgres.ColumnList{
		table.AuthChallenges.ChallengeID,
		table.AuthChallenges.Email,
		table.AuthChallenges.CodeHash,
		table.AuthChallenges.Attempts,
		table.AuthChallenges.CreatedAt,
		table.AuthChallenges.ExpiresAt,
		table.AuthChallenges.ConsumedAt,
		table.AuthChallenges.PreferredLanguage,
	}
	return cols
}
// sessionColumns is the shared projection for every `device_sessions`
// read (and the RETURNING clauses of the revoke statements).
func sessionColumns() postgres.ColumnList {
	cols := postgres.ColumnList{
		table.DeviceSessions.DeviceSessionID,
		table.DeviceSessions.UserID,
		table.DeviceSessions.ClientPublicKey,
		table.DeviceSessions.Status,
		table.DeviceSessions.CreatedAt,
		table.DeviceSessions.RevokedAt,
		table.DeviceSessions.LastSeenAt,
	}
	return cols
}
// IsEmailPermanentlyBlocked reports whether email maps to a live
// (deleted_at IS NULL) `accounts` row whose permanent_block column is
// true. The comparison is case-sensitive: callers must pass an
// already-normalised (lowercase, trimmed) email.
//
// A non-existent account yields (false, nil) — the auth flow treats
// such emails as eligible for fresh registration.
func (s *Store) IsEmailPermanentlyBlocked(ctx context.Context, email string) (bool, error) {
	predicate := table.Accounts.Email.EQ(postgres.String(email)).
		AND(table.Accounts.DeletedAt.IS_NULL())
	stmt := postgres.SELECT(table.Accounts.PermanentBlock).
		FROM(table.Accounts).
		WHERE(predicate).
		LIMIT(1)
	var row model.Accounts
	err := stmt.QueryContext(ctx, s.db, &row)
	switch {
	case err == nil:
		return row.PermanentBlock, nil
	case errors.Is(err, qrm.ErrNoRows):
		return false, nil
	default:
		return false, fmt.Errorf("auth store: query permanent_block for %q: %w", email, err)
	}
}
// LatestUnconsumedChallenge returns the most recently issued
// un-consumed, non-expired challenge for email created at or after
// since. Returns sql.ErrNoRows when no such challenge exists. The
// throttle path uses this method to reuse the existing challenge_id
// rather than emit a fresh row.
func (s *Store) LatestUnconsumedChallenge(ctx context.Context, email string, since time.Time) (Challenge, error) {
	stmt := postgres.SELECT(challengeColumns()).
		FROM(table.AuthChallenges).
		WHERE(
			table.AuthChallenges.Email.EQ(postgres.String(email)).
				AND(table.AuthChallenges.ConsumedAt.IS_NULL()).
				AND(table.AuthChallenges.ExpiresAt.GT(postgres.NOW())).
				AND(table.AuthChallenges.CreatedAt.GT_EQ(postgres.TimestampzT(since))),
		).
		ORDER_BY(table.AuthChallenges.CreatedAt.DESC()).
		LIMIT(1)
	var row model.AuthChallenges
	if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
		if errors.Is(err, qrm.ErrNoRows) {
			return Challenge{}, sql.ErrNoRows
		}
		// Wrap with the store prefix for consistency with every other
		// Store method; previously this error escaped unwrapped.
		return Challenge{}, fmt.Errorf("auth store: latest unconsumed challenge for %q: %w", email, err)
	}
	return modelToChallenge(row), nil
}
// CountRecentChallenges counts the un-consumed, non-expired challenges
// issued for email at or after since. SendEmailCode's throttle gate is
// the only consumer.
func (s *Store) CountRecentChallenges(ctx context.Context, email string, since time.Time) (int, error) {
	cond := table.AuthChallenges.Email.EQ(postgres.String(email)).
		AND(table.AuthChallenges.ConsumedAt.IS_NULL()).
		AND(table.AuthChallenges.ExpiresAt.GT(postgres.NOW())).
		AND(table.AuthChallenges.CreatedAt.GT_EQ(postgres.TimestampzT(since)))
	stmt := postgres.SELECT(postgres.COUNT(postgres.STAR).AS("count")).
		FROM(table.AuthChallenges).
		WHERE(cond)
	var dest struct {
		Count int64 `alias:"count"`
	}
	if err := stmt.QueryContext(ctx, s.db, &dest); err != nil {
		return 0, fmt.Errorf("auth store: count recent challenges: %w", err)
	}
	return int(dest.Count), nil
}
// InsertChallenge writes a fresh `auth_challenges` row. The caller owns
// the primary key, the bcrypt hash, the expires_at timestamp and the
// captured locale; created_at and attempts take their schema-level
// defaults.
func (s *Store) InsertChallenge(ctx context.Context, c Challenge) error {
	stmt := table.AuthChallenges.
		INSERT(
			table.AuthChallenges.ChallengeID,
			table.AuthChallenges.Email,
			table.AuthChallenges.CodeHash,
			table.AuthChallenges.ExpiresAt,
			table.AuthChallenges.PreferredLanguage,
		).
		VALUES(c.ChallengeID, c.Email, c.CodeHash, c.ExpiresAt, c.PreferredLanguage)
	_, err := stmt.ExecContext(ctx, s.db)
	if err != nil {
		return fmt.Errorf("auth store: insert challenge: %w", err)
	}
	return nil
}
// LoadAndIncrementChallenge atomically locks the challenge row,
// validates that it is still un-consumed and non-expired, and increments
// its `attempts` counter. The returned Challenge carries the
// post-increment counter so the caller can compare it against the
// configured ceiling without a second query.
//
// Returns ErrChallengeNotFound when the row does not exist, has been
// consumed, or has expired. Any other error is wrapped with the auth
// store prefix.
func (s *Store) LoadAndIncrementChallenge(ctx context.Context, challengeID uuid.UUID) (Challenge, error) {
	var loaded Challenge
	err := withTx(ctx, s.db, func(tx *sql.Tx) error {
		// SELECT ... FOR UPDATE serialises concurrent confirm attempts on
		// the same challenge: a second caller blocks here until the first
		// commits its increment.
		selectStmt := postgres.SELECT(challengeColumns()).
			FROM(table.AuthChallenges).
			WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))).
			FOR(postgres.UPDATE())
		var row model.AuthChallenges
		if err := selectStmt.QueryContext(ctx, tx, &row); err != nil {
			if errors.Is(err, qrm.ErrNoRows) {
				return ErrChallengeNotFound
			}
			return err
		}
		loaded = modelToChallenge(row)
		// Consumed and expired rows deliberately collapse into the same
		// sentinel so the API does not reveal which case occurred.
		if loaded.ConsumedAt != nil {
			return ErrChallengeNotFound
		}
		// Expiry is judged on the app clock here (time.Now) while the
		// read-side queries use postgres.NOW() — NOTE(review): assumes app
		// and database clocks are closely synchronised.
		if !loaded.ExpiresAt.After(time.Now()) {
			return ErrChallengeNotFound
		}
		updateStmt := table.AuthChallenges.
			UPDATE(table.AuthChallenges.Attempts).
			SET(table.AuthChallenges.Attempts.ADD(postgres.Int(1))).
			WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID)))
		if _, err := updateStmt.ExecContext(ctx, tx); err != nil {
			return err
		}
		// Mirror the database-side increment locally instead of re-reading
		// the row.
		loaded.Attempts++
		return nil
	})
	if err != nil {
		if errors.Is(err, ErrChallengeNotFound) {
			return Challenge{}, err
		}
		return Challenge{}, fmt.Errorf("auth store: load and increment challenge: %w", err)
	}
	return loaded, nil
}
// MarkConsumedAndInsertSession atomically:
//
//  1. Locks the challenge row.
//  2. Validates that it is still un-consumed and non-expired.
//  3. Sets consumed_at = now().
//  4. Inserts the supplied Session into device_sessions with status =
//     'active'.
//
// The two writes are committed together so a single challenge yields at
// most one device session even under concurrent confirm-email-code
// callers.
//
// Returns ErrChallengeNotFound when the challenge has been consumed (by
// a concurrent caller) or has expired in the gap between the
// LoadAndIncrementChallenge call and this one.
func (s *Store) MarkConsumedAndInsertSession(ctx context.Context, challengeID uuid.UUID, session Session) error {
	err := withTx(ctx, s.db, func(tx *sql.Tx) error {
		// FOR UPDATE guarantees that only one confirm caller gets past the
		// consumed/expired check below.
		lockStmt := postgres.SELECT(table.AuthChallenges.ConsumedAt, table.AuthChallenges.ExpiresAt).
			FROM(table.AuthChallenges).
			WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))).
			FOR(postgres.UPDATE())
		var locked model.AuthChallenges
		if err := lockStmt.QueryContext(ctx, tx, &locked); err != nil {
			if errors.Is(err, qrm.ErrNoRows) {
				return ErrChallengeNotFound
			}
			return err
		}
		// Expiry is judged on the app clock (time.Now) — NOTE(review):
		// assumes app and database clocks agree closely.
		if locked.ConsumedAt != nil || !locked.ExpiresAt.After(time.Now()) {
			return ErrChallengeNotFound
		}
		consumeStmt := table.AuthChallenges.
			UPDATE(table.AuthChallenges.ConsumedAt).
			SET(postgres.NOW()).
			WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID)))
		if _, err := consumeStmt.ExecContext(ctx, tx); err != nil {
			return err
		}
		// status is forced to 'active' regardless of session.Status;
		// created_at defaults at the schema level.
		insertStmt := table.DeviceSessions.INSERT(
			table.DeviceSessions.DeviceSessionID,
			table.DeviceSessions.UserID,
			table.DeviceSessions.ClientPublicKey,
			table.DeviceSessions.Status,
		).VALUES(session.DeviceSessionID, session.UserID, session.ClientPublicKey, SessionStatusActive)
		if _, err := insertStmt.ExecContext(ctx, tx); err != nil {
			return err
		}
		return nil
	})
	if err != nil {
		if errors.Is(err, ErrChallengeNotFound) {
			return err
		}
		return fmt.Errorf("auth store: mark consumed and insert session: %w", err)
	}
	return nil
}
// ListActiveSessions returns every device_sessions row whose status is
// 'active'. Cache.Warm calls this once at process boot to seed the
// in-memory projection. Zero rows produce an empty (non-nil) slice.
func (s *Store) ListActiveSessions(ctx context.Context) ([]Session, error) {
	stmt := postgres.SELECT(sessionColumns()).
		FROM(table.DeviceSessions).
		WHERE(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive)))
	var rows []model.DeviceSessions
	if err := stmt.QueryContext(ctx, s.db, &rows); err != nil {
		return nil, fmt.Errorf("auth store: list active sessions: %w", err)
	}
	sessions := make([]Session, len(rows))
	for i := range rows {
		sessions[i] = modelToSession(rows[i])
	}
	return sessions, nil
}
// LoadSession fetches the device_sessions row for deviceSessionID in
// any status. A missing row maps to ErrSessionNotFound; other failures
// are wrapped with the store prefix.
func (s *Store) LoadSession(ctx context.Context, deviceSessionID uuid.UUID) (Session, error) {
	stmt := postgres.SELECT(sessionColumns()).
		FROM(table.DeviceSessions).
		WHERE(table.DeviceSessions.DeviceSessionID.EQ(postgres.UUID(deviceSessionID))).
		LIMIT(1)
	var row model.DeviceSessions
	err := stmt.QueryContext(ctx, s.db, &row)
	switch {
	case err == nil:
		return modelToSession(row), nil
	case errors.Is(err, qrm.ErrNoRows):
		return Session{}, ErrSessionNotFound
	default:
		return Session{}, fmt.Errorf("auth store: load session %s: %w", deviceSessionID, err)
	}
}
// RevokeSession flips an active row to status='revoked' and returns the
// post-update row. The boolean reports whether the UPDATE actually
// changed anything — false means the row was already revoked or never
// existed, in which case the auth Service falls back to LoadSession to
// build its idempotent-revoke response.
func (s *Store) RevokeSession(ctx context.Context, deviceSessionID uuid.UUID) (Session, bool, error) {
	stmt := table.DeviceSessions.
		UPDATE(table.DeviceSessions.Status, table.DeviceSessions.RevokedAt).
		SET(postgres.String(SessionStatusRevoked), postgres.NOW()).
		WHERE(
			table.DeviceSessions.DeviceSessionID.EQ(postgres.UUID(deviceSessionID)).
				AND(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))),
		).
		RETURNING(sessionColumns())
	var row model.DeviceSessions
	err := stmt.QueryContext(ctx, s.db, &row)
	switch {
	case err == nil:
		return modelToSession(row), true, nil
	case errors.Is(err, qrm.ErrNoRows):
		return Session{}, false, nil
	default:
		return Session{}, false, fmt.Errorf("auth store: revoke session %s: %w", deviceSessionID, err)
	}
}
// RevokeAllForUser flips every active row owned by userID to
// status='revoked' and returns the post-update rows. An empty slice
// with a nil error means the user owned no active sessions; callers
// must treat that as a successful idempotent revoke (the API surface
// reports revoked_count=0 in that case).
func (s *Store) RevokeAllForUser(ctx context.Context, userID uuid.UUID) ([]Session, error) {
	stmt := table.DeviceSessions.
		UPDATE(table.DeviceSessions.Status, table.DeviceSessions.RevokedAt).
		SET(postgres.String(SessionStatusRevoked), postgres.NOW()).
		WHERE(
			table.DeviceSessions.UserID.EQ(postgres.UUID(userID)).
				AND(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))),
		).
		RETURNING(sessionColumns())
	var rows []model.DeviceSessions
	if err := stmt.QueryContext(ctx, s.db, &rows); err != nil {
		return nil, fmt.Errorf("auth store: revoke all for user %s: %w", userID, err)
	}
	sessions := make([]Session, len(rows))
	for i := range rows {
		sessions[i] = modelToSession(rows[i])
	}
	return sessions, nil
}
// modelToChallenge converts a generated auth_challenges row into the
// public Challenge struct. The nullable ConsumedAt timestamp is
// deep-copied so the result does not alias the scan destination; the
// CodeHash slice header is copied as-is.
func modelToChallenge(row model.AuthChallenges) Challenge {
	var consumed *time.Time
	if row.ConsumedAt != nil {
		v := *row.ConsumedAt
		consumed = &v
	}
	return Challenge{
		ChallengeID:       row.ChallengeID,
		Email:             row.Email,
		CodeHash:          row.CodeHash,
		Attempts:          row.Attempts,
		CreatedAt:         row.CreatedAt,
		ExpiresAt:         row.ExpiresAt,
		ConsumedAt:        consumed,
		PreferredLanguage: row.PreferredLanguage,
	}
}
// modelToSession converts a generated device_sessions row into the
// public Session struct, deep-copying the nullable timestamps so the
// result does not alias the scan destination.
func modelToSession(row model.DeviceSessions) Session {
	var revoked, lastSeen *time.Time
	if row.RevokedAt != nil {
		v := *row.RevokedAt
		revoked = &v
	}
	if row.LastSeenAt != nil {
		v := *row.LastSeenAt
		lastSeen = &v
	}
	return Session{
		DeviceSessionID: row.DeviceSessionID,
		UserID:          row.UserID,
		Status:          row.Status,
		ClientPublicKey: row.ClientPublicKey,
		CreatedAt:       row.CreatedAt,
		RevokedAt:       revoked,
		LastSeenAt:      lastSeen,
	}
}
// withTx runs fn inside a transaction on db: a nil return from fn
// commits, a non-nil return rolls back. A rollback failure is discarded
// on purpose — fn's own error is the actionable one.
func withTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
	tx, beginErr := db.BeginTx(ctx, nil)
	if beginErr != nil {
		return fmt.Errorf("auth store: begin tx: %w", beginErr)
	}
	if runErr := fn(tx); runErr != nil {
		_ = tx.Rollback()
		return runErr
	}
	if commitErr := tx.Commit(); commitErr != nil {
		return fmt.Errorf("auth store: commit tx: %w", commitErr)
	}
	return nil
}