feat: backend service
This commit is contained in:
@@ -0,0 +1,93 @@
|
||||
// Package auth implements the email-code authentication flow and the
|
||||
// active-session bookkeeping consumed by gateway. The package is
|
||||
// described end-to-end in `backend/PLAN.md` §5.1.
|
||||
//
|
||||
// External dependencies that have not landed yet (mail in 5.6, push
|
||||
// session_invalidation in 6) are injected through the LoginCodeMailer
|
||||
// and SessionInvalidator interfaces; auth ships no-op implementations
|
||||
// that satisfy the contract until the real services arrive.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"time"
|
||||
|
||||
"galaxy/backend/internal/config"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Deps aggregates every collaborator the Service depends on.
|
||||
// Constructing the Service through Deps (rather than positional args)
|
||||
// keeps wiring patches small when new dependencies are added.
|
||||
//
|
||||
// Cache and Store must be non-nil: GetSession reads through Cache,
|
||||
// SendEmailCode and ConfirmEmailCode mutate Store. User, Geo, Mail and
|
||||
// Push are tested-in-isolation interfaces; production wires the real
|
||||
// `*user.Service`, `*geo.Service`, mail, and push implementations.
|
||||
type Deps struct {
|
||||
Store *Store
|
||||
Cache *Cache
|
||||
User UserEnsurer
|
||||
Geo GeoService
|
||||
Mail LoginCodeMailer
|
||||
Push SessionInvalidator
|
||||
Config config.AuthConfig
|
||||
// Now overrides time.Now for deterministic tests. A nil Now defaults
|
||||
// to time.Now in NewService.
|
||||
Now func() time.Time
|
||||
// Logger is named under "auth" by NewService. Nil falls back to
|
||||
// zap.NewNop.
|
||||
Logger *zap.Logger
|
||||
}
|
||||
|
||||
// Service is the auth-domain entry point.
|
||||
type Service struct {
|
||||
deps Deps
|
||||
|
||||
// emailHashKey keys the HMAC used to derive `email_hash` log fields.
|
||||
// A per-boot random key keeps email PII out of structured logs while
|
||||
// still letting operators correlate log entries within a single
|
||||
// process lifetime.
|
||||
emailHashKey []byte
|
||||
}
|
||||
|
||||
// NewService constructs a Service from deps. A nil Now defaults to
|
||||
// time.Now; a nil Logger defaults to zap.NewNop. The other dependencies
|
||||
// must be supplied — calling Service methods with nil Cache/Store/User/
|
||||
// Geo/Mail/Push will panic at first use, matching how main.go signals
|
||||
// missing wiring.
|
||||
func NewService(deps Deps) *Service {
|
||||
if deps.Now == nil {
|
||||
deps.Now = time.Now
|
||||
}
|
||||
if deps.Logger == nil {
|
||||
deps.Logger = zap.NewNop()
|
||||
}
|
||||
deps.Logger = deps.Logger.Named("auth")
|
||||
|
||||
key := make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
// rand.Read should not fail in practice; if it does, fall back
|
||||
// to a deterministic key. Email hashing is a log-scoping aid,
|
||||
// not a security primitive, so a constant key is acceptable.
|
||||
copy(key, []byte("galaxy-backend-auth-fallback-key"))
|
||||
}
|
||||
return &Service{deps: deps, emailHashKey: key}
|
||||
}
|
||||
|
||||
// hashEmail returns a stable, hex-encoded HMAC-SHA256 prefix of email
|
||||
// suitable for use in structured logs. The key is per-process so the
|
||||
// same email maps to the same hash across log lines emitted by this
|
||||
// process, but never across process restarts. The truncation gives
|
||||
// operators enough collision-resistance for ad-hoc grep without keeping
|
||||
// an offline key store.
|
||||
func (s *Service) hashEmail(email string) string {
|
||||
mac := hmac.New(sha256.New, s.emailHashKey)
|
||||
_, _ = mac.Write([]byte(email))
|
||||
full := mac.Sum(nil)
|
||||
return hex.EncodeToString(full[:8])
|
||||
}
|
||||
@@ -0,0 +1,511 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/backend/internal/auth"
|
||||
"galaxy/backend/internal/config"
|
||||
backendpg "galaxy/backend/internal/postgres"
|
||||
"galaxy/backend/internal/user"
|
||||
pgshared "galaxy/postgres"
|
||||
|
||||
"github.com/google/uuid"
|
||||
testcontainers "github.com/testcontainers/testcontainers-go"
|
||||
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
const (
|
||||
pgImage = "postgres:16-alpine"
|
||||
pgUser = "galaxy"
|
||||
pgPassword = "galaxy"
|
||||
pgDatabase = "galaxy_backend"
|
||||
pgSchema = "backend"
|
||||
pgStartup = 90 * time.Second
|
||||
pgOpTO = 10 * time.Second
|
||||
)
|
||||
|
||||
// startPostgres spins up a Postgres testcontainer with the backend
|
||||
// migrations applied. The returned *sql.DB is closed and the container
|
||||
// terminated by t.Cleanup hooks.
|
||||
func startPostgres(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
pgContainer, err := tcpostgres.Run(ctx, pgImage,
|
||||
tcpostgres.WithDatabase(pgDatabase),
|
||||
tcpostgres.WithUsername(pgUser),
|
||||
tcpostgres.WithPassword(pgPassword),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(pgStartup),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
t.Skipf("postgres testcontainer unavailable, skipping: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil {
|
||||
t.Errorf("terminate postgres container: %v", termErr)
|
||||
}
|
||||
})
|
||||
|
||||
baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("connection string: %v", err)
|
||||
}
|
||||
scopedDSN, err := dsnWithSearchPath(baseDSN, pgSchema)
|
||||
if err != nil {
|
||||
t.Fatalf("scope dsn: %v", err)
|
||||
}
|
||||
|
||||
cfg := pgshared.DefaultConfig()
|
||||
cfg.PrimaryDSN = scopedDSN
|
||||
cfg.OperationTimeout = pgOpTO
|
||||
|
||||
db, err := pgshared.OpenPrimary(ctx, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("open primary: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil {
|
||||
t.Fatalf("ping: %v", err)
|
||||
}
|
||||
if err := backendpg.ApplyMigrations(ctx, db); err != nil {
|
||||
t.Fatalf("apply migrations: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func dsnWithSearchPath(baseDSN, schema string) (string, error) {
|
||||
parsed, err := url.Parse(baseDSN)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
values := parsed.Query()
|
||||
values.Set("search_path", schema)
|
||||
if values.Get("sslmode") == "" {
|
||||
values.Set("sslmode", "disable")
|
||||
}
|
||||
parsed.RawQuery = values.Encode()
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
// recordingMailer implements auth.LoginCodeMailer and remembers the most
|
||||
// recent enqueue.
|
||||
type recordingMailer struct {
|
||||
mu sync.Mutex
|
||||
lastCode string
|
||||
lastTo string
|
||||
calls int
|
||||
}
|
||||
|
||||
func newRecordingMailer() *recordingMailer { return &recordingMailer{} }
|
||||
|
||||
func (m *recordingMailer) EnqueueLoginCode(_ context.Context, email, code string, _ time.Duration) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.lastTo = email
|
||||
m.lastCode = code
|
||||
m.calls++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *recordingMailer) snapshot() (string, string, int) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.lastTo, m.lastCode, m.calls
|
||||
}
|
||||
|
||||
// recordingPush implements auth.SessionInvalidator and counts emissions.
|
||||
type recordingPush struct {
|
||||
mu sync.Mutex
|
||||
calls []recordedPush
|
||||
}
|
||||
|
||||
type recordedPush struct {
|
||||
deviceSessionID, userID uuid.UUID
|
||||
reason string
|
||||
}
|
||||
|
||||
func newRecordingPush() *recordingPush { return &recordingPush{} }
|
||||
|
||||
func (p *recordingPush) PublishSessionInvalidation(_ context.Context, dsID, uid uuid.UUID, reason string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.calls = append(p.calls, recordedPush{deviceSessionID: dsID, userID: uid, reason: reason})
|
||||
}
|
||||
|
||||
func (p *recordingPush) snapshot() []recordedPush {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
out := make([]recordedPush, len(p.calls))
|
||||
copy(out, p.calls)
|
||||
return out
|
||||
}
|
||||
|
||||
// stubGeo implements auth.GeoService with no real lookups. The country
|
||||
// it returns is configurable per call via CountryForIP; LanguageForIP
|
||||
// returns "" so the auth flow exercises the "en" fallback path.
|
||||
type stubGeo struct {
|
||||
countryByIP map[string]string
|
||||
}
|
||||
|
||||
func newStubGeo() *stubGeo {
|
||||
return &stubGeo{countryByIP: map[string]string{}}
|
||||
}
|
||||
|
||||
func (g *stubGeo) LookupCountry(sourceIP string) string {
|
||||
return g.countryByIP[sourceIP]
|
||||
}
|
||||
|
||||
func (g *stubGeo) LanguageForIP(_ string) string { return "" }
|
||||
|
||||
func (g *stubGeo) SetDeclaredCountryAtRegistration(_ context.Context, _ uuid.UUID, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// authConfig builds an AuthConfig suitable for tests.
|
||||
func authConfig() config.AuthConfig {
|
||||
return config.AuthConfig{
|
||||
ChallengeTTL: 5 * time.Minute,
|
||||
ChallengeMaxAttempts: 3,
|
||||
ChallengeThrottle: config.AuthChallengeThrottleConfig{
|
||||
Window: time.Minute,
|
||||
Max: 3,
|
||||
},
|
||||
UserNameMaxRetries: 10,
|
||||
}
|
||||
}
|
||||
|
||||
// buildService wires every dependency around db and returns the service
|
||||
// plus the recording fakes for assertions.
|
||||
func buildService(t *testing.T, db *sql.DB) (*auth.Service, *recordingMailer, *recordingPush, *stubGeo) {
|
||||
t.Helper()
|
||||
store := auth.NewStore(db)
|
||||
cache := auth.NewCache()
|
||||
if err := cache.Warm(context.Background(), store); err != nil {
|
||||
t.Fatalf("warm cache: %v", err)
|
||||
}
|
||||
mailer := newRecordingMailer()
|
||||
pusher := newRecordingPush()
|
||||
geo := newStubGeo()
|
||||
userStore := user.NewStore(db)
|
||||
userSvc := user.NewService(user.Deps{
|
||||
|
||||
Store: userStore,
|
||||
Cache: user.NewCache(),
|
||||
UserNameMaxRetries: 10,
|
||||
Now: time.Now,
|
||||
})
|
||||
svc := auth.NewService(auth.Deps{
|
||||
Store: store,
|
||||
Cache: cache,
|
||||
User: userSvc,
|
||||
Geo: geo,
|
||||
Mail: mailer,
|
||||
Push: pusher,
|
||||
Config: authConfig(),
|
||||
Now: time.Now,
|
||||
})
|
||||
return svc, mailer, pusher, geo
|
||||
}
|
||||
|
||||
func randomKey(t *testing.T) []byte {
|
||||
t.Helper()
|
||||
key := make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
t.Fatalf("rand: %v", err)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func TestAuthEndToEnd(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, mailer, pusher, _ := buildService(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
challengeID, err := svc.SendEmailCode(ctx, "Alice@Example.Test", "ru", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("SendEmailCode: %v", err)
|
||||
}
|
||||
if challengeID == uuid.Nil {
|
||||
t.Fatalf("SendEmailCode returned nil challenge_id")
|
||||
}
|
||||
gotEmail, gotCode, calls := mailer.snapshot()
|
||||
if gotEmail != "alice@example.test" {
|
||||
t.Fatalf("mailer email = %q, want lower-cased", gotEmail)
|
||||
}
|
||||
if len(gotCode) != auth.CodeLength {
|
||||
t.Fatalf("mailer code = %q (len %d), want length %d", gotCode, len(gotCode), auth.CodeLength)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("mailer calls = %d, want 1", calls)
|
||||
}
|
||||
|
||||
pubKey := randomKey(t)
|
||||
session, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
||||
ChallengeID: challengeID,
|
||||
Code: gotCode,
|
||||
ClientPublicKey: pubKey,
|
||||
TimeZone: "Europe/Moscow",
|
||||
SourceIP: "",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ConfirmEmailCode: %v", err)
|
||||
}
|
||||
if session.UserID == uuid.Nil {
|
||||
t.Fatalf("session has nil user_id")
|
||||
}
|
||||
if session.Status != auth.SessionStatusActive {
|
||||
t.Fatalf("session.Status = %q, want %q", session.Status, auth.SessionStatusActive)
|
||||
}
|
||||
|
||||
got, err := svc.GetSession(ctx, session.DeviceSessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession: %v", err)
|
||||
}
|
||||
if got.UserID != session.UserID {
|
||||
t.Fatalf("GetSession user_id = %s, want %s", got.UserID, session.UserID)
|
||||
}
|
||||
|
||||
revoked, err := svc.RevokeSession(ctx, session.DeviceSessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("RevokeSession: %v", err)
|
||||
}
|
||||
if revoked.Status != auth.SessionStatusRevoked {
|
||||
t.Fatalf("revoked.Status = %q, want %q", revoked.Status, auth.SessionStatusRevoked)
|
||||
}
|
||||
if revoked.RevokedAt == nil {
|
||||
t.Fatalf("revoked.RevokedAt nil after revoke")
|
||||
}
|
||||
|
||||
if _, err := svc.GetSession(ctx, session.DeviceSessionID); !errors.Is(err, auth.ErrSessionNotFound) {
|
||||
t.Fatalf("GetSession after revoke = %v, want ErrSessionNotFound", err)
|
||||
}
|
||||
|
||||
again, err := svc.RevokeSession(ctx, session.DeviceSessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("idempotent RevokeSession: %v", err)
|
||||
}
|
||||
if again.DeviceSessionID != session.DeviceSessionID || again.Status != auth.SessionStatusRevoked {
|
||||
t.Fatalf("idempotent revoke shape mismatch: %+v", again)
|
||||
}
|
||||
|
||||
pushes := pusher.snapshot()
|
||||
if len(pushes) != 1 {
|
||||
t.Fatalf("push emissions = %d, want 1", len(pushes))
|
||||
}
|
||||
if pushes[0].deviceSessionID != session.DeviceSessionID {
|
||||
t.Fatalf("push device_session_id mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendEmailCodePermanentlyBlocked(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, _, _, _ := buildService(t, db)
|
||||
|
||||
// Insert a permanent_block account directly.
|
||||
if _, err := db.Exec(`
|
||||
INSERT INTO backend.accounts (
|
||||
user_id, email, user_name, preferred_language, time_zone, permanent_block
|
||||
) VALUES ($1, $2, $3, $4, $5, true)
|
||||
`, uuid.New(), "blocked@example.test", "Player-XXBLOCK1", "en", "UTC"); err != nil {
|
||||
t.Fatalf("seed account: %v", err)
|
||||
}
|
||||
|
||||
_, err := svc.SendEmailCode(context.Background(), "blocked@example.test", "", "", "")
|
||||
if !errors.Is(err, auth.ErrEmailPermanentlyBlocked) {
|
||||
t.Fatalf("SendEmailCode for blocked email = %v, want ErrEmailPermanentlyBlocked", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendEmailCodeThrottleReusesChallenge(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, mailer, _, _ := buildService(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
const email = "throttle@example.test"
|
||||
cfg := authConfig()
|
||||
var firstID uuid.UUID
|
||||
for i := range cfg.ChallengeThrottle.Max {
|
||||
id, err := svc.SendEmailCode(ctx, email, "", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("SendEmailCode #%d: %v", i, err)
|
||||
}
|
||||
if i == 0 {
|
||||
firstID = id
|
||||
}
|
||||
}
|
||||
_, _, callsBefore := mailer.snapshot()
|
||||
|
||||
// One more call — must reuse the latest challenge_id and skip mail.
|
||||
id, err := svc.SendEmailCode(ctx, email, "", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("SendEmailCode (throttled): %v", err)
|
||||
}
|
||||
_, _, callsAfter := mailer.snapshot()
|
||||
if callsAfter != callsBefore {
|
||||
t.Fatalf("mail enqueue should be skipped on throttle: before=%d after=%d", callsBefore, callsAfter)
|
||||
}
|
||||
if id == uuid.Nil {
|
||||
t.Fatalf("throttled call returned nil challenge_id")
|
||||
}
|
||||
if id == firstID {
|
||||
t.Fatalf("throttled call returned the FIRST challenge — expected the latest")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmEmailCodeWrongCode(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, mailer, _, _ := buildService(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := svc.SendEmailCode(ctx, "wrong@example.test", "en", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("send: %v", err)
|
||||
}
|
||||
_, code, _ := mailer.snapshot()
|
||||
wrong := flipDigit(code)
|
||||
|
||||
_, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
||||
ChallengeID: id,
|
||||
Code: wrong,
|
||||
ClientPublicKey: randomKey(t),
|
||||
TimeZone: "UTC",
|
||||
})
|
||||
if !errors.Is(err, auth.ErrCodeMismatch) {
|
||||
t.Fatalf("ConfirmEmailCode wrong code = %v, want ErrCodeMismatch", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmEmailCodeAttemptsCeiling(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, mailer, _, _ := buildService(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := svc.SendEmailCode(ctx, "ceiling@example.test", "en", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("send: %v", err)
|
||||
}
|
||||
_, code, _ := mailer.snapshot()
|
||||
wrong := flipDigit(code)
|
||||
|
||||
// Burn `max` attempts with the wrong code.
|
||||
for i := range authConfig().ChallengeMaxAttempts {
|
||||
_, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
||||
ChallengeID: id,
|
||||
Code: wrong,
|
||||
ClientPublicKey: randomKey(t),
|
||||
TimeZone: "UTC",
|
||||
})
|
||||
if !errors.Is(err, auth.ErrCodeMismatch) {
|
||||
t.Fatalf("attempt %d: %v, want ErrCodeMismatch", i, err)
|
||||
}
|
||||
}
|
||||
// One past the ceiling — even with the right code, ErrTooManyAttempts.
|
||||
_, err = svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
||||
ChallengeID: id,
|
||||
Code: code,
|
||||
ClientPublicKey: randomKey(t),
|
||||
TimeZone: "UTC",
|
||||
})
|
||||
if !errors.Is(err, auth.ErrTooManyAttempts) {
|
||||
t.Fatalf("post-ceiling = %v, want ErrTooManyAttempts", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfirmEmailCodeChallengeNotFound(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, _, _, _ := buildService(t, db)
|
||||
|
||||
_, err := svc.ConfirmEmailCode(context.Background(), auth.ConfirmInputs{
|
||||
ChallengeID: uuid.New(),
|
||||
Code: "000000",
|
||||
ClientPublicKey: randomKey(t),
|
||||
TimeZone: "UTC",
|
||||
})
|
||||
if !errors.Is(err, auth.ErrChallengeNotFound) {
|
||||
t.Fatalf("unknown challenge = %v, want ErrChallengeNotFound", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevokeAllForUser(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc, mailer, pusher, _ := buildService(t, db)
|
||||
ctx := context.Background()
|
||||
|
||||
const email = "many@example.test"
|
||||
const sessionsToCreate = 3
|
||||
var userID uuid.UUID
|
||||
deviceSessionIDs := make([]uuid.UUID, 0, sessionsToCreate)
|
||||
for range sessionsToCreate {
|
||||
id, err := svc.SendEmailCode(ctx, email, "en", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("send: %v", err)
|
||||
}
|
||||
_, code, _ := mailer.snapshot()
|
||||
sess, err := svc.ConfirmEmailCode(ctx, auth.ConfirmInputs{
|
||||
ChallengeID: id,
|
||||
Code: code,
|
||||
ClientPublicKey: randomKey(t),
|
||||
TimeZone: "UTC",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("confirm: %v", err)
|
||||
}
|
||||
userID = sess.UserID
|
||||
deviceSessionIDs = append(deviceSessionIDs, sess.DeviceSessionID)
|
||||
}
|
||||
|
||||
revoked, err := svc.RevokeAllForUser(ctx, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("RevokeAllForUser: %v", err)
|
||||
}
|
||||
if len(revoked) != sessionsToCreate {
|
||||
t.Fatalf("revoked count = %d, want %d", len(revoked), sessionsToCreate)
|
||||
}
|
||||
for _, dsID := range deviceSessionIDs {
|
||||
if _, err := svc.GetSession(ctx, dsID); !errors.Is(err, auth.ErrSessionNotFound) {
|
||||
t.Fatalf("session %s still in cache: %v", dsID, err)
|
||||
}
|
||||
}
|
||||
if got := len(pusher.snapshot()); got != sessionsToCreate {
|
||||
t.Fatalf("push emissions = %d, want %d", got, sessionsToCreate)
|
||||
}
|
||||
|
||||
// Idempotent: revoking again returns an empty slice.
|
||||
again, err := svc.RevokeAllForUser(ctx, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("idempotent RevokeAllForUser: %v", err)
|
||||
}
|
||||
if len(again) != 0 {
|
||||
t.Fatalf("idempotent RevokeAllForUser = %d sessions, want 0", len(again))
|
||||
}
|
||||
}
|
||||
|
||||
// flipDigit returns code with its first digit replaced by ((digit+1) % 10)
|
||||
// so the resulting string is still a valid CodeLength-digit code but
|
||||
// guaranteed to differ.
|
||||
func flipDigit(code string) string {
|
||||
if code == "" {
|
||||
return "0"
|
||||
}
|
||||
bytes := []byte(code)
|
||||
if bytes[0] >= '0' && bytes[0] <= '9' {
|
||||
bytes[0] = '0' + ((bytes[0]-'0')+1)%10
|
||||
} else {
|
||||
bytes[0] = '0'
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Cache is the in-memory write-through projection of the active rows in
|
||||
// `backend.device_sessions`. Reads (Get) are RLocked; writes (Add,
|
||||
// Remove, RemoveByUser) are Locked. The cache holds two maps:
|
||||
//
|
||||
// - byID maps device_session_id → Session.
|
||||
// - byUser maps user_id → set of device_session_ids belonging to that
|
||||
// user, used to satisfy bulk revoke without scanning byID.
|
||||
//
|
||||
// Both maps are updated atomically inside one Lock per mutation. The
|
||||
// caller is expected to commit the corresponding database write *before*
|
||||
// invoking Add or Remove so that the cache stays consistent under crash:
|
||||
// a Postgres commit failure leaves the cache untouched, matching the
|
||||
// previous DB state.
|
||||
type Cache struct {
|
||||
mu sync.RWMutex
|
||||
byID map[uuid.UUID]Session
|
||||
byUser map[uuid.UUID]map[uuid.UUID]struct{}
|
||||
ready atomic.Bool
|
||||
}
|
||||
|
||||
// NewCache constructs an empty Cache. The cache reports Ready() == false
|
||||
// until Warm completes successfully.
|
||||
func NewCache() *Cache {
|
||||
return &Cache{
|
||||
byID: make(map[uuid.UUID]Session),
|
||||
byUser: make(map[uuid.UUID]map[uuid.UUID]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Warm replaces the cache contents with every active session loaded from
|
||||
// store. It is intended to be called exactly once at process boot before
|
||||
// the HTTP listener accepts traffic; successful completion flips Ready
|
||||
// to true. Subsequent calls re-warm the cache (useful in tests).
|
||||
func (c *Cache) Warm(ctx context.Context, store *Store) error {
|
||||
sessions, err := store.ListActiveSessions(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.byID = make(map[uuid.UUID]Session, len(sessions))
|
||||
c.byUser = make(map[uuid.UUID]map[uuid.UUID]struct{})
|
||||
for _, s := range sessions {
|
||||
c.byID[s.DeviceSessionID] = s
|
||||
set, ok := c.byUser[s.UserID]
|
||||
if !ok {
|
||||
set = make(map[uuid.UUID]struct{})
|
||||
c.byUser[s.UserID] = set
|
||||
}
|
||||
set[s.DeviceSessionID] = struct{}{}
|
||||
}
|
||||
c.ready.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ready reports whether Warm has completed at least once. The HTTP
|
||||
// readiness probe wires through this method so `/readyz` only flips to
|
||||
// 200 after the cache is hydrated.
|
||||
func (c *Cache) Ready() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
return c.ready.Load()
|
||||
}
|
||||
|
||||
// Size returns the number of cached active sessions. Useful in startup
|
||||
// logs ("auth cache warmed: N sessions") and in tests.
|
||||
func (c *Cache) Size() int {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return len(c.byID)
|
||||
}
|
||||
|
||||
// Get returns the session with deviceSessionID and a presence flag.
|
||||
// Misses always return the zero Session and false; callers should not
|
||||
// inspect the returned value when ok is false.
|
||||
func (c *Cache) Get(deviceSessionID uuid.UUID) (Session, bool) {
|
||||
if c == nil {
|
||||
return Session{}, false
|
||||
}
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
s, ok := c.byID[deviceSessionID]
|
||||
return s, ok
|
||||
}
|
||||
|
||||
// Add stores s in the cache. It is safe to call on an existing entry
|
||||
// — both the primary map and the user index are updated to the latest
|
||||
// snapshot.
|
||||
func (c *Cache) Add(s Session) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.byID[s.DeviceSessionID] = s
|
||||
set, ok := c.byUser[s.UserID]
|
||||
if !ok {
|
||||
set = make(map[uuid.UUID]struct{})
|
||||
c.byUser[s.UserID] = set
|
||||
}
|
||||
set[s.DeviceSessionID] = struct{}{}
|
||||
}
|
||||
|
||||
// Remove evicts the entry for deviceSessionID from both maps. Calling
|
||||
// Remove on a missing entry is a no-op.
|
||||
func (c *Cache) Remove(deviceSessionID uuid.UUID) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
s, ok := c.byID[deviceSessionID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(c.byID, deviceSessionID)
|
||||
if set := c.byUser[s.UserID]; set != nil {
|
||||
delete(set, deviceSessionID)
|
||||
if len(set) == 0 {
|
||||
delete(c.byUser, s.UserID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveByUser evicts every cached entry belonging to userID and returns
|
||||
// the device_session_ids it removed. The returned slice is safe for the
|
||||
// caller to hold past the call — it is freshly allocated.
|
||||
func (c *Cache) RemoveByUser(userID uuid.UUID) []uuid.UUID {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
set, ok := c.byUser[userID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
removed := make([]uuid.UUID, 0, len(set))
|
||||
for id := range set {
|
||||
removed = append(removed, id)
|
||||
delete(c.byID, id)
|
||||
}
|
||||
delete(c.byUser, userID)
|
||||
return removed
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestCacheGetAddRemove(t *testing.T) {
|
||||
c := NewCache()
|
||||
if c.Ready() {
|
||||
t.Fatalf("fresh cache should not be Ready before Warm")
|
||||
}
|
||||
if c.Size() != 0 {
|
||||
t.Fatalf("fresh cache size = %d, want 0", c.Size())
|
||||
}
|
||||
|
||||
id := uuid.New()
|
||||
uid := uuid.New()
|
||||
s := Session{DeviceSessionID: id, UserID: uid, Status: SessionStatusActive}
|
||||
c.Add(s)
|
||||
if c.Size() != 1 {
|
||||
t.Fatalf("size after Add = %d, want 1", c.Size())
|
||||
}
|
||||
got, ok := c.Get(id)
|
||||
if !ok || got.DeviceSessionID != id {
|
||||
t.Fatalf("Get after Add: ok=%v session=%+v", ok, got)
|
||||
}
|
||||
|
||||
c.Remove(id)
|
||||
if c.Size() != 0 {
|
||||
t.Fatalf("size after Remove = %d, want 0", c.Size())
|
||||
}
|
||||
if _, ok := c.Get(id); ok {
|
||||
t.Fatalf("Get after Remove returned a hit")
|
||||
}
|
||||
|
||||
// Remove on already-evicted entry is a no-op.
|
||||
c.Remove(id)
|
||||
}
|
||||
|
||||
func TestCacheRemoveByUser(t *testing.T) {
|
||||
c := NewCache()
|
||||
uid := uuid.New()
|
||||
other := uuid.New()
|
||||
c.Add(Session{DeviceSessionID: uuid.New(), UserID: uid, Status: SessionStatusActive})
|
||||
c.Add(Session{DeviceSessionID: uuid.New(), UserID: uid, Status: SessionStatusActive})
|
||||
c.Add(Session{DeviceSessionID: uuid.New(), UserID: other, Status: SessionStatusActive})
|
||||
|
||||
removed := c.RemoveByUser(uid)
|
||||
if len(removed) != 2 {
|
||||
t.Fatalf("RemoveByUser removed %d, want 2", len(removed))
|
||||
}
|
||||
if c.Size() != 1 {
|
||||
t.Fatalf("size after RemoveByUser = %d, want 1", c.Size())
|
||||
}
|
||||
if got := c.RemoveByUser(uid); got != nil {
|
||||
t.Fatalf("RemoveByUser on empty user returned %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheWarmFlipsReady(t *testing.T) {
|
||||
// Constructing a Cache and calling Warm against a Store without a real
|
||||
// database is awkward — the e2e test exercises Warm against Postgres.
|
||||
// Here we manually populate to confirm Ready toggles.
|
||||
c := NewCache()
|
||||
if c.Ready() {
|
||||
t.Fatalf("Ready before Warm")
|
||||
}
|
||||
// Simulate a successful Warm by setting ready and inserting via Add.
|
||||
c.ready.Store(true)
|
||||
if !c.Ready() {
|
||||
t.Fatalf("Ready did not flip after store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheConcurrentGetAddRemove(t *testing.T) {
|
||||
c := NewCache()
|
||||
|
||||
const writers = 4
|
||||
const readers = 4
|
||||
const opsPerWorker = 1000
|
||||
|
||||
uid := uuid.New()
|
||||
ids := make([]uuid.UUID, opsPerWorker)
|
||||
for i := range ids {
|
||||
ids[i] = uuid.New()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var stop atomic.Bool
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for range writers {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := range opsPerWorker {
|
||||
if stop.Load() {
|
||||
return
|
||||
}
|
||||
c.Add(Session{DeviceSessionID: ids[i], UserID: uid, Status: SessionStatusActive})
|
||||
c.Remove(ids[i])
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for range readers {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := range opsPerWorker {
|
||||
if stop.Load() {
|
||||
return
|
||||
}
|
||||
_, _ = c.Get(ids[i%len(ids)])
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() { wg.Wait(); close(done) }()
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
stop.Store(true)
|
||||
<-done
|
||||
t.Fatalf("cache concurrency test timed out")
|
||||
}
|
||||
|
||||
// After all goroutines finish, the cache must be empty (every Add
|
||||
// is paired with a Remove).
|
||||
if c.Size() != 0 {
|
||||
t.Fatalf("cache size after concurrent run = %d, want 0", c.Size())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,262 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SendEmailCode issues an email login challenge for email and returns
|
||||
// its challenge_id. The wire shape is intentionally identical for new
|
||||
// users, existing users, and throttled requesters; the only path that
|
||||
// returns ErrEmailPermanentlyBlocked is when email maps to an account
|
||||
// whose `permanent_block` column is true (handler maps that sentinel to
|
||||
// 400 invalid_request).
|
||||
//
|
||||
// Throttle behaviour: when the count of un-consumed, non-expired
|
||||
// challenges for email created within ChallengeThrottle.Window already
|
||||
// equals or exceeds ChallengeThrottle.Max, SendEmailCode reuses the
|
||||
// most recent existing challenge_id and skips the mail enqueue. This
|
||||
// avoids a leak where an attacker who controls their own SMTP server
|
||||
// could otherwise correlate "row created without mail" with
|
||||
// throttle-state on the platform.
|
||||
//
|
||||
// locale (request body, BCP 47) takes precedence over acceptLanguage
|
||||
// (the standard HTTP header forwarded by gateway) when both are
|
||||
// supplied. The captured value is persisted on the challenge row as
|
||||
// `preferred_language`, replayed at confirm-email-code, and used only
|
||||
// for newly-registered accounts; existing accounts keep their stored
|
||||
// language.
|
||||
func (s *Service) SendEmailCode(
|
||||
ctx context.Context,
|
||||
email, locale, acceptLanguage, sourceIP string,
|
||||
) (uuid.UUID, error) {
|
||||
normalised := normaliseEmail(email)
|
||||
if normalised == "" {
|
||||
return uuid.Nil, fmt.Errorf("auth: email is empty")
|
||||
}
|
||||
|
||||
permanent, err := s.deps.Store.IsEmailPermanentlyBlocked(ctx, normalised)
|
||||
if err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
if permanent {
|
||||
return uuid.Nil, ErrEmailPermanentlyBlocked
|
||||
}
|
||||
|
||||
captured := pickCapturedLocale(locale, acceptLanguage)
|
||||
|
||||
now := s.deps.Now()
|
||||
windowStart := now.Add(-s.deps.Config.ChallengeThrottle.Window)
|
||||
count, err := s.deps.Store.CountRecentChallenges(ctx, normalised, windowStart)
|
||||
if err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
if count >= s.deps.Config.ChallengeThrottle.Max {
|
||||
existing, lerr := s.deps.Store.LatestUnconsumedChallenge(ctx, normalised, windowStart)
|
||||
if lerr == nil {
|
||||
s.deps.Logger.Info("auth challenge reused (throttled)",
|
||||
zap.String("email_hash", s.hashEmail(normalised)),
|
||||
zap.String("challenge_id", existing.ChallengeID.String()),
|
||||
zap.Int("recent_count", count),
|
||||
)
|
||||
return existing.ChallengeID, nil
|
||||
}
|
||||
if !errors.Is(lerr, sql.ErrNoRows) {
|
||||
return uuid.Nil, lerr
|
||||
}
|
||||
// sql.ErrNoRows here is a race (a concurrent confirm consumed
|
||||
// the row between count and select); fall through and issue a
|
||||
// fresh challenge.
|
||||
}
|
||||
|
||||
code, err := generateCode()
|
||||
if err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
hash, err := hashCode(code)
|
||||
if err != nil {
|
||||
return uuid.Nil, fmt.Errorf("auth: hash code: %w", err)
|
||||
}
|
||||
|
||||
challenge := Challenge{
|
||||
ChallengeID: uuid.New(),
|
||||
Email: normalised,
|
||||
CodeHash: hash,
|
||||
ExpiresAt: now.Add(s.deps.Config.ChallengeTTL),
|
||||
PreferredLanguage: captured,
|
||||
}
|
||||
if err := s.deps.Store.InsertChallenge(ctx, challenge); err != nil {
|
||||
return uuid.Nil, err
|
||||
}
|
||||
|
||||
if err := s.deps.Mail.EnqueueLoginCode(ctx, normalised, code, s.deps.Config.ChallengeTTL); err != nil {
|
||||
// A mail-enqueue failure is logged but not surfaced — the user
|
||||
// can issue another challenge. The implementation will surface a
|
||||
// transient error path; for The implementation the no-op publisher never
|
||||
// returns an error.
|
||||
s.deps.Logger.Warn("auth: enqueue login code failed",
|
||||
zap.String("email_hash", s.hashEmail(normalised)),
|
||||
zap.String("challenge_id", challenge.ChallengeID.String()),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
s.deps.Logger.Info("auth challenge issued",
|
||||
zap.String("email_hash", s.hashEmail(normalised)),
|
||||
zap.String("challenge_id", challenge.ChallengeID.String()),
|
||||
)
|
||||
|
||||
return challenge.ChallengeID, nil
|
||||
}
|
||||
|
||||
// ConfirmInputs is the parsed-and-validated input to ConfirmEmailCode.
|
||||
// Wire-format validation (base64 decode, 32-byte length, IANA time-zone
|
||||
// parse, source-IP extraction) happens at the handler boundary so the
|
||||
// service operates on already-typed values.
|
||||
type ConfirmInputs struct {
|
||||
ChallengeID uuid.UUID
|
||||
Code string
|
||||
ClientPublicKey []byte
|
||||
TimeZone string
|
||||
SourceIP string
|
||||
}
|
||||
|
||||
// ConfirmEmailCode redeems a challenge_id, ensures the corresponding
|
||||
// `accounts` row exists, and creates an active `device_sessions` row.
|
||||
// The returned Session is identical to the row stored in the database
|
||||
// (including server-assigned timestamps).
|
||||
//
|
||||
// The flow runs in two transactions:
|
||||
//
|
||||
// 1. LoadAndIncrementChallenge increments the attempts counter under
|
||||
// SELECT FOR UPDATE so concurrent attempts cannot bypass the ceiling.
|
||||
// 2. Out-of-band: ceiling check, bcrypt verify, EnsureByEmail.
|
||||
// 3. MarkConsumedAndInsertSession atomically marks the challenge
|
||||
// consumed and inserts the device_session row, satisfying the
|
||||
// "single challenge → at most one session" invariant.
|
||||
//
|
||||
// Post-commit work (cache write-through, declared_country backfill) is
|
||||
// best-effort: a failure does not roll the registration back.
|
||||
func (s *Service) ConfirmEmailCode(ctx context.Context, in ConfirmInputs) (Session, error) {
|
||||
if in.ChallengeID == uuid.Nil {
|
||||
return Session{}, ErrChallengeNotFound
|
||||
}
|
||||
if len(in.ClientPublicKey) != 32 {
|
||||
return Session{}, fmt.Errorf("auth: client public key must be 32 bytes, got %d", len(in.ClientPublicKey))
|
||||
}
|
||||
if strings.TrimSpace(in.TimeZone) == "" {
|
||||
return Session{}, fmt.Errorf("auth: time_zone must not be empty")
|
||||
}
|
||||
|
||||
loaded, err := s.deps.Store.LoadAndIncrementChallenge(ctx, in.ChallengeID)
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
}
|
||||
|
||||
if int(loaded.Attempts) > s.deps.Config.ChallengeMaxAttempts {
|
||||
s.deps.Logger.Info("auth challenge attempts exhausted",
|
||||
zap.String("challenge_id", in.ChallengeID.String()),
|
||||
zap.Int32("attempts", loaded.Attempts),
|
||||
)
|
||||
return Session{}, ErrTooManyAttempts
|
||||
}
|
||||
|
||||
if err := verifyCode(loaded.CodeHash, in.Code); err != nil {
|
||||
if errors.Is(err, ErrCodeMismatch) {
|
||||
s.deps.Logger.Info("auth challenge code mismatch",
|
||||
zap.String("challenge_id", in.ChallengeID.String()),
|
||||
zap.Int32("attempts", loaded.Attempts),
|
||||
)
|
||||
return Session{}, ErrCodeMismatch
|
||||
}
|
||||
return Session{}, err
|
||||
}
|
||||
|
||||
preferredLang := loaded.PreferredLanguage
|
||||
if preferredLang == "" {
|
||||
preferredLang = s.deps.Geo.LanguageForIP(in.SourceIP)
|
||||
}
|
||||
if preferredLang == "" {
|
||||
preferredLang = defaultLanguage
|
||||
}
|
||||
|
||||
declaredCountry := s.deps.Geo.LookupCountry(in.SourceIP)
|
||||
|
||||
userID, err := s.deps.User.EnsureByEmail(ctx, loaded.Email, preferredLang, in.TimeZone, declaredCountry)
|
||||
if err != nil {
|
||||
return Session{}, fmt.Errorf("auth: ensure account by email: %w", err)
|
||||
}
|
||||
|
||||
deviceSessionID := uuid.New()
|
||||
pending := Session{
|
||||
DeviceSessionID: deviceSessionID,
|
||||
UserID: userID,
|
||||
Status: SessionStatusActive,
|
||||
ClientPublicKey: cloneBytes(in.ClientPublicKey),
|
||||
}
|
||||
if err := s.deps.Store.MarkConsumedAndInsertSession(ctx, in.ChallengeID, pending); err != nil {
|
||||
return Session{}, err
|
||||
}
|
||||
|
||||
persisted, err := s.deps.Store.LoadSession(ctx, deviceSessionID)
|
||||
if err != nil {
|
||||
return Session{}, fmt.Errorf("auth: reload created session: %w", err)
|
||||
}
|
||||
s.deps.Cache.Add(persisted)
|
||||
|
||||
if err := s.deps.Geo.SetDeclaredCountryAtRegistration(ctx, userID, in.SourceIP); err != nil {
|
||||
s.deps.Logger.Warn("auth: declared country backfill failed",
|
||||
zap.String("user_id", userID.String()),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
s.deps.Logger.Info("auth session created",
|
||||
zap.String("user_id", userID.String()),
|
||||
zap.String("device_session_id", deviceSessionID.String()),
|
||||
)
|
||||
|
||||
return persisted, nil
|
||||
}
|
||||
|
||||
// defaultLanguage is the fallback locale written when neither the body
|
||||
// nor the Accept-Language header nor the geoip-derived language produce
|
||||
// a value.
|
||||
const defaultLanguage = "en"
|
||||
|
||||
func normaliseEmail(email string) string {
|
||||
return strings.ToLower(strings.TrimSpace(email))
|
||||
}
|
||||
|
||||
// pickCapturedLocale picks the locale to persist on the challenge row.
|
||||
// The body field wins over the header. The header parsing is
|
||||
// intentionally minimal — auth only stores the value, so a richer parse
|
||||
// would be wasted; user.Service treats the captured string as opaque.
|
||||
func pickCapturedLocale(locale, acceptLanguage string) string {
|
||||
if v := strings.TrimSpace(locale); v != "" {
|
||||
return v
|
||||
}
|
||||
if acceptLanguage == "" {
|
||||
return ""
|
||||
}
|
||||
first := acceptLanguage
|
||||
if idx := strings.IndexAny(first, ",;"); idx >= 0 {
|
||||
first = first[:idx]
|
||||
}
|
||||
return strings.TrimSpace(first)
|
||||
}
|
||||
|
||||
func cloneBytes(b []byte) []byte {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]byte, len(b))
|
||||
copy(out, b)
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// CodeLength is the fixed length of the decimal code delivered by
|
||||
// SendEmailCode. The OpenAPI description ("six-digit") locks the value
|
||||
// at six; tests cannot lower it without breaking the contract test
|
||||
// against the schema.
|
||||
const CodeLength = 6
|
||||
|
||||
// codeBcryptCost is the bcrypt cost used to store the hashed code in
|
||||
// auth_challenges.code_hash. Cost 10 matches the convention documented
|
||||
// for admin password storage in `backend/README.md` §12. Six-digit codes
|
||||
// have only ~1M entropy, so the bcrypt slowdown is what bounds online
|
||||
// attacks together with the per-challenge attempt ceiling.
|
||||
const codeBcryptCost = bcrypt.DefaultCost
|
||||
|
||||
// generateCode returns a random CodeLength-character decimal string. The
|
||||
// modulo bias when mapping uniform bytes to ten digits is acceptable for
|
||||
// short-lived registration codes — the per-challenge attempt ceiling and
|
||||
// the TTL bound abuse far more tightly than the negligible bias.
|
||||
func generateCode() (string, error) {
|
||||
digits := make([]byte, CodeLength)
|
||||
if _, err := rand.Read(digits); err != nil {
|
||||
return "", fmt.Errorf("auth: generate code: %w", err)
|
||||
}
|
||||
var sb strings.Builder
|
||||
sb.Grow(CodeLength)
|
||||
for _, b := range digits {
|
||||
sb.WriteByte('0' + b%10)
|
||||
}
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// hashCode returns the bcrypt hash of code using the package-level cost.
|
||||
func hashCode(code string) ([]byte, error) {
|
||||
return bcrypt.GenerateFromPassword([]byte(code), codeBcryptCost)
|
||||
}
|
||||
|
||||
// verifyCode reports whether code matches hash. The function is a thin
|
||||
// wrapper around bcrypt.CompareHashAndPassword so the comparison is
|
||||
// constant-time on the matching path. Returns nil on match,
|
||||
// ErrCodeMismatch when the bcrypt mismatch error fires, and a wrapped
|
||||
// error for any other failure (e.g. malformed hash).
|
||||
func verifyCode(hash []byte, code string) error {
|
||||
err := bcrypt.CompareHashAndPassword(hash, []byte(code))
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
||||
return ErrCodeMismatch
|
||||
}
|
||||
return fmt.Errorf("auth: verify code: %w", err)
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"errors"
|
||||
)
|
||||
|
||||
func TestGenerateCodeShape(t *testing.T) {
|
||||
for range 100 {
|
||||
code, err := generateCode()
|
||||
if err != nil {
|
||||
t.Fatalf("generateCode: %v", err)
|
||||
}
|
||||
if len(code) != CodeLength {
|
||||
t.Fatalf("len(code) = %d, want %d (got %q)", len(code), CodeLength, code)
|
||||
}
|
||||
for _, r := range code {
|
||||
if r < '0' || r > '9' {
|
||||
t.Fatalf("non-digit rune %q in code %q", r, code)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeRandomness(t *testing.T) {
|
||||
seen := make(map[string]struct{})
|
||||
const trials = 50
|
||||
for range trials {
|
||||
code, err := generateCode()
|
||||
if err != nil {
|
||||
t.Fatalf("generateCode: %v", err)
|
||||
}
|
||||
seen[code] = struct{}{}
|
||||
}
|
||||
// 50 trials over a 10^6 space — duplicate is astronomically unlikely.
|
||||
if len(seen) < trials-1 {
|
||||
t.Fatalf("generateCode produced too many duplicates: %d/%d unique", len(seen), trials)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashAndVerifyCodeRoundTrip(t *testing.T) {
|
||||
const code = "654321"
|
||||
hash, err := hashCode(code)
|
||||
if err != nil {
|
||||
t.Fatalf("hashCode: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(string(hash), "$2") {
|
||||
t.Fatalf("hash does not look like bcrypt: %q", string(hash))
|
||||
}
|
||||
if err := verifyCode(hash, code); err != nil {
|
||||
t.Fatalf("verifyCode on matching code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyCodeMismatch(t *testing.T) {
|
||||
hash, err := hashCode("111111")
|
||||
if err != nil {
|
||||
t.Fatalf("hashCode: %v", err)
|
||||
}
|
||||
err = verifyCode(hash, "222222")
|
||||
if !errors.Is(err, ErrCodeMismatch) {
|
||||
t.Fatalf("verifyCode mismatch returned %v, want ErrCodeMismatch", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyCodeMalformedHash(t *testing.T) {
|
||||
err := verifyCode([]byte("not-a-hash"), "111111")
|
||||
if err == nil {
|
||||
t.Fatalf("verifyCode with garbage hash returned nil")
|
||||
}
|
||||
if errors.Is(err, ErrCodeMismatch) {
|
||||
t.Fatalf("malformed hash classified as mismatch: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// LoginCodeMailer is the publisher contract auth uses to deliver a
|
||||
// one-time login code to a user's mailbox. The canonical
|
||||
// implementation lives in `backend/internal/mail`; tests can use
|
||||
// `NewNoopLoginCodeMailer` to record the outbound code without wiring
|
||||
// SMTP.
|
||||
type LoginCodeMailer interface {
|
||||
EnqueueLoginCode(ctx context.Context, email, code string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// SessionInvalidator emits the gRPC push session_invalidation event
|
||||
// when auth revokes one or more device sessions. The canonical
|
||||
// implementation lives in `backend/internal/push`; tests can use
|
||||
// `NewNoopSessionInvalidator` for an in-memory log-only fallback.
|
||||
type SessionInvalidator interface {
|
||||
PublishSessionInvalidation(ctx context.Context, deviceSessionID, userID uuid.UUID, reason string)
|
||||
}
|
||||
|
||||
// UserEnsurer binds a confirmed email to an `accounts.user_id`. The
|
||||
// canonical implementation is `*user.Service`; tests can swap in a
|
||||
// recording fake.
|
||||
type UserEnsurer interface {
|
||||
EnsureByEmail(ctx context.Context, email, preferredLanguage, timeZone, declaredCountry string) (uuid.UUID, error)
|
||||
}
|
||||
|
||||
// GeoService provides the geo helpers auth needs at confirm-email-code:
|
||||
// a country lookup for the `preferred_language` fallback and a
|
||||
// post-commit write of `accounts.declared_country`. Both methods are
|
||||
// best-effort — auth never blocks the registration flow on geo failures.
|
||||
type GeoService interface {
|
||||
LookupCountry(sourceIP string) string
|
||||
LanguageForIP(sourceIP string) string
|
||||
SetDeclaredCountryAtRegistration(ctx context.Context, userID uuid.UUID, sourceIP string) error
|
||||
}
|
||||
|
||||
// NewNoopLoginCodeMailer returns a LoginCodeMailer that logs the
|
||||
// outbound code at info level and returns nil. The wiring code uses
|
||||
// the real `mail.Service`; this constructor exists for tests and for
|
||||
// local smoke runs that do not want to bring up an SMTP relay.
|
||||
func NewNoopLoginCodeMailer(logger *zap.Logger) LoginCodeMailer {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &noopLoginCodeMailer{logger: logger.Named("auth.mail.noop")}
|
||||
}
|
||||
|
||||
type noopLoginCodeMailer struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func (m *noopLoginCodeMailer) EnqueueLoginCode(_ context.Context, email, code string, ttl time.Duration) error {
|
||||
m.logger.Info("auth login code (noop publisher)",
|
||||
zap.String("email", email),
|
||||
zap.String("code", code),
|
||||
zap.Duration("ttl", ttl),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewNoopSessionInvalidator returns a SessionInvalidator that logs
|
||||
// every invalidation at info level and never blocks. The wiring code
|
||||
// uses the real `push.Service`; this constructor exists for tests
|
||||
// that need a callable surface without bringing up gRPC.
|
||||
func NewNoopSessionInvalidator(logger *zap.Logger) SessionInvalidator {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &noopSessionInvalidator{logger: logger.Named("auth.push.noop")}
|
||||
}
|
||||
|
||||
type noopSessionInvalidator struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func (p *noopSessionInvalidator) PublishSessionInvalidation(_ context.Context, deviceSessionID, userID uuid.UUID, reason string) {
|
||||
p.logger.Info("session invalidation (noop publisher)",
|
||||
zap.String("device_session_id", deviceSessionID.String()),
|
||||
zap.String("user_id", userID.String()),
|
||||
zap.String("reason", reason),
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package auth
|
||||
|
||||
import "errors"
|
||||
|
||||
// Sentinel errors emitted by Service methods. Handlers translate them
|
||||
// into HTTP responses; callers in tests can match on them with
|
||||
// errors.Is.
|
||||
var (
|
||||
// ErrChallengeNotFound is returned when a confirm-email-code request
|
||||
// references a challenge_id that does not exist, has already been
|
||||
// consumed, or has expired. Returned as a single sentinel because the
|
||||
// API surface deliberately does not differentiate between these cases
|
||||
// — distinguishing them would leak whether a challenge_id was ever
|
||||
// valid, which is signal an attacker should not have.
|
||||
ErrChallengeNotFound = errors.New("auth: challenge is not redeemable")
|
||||
|
||||
// ErrTooManyAttempts is returned when confirm-email-code increments
|
||||
// the attempts counter past the configured ceiling. The challenge row
|
||||
// remains in the database with its incremented counter so further
|
||||
// attempts on the same challenge_id continue to fail with the same
|
||||
// error until the row expires.
|
||||
ErrTooManyAttempts = errors.New("auth: too many attempts")
|
||||
|
||||
// ErrCodeMismatch is returned when the supplied code does not match
|
||||
// the stored bcrypt hash. The challenge stays un-consumed so the user
|
||||
// can try again — bounded by ErrTooManyAttempts.
|
||||
ErrCodeMismatch = errors.New("auth: code is incorrect")
|
||||
|
||||
// ErrEmailPermanentlyBlocked is returned by SendEmailCode when the
|
||||
// supplied email maps to an existing account whose `permanent_block`
|
||||
// column is true. This is the only path that does not return an
|
||||
// opaque success shape.
|
||||
ErrEmailPermanentlyBlocked = errors.New("auth: email is permanently blocked")
|
||||
|
||||
// ErrSessionNotFound is returned by GetSession (and the revoke
|
||||
// helpers in their look-it-up-after-zero-rows fallback) when the
|
||||
// device_session_id does not name a row in `device_sessions`.
|
||||
ErrSessionNotFound = errors.New("auth: session not found")
|
||||
)
|
||||
@@ -0,0 +1,90 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// GetSession returns the active session keyed by deviceSessionID. The
|
||||
// lookup is cache-only: the cache is the write-through projection of
|
||||
// `device_sessions WHERE status='active'`, so a miss means the session
|
||||
// is either revoked or absent. Either way the gateway sees
|
||||
// ErrSessionNotFound and treats the calling client as unauthenticated.
|
||||
func (s *Service) GetSession(_ context.Context, deviceSessionID uuid.UUID) (Session, error) {
|
||||
if deviceSessionID == uuid.Nil {
|
||||
return Session{}, ErrSessionNotFound
|
||||
}
|
||||
sess, ok := s.deps.Cache.Get(deviceSessionID)
|
||||
if !ok {
|
||||
return Session{}, ErrSessionNotFound
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// RevokeSession marks deviceSessionID revoked, evicts it from the cache,
|
||||
// and emits a session_invalidation push event. The call is idempotent:
|
||||
// a second revoke on an already-revoked session returns the existing
|
||||
// row with status='revoked' (HTTP 200), not ErrSessionNotFound. An
|
||||
// unknown device_session_id yields ErrSessionNotFound.
|
||||
//
|
||||
// Cache eviction and the push emission run after the database UPDATE
|
||||
// commits so a failed UPDATE leaves both cache and gateway view intact.
|
||||
func (s *Service) RevokeSession(ctx context.Context, deviceSessionID uuid.UUID) (Session, error) {
|
||||
if deviceSessionID == uuid.Nil {
|
||||
return Session{}, ErrSessionNotFound
|
||||
}
|
||||
revoked, ok, err := s.deps.Store.RevokeSession(ctx, deviceSessionID)
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
}
|
||||
if ok {
|
||||
s.deps.Cache.Remove(deviceSessionID)
|
||||
s.deps.Push.PublishSessionInvalidation(ctx, deviceSessionID, revoked.UserID, "auth.revoke_session")
|
||||
s.deps.Logger.Info("auth session revoked",
|
||||
zap.String("device_session_id", deviceSessionID.String()),
|
||||
zap.String("user_id", revoked.UserID.String()),
|
||||
)
|
||||
return revoked, nil
|
||||
}
|
||||
// UPDATE matched no rows: the session is either already revoked or
|
||||
// never existed. Distinguish by reading the row directly so we can
|
||||
// return the idempotent revoked-shape rather than a 404 when the
|
||||
// session simply was revoked earlier.
|
||||
existing, err := s.deps.Store.LoadSession(ctx, deviceSessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSessionNotFound) {
|
||||
return Session{}, ErrSessionNotFound
|
||||
}
|
||||
return Session{}, err
|
||||
}
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
// RevokeAllForUser marks every active session for userID revoked,
|
||||
// evicts each from the cache, and emits one session_invalidation push
|
||||
// event per revoked row. Returns the list of revoked sessions in the
|
||||
// order Postgres returned them. An empty result is a successful
|
||||
// idempotent call (handler reports revoked_count=0).
|
||||
func (s *Service) RevokeAllForUser(ctx context.Context, userID uuid.UUID) ([]Session, error) {
|
||||
if userID == uuid.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
revoked, err := s.deps.Store.RevokeAllForUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, sess := range revoked {
|
||||
s.deps.Cache.Remove(sess.DeviceSessionID)
|
||||
s.deps.Push.PublishSessionInvalidation(ctx, sess.DeviceSessionID, sess.UserID, "auth.revoke_all_for_user")
|
||||
}
|
||||
if len(revoked) > 0 {
|
||||
s.deps.Logger.Info("auth sessions revoked (bulk)",
|
||||
zap.String("user_id", userID.String()),
|
||||
zap.Int("count", len(revoked)),
|
||||
)
|
||||
}
|
||||
return revoked, nil
|
||||
}
|
||||
@@ -0,0 +1,444 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"galaxy/backend/internal/postgres/jet/backend/model"
|
||||
"galaxy/backend/internal/postgres/jet/backend/table"
|
||||
|
||||
"github.com/go-jet/jet/v2/postgres"
|
||||
"github.com/go-jet/jet/v2/qrm"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Challenge mirrors a row in `backend.auth_challenges` enriched with the
|
||||
// PreferredLanguage column added by migration 00002. The CodeHash slice
|
||||
// is the raw bcrypt hash; verifyCode wraps the comparison.
|
||||
type Challenge struct {
|
||||
ChallengeID uuid.UUID
|
||||
Email string
|
||||
CodeHash []byte
|
||||
Attempts int32
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
ConsumedAt *time.Time
|
||||
PreferredLanguage string
|
||||
}
|
||||
|
||||
// Session mirrors a row in `backend.device_sessions`. The
|
||||
// ClientPublicKey slice is the raw 32-byte Ed25519 key; the handler
|
||||
// layer is responsible for base64 encoding/decoding on the wire.
|
||||
type Session struct {
|
||||
DeviceSessionID uuid.UUID
|
||||
UserID uuid.UUID
|
||||
Status string
|
||||
ClientPublicKey []byte
|
||||
CreatedAt time.Time
|
||||
RevokedAt *time.Time
|
||||
LastSeenAt *time.Time
|
||||
}
|
||||
|
||||
// SessionStatusActive and SessionStatusRevoked enumerate the values
|
||||
// auth writes. The CHECK constraint on `device_sessions.status` also
|
||||
// allows 'blocked', which the user package emits when applying a
|
||||
// `permanent_block` sanction.
|
||||
const (
|
||||
SessionStatusActive = "active"
|
||||
SessionStatusRevoked = "revoked"
|
||||
)
|
||||
|
||||
// Store is the Postgres-backed query surface for `backend.auth_challenges`,
|
||||
// `backend.device_sessions` and the read-side `backend.accounts` lookup
|
||||
// auth needs to detect permanently-blocked emails.
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewStore constructs a Store wrapping db.
|
||||
func NewStore(db *sql.DB) *Store {
|
||||
return &Store{db: db}
|
||||
}
|
||||
|
||||
// challengeColumns lists the projection used by every read of
|
||||
// `auth_challenges`. The order matches model.AuthChallenges field order
|
||||
// inside QueryContext destination scans.
|
||||
func challengeColumns() postgres.ColumnList {
|
||||
return postgres.ColumnList{
|
||||
table.AuthChallenges.ChallengeID,
|
||||
table.AuthChallenges.Email,
|
||||
table.AuthChallenges.CodeHash,
|
||||
table.AuthChallenges.Attempts,
|
||||
table.AuthChallenges.CreatedAt,
|
||||
table.AuthChallenges.ExpiresAt,
|
||||
table.AuthChallenges.ConsumedAt,
|
||||
table.AuthChallenges.PreferredLanguage,
|
||||
}
|
||||
}
|
||||
|
||||
// sessionColumns lists the projection used by every read of
|
||||
// `device_sessions`.
|
||||
func sessionColumns() postgres.ColumnList {
|
||||
return postgres.ColumnList{
|
||||
table.DeviceSessions.DeviceSessionID,
|
||||
table.DeviceSessions.UserID,
|
||||
table.DeviceSessions.ClientPublicKey,
|
||||
table.DeviceSessions.Status,
|
||||
table.DeviceSessions.CreatedAt,
|
||||
table.DeviceSessions.RevokedAt,
|
||||
table.DeviceSessions.LastSeenAt,
|
||||
}
|
||||
}
|
||||
|
||||
// IsEmailPermanentlyBlocked reports whether email maps to a live
|
||||
// `accounts` row whose permanent_block column is true. The lookup is
|
||||
// case-sensitive: callers are expected to pass an already-normalised
|
||||
// (lowercase, trimmed) email.
|
||||
//
|
||||
// A non-existent account returns (false, nil) — the auth flow treats
|
||||
// such emails as eligible for fresh registration.
|
||||
func (s *Store) IsEmailPermanentlyBlocked(ctx context.Context, email string) (bool, error) {
|
||||
stmt := postgres.SELECT(table.Accounts.PermanentBlock).
|
||||
FROM(table.Accounts).
|
||||
WHERE(
|
||||
table.Accounts.Email.EQ(postgres.String(email)).
|
||||
AND(table.Accounts.DeletedAt.IS_NULL()),
|
||||
).
|
||||
LIMIT(1)
|
||||
|
||||
var row model.Accounts
|
||||
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("auth store: query permanent_block for %q: %w", email, err)
|
||||
}
|
||||
return row.PermanentBlock, nil
|
||||
}
|
||||
|
||||
// LatestUnconsumedChallenge returns the most recently issued
|
||||
// un-consumed, non-expired challenge for email created at or after
|
||||
// since. Returns sql.ErrNoRows when no such challenge exists. The
|
||||
// throttle path uses this method to reuse the existing challenge_id
|
||||
// rather than emit a fresh row.
|
||||
func (s *Store) LatestUnconsumedChallenge(ctx context.Context, email string, since time.Time) (Challenge, error) {
|
||||
stmt := postgres.SELECT(challengeColumns()).
|
||||
FROM(table.AuthChallenges).
|
||||
WHERE(
|
||||
table.AuthChallenges.Email.EQ(postgres.String(email)).
|
||||
AND(table.AuthChallenges.ConsumedAt.IS_NULL()).
|
||||
AND(table.AuthChallenges.ExpiresAt.GT(postgres.NOW())).
|
||||
AND(table.AuthChallenges.CreatedAt.GT_EQ(postgres.TimestampzT(since))),
|
||||
).
|
||||
ORDER_BY(table.AuthChallenges.CreatedAt.DESC()).
|
||||
LIMIT(1)
|
||||
|
||||
var row model.AuthChallenges
|
||||
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return Challenge{}, sql.ErrNoRows
|
||||
}
|
||||
return Challenge{}, err
|
||||
}
|
||||
return modelToChallenge(row), nil
|
||||
}
|
||||
|
||||
// CountRecentChallenges returns the number of un-consumed, non-expired
|
||||
// challenges issued for email at or after since. Used by the throttle
|
||||
// gate in SendEmailCode.
|
||||
func (s *Store) CountRecentChallenges(ctx context.Context, email string, since time.Time) (int, error) {
|
||||
stmt := postgres.SELECT(postgres.COUNT(postgres.STAR).AS("count")).
|
||||
FROM(table.AuthChallenges).
|
||||
WHERE(
|
||||
table.AuthChallenges.Email.EQ(postgres.String(email)).
|
||||
AND(table.AuthChallenges.ConsumedAt.IS_NULL()).
|
||||
AND(table.AuthChallenges.ExpiresAt.GT(postgres.NOW())).
|
||||
AND(table.AuthChallenges.CreatedAt.GT_EQ(postgres.TimestampzT(since))),
|
||||
)
|
||||
|
||||
var dest struct {
|
||||
Count int64 `alias:"count"`
|
||||
}
|
||||
if err := stmt.QueryContext(ctx, s.db, &dest); err != nil {
|
||||
return 0, fmt.Errorf("auth store: count recent challenges: %w", err)
|
||||
}
|
||||
return int(dest.Count), nil
|
||||
}
|
||||
|
||||
// InsertChallenge persists a fresh `auth_challenges` row. The caller
|
||||
// owns the primary-key, the bcrypt hash, the expires_at timestamp and
|
||||
// the captured locale. created_at and attempts default at the schema
|
||||
// level.
|
||||
func (s *Store) InsertChallenge(ctx context.Context, c Challenge) error {
|
||||
stmt := table.AuthChallenges.INSERT(
|
||||
table.AuthChallenges.ChallengeID,
|
||||
table.AuthChallenges.Email,
|
||||
table.AuthChallenges.CodeHash,
|
||||
table.AuthChallenges.ExpiresAt,
|
||||
table.AuthChallenges.PreferredLanguage,
|
||||
).VALUES(c.ChallengeID, c.Email, c.CodeHash, c.ExpiresAt, c.PreferredLanguage)
|
||||
|
||||
if _, err := stmt.ExecContext(ctx, s.db); err != nil {
|
||||
return fmt.Errorf("auth store: insert challenge: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAndIncrementChallenge atomically locks the challenge row,
|
||||
// validates that it is still un-consumed and non-expired, and increments
|
||||
// its `attempts` counter. The returned Challenge carries the
|
||||
// post-increment counter so the caller can compare it against the
|
||||
// configured ceiling without a second query.
|
||||
//
|
||||
// Returns ErrChallengeNotFound when the row does not exist, has been
|
||||
// consumed, or has expired. Any other error is wrapped with the auth
|
||||
// store prefix.
|
||||
func (s *Store) LoadAndIncrementChallenge(ctx context.Context, challengeID uuid.UUID) (Challenge, error) {
|
||||
var loaded Challenge
|
||||
err := withTx(ctx, s.db, func(tx *sql.Tx) error {
|
||||
selectStmt := postgres.SELECT(challengeColumns()).
|
||||
FROM(table.AuthChallenges).
|
||||
WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))).
|
||||
FOR(postgres.UPDATE())
|
||||
|
||||
var row model.AuthChallenges
|
||||
if err := selectStmt.QueryContext(ctx, tx, &row); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return ErrChallengeNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
loaded = modelToChallenge(row)
|
||||
if loaded.ConsumedAt != nil {
|
||||
return ErrChallengeNotFound
|
||||
}
|
||||
if !loaded.ExpiresAt.After(time.Now()) {
|
||||
return ErrChallengeNotFound
|
||||
}
|
||||
updateStmt := table.AuthChallenges.
|
||||
UPDATE(table.AuthChallenges.Attempts).
|
||||
SET(table.AuthChallenges.Attempts.ADD(postgres.Int(1))).
|
||||
WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID)))
|
||||
if _, err := updateStmt.ExecContext(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
loaded.Attempts++
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrChallengeNotFound) {
|
||||
return Challenge{}, err
|
||||
}
|
||||
return Challenge{}, fmt.Errorf("auth store: load and increment challenge: %w", err)
|
||||
}
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
// MarkConsumedAndInsertSession atomically:
|
||||
//
|
||||
// 1. Locks the challenge row.
|
||||
// 2. Validates that it is still un-consumed and non-expired.
|
||||
// 3. Sets consumed_at = now().
|
||||
// 4. Inserts the supplied Session into device_sessions with status =
|
||||
// 'active'.
|
||||
//
|
||||
// The two writes are committed together so a single challenge yields at
|
||||
// most one device session even under concurrent confirm-email-code
|
||||
// callers.
|
||||
//
|
||||
// Returns ErrChallengeNotFound when the challenge has been consumed (by
|
||||
// a concurrent caller) or has expired in the gap between the
|
||||
// LoadAndIncrementChallenge call and this one.
|
||||
func (s *Store) MarkConsumedAndInsertSession(ctx context.Context, challengeID uuid.UUID, session Session) error {
|
||||
err := withTx(ctx, s.db, func(tx *sql.Tx) error {
|
||||
lockStmt := postgres.SELECT(table.AuthChallenges.ConsumedAt, table.AuthChallenges.ExpiresAt).
|
||||
FROM(table.AuthChallenges).
|
||||
WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID))).
|
||||
FOR(postgres.UPDATE())
|
||||
|
||||
var locked model.AuthChallenges
|
||||
if err := lockStmt.QueryContext(ctx, tx, &locked); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return ErrChallengeNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if locked.ConsumedAt != nil || !locked.ExpiresAt.After(time.Now()) {
|
||||
return ErrChallengeNotFound
|
||||
}
|
||||
consumeStmt := table.AuthChallenges.
|
||||
UPDATE(table.AuthChallenges.ConsumedAt).
|
||||
SET(postgres.NOW()).
|
||||
WHERE(table.AuthChallenges.ChallengeID.EQ(postgres.UUID(challengeID)))
|
||||
if _, err := consumeStmt.ExecContext(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
insertStmt := table.DeviceSessions.INSERT(
|
||||
table.DeviceSessions.DeviceSessionID,
|
||||
table.DeviceSessions.UserID,
|
||||
table.DeviceSessions.ClientPublicKey,
|
||||
table.DeviceSessions.Status,
|
||||
).VALUES(session.DeviceSessionID, session.UserID, session.ClientPublicKey, SessionStatusActive)
|
||||
if _, err := insertStmt.ExecContext(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrChallengeNotFound) {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("auth store: mark consumed and insert session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListActiveSessions loads every row from device_sessions whose status
|
||||
// is 'active'. Cache.Warm calls this at process boot.
|
||||
func (s *Store) ListActiveSessions(ctx context.Context) ([]Session, error) {
|
||||
stmt := postgres.SELECT(sessionColumns()).
|
||||
FROM(table.DeviceSessions).
|
||||
WHERE(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive)))
|
||||
|
||||
var rows []model.DeviceSessions
|
||||
if err := stmt.QueryContext(ctx, s.db, &rows); err != nil {
|
||||
return nil, fmt.Errorf("auth store: list active sessions: %w", err)
|
||||
}
|
||||
out := make([]Session, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, modelToSession(row))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// LoadSession returns the row for deviceSessionID regardless of status.
|
||||
// Returns ErrSessionNotFound on missing row.
|
||||
func (s *Store) LoadSession(ctx context.Context, deviceSessionID uuid.UUID) (Session, error) {
|
||||
stmt := postgres.SELECT(sessionColumns()).
|
||||
FROM(table.DeviceSessions).
|
||||
WHERE(table.DeviceSessions.DeviceSessionID.EQ(postgres.UUID(deviceSessionID))).
|
||||
LIMIT(1)
|
||||
|
||||
var row model.DeviceSessions
|
||||
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return Session{}, ErrSessionNotFound
|
||||
}
|
||||
return Session{}, fmt.Errorf("auth store: load session %s: %w", deviceSessionID, err)
|
||||
}
|
||||
return modelToSession(row), nil
|
||||
}
|
||||
|
||||
// RevokeSession transitions an active row to status='revoked' and
|
||||
// returns the row as it stands after the update. The boolean reports
|
||||
// whether the UPDATE actually changed a row — false means the row was
|
||||
// already revoked or did not exist; the auth Service then falls back to
|
||||
// LoadSession for idempotent-revoke responses.
|
||||
func (s *Store) RevokeSession(ctx context.Context, deviceSessionID uuid.UUID) (Session, bool, error) {
|
||||
stmt := table.DeviceSessions.
|
||||
UPDATE(table.DeviceSessions.Status, table.DeviceSessions.RevokedAt).
|
||||
SET(postgres.String(SessionStatusRevoked), postgres.NOW()).
|
||||
WHERE(
|
||||
table.DeviceSessions.DeviceSessionID.EQ(postgres.UUID(deviceSessionID)).
|
||||
AND(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))),
|
||||
).
|
||||
RETURNING(sessionColumns())
|
||||
|
||||
var row model.DeviceSessions
|
||||
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return Session{}, false, nil
|
||||
}
|
||||
return Session{}, false, fmt.Errorf("auth store: revoke session %s: %w", deviceSessionID, err)
|
||||
}
|
||||
return modelToSession(row), true, nil
|
||||
}
|
||||
|
||||
// RevokeAllForUser transitions every active row for userID to
|
||||
// status='revoked' and returns the rows as they stand after the update.
|
||||
// An empty slice with a nil error is returned when the user owned no
|
||||
// active sessions; the caller must treat that as a successful idempotent
|
||||
// revoke (the API surface returns revoked_count=0 in that case).
|
||||
func (s *Store) RevokeAllForUser(ctx context.Context, userID uuid.UUID) ([]Session, error) {
|
||||
stmt := table.DeviceSessions.
|
||||
UPDATE(table.DeviceSessions.Status, table.DeviceSessions.RevokedAt).
|
||||
SET(postgres.String(SessionStatusRevoked), postgres.NOW()).
|
||||
WHERE(
|
||||
table.DeviceSessions.UserID.EQ(postgres.UUID(userID)).
|
||||
AND(table.DeviceSessions.Status.EQ(postgres.String(SessionStatusActive))),
|
||||
).
|
||||
RETURNING(sessionColumns())
|
||||
|
||||
var rows []model.DeviceSessions
|
||||
if err := stmt.QueryContext(ctx, s.db, &rows); err != nil {
|
||||
return nil, fmt.Errorf("auth store: revoke all for user %s: %w", userID, err)
|
||||
}
|
||||
out := make([]Session, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, modelToSession(row))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// modelToChallenge projects a generated model row into the public
|
||||
// Challenge struct. Pointer fields are copied so callers cannot mutate
|
||||
// the underlying scan buffer.
|
||||
func modelToChallenge(row model.AuthChallenges) Challenge {
|
||||
c := Challenge{
|
||||
ChallengeID: row.ChallengeID,
|
||||
Email: row.Email,
|
||||
CodeHash: row.CodeHash,
|
||||
Attempts: row.Attempts,
|
||||
CreatedAt: row.CreatedAt,
|
||||
ExpiresAt: row.ExpiresAt,
|
||||
PreferredLanguage: row.PreferredLanguage,
|
||||
}
|
||||
if row.ConsumedAt != nil {
|
||||
t := *row.ConsumedAt
|
||||
c.ConsumedAt = &t
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// modelToSession projects a generated model row into the public Session
|
||||
// struct.
|
||||
func modelToSession(row model.DeviceSessions) Session {
|
||||
s := Session{
|
||||
DeviceSessionID: row.DeviceSessionID,
|
||||
UserID: row.UserID,
|
||||
Status: row.Status,
|
||||
ClientPublicKey: row.ClientPublicKey,
|
||||
CreatedAt: row.CreatedAt,
|
||||
}
|
||||
if row.RevokedAt != nil {
|
||||
t := *row.RevokedAt
|
||||
s.RevokedAt = &t
|
||||
}
|
||||
if row.LastSeenAt != nil {
|
||||
t := *row.LastSeenAt
|
||||
s.LastSeenAt = &t
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// withTx wraps fn in a Postgres transaction. fn's return value
|
||||
// determines commit (nil) vs rollback (non-nil). Rollback errors are
|
||||
// swallowed when fn already returned an error, since the latter is more
|
||||
// actionable.
|
||||
func withTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("auth store: begin tx: %w", err)
|
||||
}
|
||||
if err := fn(tx); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("auth store: commit tx: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user