Stage 1: backend foundation (Postgres, sessions, accounts, OTel)
- internal/postgres: pgx-over-database/sql pool (otelsql), embedded goose
migrations into schema 'backend', committed go-jet code + cmd/jetgen tool.
- internal/account: durable accounts + unified telegram/email identities
(UUIDv7 keys), find-or-create provisioning with unique-conflict handling.
- internal/session: opaque 256-bit tokens stored as a SHA-256 hash, revoke-only
(no TTL); write-through cache gating /readyz; store + service.
- internal/telemetry: OTel tracer/meter providers (none/stdout) + request-timing
middleware; internal/config gains Postgres + OTel env loading.
- internal/server: /api/v1 {public,user,internal,admin} skeleton + X-User-ID
middleware; /readyz checks DB ping + cache; main wires
telemetry -> db+migrate -> warm cache -> server.
- Tests: unit + integration (build tag 'integration', testcontainers
postgres:17) for migrations, accounts, sessions, readyz; new integration.yaml.
- Docs: ARCHITECTURE, TESTING, PLAN refinements, root + backend READMEs.
Session/account REST handlers deferred to Stage 6 (gateway); OTLP + dashboards
to Stage 11.
This commit is contained in:
@@ -0,0 +1,203 @@
|
||||
// Package account owns durable internal accounts and their platform/email
|
||||
// identities. First contact from a platform auto-provisions an account bound to
|
||||
// that identity; guests are session-only and never reach this package.
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-jet/jet/v2/postgres"
|
||||
"github.com/go-jet/jet/v2/qrm"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"scrabble/backend/internal/postgres/jet/backend/model"
|
||||
"scrabble/backend/internal/postgres/jet/backend/table"
|
||||
)
|
||||
|
||||
// Identity kinds recognised by the backend. Email is modelled as an identity
|
||||
// alongside platform identities; its confirmed flag is driven by the email
|
||||
// confirm-code flow in a later stage.
|
||||
const (
|
||||
KindTelegram = "telegram"
|
||||
KindEmail = "email"
|
||||
)
|
||||
|
||||
// uniqueViolation is the PostgreSQL SQLSTATE for a unique-constraint violation.
|
||||
const uniqueViolation = "23505"
|
||||
|
||||
// ErrNotFound is returned when no account matches the lookup.
|
||||
var ErrNotFound = errors.New("account: not found")
|
||||
|
||||
// Account is a durable internal account.
|
||||
type Account struct {
|
||||
ID uuid.UUID
|
||||
DisplayName string
|
||||
PreferredLanguage string
|
||||
TimeZone string
|
||||
BlockChat bool
|
||||
BlockFriendRequests bool
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// Store is the Postgres-backed query surface for accounts and identities.
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewStore constructs a Store wrapping db.
|
||||
func NewStore(db *sql.DB) *Store {
|
||||
return &Store{db: db}
|
||||
}
|
||||
|
||||
// ProvisionByIdentity returns the account bound to (kind, externalID), creating
|
||||
// a fresh durable account and identity when none exists yet. It is safe under
|
||||
// concurrent callers: a losing race on the identity's unique constraint is
|
||||
// resolved by re-reading the winner's account. A platform identity is recorded
|
||||
// as confirmed; an email identity starts unconfirmed.
|
||||
func (s *Store) ProvisionByIdentity(ctx context.Context, kind, externalID string) (Account, error) {
|
||||
acc, err := s.findByIdentity(ctx, kind, externalID)
|
||||
if err == nil {
|
||||
return acc, nil
|
||||
}
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
return Account{}, err
|
||||
}
|
||||
|
||||
acc, err = s.create(ctx, kind, externalID)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
// A concurrent caller created the identity first; return theirs.
|
||||
return s.findByIdentity(ctx, kind, externalID)
|
||||
}
|
||||
return Account{}, err
|
||||
}
|
||||
return acc, nil
|
||||
}
|
||||
|
||||
// GetByID loads the account identified by id, or ErrNotFound when it is absent.
|
||||
func (s *Store) GetByID(ctx context.Context, id uuid.UUID) (Account, error) {
|
||||
stmt := postgres.SELECT(table.Accounts.AllColumns).
|
||||
FROM(table.Accounts).
|
||||
WHERE(table.Accounts.AccountID.EQ(postgres.UUID(id))).
|
||||
LIMIT(1)
|
||||
|
||||
var row model.Accounts
|
||||
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return Account{}, ErrNotFound
|
||||
}
|
||||
return Account{}, fmt.Errorf("account: get by id %s: %w", id, err)
|
||||
}
|
||||
return modelToAccount(row), nil
|
||||
}
|
||||
|
||||
// findByIdentity joins identities to accounts and returns the matching account,
|
||||
// or ErrNotFound.
|
||||
func (s *Store) findByIdentity(ctx context.Context, kind, externalID string) (Account, error) {
|
||||
stmt := postgres.SELECT(table.Accounts.AllColumns).
|
||||
FROM(table.Accounts.INNER_JOIN(
|
||||
table.Identities,
|
||||
table.Identities.AccountID.EQ(table.Accounts.AccountID),
|
||||
)).
|
||||
WHERE(
|
||||
table.Identities.Kind.EQ(postgres.String(kind)).
|
||||
AND(table.Identities.ExternalID.EQ(postgres.String(externalID))),
|
||||
).
|
||||
LIMIT(1)
|
||||
|
||||
var row model.Accounts
|
||||
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return Account{}, ErrNotFound
|
||||
}
|
||||
return Account{}, fmt.Errorf("account: find by identity (%s, %s): %w", kind, externalID, err)
|
||||
}
|
||||
return modelToAccount(row), nil
|
||||
}
|
||||
|
||||
// create inserts a new account and its first identity inside one transaction
|
||||
// and returns the persisted account row.
|
||||
func (s *Store) create(ctx context.Context, kind, externalID string) (Account, error) {
|
||||
accountID, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
return Account{}, fmt.Errorf("account: new account id: %w", err)
|
||||
}
|
||||
identityID, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
return Account{}, fmt.Errorf("account: new identity id: %w", err)
|
||||
}
|
||||
|
||||
var created Account
|
||||
err = withTx(ctx, s.db, func(tx *sql.Tx) error {
|
||||
insertAccount := table.Accounts.
|
||||
INSERT(table.Accounts.AccountID).
|
||||
VALUES(accountID).
|
||||
RETURNING(table.Accounts.AllColumns)
|
||||
|
||||
var row model.Accounts
|
||||
if err := insertAccount.QueryContext(ctx, tx, &row); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
insertIdentity := table.Identities.INSERT(
|
||||
table.Identities.IdentityID,
|
||||
table.Identities.AccountID,
|
||||
table.Identities.Kind,
|
||||
table.Identities.ExternalID,
|
||||
table.Identities.Confirmed,
|
||||
).VALUES(identityID, accountID, kind, externalID, kind == KindTelegram)
|
||||
if _, err := insertIdentity.ExecContext(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
created = modelToAccount(row)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return Account{}, fmt.Errorf("account: create for identity (%s, %s): %w", kind, externalID, err)
|
||||
}
|
||||
return created, nil
|
||||
}
|
||||
|
||||
// modelToAccount projects a generated model row into the public Account struct.
|
||||
func modelToAccount(row model.Accounts) Account {
|
||||
return Account{
|
||||
ID: row.AccountID,
|
||||
DisplayName: row.DisplayName,
|
||||
PreferredLanguage: row.PreferredLanguage,
|
||||
TimeZone: row.TimeZone,
|
||||
BlockChat: row.BlockChat,
|
||||
BlockFriendRequests: row.BlockFriendRequests,
|
||||
CreatedAt: row.CreatedAt,
|
||||
UpdatedAt: row.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// isUniqueViolation reports whether err is a PostgreSQL unique-constraint
|
||||
// violation, used to collapse a concurrent-provision race into a re-read.
|
||||
func isUniqueViolation(err error) bool {
|
||||
var pgErr *pgconn.PgError
|
||||
return errors.As(err, &pgErr) && pgErr.Code == uniqueViolation
|
||||
}
|
||||
|
||||
// withTx wraps fn in a transaction, committing on nil and rolling back on error.
|
||||
func withTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
if err := fn(tx); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit tx: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -5,6 +5,11 @@ package config
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"scrabble/backend/internal/postgres"
|
||||
"scrabble/backend/internal/telemetry"
|
||||
)
|
||||
|
||||
// Config holds the backend's runtime configuration.
|
||||
@@ -13,6 +18,10 @@ type Config struct {
|
||||
HTTPAddr string
|
||||
// LogLevel is the zap log level: "debug", "info", "warn" or "error".
|
||||
LogLevel string
|
||||
// Postgres configures the primary database pool.
|
||||
Postgres postgres.Config
|
||||
// Telemetry configures the OpenTelemetry providers.
|
||||
Telemetry telemetry.Config
|
||||
}
|
||||
|
||||
// Defaults applied when the corresponding environment variable is unset.
|
||||
@@ -21,12 +30,35 @@ const (
|
||||
defaultLogLevel = "info"
|
||||
)
|
||||
|
||||
// Load reads the configuration from the environment, applies defaults for
|
||||
// unset variables, and validates the result.
|
||||
// Load reads the configuration from the environment, applies defaults for unset
|
||||
// variables, and validates the result.
|
||||
func Load() (Config, error) {
|
||||
pg := postgres.DefaultConfig()
|
||||
pg.DSN = os.Getenv("BACKEND_POSTGRES_DSN")
|
||||
var err error
|
||||
if pg.MaxOpenConns, err = envInt("BACKEND_POSTGRES_MAX_OPEN_CONNS", pg.MaxOpenConns); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
if pg.MaxIdleConns, err = envInt("BACKEND_POSTGRES_MAX_IDLE_CONNS", pg.MaxIdleConns); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
if pg.ConnMaxLifetime, err = envDuration("BACKEND_POSTGRES_CONN_MAX_LIFETIME", pg.ConnMaxLifetime); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
if pg.OperationTimeout, err = envDuration("BACKEND_POSTGRES_OPERATION_TIMEOUT", pg.OperationTimeout); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
|
||||
tel := telemetry.DefaultConfig()
|
||||
tel.ServiceName = envOr("BACKEND_SERVICE_NAME", tel.ServiceName)
|
||||
tel.TracesExporter = envOr("BACKEND_OTEL_TRACES_EXPORTER", tel.TracesExporter)
|
||||
tel.MetricsExporter = envOr("BACKEND_OTEL_METRICS_EXPORTER", tel.MetricsExporter)
|
||||
|
||||
c := Config{
|
||||
HTTPAddr: envOr("BACKEND_HTTP_ADDR", defaultHTTPAddr),
|
||||
LogLevel: envOr("BACKEND_LOG_LEVEL", defaultLogLevel),
|
||||
HTTPAddr: envOr("BACKEND_HTTP_ADDR", defaultHTTPAddr),
|
||||
LogLevel: envOr("BACKEND_LOG_LEVEL", defaultLogLevel),
|
||||
Postgres: pg,
|
||||
Telemetry: tel,
|
||||
}
|
||||
if err := c.validate(); err != nil {
|
||||
return Config{}, err
|
||||
@@ -44,6 +76,12 @@ func (c Config) validate() error {
|
||||
if c.HTTPAddr == "" {
|
||||
return fmt.Errorf("config: BACKEND_HTTP_ADDR must not be empty")
|
||||
}
|
||||
if err := c.Postgres.Validate(); err != nil {
|
||||
return fmt.Errorf("config: %w (set BACKEND_POSTGRES_DSN)", err)
|
||||
}
|
||||
if err := c.Telemetry.Validate(); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -55,3 +93,31 @@ func envOr(key, fallback string) string {
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// envInt parses the environment variable named key as an int, returning
|
||||
// fallback when it is unset and an error when it is set but malformed.
|
||||
func envInt(key string, fallback int) (int, error) {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return fallback, nil
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("config: %s: %w", key, err)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// envDuration parses the environment variable named key as a Go duration,
|
||||
// returning fallback when it is unset and an error when it is set but malformed.
|
||||
func envDuration(key string, fallback time.Duration) (time.Duration, error) {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return fallback, nil
|
||||
}
|
||||
d, err := time.ParseDuration(v)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("config: %s: %w", key, err)
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
@@ -1,12 +1,22 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
// TestLoadDefaults verifies that Load applies defaults when the environment is
|
||||
// empty.
|
||||
"scrabble/backend/internal/postgres"
|
||||
"scrabble/backend/internal/telemetry"
|
||||
)
|
||||
|
||||
// testDSN is a syntactically valid DSN used to satisfy the required-DSN check.
|
||||
const testDSN = "postgres://u:p@localhost:5432/db?search_path=backend&sslmode=disable"
|
||||
|
||||
// TestLoadDefaults verifies that Load applies defaults when only the required
|
||||
// DSN is set.
|
||||
func TestLoadDefaults(t *testing.T) {
|
||||
t.Setenv("BACKEND_HTTP_ADDR", "")
|
||||
t.Setenv("BACKEND_LOG_LEVEL", "")
|
||||
t.Setenv("BACKEND_POSTGRES_DSN", testDSN)
|
||||
|
||||
c, err := Load()
|
||||
if err != nil {
|
||||
@@ -18,12 +28,29 @@ func TestLoadDefaults(t *testing.T) {
|
||||
if c.LogLevel != defaultLogLevel {
|
||||
t.Errorf("LogLevel = %q, want %q", c.LogLevel, defaultLogLevel)
|
||||
}
|
||||
if c.Postgres.DSN != testDSN {
|
||||
t.Errorf("Postgres.DSN = %q, want %q", c.Postgres.DSN, testDSN)
|
||||
}
|
||||
if c.Postgres.MaxOpenConns != postgres.DefaultMaxOpenConns {
|
||||
t.Errorf("Postgres.MaxOpenConns = %d, want %d", c.Postgres.MaxOpenConns, postgres.DefaultMaxOpenConns)
|
||||
}
|
||||
if c.Telemetry.ServiceName != telemetry.DefaultServiceName {
|
||||
t.Errorf("Telemetry.ServiceName = %q, want %q", c.Telemetry.ServiceName, telemetry.DefaultServiceName)
|
||||
}
|
||||
if c.Telemetry.TracesExporter != telemetry.ExporterNone {
|
||||
t.Errorf("Telemetry.TracesExporter = %q, want %q", c.Telemetry.TracesExporter, telemetry.ExporterNone)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadOverrides verifies that environment variables override the defaults.
|
||||
func TestLoadOverrides(t *testing.T) {
|
||||
t.Setenv("BACKEND_POSTGRES_DSN", testDSN)
|
||||
t.Setenv("BACKEND_HTTP_ADDR", "127.0.0.1:9090")
|
||||
t.Setenv("BACKEND_LOG_LEVEL", "debug")
|
||||
t.Setenv("BACKEND_POSTGRES_MAX_OPEN_CONNS", "7")
|
||||
t.Setenv("BACKEND_POSTGRES_OPERATION_TIMEOUT", "3s")
|
||||
t.Setenv("BACKEND_SERVICE_NAME", "scrabble-test")
|
||||
t.Setenv("BACKEND_OTEL_TRACES_EXPORTER", "stdout")
|
||||
|
||||
c, err := Load()
|
||||
if err != nil {
|
||||
@@ -33,14 +60,63 @@ func TestLoadOverrides(t *testing.T) {
|
||||
t.Errorf("HTTPAddr = %q, want %q", c.HTTPAddr, "127.0.0.1:9090")
|
||||
}
|
||||
if c.LogLevel != "debug" {
|
||||
t.Errorf("LogLevel = %q, want %q", c.LogLevel, "debug")
|
||||
t.Errorf("LogLevel = %q", c.LogLevel)
|
||||
}
|
||||
if c.Postgres.MaxOpenConns != 7 {
|
||||
t.Errorf("Postgres.MaxOpenConns = %d, want 7", c.Postgres.MaxOpenConns)
|
||||
}
|
||||
if c.Postgres.OperationTimeout != 3*time.Second {
|
||||
t.Errorf("Postgres.OperationTimeout = %s, want 3s", c.Postgres.OperationTimeout)
|
||||
}
|
||||
if c.Telemetry.ServiceName != "scrabble-test" {
|
||||
t.Errorf("Telemetry.ServiceName = %q", c.Telemetry.ServiceName)
|
||||
}
|
||||
if c.Telemetry.TracesExporter != telemetry.ExporterStdout {
|
||||
t.Errorf("Telemetry.TracesExporter = %q, want %q", c.Telemetry.TracesExporter, telemetry.ExporterStdout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadRejectsMissingDSN verifies that an empty DSN fails validation.
|
||||
func TestLoadRejectsMissingDSN(t *testing.T) {
|
||||
t.Setenv("BACKEND_POSTGRES_DSN", "")
|
||||
if _, err := Load(); err == nil {
|
||||
t.Fatal("Load: expected an error for a missing DSN, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadRejectsInvalidLevel verifies that an unknown log level is rejected.
|
||||
func TestLoadRejectsInvalidLevel(t *testing.T) {
|
||||
t.Setenv("BACKEND_POSTGRES_DSN", testDSN)
|
||||
t.Setenv("BACKEND_LOG_LEVEL", "verbose")
|
||||
if _, err := Load(); err == nil {
|
||||
t.Fatal("Load: expected an error for an invalid log level, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadRejectsMalformedInt verifies that a non-numeric pool size is rejected.
|
||||
func TestLoadRejectsMalformedInt(t *testing.T) {
|
||||
t.Setenv("BACKEND_POSTGRES_DSN", testDSN)
|
||||
t.Setenv("BACKEND_POSTGRES_MAX_OPEN_CONNS", "lots")
|
||||
if _, err := Load(); err == nil {
|
||||
t.Fatal("Load: expected an error for a malformed int, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadRejectsMalformedDuration verifies that a malformed duration is rejected.
|
||||
func TestLoadRejectsMalformedDuration(t *testing.T) {
|
||||
t.Setenv("BACKEND_POSTGRES_DSN", testDSN)
|
||||
t.Setenv("BACKEND_POSTGRES_OPERATION_TIMEOUT", "soon")
|
||||
if _, err := Load(); err == nil {
|
||||
t.Fatal("Load: expected an error for a malformed duration, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadRejectsUnsupportedExporter verifies that an exporter outside the MVP
|
||||
// set is rejected.
|
||||
func TestLoadRejectsUnsupportedExporter(t *testing.T) {
|
||||
t.Setenv("BACKEND_POSTGRES_DSN", testDSN)
|
||||
t.Setenv("BACKEND_OTEL_TRACES_EXPORTER", "otlp")
|
||||
if _, err := Load(); err == nil {
|
||||
t.Fatal("Load: expected an error for an unsupported exporter, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
//go:build integration
|
||||
|
||||
package inttest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"scrabble/backend/internal/account"
|
||||
)
|
||||
|
||||
// TestAccountProvisionByIdentity covers find-or-create semantics, distinct
|
||||
// accounts per identity, GetByID, and the identity confirmed flag per kind.
|
||||
func TestAccountProvisionByIdentity(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
store := account.NewStore(testDB)
|
||||
|
||||
tgExternal := "tg-" + uuid.NewString()
|
||||
first, err := store.ProvisionByIdentity(ctx, account.KindTelegram, tgExternal)
|
||||
if err != nil {
|
||||
t.Fatalf("provision telegram: %v", err)
|
||||
}
|
||||
if first.ID == uuid.Nil {
|
||||
t.Fatal("expected a non-nil account id")
|
||||
}
|
||||
if first.PreferredLanguage != "en" {
|
||||
t.Errorf("PreferredLanguage = %q, want default en", first.PreferredLanguage)
|
||||
}
|
||||
if first.TimeZone != "UTC" {
|
||||
t.Errorf("TimeZone = %q, want default UTC", first.TimeZone)
|
||||
}
|
||||
|
||||
// Re-provisioning the same identity returns the same account.
|
||||
again, err := store.ProvisionByIdentity(ctx, account.KindTelegram, tgExternal)
|
||||
if err != nil {
|
||||
t.Fatalf("re-provision telegram: %v", err)
|
||||
}
|
||||
if again.ID != first.ID {
|
||||
t.Errorf("re-provision id = %s, want %s", again.ID, first.ID)
|
||||
}
|
||||
|
||||
// A different identity yields a different account.
|
||||
other, err := store.ProvisionByIdentity(ctx, account.KindTelegram, "tg-"+uuid.NewString())
|
||||
if err != nil {
|
||||
t.Fatalf("provision other telegram: %v", err)
|
||||
}
|
||||
if other.ID == first.ID {
|
||||
t.Error("distinct identity must map to a distinct account")
|
||||
}
|
||||
|
||||
// GetByID round-trips, and a random id reports ErrNotFound.
|
||||
got, err := store.GetByID(ctx, first.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get by id: %v", err)
|
||||
}
|
||||
if got.ID != first.ID {
|
||||
t.Errorf("get id = %s, want %s", got.ID, first.ID)
|
||||
}
|
||||
if _, err := store.GetByID(ctx, uuid.New()); !errors.Is(err, account.ErrNotFound) {
|
||||
t.Errorf("get missing = %v, want ErrNotFound", err)
|
||||
}
|
||||
|
||||
// A platform identity is confirmed; an email identity starts unconfirmed.
|
||||
if c := identityConfirmed(t, account.KindTelegram, tgExternal); !c {
|
||||
t.Error("telegram identity must be confirmed")
|
||||
}
|
||||
emailExternal := "e-" + uuid.NewString() + "@example.com"
|
||||
if _, err := store.ProvisionByIdentity(ctx, account.KindEmail, emailExternal); err != nil {
|
||||
t.Fatalf("provision email: %v", err)
|
||||
}
|
||||
if c := identityConfirmed(t, account.KindEmail, emailExternal); c {
|
||||
t.Error("email identity must start unconfirmed")
|
||||
}
|
||||
}
|
||||
|
||||
// identityConfirmed reads the confirmed flag for one identity directly.
|
||||
func identityConfirmed(t *testing.T, kind, externalID string) bool {
|
||||
t.Helper()
|
||||
var confirmed bool
|
||||
err := testDB.QueryRowContext(context.Background(),
|
||||
"SELECT confirmed FROM identities WHERE kind = $1 AND external_id = $2",
|
||||
kind, externalID,
|
||||
).Scan(&confirmed)
|
||||
if err != nil {
|
||||
t.Fatalf("read confirmed for (%s, %s): %v", kind, externalID, err)
|
||||
}
|
||||
return confirmed
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
// Package inttest holds the Postgres-backed integration tests for the backend.
|
||||
//
|
||||
// The tests are guarded by the `integration` build tag and run against a
|
||||
// throwaway postgres:17-alpine container started with testcontainers-go, so a
|
||||
// reachable Docker daemon is required. Run them with:
|
||||
//
|
||||
// go test -tags=integration ./backend/...
|
||||
//
|
||||
// They fail loudly when Docker is unavailable rather than skipping, per
|
||||
// docs/TESTING.md. This file carries no build tag so the package is non-empty
|
||||
// in the default (no-tag) build.
|
||||
package inttest
|
||||
@@ -0,0 +1,109 @@
|
||||
//go:build integration
|
||||
|
||||
package inttest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
testcontainers "github.com/testcontainers/testcontainers-go"
|
||||
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
|
||||
"scrabble/backend/internal/postgres"
|
||||
)
|
||||
|
||||
// testDB is the shared, migrated pool every integration test runs against. It is
|
||||
// hydrated once by TestMain.
|
||||
var testDB *sql.DB
|
||||
|
||||
const (
|
||||
pgImage = "postgres:17-alpine"
|
||||
pgDatabase = "scrabble_backend"
|
||||
pgUser = "scrabble"
|
||||
pgPassword = "scrabble"
|
||||
pgSchema = "backend"
|
||||
containerStartup = 90 * time.Second
|
||||
containerShutdown = 30 * time.Second
|
||||
)
|
||||
|
||||
// TestMain starts one Postgres container, applies the migrations, and shares the
|
||||
// resulting pool with every test. Any setup failure aborts the suite loudly
|
||||
// (exit 1) rather than skipping coverage.
|
||||
func TestMain(m *testing.M) {
|
||||
code, err := runSuite(m)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "inttest:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func runSuite(m *testing.M) (int, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
container, err := tcpostgres.Run(ctx, pgImage,
|
||||
tcpostgres.WithDatabase(pgDatabase),
|
||||
tcpostgres.WithUsername(pgUser),
|
||||
tcpostgres.WithPassword(pgPassword),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(containerStartup),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("start postgres container: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), containerShutdown)
|
||||
defer cancel()
|
||||
if termErr := container.Terminate(shutdownCtx); termErr != nil {
|
||||
fmt.Fprintln(os.Stderr, "inttest: terminate container:", termErr)
|
||||
}
|
||||
}()
|
||||
|
||||
baseDSN, err := container.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("resolve container dsn: %w", err)
|
||||
}
|
||||
dsn, err := withSearchPath(baseDSN, pgSchema)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
cfg := postgres.DefaultConfig()
|
||||
cfg.DSN = dsn
|
||||
db, err := postgres.Open(ctx, cfg)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("open pool: %w", err)
|
||||
}
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
if err := postgres.ApplyMigrations(ctx, db); err != nil {
|
||||
return 0, fmt.Errorf("apply migrations: %w", err)
|
||||
}
|
||||
|
||||
testDB = db
|
||||
return m.Run(), nil
|
||||
}
|
||||
|
||||
// withSearchPath rewrites dsn so every connection pins search_path to schema.
|
||||
func withSearchPath(dsn, schema string) (string, error) {
|
||||
u, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse dsn: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("search_path", schema)
|
||||
if q.Get("sslmode") == "" {
|
||||
q.Set("sslmode", "disable")
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
//go:build integration
|
||||
|
||||
package inttest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"scrabble/backend/internal/postgres"
|
||||
)
|
||||
|
||||
// TestApplyMigrationsIdempotent re-applies the migrations against the already
|
||||
// migrated database and confirms the expected tables are queryable.
|
||||
func TestApplyMigrationsIdempotent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
if err := postgres.ApplyMigrations(ctx, testDB); err != nil {
|
||||
t.Fatalf("re-apply migrations: %v", err)
|
||||
}
|
||||
for _, table := range []string{"accounts", "identities", "sessions"} {
|
||||
var n int
|
||||
if err := testDB.QueryRowContext(ctx, "SELECT count(*) FROM "+table).Scan(&n); err != nil {
|
||||
t.Errorf("count %s: %v", table, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
//go:build integration
|
||||
|
||||
package inttest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
|
||||
"scrabble/backend/internal/server"
|
||||
"scrabble/backend/internal/session"
|
||||
)
|
||||
|
||||
// TestReadyzWithRealDatabase exercises the assembled server against the live
|
||||
// container: /readyz answers 200 once the database pings and the session cache
|
||||
// is warmed, closing the gap the unit test (nil database) cannot cover.
|
||||
func TestReadyzWithRealDatabase(t *testing.T) {
|
||||
svc := session.NewService(session.NewStore(testDB), session.NewCache())
|
||||
if err := svc.Warm(context.Background()); err != nil {
|
||||
t.Fatalf("warm session cache: %v", err)
|
||||
}
|
||||
|
||||
srv := server.New(":0", server.Deps{
|
||||
Logger: zaptest.NewLogger(t),
|
||||
DB: testDB,
|
||||
PingTimeout: 5 * time.Second,
|
||||
SessionsReady: svc.Ready,
|
||||
})
|
||||
|
||||
for _, path := range []string{"/healthz", "/readyz"} {
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Handler().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, path, nil))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("%s status = %d, want 200", path, rec.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
//go:build integration
|
||||
|
||||
package inttest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"scrabble/backend/internal/account"
|
||||
"scrabble/backend/internal/session"
|
||||
)
|
||||
|
||||
// TestSessionLifecycle covers create, cache-hit resolve, DB-fallback resolve
|
||||
// after a cold cache warm, idempotent revoke, and post-revoke resolution.
|
||||
func TestSessionLifecycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
acc, err := account.NewStore(testDB).ProvisionByIdentity(ctx, account.KindTelegram, "tg-"+uuid.NewString())
|
||||
if err != nil {
|
||||
t.Fatalf("provision account: %v", err)
|
||||
}
|
||||
|
||||
store := session.NewStore(testDB)
|
||||
svc := session.NewService(store, session.NewCache())
|
||||
|
||||
token, sess, err := svc.Create(ctx, acc.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("create session: %v", err)
|
||||
}
|
||||
if sess.AccountID != acc.ID {
|
||||
t.Errorf("session account = %s, want %s", sess.AccountID, acc.ID)
|
||||
}
|
||||
if token == sess.TokenHash {
|
||||
t.Error("plaintext token must not equal the stored hash")
|
||||
}
|
||||
|
||||
// Resolve via the warm write-through cache.
|
||||
got, err := svc.Resolve(ctx, token)
|
||||
if err != nil {
|
||||
t.Fatalf("resolve (cache): %v", err)
|
||||
}
|
||||
if got.ID != sess.ID {
|
||||
t.Errorf("resolve id = %s, want %s", got.ID, sess.ID)
|
||||
}
|
||||
|
||||
// An unknown token is not found.
|
||||
if _, err := svc.Resolve(ctx, "not-a-real-token"); !errors.Is(err, session.ErrNotFound) {
|
||||
t.Errorf("resolve unknown = %v, want ErrNotFound", err)
|
||||
}
|
||||
|
||||
// A fresh service with a cold cache resolves through the DB after Warm.
|
||||
cold := session.NewCache()
|
||||
svc2 := session.NewService(store, cold)
|
||||
if err := svc2.Warm(ctx); err != nil {
|
||||
t.Fatalf("warm: %v", err)
|
||||
}
|
||||
if !cold.Ready() {
|
||||
t.Error("cache must be ready after Warm")
|
||||
}
|
||||
if _, ok := cold.Get(session.HashToken(token)); !ok {
|
||||
t.Error("Warm must load the active session into the cache")
|
||||
}
|
||||
if got2, err := svc2.Resolve(ctx, token); err != nil || got2.ID != sess.ID {
|
||||
t.Errorf("resolve after warm = (%s, %v), want %s", got2.ID, err, sess.ID)
|
||||
}
|
||||
|
||||
// Revoke, then the token no longer resolves; revoke again is a no-op.
|
||||
if err := svc.Revoke(ctx, token); err != nil {
|
||||
t.Fatalf("revoke: %v", err)
|
||||
}
|
||||
if _, err := svc.Resolve(ctx, token); !errors.Is(err, session.ErrNotFound) {
|
||||
t.Errorf("resolve after revoke = %v, want ErrNotFound", err)
|
||||
}
|
||||
if err := svc.Revoke(ctx, token); err != nil {
|
||||
t.Errorf("idempotent revoke: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
//
|
||||
// Code generated by go-jet DO NOT EDIT.
|
||||
//
|
||||
// WARNING: Changes to this file may cause incorrect behavior
|
||||
// and will be lost if the code is regenerated
|
||||
//
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Accounts struct {
|
||||
AccountID uuid.UUID `sql:"primary_key"`
|
||||
DisplayName string
|
||||
PreferredLanguage string
|
||||
TimeZone string
|
||||
BlockChat bool
|
||||
BlockFriendRequests bool
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
//
|
||||
// Code generated by go-jet DO NOT EDIT.
|
||||
//
|
||||
// WARNING: Changes to this file may cause incorrect behavior
|
||||
// and will be lost if the code is regenerated
|
||||
//
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Identities struct {
|
||||
IdentityID uuid.UUID `sql:"primary_key"`
|
||||
AccountID uuid.UUID
|
||||
Kind string
|
||||
ExternalID string
|
||||
Confirmed bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
//
|
||||
// Code generated by go-jet DO NOT EDIT.
|
||||
//
|
||||
// WARNING: Changes to this file may cause incorrect behavior
|
||||
// and will be lost if the code is regenerated
|
||||
//
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/google/uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Sessions struct {
|
||||
SessionID uuid.UUID `sql:"primary_key"`
|
||||
AccountID uuid.UUID
|
||||
TokenHash string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
LastSeenAt *time.Time
|
||||
RevokedAt *time.Time
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
//
|
||||
// Code generated by go-jet DO NOT EDIT.
|
||||
//
|
||||
// WARNING: Changes to this file may cause incorrect behavior
|
||||
// and will be lost if the code is regenerated
|
||||
//
|
||||
|
||||
package table
|
||||
|
||||
import (
|
||||
"github.com/go-jet/jet/v2/postgres"
|
||||
)
|
||||
|
||||
var Accounts = newAccountsTable("backend", "accounts", "")
|
||||
|
||||
type accountsTable struct {
|
||||
postgres.Table
|
||||
|
||||
// Columns
|
||||
AccountID postgres.ColumnString
|
||||
DisplayName postgres.ColumnString
|
||||
PreferredLanguage postgres.ColumnString
|
||||
TimeZone postgres.ColumnString
|
||||
BlockChat postgres.ColumnBool
|
||||
BlockFriendRequests postgres.ColumnBool
|
||||
CreatedAt postgres.ColumnTimestampz
|
||||
UpdatedAt postgres.ColumnTimestampz
|
||||
|
||||
AllColumns postgres.ColumnList
|
||||
MutableColumns postgres.ColumnList
|
||||
DefaultColumns postgres.ColumnList
|
||||
}
|
||||
|
||||
type AccountsTable struct {
|
||||
accountsTable
|
||||
|
||||
EXCLUDED accountsTable
|
||||
}
|
||||
|
||||
// AS creates new AccountsTable with assigned alias
|
||||
func (a AccountsTable) AS(alias string) *AccountsTable {
|
||||
return newAccountsTable(a.SchemaName(), a.TableName(), alias)
|
||||
}
|
||||
|
||||
// Schema creates new AccountsTable with assigned schema name
|
||||
func (a AccountsTable) FromSchema(schemaName string) *AccountsTable {
|
||||
return newAccountsTable(schemaName, a.TableName(), a.Alias())
|
||||
}
|
||||
|
||||
// WithPrefix creates new AccountsTable with assigned table prefix
|
||||
func (a AccountsTable) WithPrefix(prefix string) *AccountsTable {
|
||||
return newAccountsTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
|
||||
}
|
||||
|
||||
// WithSuffix creates new AccountsTable with assigned table suffix
|
||||
func (a AccountsTable) WithSuffix(suffix string) *AccountsTable {
|
||||
return newAccountsTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
|
||||
}
|
||||
|
||||
func newAccountsTable(schemaName, tableName, alias string) *AccountsTable {
|
||||
return &AccountsTable{
|
||||
accountsTable: newAccountsTableImpl(schemaName, tableName, alias),
|
||||
EXCLUDED: newAccountsTableImpl("", "excluded", ""),
|
||||
}
|
||||
}
|
||||
|
||||
func newAccountsTableImpl(schemaName, tableName, alias string) accountsTable {
|
||||
var (
|
||||
AccountIDColumn = postgres.StringColumn("account_id")
|
||||
DisplayNameColumn = postgres.StringColumn("display_name")
|
||||
PreferredLanguageColumn = postgres.StringColumn("preferred_language")
|
||||
TimeZoneColumn = postgres.StringColumn("time_zone")
|
||||
BlockChatColumn = postgres.BoolColumn("block_chat")
|
||||
BlockFriendRequestsColumn = postgres.BoolColumn("block_friend_requests")
|
||||
CreatedAtColumn = postgres.TimestampzColumn("created_at")
|
||||
UpdatedAtColumn = postgres.TimestampzColumn("updated_at")
|
||||
allColumns = postgres.ColumnList{AccountIDColumn, DisplayNameColumn, PreferredLanguageColumn, TimeZoneColumn, BlockChatColumn, BlockFriendRequestsColumn, CreatedAtColumn, UpdatedAtColumn}
|
||||
mutableColumns = postgres.ColumnList{DisplayNameColumn, PreferredLanguageColumn, TimeZoneColumn, BlockChatColumn, BlockFriendRequestsColumn, CreatedAtColumn, UpdatedAtColumn}
|
||||
defaultColumns = postgres.ColumnList{DisplayNameColumn, PreferredLanguageColumn, TimeZoneColumn, BlockChatColumn, BlockFriendRequestsColumn, CreatedAtColumn, UpdatedAtColumn}
|
||||
)
|
||||
|
||||
return accountsTable{
|
||||
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
|
||||
|
||||
//Columns
|
||||
AccountID: AccountIDColumn,
|
||||
DisplayName: DisplayNameColumn,
|
||||
PreferredLanguage: PreferredLanguageColumn,
|
||||
TimeZone: TimeZoneColumn,
|
||||
BlockChat: BlockChatColumn,
|
||||
BlockFriendRequests: BlockFriendRequestsColumn,
|
||||
CreatedAt: CreatedAtColumn,
|
||||
UpdatedAt: UpdatedAtColumn,
|
||||
|
||||
AllColumns: allColumns,
|
||||
MutableColumns: mutableColumns,
|
||||
DefaultColumns: defaultColumns,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
//
|
||||
// Code generated by go-jet DO NOT EDIT.
|
||||
//
|
||||
// WARNING: Changes to this file may cause incorrect behavior
|
||||
// and will be lost if the code is regenerated
|
||||
//
|
||||
|
||||
package table
|
||||
|
||||
import (
|
||||
"github.com/go-jet/jet/v2/postgres"
|
||||
)
|
||||
|
||||
var Identities = newIdentitiesTable("backend", "identities", "")
|
||||
|
||||
type identitiesTable struct {
|
||||
postgres.Table
|
||||
|
||||
// Columns
|
||||
IdentityID postgres.ColumnString
|
||||
AccountID postgres.ColumnString
|
||||
Kind postgres.ColumnString
|
||||
ExternalID postgres.ColumnString
|
||||
Confirmed postgres.ColumnBool
|
||||
CreatedAt postgres.ColumnTimestampz
|
||||
|
||||
AllColumns postgres.ColumnList
|
||||
MutableColumns postgres.ColumnList
|
||||
DefaultColumns postgres.ColumnList
|
||||
}
|
||||
|
||||
type IdentitiesTable struct {
|
||||
identitiesTable
|
||||
|
||||
EXCLUDED identitiesTable
|
||||
}
|
||||
|
||||
// AS creates new IdentitiesTable with assigned alias
|
||||
func (a IdentitiesTable) AS(alias string) *IdentitiesTable {
|
||||
return newIdentitiesTable(a.SchemaName(), a.TableName(), alias)
|
||||
}
|
||||
|
||||
// Schema creates new IdentitiesTable with assigned schema name
|
||||
func (a IdentitiesTable) FromSchema(schemaName string) *IdentitiesTable {
|
||||
return newIdentitiesTable(schemaName, a.TableName(), a.Alias())
|
||||
}
|
||||
|
||||
// WithPrefix creates new IdentitiesTable with assigned table prefix
|
||||
func (a IdentitiesTable) WithPrefix(prefix string) *IdentitiesTable {
|
||||
return newIdentitiesTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
|
||||
}
|
||||
|
||||
// WithSuffix creates new IdentitiesTable with assigned table suffix
|
||||
func (a IdentitiesTable) WithSuffix(suffix string) *IdentitiesTable {
|
||||
return newIdentitiesTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
|
||||
}
|
||||
|
||||
func newIdentitiesTable(schemaName, tableName, alias string) *IdentitiesTable {
|
||||
return &IdentitiesTable{
|
||||
identitiesTable: newIdentitiesTableImpl(schemaName, tableName, alias),
|
||||
EXCLUDED: newIdentitiesTableImpl("", "excluded", ""),
|
||||
}
|
||||
}
|
||||
|
||||
func newIdentitiesTableImpl(schemaName, tableName, alias string) identitiesTable {
|
||||
var (
|
||||
IdentityIDColumn = postgres.StringColumn("identity_id")
|
||||
AccountIDColumn = postgres.StringColumn("account_id")
|
||||
KindColumn = postgres.StringColumn("kind")
|
||||
ExternalIDColumn = postgres.StringColumn("external_id")
|
||||
ConfirmedColumn = postgres.BoolColumn("confirmed")
|
||||
CreatedAtColumn = postgres.TimestampzColumn("created_at")
|
||||
allColumns = postgres.ColumnList{IdentityIDColumn, AccountIDColumn, KindColumn, ExternalIDColumn, ConfirmedColumn, CreatedAtColumn}
|
||||
mutableColumns = postgres.ColumnList{AccountIDColumn, KindColumn, ExternalIDColumn, ConfirmedColumn, CreatedAtColumn}
|
||||
defaultColumns = postgres.ColumnList{ConfirmedColumn, CreatedAtColumn}
|
||||
)
|
||||
|
||||
return identitiesTable{
|
||||
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
|
||||
|
||||
//Columns
|
||||
IdentityID: IdentityIDColumn,
|
||||
AccountID: AccountIDColumn,
|
||||
Kind: KindColumn,
|
||||
ExternalID: ExternalIDColumn,
|
||||
Confirmed: ConfirmedColumn,
|
||||
CreatedAt: CreatedAtColumn,
|
||||
|
||||
AllColumns: allColumns,
|
||||
MutableColumns: mutableColumns,
|
||||
DefaultColumns: defaultColumns,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
//
|
||||
// Code generated by go-jet DO NOT EDIT.
|
||||
//
|
||||
// WARNING: Changes to this file may cause incorrect behavior
|
||||
// and will be lost if the code is regenerated
|
||||
//
|
||||
|
||||
package table
|
||||
|
||||
import (
|
||||
"github.com/go-jet/jet/v2/postgres"
|
||||
)
|
||||
|
||||
var Sessions = newSessionsTable("backend", "sessions", "")
|
||||
|
||||
type sessionsTable struct {
|
||||
postgres.Table
|
||||
|
||||
// Columns
|
||||
SessionID postgres.ColumnString
|
||||
AccountID postgres.ColumnString
|
||||
TokenHash postgres.ColumnString
|
||||
Status postgres.ColumnString
|
||||
CreatedAt postgres.ColumnTimestampz
|
||||
LastSeenAt postgres.ColumnTimestampz
|
||||
RevokedAt postgres.ColumnTimestampz
|
||||
|
||||
AllColumns postgres.ColumnList
|
||||
MutableColumns postgres.ColumnList
|
||||
DefaultColumns postgres.ColumnList
|
||||
}
|
||||
|
||||
type SessionsTable struct {
|
||||
sessionsTable
|
||||
|
||||
EXCLUDED sessionsTable
|
||||
}
|
||||
|
||||
// AS creates new SessionsTable with assigned alias
|
||||
func (a SessionsTable) AS(alias string) *SessionsTable {
|
||||
return newSessionsTable(a.SchemaName(), a.TableName(), alias)
|
||||
}
|
||||
|
||||
// Schema creates new SessionsTable with assigned schema name
|
||||
func (a SessionsTable) FromSchema(schemaName string) *SessionsTable {
|
||||
return newSessionsTable(schemaName, a.TableName(), a.Alias())
|
||||
}
|
||||
|
||||
// WithPrefix creates new SessionsTable with assigned table prefix
|
||||
func (a SessionsTable) WithPrefix(prefix string) *SessionsTable {
|
||||
return newSessionsTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
|
||||
}
|
||||
|
||||
// WithSuffix creates new SessionsTable with assigned table suffix
|
||||
func (a SessionsTable) WithSuffix(suffix string) *SessionsTable {
|
||||
return newSessionsTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
|
||||
}
|
||||
|
||||
func newSessionsTable(schemaName, tableName, alias string) *SessionsTable {
|
||||
return &SessionsTable{
|
||||
sessionsTable: newSessionsTableImpl(schemaName, tableName, alias),
|
||||
EXCLUDED: newSessionsTableImpl("", "excluded", ""),
|
||||
}
|
||||
}
|
||||
|
||||
func newSessionsTableImpl(schemaName, tableName, alias string) sessionsTable {
|
||||
var (
|
||||
SessionIDColumn = postgres.StringColumn("session_id")
|
||||
AccountIDColumn = postgres.StringColumn("account_id")
|
||||
TokenHashColumn = postgres.StringColumn("token_hash")
|
||||
StatusColumn = postgres.StringColumn("status")
|
||||
CreatedAtColumn = postgres.TimestampzColumn("created_at")
|
||||
LastSeenAtColumn = postgres.TimestampzColumn("last_seen_at")
|
||||
RevokedAtColumn = postgres.TimestampzColumn("revoked_at")
|
||||
allColumns = postgres.ColumnList{SessionIDColumn, AccountIDColumn, TokenHashColumn, StatusColumn, CreatedAtColumn, LastSeenAtColumn, RevokedAtColumn}
|
||||
mutableColumns = postgres.ColumnList{AccountIDColumn, TokenHashColumn, StatusColumn, CreatedAtColumn, LastSeenAtColumn, RevokedAtColumn}
|
||||
defaultColumns = postgres.ColumnList{StatusColumn, CreatedAtColumn}
|
||||
)
|
||||
|
||||
return sessionsTable{
|
||||
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
|
||||
|
||||
//Columns
|
||||
SessionID: SessionIDColumn,
|
||||
AccountID: AccountIDColumn,
|
||||
TokenHash: TokenHashColumn,
|
||||
Status: StatusColumn,
|
||||
CreatedAt: CreatedAtColumn,
|
||||
LastSeenAt: LastSeenAtColumn,
|
||||
RevokedAt: RevokedAtColumn,
|
||||
|
||||
AllColumns: allColumns,
|
||||
MutableColumns: mutableColumns,
|
||||
DefaultColumns: defaultColumns,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
//
|
||||
// Code generated by go-jet DO NOT EDIT.
|
||||
//
|
||||
// WARNING: Changes to this file may cause incorrect behavior
|
||||
// and will be lost if the code is regenerated
|
||||
//
|
||||
|
||||
package table
|
||||
|
||||
// UseSchema sets a new schema name for all generated table SQL builder types. It is recommended to invoke
|
||||
// this method only once at the beginning of the program.
|
||||
func UseSchema(schema string) {
|
||||
Accounts = Accounts.FromSchema(schema)
|
||||
Identities = Identities.FromSchema(schema)
|
||||
Sessions = Sessions.FromSchema(schema)
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
|
||||
"scrabble/backend/internal/postgres/migrations"
|
||||
)
|
||||
|
||||
// schemaName is the Postgres schema owned by the backend service. Every backend
|
||||
// table lives here, and the DSN pins search_path to it.
|
||||
const schemaName = "backend"
|
||||
|
||||
// migrationRetryAttempts and migrationRetryBackoff bound the transient-error
|
||||
// retry around ApplyMigrations. A freshly started Postgres — notably a test
|
||||
// container — can reset a pooled connection moments after it reports ready,
|
||||
// which surfaces as "bad connection" mid-migration; a handful of quick retries
|
||||
// ride over that without masking real failures.
|
||||
const (
|
||||
migrationRetryAttempts = 5
|
||||
migrationRetryBackoff = 250 * time.Millisecond
|
||||
)
|
||||
|
||||
// gooseMu serialises access to goose's package-level filesystem state so a
|
||||
// second caller in the same process cannot race on goose.SetBaseFS.
|
||||
var gooseMu sync.Mutex
|
||||
|
||||
// ApplyMigrations runs every pending Up migration embedded in the backend
|
||||
// binary against db. The schema is created upfront so goose's bookkeeping table
|
||||
// (`goose_db_version`, scoped to the DSN search_path) has somewhere to land
|
||||
// before the first migration runs; migration 00001_init.sql re-asserts the
|
||||
// schema with IF NOT EXISTS, so the double-create is idempotent.
|
||||
//
|
||||
// The apply is retried on transient connection errors. Both steps are
|
||||
// idempotent, so a retry after a dropped connection resumes from the last
|
||||
// committed migration.
|
||||
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
|
||||
return retryOnTransient(ctx, migrationRetryAttempts, migrationRetryBackoff, func() error {
|
||||
if _, err := db.ExecContext(ctx, "CREATE SCHEMA IF NOT EXISTS "+schemaName); err != nil {
|
||||
return fmt.Errorf("ensure backend schema: %w", err)
|
||||
}
|
||||
if err := runMigrations(ctx, db, migrations.Migrations(), "."); err != nil {
|
||||
return fmt.Errorf("apply backend migrations: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// runMigrations applies every pending Up migration found under dir inside fsys
|
||||
// against db. The PostgreSQL dialect is forced; goose's package-level base FS is
|
||||
// restored on the way out so a second caller in the same process is safe. dir
|
||||
// is "." when the migration files sit at the embed root.
|
||||
func runMigrations(ctx context.Context, db *sql.DB, fsys fs.FS, dir string) error {
|
||||
if db == nil {
|
||||
return errors.New("run migrations: nil db")
|
||||
}
|
||||
if fsys == nil {
|
||||
return errors.New("run migrations: nil fs")
|
||||
}
|
||||
|
||||
gooseMu.Lock()
|
||||
defer gooseMu.Unlock()
|
||||
|
||||
goose.SetBaseFS(fsys)
|
||||
defer goose.SetBaseFS(nil)
|
||||
|
||||
if err := goose.SetDialect("postgres"); err != nil {
|
||||
return fmt.Errorf("run migrations: set dialect: %w", err)
|
||||
}
|
||||
if err := goose.UpContext(ctx, db, dir); err != nil {
|
||||
return fmt.Errorf("run migrations: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// retryOnTransient runs op up to attempts times, retrying only when op fails
|
||||
// with a transient connection error — a dropped, reset, or refused connection,
|
||||
// as opposed to a deterministic SQL error. It waits backoff between attempts and
|
||||
// stops early if ctx is cancelled.
|
||||
func retryOnTransient(ctx context.Context, attempts int, backoff time.Duration, op func() error) error {
|
||||
var err error
|
||||
for attempt := 1; attempt <= attempts; attempt++ {
|
||||
if err = op(); err == nil {
|
||||
return nil
|
||||
}
|
||||
if attempt == attempts || !isTransientConnError(err) {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.Join(err, ctx.Err())
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// isTransientConnError reports whether err is a transient connection-level
|
||||
// failure worth retrying, leaving deterministic SQL errors (syntax, constraint
|
||||
// violations) to fail fast.
|
||||
func isTransientConnError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
return true
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
for _, s := range []string{
|
||||
"bad connection",
|
||||
"connection refused",
|
||||
"connection reset",
|
||||
"broken pipe",
|
||||
"server closed the connection",
|
||||
} {
|
||||
if strings.Contains(msg, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
-- +goose Up
|
||||
-- Initial schema for the Scrabble backend service: durable accounts, their
|
||||
-- platform/email identities, and opaque server sessions.
|
||||
--
|
||||
-- Every backend table lives in the `backend` schema. The schema is created here
|
||||
-- so a fresh database can apply this migration, and search_path is pinned for
|
||||
-- the rest of the migration so the CREATE statements land in `backend` without
|
||||
-- qualifying every object. Production also pins search_path via
|
||||
-- BACKEND_POSTGRES_DSN.
|
||||
CREATE SCHEMA IF NOT EXISTS backend;
|
||||
SET search_path = backend, pg_catalog;
|
||||
|
||||
-- Durable internal accounts. Guests are session-only and never reach this table.
|
||||
CREATE TABLE accounts (
|
||||
account_id uuid PRIMARY KEY,
|
||||
display_name text NOT NULL DEFAULT '',
|
||||
preferred_language text NOT NULL DEFAULT 'en',
|
||||
time_zone text NOT NULL DEFAULT 'UTC',
|
||||
block_chat boolean NOT NULL DEFAULT false,
|
||||
block_friend_requests boolean NOT NULL DEFAULT false,
|
||||
created_at timestamptz NOT NULL DEFAULT now(),
|
||||
updated_at timestamptz NOT NULL DEFAULT now(),
|
||||
CONSTRAINT accounts_preferred_language_chk CHECK (preferred_language IN ('en', 'ru'))
|
||||
);
|
||||
|
||||
-- Platform and email identities attached to an account. external_id is the
|
||||
-- platform user id (kind='telegram') or the email address (kind='email');
|
||||
-- confirmed flips true once an email confirm-code is verified (later stages).
|
||||
CREATE TABLE identities (
|
||||
identity_id uuid PRIMARY KEY,
|
||||
account_id uuid NOT NULL REFERENCES accounts (account_id) ON DELETE CASCADE,
|
||||
kind text NOT NULL,
|
||||
external_id text NOT NULL,
|
||||
confirmed boolean NOT NULL DEFAULT false,
|
||||
created_at timestamptz NOT NULL DEFAULT now(),
|
||||
CONSTRAINT identities_kind_chk CHECK (kind IN ('telegram', 'email')),
|
||||
CONSTRAINT identities_kind_external_id_key UNIQUE (kind, external_id)
|
||||
);
|
||||
CREATE INDEX identities_account_idx ON identities (account_id);
|
||||
|
||||
-- Opaque server sessions. token_hash is the hex-encoded SHA-256 of the bearer
|
||||
-- token; the plaintext token is never stored. Sessions are revoke-only (no
|
||||
-- TTL): status moves active -> revoked and revoked_at is stamped.
|
||||
CREATE TABLE sessions (
|
||||
session_id uuid PRIMARY KEY,
|
||||
account_id uuid NOT NULL REFERENCES accounts (account_id) ON DELETE CASCADE,
|
||||
token_hash text NOT NULL,
|
||||
status text NOT NULL DEFAULT 'active',
|
||||
created_at timestamptz NOT NULL DEFAULT now(),
|
||||
last_seen_at timestamptz,
|
||||
revoked_at timestamptz,
|
||||
CONSTRAINT sessions_status_chk CHECK (status IN ('active', 'revoked')),
|
||||
CONSTRAINT sessions_token_hash_key UNIQUE (token_hash)
|
||||
);
|
||||
CREATE INDEX sessions_account_idx ON sessions (account_id);
|
||||
|
||||
-- +goose Down
|
||||
DROP TABLE sessions;
|
||||
DROP TABLE identities;
|
||||
DROP TABLE accounts;
|
||||
@@ -0,0 +1,16 @@
|
||||
// Package migrations exposes the goose migrations applied at backend startup.
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
)
|
||||
|
||||
//go:embed *.sql
|
||||
var migrationFiles embed.FS
|
||||
|
||||
// Migrations returns the embedded goose migration filesystem. The migration
|
||||
// files sit at the FS root, so callers pass "." as the directory argument.
|
||||
func Migrations() fs.FS {
|
||||
return migrationFiles
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
// Package postgres opens the backend's Postgres pool and applies the embedded
|
||||
// goose migrations into the `backend` schema at startup.
|
||||
//
|
||||
// The pool is a standard library *sql.DB backed by the pgx driver (registered
|
||||
// through pgx/stdlib) and instrumented with otelsql, so go-jet queries run over
|
||||
// database/sql while statement spans and connection-pool metrics flow into
|
||||
// OpenTelemetry. The DSN must pin search_path to the backend schema.
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/XSAM/otelsql"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// Default pool tuning applied by DefaultConfig.
|
||||
const (
|
||||
DefaultMaxOpenConns = 25
|
||||
DefaultMaxIdleConns = 5
|
||||
DefaultConnMaxLifetime = 30 * time.Minute
|
||||
DefaultOperationTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// dbSystemAttribute identifies the wrapped backend in OpenTelemetry spans
|
||||
// without pinning the package to a specific semconv release.
|
||||
var dbSystemAttribute = attribute.String("db.system", "postgresql")
|
||||
|
||||
// Config describes how to open the backend Postgres pool.
|
||||
type Config struct {
|
||||
// DSN is the pgx/libpq connection string. It must pin search_path to the
|
||||
// backend schema, e.g. "postgres://…/db?search_path=backend&sslmode=disable".
|
||||
DSN string
|
||||
// MaxOpenConns bounds the pool's open connections (database/sql).
|
||||
MaxOpenConns int
|
||||
// MaxIdleConns bounds the pool's idle connections (database/sql).
|
||||
MaxIdleConns int
|
||||
// ConnMaxLifetime caps how long a pooled connection may be reused.
|
||||
ConnMaxLifetime time.Duration
|
||||
// OperationTimeout bounds a single connect attempt and the startup Ping.
|
||||
OperationTimeout time.Duration
|
||||
}
|
||||
|
||||
// DefaultConfig returns a Config carrying the default pool tuning and an empty
|
||||
// DSN. Callers fill DSN from the environment before opening.
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
MaxOpenConns: DefaultMaxOpenConns,
|
||||
MaxIdleConns: DefaultMaxIdleConns,
|
||||
ConnMaxLifetime: DefaultConnMaxLifetime,
|
||||
OperationTimeout: DefaultOperationTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate reports whether the configuration is usable.
|
||||
func (c Config) Validate() error {
|
||||
if strings.TrimSpace(c.DSN) == "" {
|
||||
return errors.New("postgres: DSN must not be empty")
|
||||
}
|
||||
if c.MaxOpenConns <= 0 {
|
||||
return fmt.Errorf("postgres: MaxOpenConns must be positive, got %d", c.MaxOpenConns)
|
||||
}
|
||||
if c.MaxIdleConns < 0 {
|
||||
return fmt.Errorf("postgres: MaxIdleConns must not be negative, got %d", c.MaxIdleConns)
|
||||
}
|
||||
if c.OperationTimeout <= 0 {
|
||||
return fmt.Errorf("postgres: OperationTimeout must be positive, got %s", c.OperationTimeout)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Option configures the OpenTelemetry providers attached to a pool by Open.
|
||||
// Unset providers fall back to the OpenTelemetry global providers.
|
||||
type Option func(*options)
|
||||
|
||||
type options struct {
|
||||
tracerProvider trace.TracerProvider
|
||||
meterProvider metric.MeterProvider
|
||||
}
|
||||
|
||||
// WithTracerProvider sets the tracer provider used for SQL statement spans.
|
||||
func WithTracerProvider(tp trace.TracerProvider) Option {
|
||||
return func(o *options) { o.tracerProvider = tp }
|
||||
}
|
||||
|
||||
// WithMeterProvider sets the meter provider used for connection-pool metrics.
|
||||
func WithMeterProvider(mp metric.MeterProvider) Option {
|
||||
return func(o *options) { o.meterProvider = mp }
|
||||
}
|
||||
|
||||
func evalOptions(opts []Option) options {
|
||||
var resolved options
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(&resolved)
|
||||
}
|
||||
}
|
||||
return resolved
|
||||
}
|
||||
|
||||
func (o options) otelsqlOptions() []otelsql.Option {
|
||||
out := []otelsql.Option{otelsql.WithAttributes(dbSystemAttribute)}
|
||||
if o.tracerProvider != nil {
|
||||
out = append(out, otelsql.WithTracerProvider(o.tracerProvider))
|
||||
}
|
||||
if o.meterProvider != nil {
|
||||
out = append(out, otelsql.WithMeterProvider(o.meterProvider))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Open opens the instrumented pool described by cfg, registers connection-pool
|
||||
// metrics, and verifies connectivity with a bounded Ping. Closing the returned
|
||||
// *sql.DB is the caller's responsibility.
|
||||
func Open(ctx context.Context, cfg Config, opts ...Option) (*sql.DB, error) {
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("open postgres: %w", err)
|
||||
}
|
||||
resolved := evalOptions(opts)
|
||||
|
||||
pgxCfg, err := pgx.ParseConfig(cfg.DSN)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open postgres: parse dsn: %w", err)
|
||||
}
|
||||
pgxCfg.ConnectTimeout = cfg.OperationTimeout
|
||||
registeredName := stdlib.RegisterConnConfig(pgxCfg)
|
||||
|
||||
db, err := otelsql.Open("pgx", registeredName, resolved.otelsqlOptions()...)
|
||||
if err != nil {
|
||||
stdlib.UnregisterConnConfig(registeredName)
|
||||
return nil, fmt.Errorf("open postgres: otelsql open: %w", err)
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(cfg.ConnMaxLifetime)
|
||||
|
||||
if _, err := otelsql.RegisterDBStatsMetrics(db, resolved.otelsqlOptions()...); err != nil {
|
||||
_ = db.Close()
|
||||
stdlib.UnregisterConnConfig(registeredName)
|
||||
return nil, fmt.Errorf("open postgres: register db stats: %w", err)
|
||||
}
|
||||
|
||||
if err := Ping(ctx, db, cfg.OperationTimeout); err != nil {
|
||||
_ = db.Close()
|
||||
stdlib.UnregisterConnConfig(registeredName)
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// Ping bounds db.PingContext under timeout and wraps the error so startup
|
||||
// failures are easy to spot in service logs. The same call backs the /readyz
|
||||
// probe. timeout is typically Config.OperationTimeout.
|
||||
func Ping(ctx context.Context, db *sql.DB, timeout time.Duration) error {
|
||||
if db == nil {
|
||||
return errors.New("ping postgres: nil db")
|
||||
}
|
||||
if timeout <= 0 {
|
||||
return errors.New("ping postgres: timeout must be positive")
|
||||
}
|
||||
pingCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
if err := db.PingContext(pingCtx); err != nil {
|
||||
return fmt.Errorf("ping postgres: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// headerUserID is the identity header the gateway injects after resolving a
|
||||
// session to an internal account.
|
||||
const headerUserID = "X-User-ID"
|
||||
|
||||
// contextKey is an unexported type for request-context keys set by this package.
|
||||
type contextKey string
|
||||
|
||||
const userIDContextKey contextKey = "scrabble.user_id"
|
||||
|
||||
// RequireUserID returns middleware that requires a valid X-User-ID header and
|
||||
// stores the parsed account id in the request context. Requests without a
|
||||
// parseable UUID are rejected with 401. The backend treats X-User-ID as the
|
||||
// sole identity input and never derives identity from the request body.
|
||||
func RequireUserID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
id, err := uuid.Parse(c.GetHeader(headerUserID))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing or invalid X-User-ID"})
|
||||
return
|
||||
}
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), userIDContextKey, id))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// UserIDFromContext returns the authenticated account id stored by
|
||||
// RequireUserID, and whether it was present.
|
||||
func UserIDFromContext(ctx context.Context) (uuid.UUID, bool) {
|
||||
id, ok := ctx.Value(userIDContextKey).(uuid.UUID)
|
||||
return id, ok
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestRequireUserID checks that the middleware accepts a valid X-User-ID,
|
||||
// exposes it through the request context, and rejects missing or malformed
|
||||
// headers with 401.
|
||||
func TestRequireUserID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var seen uuid.UUID
|
||||
var ok bool
|
||||
r := gin.New()
|
||||
r.Use(RequireUserID())
|
||||
r.GET("/x", func(c *gin.Context) {
|
||||
seen, ok = UserIDFromContext(c.Request.Context())
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
seen, ok = uuid.Nil, false
|
||||
id := uuid.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/x", nil)
|
||||
req.Header.Set("X-User-ID", id.String())
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want 200", rec.Code)
|
||||
}
|
||||
if !ok || seen != id {
|
||||
t.Fatalf("context id = %s (ok=%v), want %s", seen, ok, id)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/x", nil))
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status = %d, want 401", rec.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("malformed", func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/x", nil)
|
||||
req.Header.Set("X-User-ID", "not-a-uuid")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status = %d, want 401", rec.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,53 +1,146 @@
|
||||
// Package server wires the backend's HTTP listener: the gin engine, its route
|
||||
// groups and the start/stop lifecycle. At this stage it serves only the
|
||||
// infrastructure probes; the domain route groups described in PLAN.md
|
||||
// (/api/v1/public, /user, /internal, /admin) are added by later stages.
|
||||
// groups, the per-request telemetry middleware and the start/stop lifecycle.
|
||||
//
|
||||
// The /api/v1 route groups (public, user, internal, admin) are created here so
|
||||
// later stages attach their endpoints to a stable structure; the /user group
|
||||
// requires the X-User-ID identity header. The probes /healthz (liveness) and
|
||||
// /readyz (database + session-cache readiness) are unauthenticated.
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"scrabble/backend/internal/telemetry"
|
||||
)
|
||||
|
||||
// shutdownTimeout bounds how long Run waits for in-flight requests to finish
|
||||
// during a graceful shutdown.
|
||||
const shutdownTimeout = 10 * time.Second
|
||||
|
||||
// Server owns the gin engine and the underlying HTTP server.
|
||||
type Server struct {
|
||||
log *zap.Logger
|
||||
http *http.Server
|
||||
// defaultPingTimeout bounds the /readyz database ping when Deps.PingTimeout is
|
||||
// not set.
|
||||
const defaultPingTimeout = 5 * time.Second
|
||||
|
||||
// Deps carries the runtime dependencies the HTTP layer needs.
|
||||
type Deps struct {
|
||||
// Logger receives lifecycle, request and readiness diagnostics.
|
||||
Logger *zap.Logger
|
||||
// DB backs the /readyz database ping. A nil DB skips the database check.
|
||||
DB *sql.DB
|
||||
// PingTimeout bounds the /readyz database ping.
|
||||
PingTimeout time.Duration
|
||||
// SessionsReady reports whether the session cache has been warmed. A nil
|
||||
// func skips the session-readiness check.
|
||||
SessionsReady func() bool
|
||||
}
|
||||
|
||||
// New returns a Server that will listen on addr. The logger receives lifecycle
|
||||
// and request diagnostics.
|
||||
func New(addr string, log *zap.Logger) *Server {
|
||||
// Server owns the gin engine, the underlying HTTP server and the readiness
|
||||
// dependencies.
|
||||
type Server struct {
|
||||
log *zap.Logger
|
||||
http *http.Server
|
||||
db *sql.DB
|
||||
pingTimeout time.Duration
|
||||
sessionsReady func() bool
|
||||
|
||||
public *gin.RouterGroup
|
||||
user *gin.RouterGroup
|
||||
internal *gin.RouterGroup
|
||||
admin *gin.RouterGroup
|
||||
}
|
||||
|
||||
// New returns a Server that will listen on addr. It installs the recovery and
|
||||
// telemetry middleware, the infrastructure probes, and the /api/v1 route groups.
|
||||
func New(addr string, deps Deps) *Server {
|
||||
log := deps.Logger
|
||||
if log == nil {
|
||||
log = zap.NewNop()
|
||||
}
|
||||
pingTimeout := deps.PingTimeout
|
||||
if pingTimeout <= 0 {
|
||||
pingTimeout = defaultPingTimeout
|
||||
}
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
engine := gin.New()
|
||||
engine.Use(gin.Recovery())
|
||||
registerProbes(engine)
|
||||
engine.Use(telemetry.Middleware(log))
|
||||
|
||||
return &Server{
|
||||
log: log,
|
||||
http: &http.Server{Addr: addr, Handler: engine},
|
||||
s := &Server{
|
||||
log: log,
|
||||
db: deps.DB,
|
||||
pingTimeout: pingTimeout,
|
||||
sessionsReady: deps.SessionsReady,
|
||||
http: &http.Server{Addr: addr, Handler: engine},
|
||||
}
|
||||
s.registerProbes(engine)
|
||||
s.registerAPIGroups(engine)
|
||||
return s
|
||||
}
|
||||
|
||||
// registerProbes installs the unauthenticated infrastructure probes: /healthz
|
||||
// reports process liveness and /readyz reports readiness to serve traffic.
|
||||
// Until later stages add real dependencies (Postgres, warmed caches),
|
||||
// readiness mirrors liveness.
|
||||
func registerProbes(engine *gin.Engine) {
|
||||
ok := func(c *gin.Context) { c.String(http.StatusOK, "ok") }
|
||||
engine.GET("/healthz", ok)
|
||||
engine.GET("/readyz", ok)
|
||||
// reports process liveness and /readyz reports readiness to serve traffic
|
||||
// (database reachable and session cache warmed).
|
||||
func (s *Server) registerProbes(engine *gin.Engine) {
|
||||
engine.GET("/healthz", func(c *gin.Context) { c.String(http.StatusOK, "ok") })
|
||||
engine.GET("/readyz", s.readyz)
|
||||
}
|
||||
|
||||
// readyz reports 200 only when the database answers a bounded ping and the
|
||||
// session cache is warmed; otherwise 503.
|
||||
func (s *Server) readyz(c *gin.Context) {
|
||||
if s.db != nil {
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), s.pingTimeout)
|
||||
defer cancel()
|
||||
if err := s.db.PingContext(ctx); err != nil {
|
||||
s.log.Warn("readiness: database ping failed", zap.Error(err))
|
||||
c.String(http.StatusServiceUnavailable, "database unavailable")
|
||||
return
|
||||
}
|
||||
}
|
||||
if s.sessionsReady != nil && !s.sessionsReady() {
|
||||
c.String(http.StatusServiceUnavailable, "sessions not ready")
|
||||
return
|
||||
}
|
||||
c.String(http.StatusOK, "ok")
|
||||
}
|
||||
|
||||
// registerAPIGroups wires the /api/v1 route groups. They are populated by the
|
||||
// stages that add their first endpoint; the /user group requires X-User-ID,
|
||||
// which the gateway injects after resolving a session.
|
||||
func (s *Server) registerAPIGroups(engine *gin.Engine) {
|
||||
v1 := engine.Group("/api/v1")
|
||||
s.public = v1.Group("/public")
|
||||
s.user = v1.Group("/user")
|
||||
s.user.Use(RequireUserID())
|
||||
s.internal = v1.Group("/internal")
|
||||
s.admin = v1.Group("/admin")
|
||||
}
|
||||
|
||||
// PublicGroup returns the unauthenticated public route group.
|
||||
func (s *Server) PublicGroup() *gin.RouterGroup { return s.public }
|
||||
|
||||
// UserGroup returns the authenticated user route group (requires X-User-ID).
|
||||
func (s *Server) UserGroup() *gin.RouterGroup { return s.user }
|
||||
|
||||
// InternalGroup returns the gateway-facing internal route group.
|
||||
func (s *Server) InternalGroup() *gin.RouterGroup { return s.internal }
|
||||
|
||||
// AdminGroup returns the admin route group (authenticated at the gateway).
|
||||
func (s *Server) AdminGroup() *gin.RouterGroup { return s.admin }
|
||||
|
||||
// Handler returns the underlying HTTP handler. It lets tests drive the server
|
||||
// without binding a socket and lets later stages compose the backend behind
|
||||
// another listener.
|
||||
func (s *Server) Handler() http.Handler { return s.http.Handler }
|
||||
|
||||
// Run starts the listener and blocks until ctx is cancelled, then shuts the
|
||||
// server down gracefully within shutdownTimeout. It returns the first error
|
||||
// that is not the expected http.ErrServerClosed.
|
||||
|
||||
@@ -8,15 +8,51 @@ import (
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestProbes verifies that the infrastructure probes answer 200 OK.
|
||||
func TestProbes(t *testing.T) {
|
||||
srv := New(":0", zaptest.NewLogger(t))
|
||||
for _, path := range []string{"/healthz", "/readyz"} {
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.http.Handler.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("%s: status = %d, want %d", path, rec.Code, http.StatusOK)
|
||||
}
|
||||
// get serves a GET request against the server's handler and returns the recorder.
|
||||
func get(srv *Server, path string) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.Handler().ServeHTTP(rec, req)
|
||||
return rec
|
||||
}
|
||||
|
||||
// TestHealthz verifies that /healthz answers 200 OK.
|
||||
func TestHealthz(t *testing.T) {
|
||||
srv := New(":0", Deps{Logger: zaptest.NewLogger(t)})
|
||||
if rec := get(srv, "/healthz"); rec.Code != http.StatusOK {
|
||||
t.Fatalf("/healthz status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadyzReadyWithoutDeps verifies that, with no database and no session
|
||||
// readiness gate wired, /readyz answers 200 OK.
|
||||
func TestReadyzReadyWithoutDeps(t *testing.T) {
|
||||
srv := New(":0", Deps{Logger: zaptest.NewLogger(t)})
|
||||
if rec := get(srv, "/readyz"); rec.Code != http.StatusOK {
|
||||
t.Fatalf("/readyz status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadyzNotReadyWhenSessionsCold verifies that /readyz answers 503 while the
|
||||
// session cache reports not-ready.
|
||||
func TestReadyzNotReadyWhenSessionsCold(t *testing.T) {
|
||||
srv := New(":0", Deps{
|
||||
Logger: zaptest.NewLogger(t),
|
||||
SessionsReady: func() bool { return false },
|
||||
})
|
||||
if rec := get(srv, "/readyz"); rec.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("/readyz status = %d, want %d", rec.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadyzReadyWhenSessionsWarm verifies that /readyz answers 200 once the
|
||||
// session cache reports ready (and no database is wired).
|
||||
func TestReadyzReadyWhenSessionsWarm(t *testing.T) {
|
||||
srv := New(":0", Deps{
|
||||
Logger: zaptest.NewLogger(t),
|
||||
SessionsReady: func() bool { return true },
|
||||
})
|
||||
if rec := get(srv, "/readyz"); rec.Code != http.StatusOK {
|
||||
t.Fatalf("/readyz status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Cache is the in-memory write-through projection of the active rows in
|
||||
// backend.sessions, keyed by token hash so Resolve avoids a database round-trip
|
||||
// on the hot path. Reads are RLocked; writes are Locked. Callers commit the
|
||||
// corresponding database write before invoking Add or Remove so the cache stays
|
||||
// consistent with the persisted state.
|
||||
type Cache struct {
|
||||
mu sync.RWMutex
|
||||
byHash map[string]Session
|
||||
ready atomic.Bool
|
||||
}
|
||||
|
||||
// NewCache constructs an empty Cache. It reports Ready() == false until Warm
|
||||
// completes successfully.
|
||||
func NewCache() *Cache {
|
||||
return &Cache{byHash: make(map[string]Session)}
|
||||
}
|
||||
|
||||
// Warm replaces the cache contents with every active session loaded from store.
|
||||
// It is intended to run once at process boot before the listener accepts
|
||||
// traffic; success flips Ready to true. Re-warming is supported (useful in
|
||||
// tests).
|
||||
func (c *Cache) Warm(ctx context.Context, store *Store) error {
|
||||
sessions, err := store.ListActive(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.byHash = make(map[string]Session, len(sessions))
|
||||
for _, s := range sessions {
|
||||
c.byHash[s.TokenHash] = s
|
||||
}
|
||||
c.ready.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ready reports whether Warm has completed at least once. The /readyz probe
|
||||
// wires through this so the backend only reports ready once sessions are
|
||||
// hydrated.
|
||||
func (c *Cache) Ready() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
return c.ready.Load()
|
||||
}
|
||||
|
||||
// Size returns the number of cached active sessions, for startup logs and tests.
|
||||
func (c *Cache) Size() int {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return len(c.byHash)
|
||||
}
|
||||
|
||||
// Get returns the session for tokenHash and a presence flag. A miss returns the
|
||||
// zero Session and false.
|
||||
func (c *Cache) Get(tokenHash string) (Session, bool) {
|
||||
if c == nil {
|
||||
return Session{}, false
|
||||
}
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
s, ok := c.byHash[tokenHash]
|
||||
return s, ok
|
||||
}
|
||||
|
||||
// Add stores s under its token hash. It is safe to call on an existing entry.
|
||||
func (c *Cache) Add(s Session) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.byHash[s.TokenHash] = s
|
||||
}
|
||||
|
||||
// Remove evicts the entry for tokenHash. Removing a missing entry is a no-op.
|
||||
func (c *Cache) Remove(tokenHash string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.byHash, tokenHash)
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestCache exercises the write-through cache's add/get/remove/size/ready cycle.
|
||||
func TestCache(t *testing.T) {
|
||||
c := NewCache()
|
||||
if c.Ready() {
|
||||
t.Error("a fresh cache must not report ready")
|
||||
}
|
||||
if _, ok := c.Get("h1"); ok {
|
||||
t.Error("get on empty cache must miss")
|
||||
}
|
||||
|
||||
s := Session{ID: uuid.New(), AccountID: uuid.New(), TokenHash: "h1", Status: StatusActive}
|
||||
c.Add(s)
|
||||
if got, ok := c.Get("h1"); !ok || got.ID != s.ID {
|
||||
t.Fatalf("get after add: got %v ok=%v", got.ID, ok)
|
||||
}
|
||||
if c.Size() != 1 {
|
||||
t.Errorf("size = %d, want 1", c.Size())
|
||||
}
|
||||
|
||||
c.Remove("h1")
|
||||
if _, ok := c.Get("h1"); ok {
|
||||
t.Error("get after remove must miss")
|
||||
}
|
||||
if c.Size() != 0 {
|
||||
t.Errorf("size = %d, want 0", c.Size())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheNilSafe checks that the cache methods are safe on a nil receiver,
|
||||
// which the readiness probe relies on before the cache is constructed.
|
||||
func TestCacheNilSafe(t *testing.T) {
|
||||
var c *Cache
|
||||
if c.Ready() {
|
||||
t.Error("nil cache must not be ready")
|
||||
}
|
||||
if _, ok := c.Get("x"); ok {
|
||||
t.Error("nil cache get must miss")
|
||||
}
|
||||
if c.Size() != 0 {
|
||||
t.Error("nil cache size must be 0")
|
||||
}
|
||||
c.Add(Session{TokenHash: "x"}) // must not panic
|
||||
c.Remove("x") // must not panic
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Service mints, resolves, and revokes sessions over the store and the
|
||||
// write-through cache. The gateway is its only caller (from a later stage); the
|
||||
// HTTP surface is wired then.
|
||||
type Service struct {
|
||||
store *Store
|
||||
cache *Cache
|
||||
}
|
||||
|
||||
// NewService constructs a Service over store and cache.
|
||||
func NewService(store *Store, cache *Cache) *Service {
|
||||
return &Service{store: store, cache: cache}
|
||||
}
|
||||
|
||||
// Warm hydrates the cache from the store. Call once before serving traffic.
|
||||
func (svc *Service) Warm(ctx context.Context) error {
|
||||
return svc.cache.Warm(ctx, svc.store)
|
||||
}
|
||||
|
||||
// Ready reports whether the session cache has been warmed.
|
||||
func (svc *Service) Ready() bool {
|
||||
return svc.cache.Ready()
|
||||
}
|
||||
|
||||
// Create mints a new active session for accountID and returns the plaintext
|
||||
// token (shown to the caller once) together with the persisted session.
|
||||
func (svc *Service) Create(ctx context.Context, accountID uuid.UUID) (string, Session, error) {
|
||||
token, tokenHash, err := GenerateToken()
|
||||
if err != nil {
|
||||
return "", Session{}, err
|
||||
}
|
||||
sess, err := svc.store.Insert(ctx, accountID, tokenHash)
|
||||
if err != nil {
|
||||
return "", Session{}, err
|
||||
}
|
||||
svc.cache.Add(sess)
|
||||
return token, sess, nil
|
||||
}
|
||||
|
||||
// Resolve maps a presented token to its active session, consulting the cache
|
||||
// first and falling back to the store (repopulating the cache on a hit).
|
||||
// Returns ErrNotFound when no active session matches.
|
||||
func (svc *Service) Resolve(ctx context.Context, token string) (Session, error) {
|
||||
hash := HashToken(token)
|
||||
if sess, ok := svc.cache.Get(hash); ok {
|
||||
return sess, nil
|
||||
}
|
||||
sess, err := svc.store.FindActiveByTokenHash(ctx, hash)
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
}
|
||||
svc.cache.Add(sess)
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// Revoke revokes the session for the presented token. It is idempotent:
|
||||
// revoking an unknown or already-revoked token returns nil.
|
||||
func (svc *Service) Revoke(ctx context.Context, token string) error {
|
||||
hash := HashToken(token)
|
||||
if _, _, err := svc.store.RevokeByTokenHash(ctx, hash, time.Now().UTC()); err != nil {
|
||||
return err
|
||||
}
|
||||
svc.cache.Remove(hash)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-jet/jet/v2/postgres"
|
||||
"github.com/go-jet/jet/v2/qrm"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"scrabble/backend/internal/postgres/jet/backend/model"
|
||||
"scrabble/backend/internal/postgres/jet/backend/table"
|
||||
)
|
||||
|
||||
// Session lifecycle statuses persisted in the status column.
|
||||
const (
|
||||
StatusActive = "active"
|
||||
StatusRevoked = "revoked"
|
||||
)
|
||||
|
||||
// ErrNotFound is returned when no active session matches the lookup.
|
||||
var ErrNotFound = errors.New("session: not found")
|
||||
|
||||
// Session mirrors a row in backend.sessions. TokenHash is the hex-encoded
|
||||
// SHA-256 of the bearer token.
|
||||
type Session struct {
|
||||
ID uuid.UUID
|
||||
AccountID uuid.UUID
|
||||
TokenHash string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
LastSeenAt *time.Time
|
||||
RevokedAt *time.Time
|
||||
}
|
||||
|
||||
// Store is the Postgres-backed query surface for backend.sessions.
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewStore constructs a Store wrapping db.
|
||||
func NewStore(db *sql.DB) *Store {
|
||||
return &Store{db: db}
|
||||
}
|
||||
|
||||
// Insert persists a new active session for accountID carrying tokenHash and
|
||||
// returns the persisted row.
|
||||
func (s *Store) Insert(ctx context.Context, accountID uuid.UUID, tokenHash string) (Session, error) {
|
||||
id, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
return Session{}, fmt.Errorf("session: new id: %w", err)
|
||||
}
|
||||
stmt := table.Sessions.INSERT(
|
||||
table.Sessions.SessionID,
|
||||
table.Sessions.AccountID,
|
||||
table.Sessions.TokenHash,
|
||||
).VALUES(id, accountID, tokenHash).RETURNING(table.Sessions.AllColumns)
|
||||
|
||||
var row model.Sessions
|
||||
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
||||
return Session{}, fmt.Errorf("session: insert: %w", err)
|
||||
}
|
||||
return modelToSession(row), nil
|
||||
}
|
||||
|
||||
// FindActiveByTokenHash returns the active session matching tokenHash, or
|
||||
// ErrNotFound.
|
||||
func (s *Store) FindActiveByTokenHash(ctx context.Context, tokenHash string) (Session, error) {
|
||||
stmt := postgres.SELECT(table.Sessions.AllColumns).
|
||||
FROM(table.Sessions).
|
||||
WHERE(
|
||||
table.Sessions.TokenHash.EQ(postgres.String(tokenHash)).
|
||||
AND(table.Sessions.Status.EQ(postgres.String(StatusActive))),
|
||||
).
|
||||
LIMIT(1)
|
||||
|
||||
var row model.Sessions
|
||||
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return Session{}, ErrNotFound
|
||||
}
|
||||
return Session{}, fmt.Errorf("session: find by token hash: %w", err)
|
||||
}
|
||||
return modelToSession(row), nil
|
||||
}
|
||||
|
||||
// RevokeByTokenHash transitions the active session for tokenHash to revoked and
|
||||
// returns the post-update row. ok is false with a nil error when no active
|
||||
// session matched, so revocation is idempotent.
|
||||
func (s *Store) RevokeByTokenHash(ctx context.Context, tokenHash string, at time.Time) (Session, bool, error) {
|
||||
stmt := table.Sessions.
|
||||
UPDATE(table.Sessions.Status, table.Sessions.RevokedAt).
|
||||
SET(postgres.String(StatusRevoked), postgres.TimestampzT(at)).
|
||||
WHERE(
|
||||
table.Sessions.TokenHash.EQ(postgres.String(tokenHash)).
|
||||
AND(table.Sessions.Status.EQ(postgres.String(StatusActive))),
|
||||
).
|
||||
RETURNING(table.Sessions.AllColumns)
|
||||
|
||||
var row model.Sessions
|
||||
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return Session{}, false, nil
|
||||
}
|
||||
return Session{}, false, fmt.Errorf("session: revoke by token hash: %w", err)
|
||||
}
|
||||
return modelToSession(row), true, nil
|
||||
}
|
||||
|
||||
// ListActive loads every active session. Cache.Warm calls this at boot.
|
||||
func (s *Store) ListActive(ctx context.Context) ([]Session, error) {
|
||||
stmt := postgres.SELECT(table.Sessions.AllColumns).
|
||||
FROM(table.Sessions).
|
||||
WHERE(table.Sessions.Status.EQ(postgres.String(StatusActive)))
|
||||
|
||||
var rows []model.Sessions
|
||||
if err := stmt.QueryContext(ctx, s.db, &rows); err != nil {
|
||||
return nil, fmt.Errorf("session: list active: %w", err)
|
||||
}
|
||||
out := make([]Session, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
out = append(out, modelToSession(row))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// modelToSession projects a generated model row into the public Session struct,
|
||||
// copying pointer fields so callers cannot mutate the scan buffer.
|
||||
func modelToSession(row model.Sessions) Session {
|
||||
s := Session{
|
||||
ID: row.SessionID,
|
||||
AccountID: row.AccountID,
|
||||
TokenHash: row.TokenHash,
|
||||
Status: row.Status,
|
||||
CreatedAt: row.CreatedAt,
|
||||
}
|
||||
if row.LastSeenAt != nil {
|
||||
t := *row.LastSeenAt
|
||||
s.LastSeenAt = &t
|
||||
}
|
||||
if row.RevokedAt != nil {
|
||||
t := *row.RevokedAt
|
||||
s.RevokedAt = &t
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
// Package session owns opaque server sessions: minting bearer tokens, resolving
|
||||
// them to accounts through a write-through in-memory cache, and revoking them.
|
||||
// Only the SHA-256 hash of a token is persisted; the plaintext is returned to
|
||||
// the caller once and never stored. Sessions are revoke-only (no TTL).
|
||||
package session
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// tokenBytes is the entropy of an opaque session token: 256 bits.
|
||||
const tokenBytes = 32
|
||||
|
||||
// GenerateToken returns a fresh random opaque token (URL-safe base64, 256-bit)
|
||||
// together with its hex-encoded SHA-256 hash for storage. The plaintext token
|
||||
// is handed to the caller once and is never persisted.
|
||||
func GenerateToken() (token, tokenHash string, err error) {
|
||||
buf := make([]byte, tokenBytes)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", "", fmt.Errorf("session: read random token: %w", err)
|
||||
}
|
||||
token = base64.RawURLEncoding.EncodeToString(buf)
|
||||
return token, HashToken(token), nil
|
||||
}
|
||||
|
||||
// HashToken returns the hex-encoded SHA-256 of token. Lookups hash the presented
|
||||
// token and compare against the stored hash.
|
||||
func HashToken(token string) string {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package session
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestGenerateTokenUniqueAndHashed checks that tokens are unique, the stored
|
||||
// value is the hash (not the plaintext), and the hash is a 64-char SHA-256 hex.
|
||||
func TestGenerateTokenUniqueAndHashed(t *testing.T) {
|
||||
tok1, hash1, err := GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateToken: %v", err)
|
||||
}
|
||||
tok2, hash2, err := GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateToken: %v", err)
|
||||
}
|
||||
|
||||
if tok1 == tok2 {
|
||||
t.Error("tokens must be unique")
|
||||
}
|
||||
if hash1 == hash2 {
|
||||
t.Error("hashes must differ for distinct tokens")
|
||||
}
|
||||
if hash1 != HashToken(tok1) {
|
||||
t.Error("stored hash must equal HashToken(token)")
|
||||
}
|
||||
if tok1 == hash1 {
|
||||
t.Error("stored hash must not equal the plaintext token")
|
||||
}
|
||||
if len(hash1) != 64 {
|
||||
t.Errorf("hash length = %d, want 64 (sha256 hex)", len(hash1))
|
||||
}
|
||||
}
|
||||
|
||||
// TestHashTokenDeterministic checks that hashing is stable for a given token.
|
||||
func TestHashTokenDeterministic(t *testing.T) {
|
||||
first := HashToken("alpha")
|
||||
second := HashToken("alpha")
|
||||
if first != second {
|
||||
t.Error("HashToken must be deterministic")
|
||||
}
|
||||
if first == HashToken("beta") {
|
||||
t.Error("distinct tokens must hash differently")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// tracerName names the instrumentation scope for backend HTTP spans.
|
||||
const tracerName = "scrabble/backend/server"
|
||||
|
||||
// Middleware returns gin middleware that, for every request, opens a server
|
||||
// span, measures server-side latency, and emits a structured access log
|
||||
// correlated with the active trace. It uses the globally-registered tracer, so
|
||||
// spans are exported only when an exporter is configured, while the timing log
|
||||
// is always emitted. Probe paths (/healthz, /readyz) log at debug level to keep
|
||||
// the default log clean.
|
||||
func Middleware(logger *zap.Logger) gin.HandlerFunc {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
tracer := otel.Tracer(tracerName)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
route := c.FullPath()
|
||||
if route == "" {
|
||||
route = c.Request.URL.Path
|
||||
}
|
||||
|
||||
ctx, span := tracer.Start(c.Request.Context(), c.Request.Method+" "+route)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
c.Next()
|
||||
|
||||
status := c.Writer.Status()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
span.SetAttributes(
|
||||
attribute.String("http.request.method", c.Request.Method),
|
||||
attribute.String("http.route", route),
|
||||
attribute.Int("http.response.status_code", status),
|
||||
)
|
||||
if status >= http.StatusInternalServerError {
|
||||
span.SetStatus(codes.Error, http.StatusText(status))
|
||||
}
|
||||
span.End()
|
||||
|
||||
fields := []zap.Field{
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", route),
|
||||
zap.Int("status", status),
|
||||
zap.Duration("latency", elapsed),
|
||||
}
|
||||
fields = append(fields, TraceFieldsFromContext(ctx)...)
|
||||
|
||||
if isProbePath(c.Request.URL.Path) {
|
||||
logger.Debug("http request", fields...)
|
||||
return
|
||||
}
|
||||
logger.Info("http request", fields...)
|
||||
}
|
||||
}
|
||||
|
||||
// isProbePath reports whether path is one of the unauthenticated infrastructure
|
||||
// probes, whose access logs are demoted to debug level.
|
||||
func isProbePath(path string) bool {
|
||||
return path == "/healthz" || path == "/readyz"
|
||||
}
|
||||
@@ -0,0 +1,204 @@
|
||||
// Package telemetry owns the OpenTelemetry runtime for the backend process.
|
||||
//
|
||||
// New constructs the configured tracer and meter providers, registers them as
|
||||
// the OpenTelemetry globals, and exposes Shutdown for orderly exit. The MVP
|
||||
// supports the `none` and `stdout` exporters; OTLP export and dashboards arrive
|
||||
// in a later stage. The per-request timing middleware lives in middleware.go and
|
||||
// uses the registered global tracer, so requests are timed and logged even when
|
||||
// the exporter is `none`.
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/exporters/stdout/stdoutmetric"
|
||||
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Exporter selectors supported by the backend.
|
||||
const (
|
||||
ExporterNone = "none"
|
||||
ExporterStdout = "stdout"
|
||||
)
|
||||
|
||||
// DefaultServiceName labels traces and metrics when BACKEND_SERVICE_NAME is
|
||||
// unset.
|
||||
const DefaultServiceName = "scrabble-backend"
|
||||
|
||||
// Config selects the telemetry providers' service name and exporters.
|
||||
type Config struct {
|
||||
// ServiceName is reported as the OpenTelemetry service.name resource.
|
||||
ServiceName string
|
||||
// TracesExporter is one of ExporterNone or ExporterStdout.
|
||||
TracesExporter string
|
||||
// MetricsExporter is one of ExporterNone or ExporterStdout.
|
||||
MetricsExporter string
|
||||
}
|
||||
|
||||
// DefaultConfig returns the MVP telemetry configuration: named service, no
|
||||
// exporters (so no collector is required locally or in CI).
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
ServiceName: DefaultServiceName,
|
||||
TracesExporter: ExporterNone,
|
||||
MetricsExporter: ExporterNone,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate reports whether the configuration selects supported exporters.
|
||||
func (c Config) Validate() error {
|
||||
if c.ServiceName == "" {
|
||||
return errors.New("telemetry: ServiceName must not be empty")
|
||||
}
|
||||
if err := validateExporter("traces", c.TracesExporter); err != nil {
|
||||
return err
|
||||
}
|
||||
return validateExporter("metrics", c.MetricsExporter)
|
||||
}
|
||||
|
||||
func validateExporter(kind, value string) error {
|
||||
switch value {
|
||||
case ExporterNone, ExporterStdout:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("telemetry: unsupported %s exporter %q", kind, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Runtime owns the shared OpenTelemetry providers.
|
||||
type Runtime struct {
|
||||
tracerProvider *sdktrace.TracerProvider
|
||||
meterProvider *sdkmetric.MeterProvider
|
||||
}
|
||||
|
||||
// New constructs the telemetry runtime, registers the global providers and the
|
||||
// W3C trace-context/baggage propagators, and returns the Runtime. Callers must
|
||||
// invoke Runtime.Shutdown during process exit.
|
||||
func New(ctx context.Context, cfg Config) (*Runtime, error) {
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := resource.New(ctx, resource.WithAttributes(
|
||||
attribute.String("service.name", cfg.ServiceName),
|
||||
))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("telemetry: build resource: %w", err)
|
||||
}
|
||||
|
||||
tracerProvider, err := newTracerProvider(cfg, res)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("telemetry: build tracer provider: %w", err)
|
||||
}
|
||||
meterProvider, err := newMeterProvider(cfg, res)
|
||||
if err != nil {
|
||||
_ = tracerProvider.Shutdown(ctx)
|
||||
return nil, fmt.Errorf("telemetry: build meter provider: %w", err)
|
||||
}
|
||||
|
||||
otel.SetTracerProvider(tracerProvider)
|
||||
otel.SetMeterProvider(meterProvider)
|
||||
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
|
||||
propagation.TraceContext{},
|
||||
propagation.Baggage{},
|
||||
))
|
||||
|
||||
return &Runtime{tracerProvider: tracerProvider, meterProvider: meterProvider}, nil
|
||||
}
|
||||
|
||||
// TracerProvider returns the runtime tracer provider, or the global one when r
|
||||
// is not initialised.
|
||||
func (r *Runtime) TracerProvider() trace.TracerProvider {
|
||||
if r == nil || r.tracerProvider == nil {
|
||||
return otel.GetTracerProvider()
|
||||
}
|
||||
return r.tracerProvider
|
||||
}
|
||||
|
||||
// MeterProvider returns the runtime meter provider, or the global one when r is
|
||||
// not initialised.
|
||||
func (r *Runtime) MeterProvider() metric.MeterProvider {
|
||||
if r == nil || r.meterProvider == nil {
|
||||
return otel.GetMeterProvider()
|
||||
}
|
||||
return r.meterProvider
|
||||
}
|
||||
|
||||
// Shutdown flushes both providers within ctx.
|
||||
func (r *Runtime) Shutdown(ctx context.Context) error {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
if r.meterProvider != nil {
|
||||
err = errors.Join(err, r.meterProvider.Shutdown(ctx))
|
||||
}
|
||||
if r.tracerProvider != nil {
|
||||
err = errors.Join(err, r.tracerProvider.Shutdown(ctx))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// TraceFieldsFromContext returns zap fields identifying the active span, or nil
|
||||
// when ctx carries no valid span context. Collocated here so callers do not
|
||||
// import the OpenTelemetry API directly.
|
||||
func TraceFieldsFromContext(ctx context.Context) []zap.Field {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
sc := trace.SpanContextFromContext(ctx)
|
||||
if !sc.IsValid() {
|
||||
return nil
|
||||
}
|
||||
return []zap.Field{
|
||||
zap.String("otel_trace_id", sc.TraceID().String()),
|
||||
zap.String("otel_span_id", sc.SpanID().String()),
|
||||
}
|
||||
}
|
||||
|
||||
func newTracerProvider(cfg Config, res *resource.Resource) (*sdktrace.TracerProvider, error) {
|
||||
switch cfg.TracesExporter {
|
||||
case ExporterNone:
|
||||
return sdktrace.NewTracerProvider(sdktrace.WithResource(res)), nil
|
||||
case ExporterStdout:
|
||||
exporter, err := stdouttrace.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stdout trace exporter: %w", err)
|
||||
}
|
||||
return sdktrace.NewTracerProvider(
|
||||
sdktrace.WithBatcher(exporter),
|
||||
sdktrace.WithResource(res),
|
||||
), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported traces exporter %q", cfg.TracesExporter)
|
||||
}
|
||||
}
|
||||
|
||||
func newMeterProvider(cfg Config, res *resource.Resource) (*sdkmetric.MeterProvider, error) {
|
||||
switch cfg.MetricsExporter {
|
||||
case ExporterNone:
|
||||
return sdkmetric.NewMeterProvider(sdkmetric.WithResource(res)), nil
|
||||
case ExporterStdout:
|
||||
exporter, err := stdoutmetric.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stdout metric exporter: %w", err)
|
||||
}
|
||||
return sdkmetric.NewMeterProvider(
|
||||
sdkmetric.WithResource(res),
|
||||
sdkmetric.WithReader(sdkmetric.NewPeriodicReader(exporter)),
|
||||
), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported metrics exporter %q", cfg.MetricsExporter)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestConfigValidate covers the supported and rejected exporter selections.
|
||||
func TestConfigValidate(t *testing.T) {
|
||||
if err := DefaultConfig().Validate(); err != nil {
|
||||
t.Fatalf("default config must be valid: %v", err)
|
||||
}
|
||||
|
||||
otlp := DefaultConfig()
|
||||
otlp.TracesExporter = "otlp"
|
||||
if err := otlp.Validate(); err == nil {
|
||||
t.Error("otlp exporter must be rejected in the MVP set")
|
||||
}
|
||||
|
||||
noName := DefaultConfig()
|
||||
noName.ServiceName = ""
|
||||
if err := noName.Validate(); err == nil {
|
||||
t.Error("empty service name must be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewNoneAndShutdown builds the providers with the none exporter and shuts
|
||||
// them down.
|
||||
func TestNewNoneAndShutdown(t *testing.T) {
|
||||
rt, err := New(context.Background(), DefaultConfig())
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if rt.TracerProvider() == nil {
|
||||
t.Error("TracerProvider must not be nil")
|
||||
}
|
||||
if rt.MeterProvider() == nil {
|
||||
t.Error("MeterProvider must not be nil")
|
||||
}
|
||||
if err := rt.Shutdown(context.Background()); err != nil {
|
||||
t.Errorf("Shutdown: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewStdoutTraces builds the providers with the stdout trace exporter; no
|
||||
// spans are recorded, so shutdown flushes nothing to stdout.
|
||||
func TestNewStdoutTraces(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.TracesExporter = ExporterStdout
|
||||
rt, err := New(context.Background(), cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
if err := rt.Shutdown(context.Background()); err != nil {
|
||||
t.Errorf("Shutdown: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTraceFieldsFromContextEmpty returns no fields without an active span.
|
||||
func TestTraceFieldsFromContextEmpty(t *testing.T) {
|
||||
if TraceFieldsFromContext(context.Background()) != nil {
|
||||
t.Error("expected nil fields without an active span")
|
||||
}
|
||||
var nilCtx context.Context
|
||||
if TraceFieldsFromContext(nilCtx) != nil {
|
||||
t.Error("expected nil fields for a nil context")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNilRuntime checks the nil-receiver fallbacks used before initialisation.
|
||||
func TestNilRuntime(t *testing.T) {
|
||||
var rt *Runtime
|
||||
if rt.TracerProvider() == nil {
|
||||
t.Error("nil runtime must fall back to the global tracer provider")
|
||||
}
|
||||
if rt.MeterProvider() == nil {
|
||||
t.Error("nil runtime must fall back to the global meter provider")
|
||||
}
|
||||
if err := rt.Shutdown(context.Background()); err != nil {
|
||||
t.Errorf("nil runtime Shutdown: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user