feat: backend service
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
package geo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"galaxy/backend/internal/postgres/jet/backend/table"
|
||||
|
||||
"github.com/go-jet/jet/v2/postgres"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// OnUserDeleted removes every `backend.user_country_counters` row for
|
||||
// userID. It is the geo-side leg of the soft-delete cascade documented
|
||||
// in `backend/PLAN.md` §5.2 / §5.8 and is invoked from
|
||||
// `backend/internal/user.Service.SoftDelete` after the
|
||||
// `accounts.deleted_at` write commits.
|
||||
//
|
||||
// The DELETE is idempotent: re-running on a user with no counters is a
|
||||
// successful no-op. Errors from the database are wrapped with the geo
|
||||
// prefix so caller logs identify the source.
|
||||
func (s *Service) OnUserDeleted(ctx context.Context, userID uuid.UUID) error {
|
||||
if s == nil {
|
||||
return errors.New("geo: nil service")
|
||||
}
|
||||
if userID == uuid.Nil {
|
||||
return errors.New("geo: nil user id")
|
||||
}
|
||||
stmt := table.UserCountryCounters.DELETE().
|
||||
WHERE(table.UserCountryCounters.UserID.EQ(postgres.UUID(userID)))
|
||||
if _, err := stmt.ExecContext(ctx, s.db); err != nil {
|
||||
return fmt.Errorf("geo: delete counters for %s: %w", userID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
package geo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"galaxy/backend/internal/postgres/jet/backend/model"
|
||||
"galaxy/backend/internal/postgres/jet/backend/table"
|
||||
|
||||
"github.com/go-jet/jet/v2/postgres"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// counterUpsertTimeout bounds the database call performed by a single
|
||||
// fire-and-forget counter goroutine. The upsert is a single statement on
|
||||
// a tiny table and should complete in well under a second; the timeout
|
||||
// exists to keep one slow Postgres node from accumulating leaked
|
||||
// goroutines under load.
|
||||
const counterUpsertTimeout = 5 * time.Second
|
||||
|
||||
// CountryCounter is one row from `backend.user_country_counters` exposed
|
||||
// to the admin surface (`GET /api/v1/admin/geo/users/{user_id}/countries`).
|
||||
//
|
||||
// Country is the uppercase ISO 3166-1 alpha-2 code stored alongside the
|
||||
// running count. LastSeenAt is nullable on the table and therefore
|
||||
// optional; the admin response surfaces null when it is unset.
|
||||
type CountryCounter struct {
|
||||
Country string
|
||||
Count int64
|
||||
LastSeenAt *time.Time
|
||||
}
|
||||
|
||||
// IncrementCounterAsync upserts the per-country counter for userID as a
|
||||
// fire-and-forget goroutine: the country lookup is performed
|
||||
// synchronously (it is pure CPU plus an mmap read), then a goroutine
|
||||
// runs the database upsert against the Service-internal background
|
||||
// context. The caller never blocks on the database round-trip and never
|
||||
// observes errors directly — failures are logged via the Service logger
|
||||
// configured through SetLogger.
|
||||
//
|
||||
// Inputs that yield no useful data short-circuit without launching the
|
||||
// goroutine: a nil receiver, a zero userID, an empty sourceIP, or a
|
||||
// failed country lookup all return immediately. A Service whose
|
||||
// background context has already been cancelled (typically because Drain
|
||||
// or Close ran) also short-circuits — counters are not started during
|
||||
// shutdown, but live ones are awaited by Drain.
|
||||
//
|
||||
// The ctx parameter is intentionally unused for the database call: the
|
||||
// request-scoped context is cancelled the moment the response is
|
||||
// flushed to the gateway, which would race with the upsert. The
|
||||
// goroutine derives its context from the Service-internal one
|
||||
// instead.
|
||||
func (s *Service) IncrementCounterAsync(_ context.Context, userID uuid.UUID, sourceIP string) {
|
||||
if s == nil || userID == uuid.Nil || sourceIP == "" {
|
||||
return
|
||||
}
|
||||
if s.bgCtx == nil || s.bgCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
country := s.LookupCountry(sourceIP)
|
||||
if country == "" {
|
||||
return
|
||||
}
|
||||
|
||||
s.wg.Go(func() {
|
||||
ctx, cancel := context.WithTimeout(s.bgCtx, counterUpsertTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := s.upsertCounter(ctx, userID, country); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return
|
||||
}
|
||||
s.logger.Warn("counter upsert failed",
|
||||
zap.String("user_id", userID.String()),
|
||||
zap.String("country", country),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// upsertCounter executes the atomic INSERT...ON CONFLICT against
|
||||
// `backend.user_country_counters`. The compound primary key
|
||||
// `(user_id, country)` makes the upsert race-safe across concurrent
|
||||
// goroutines.
|
||||
func (s *Service) upsertCounter(ctx context.Context, userID uuid.UUID, country string) error {
|
||||
ucc := table.UserCountryCounters
|
||||
stmt := ucc.INSERT(ucc.UserID, ucc.Country, ucc.Count, ucc.LastSeenAt).
|
||||
VALUES(userID, country, postgres.Int(1), postgres.NOW()).
|
||||
ON_CONFLICT(ucc.UserID, ucc.Country).
|
||||
DO_UPDATE(postgres.SET(
|
||||
ucc.Count.SET(ucc.Count.ADD(postgres.Int(1))),
|
||||
ucc.LastSeenAt.SET(postgres.TimestampzExp(postgres.NOW())),
|
||||
))
|
||||
if _, err := stmt.ExecContext(ctx, s.db); err != nil {
|
||||
return fmt.Errorf("geo: upsert counter for %s/%s: %w", userID, country, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListUserCounters returns every per-country counter recorded for
|
||||
// userID, ordered by country ASC. The list is empty (and the error is
|
||||
// nil) when the user has no rows; ListUserCounters does not check that
|
||||
// the user exists in `backend.accounts` because the admin surface gates
|
||||
// existence through a separate listing endpoint.
|
||||
func (s *Service) ListUserCounters(ctx context.Context, userID uuid.UUID) ([]CountryCounter, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("geo: nil service")
|
||||
}
|
||||
if userID == uuid.Nil {
|
||||
return nil, errors.New("geo: nil user id")
|
||||
}
|
||||
ucc := table.UserCountryCounters
|
||||
stmt := postgres.SELECT(ucc.Country, ucc.Count, ucc.LastSeenAt).
|
||||
FROM(ucc).
|
||||
WHERE(ucc.UserID.EQ(postgres.UUID(userID))).
|
||||
ORDER_BY(ucc.Country.ASC())
|
||||
|
||||
var dest []model.UserCountryCounters
|
||||
if err := stmt.QueryContext(ctx, s.db, &dest); err != nil {
|
||||
return nil, fmt.Errorf("geo: list counters for %s: %w", userID, err)
|
||||
}
|
||||
out := make([]CountryCounter, 0, len(dest))
|
||||
for _, row := range dest {
|
||||
entry := CountryCounter{Country: row.Country, Count: row.Count}
|
||||
if row.LastSeenAt != nil {
|
||||
ts := row.LastSeenAt.UTC()
|
||||
entry.LastSeenAt = &ts
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -0,0 +1,320 @@
|
||||
package geo_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/backend/internal/geo"
|
||||
backendpg "galaxy/backend/internal/postgres"
|
||||
pgshared "galaxy/postgres"
|
||||
|
||||
"github.com/google/uuid"
|
||||
testcontainers "github.com/testcontainers/testcontainers-go"
|
||||
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
const (
|
||||
pgImage = "postgres:16-alpine"
|
||||
pgUser = "galaxy"
|
||||
pgPassword = "galaxy"
|
||||
pgDatabase = "galaxy_backend"
|
||||
pgSchema = "backend"
|
||||
pgStartup = 90 * time.Second
|
||||
pgOpTO = 10 * time.Second
|
||||
)
|
||||
|
||||
// startPostgres mirrors the auth/notification test scaffolding: spin up
|
||||
// a Postgres testcontainer, apply backend migrations, return *sql.DB.
|
||||
func startPostgres(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
pgContainer, err := tcpostgres.Run(ctx, pgImage,
|
||||
tcpostgres.WithDatabase(pgDatabase),
|
||||
tcpostgres.WithUsername(pgUser),
|
||||
tcpostgres.WithPassword(pgPassword),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(pgStartup),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
t.Skipf("postgres testcontainer unavailable, skipping: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil {
|
||||
t.Errorf("terminate postgres container: %v", termErr)
|
||||
}
|
||||
})
|
||||
|
||||
baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("connection string: %v", err)
|
||||
}
|
||||
scoped, err := dsnWithSearchPath(baseDSN, pgSchema)
|
||||
if err != nil {
|
||||
t.Fatalf("scope dsn: %v", err)
|
||||
}
|
||||
cfg := pgshared.DefaultConfig()
|
||||
cfg.PrimaryDSN = scoped
|
||||
cfg.OperationTimeout = pgOpTO
|
||||
db, err := pgshared.OpenPrimary(ctx, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("open primary: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil {
|
||||
t.Fatalf("ping: %v", err)
|
||||
}
|
||||
if err := backendpg.ApplyMigrations(ctx, db); err != nil {
|
||||
t.Fatalf("apply migrations: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func dsnWithSearchPath(baseDSN, schema string) (string, error) {
|
||||
parsed, err := url.Parse(baseDSN)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
values := parsed.Query()
|
||||
values.Set("search_path", schema)
|
||||
if values.Get("sslmode") == "" {
|
||||
values.Set("sslmode", "disable")
|
||||
}
|
||||
parsed.RawQuery = values.Encode()
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
// fixtureService constructs a Service that uses an injected database
|
||||
// pool and skips the GeoLite2 resolver — the resolver is exercised by
|
||||
// `pkg/geoip` tests, while the counter path under test is independent
|
||||
// of the lookup. The caller is responsible for invoking Drain/Close.
|
||||
func fixtureService(t *testing.T, db *sql.DB) *geo.Service {
|
||||
t.Helper()
|
||||
svc, err := geo.NewServiceForTest(db)
|
||||
if err != nil {
|
||||
t.Fatalf("new service: %v", err)
|
||||
}
|
||||
svc.SetLogger(zaptest.NewLogger(t))
|
||||
return svc
|
||||
}
|
||||
|
||||
func TestIncrementCounterAsyncCreatesRow(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
svc.Drain(ctx)
|
||||
_ = svc.Close()
|
||||
})
|
||||
|
||||
userID := uuid.New()
|
||||
svc.IncrementCounterTestSync(t, userID, "DE")
|
||||
|
||||
count, lastSeen := readCounter(t, db, userID, "DE")
|
||||
if count != 1 {
|
||||
t.Fatalf("count: want 1, got %d", count)
|
||||
}
|
||||
if lastSeen == nil {
|
||||
t.Fatal("last_seen_at: want non-null, got null")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncrementCounterAsyncIncrementsExistingRow(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
svc.Drain(ctx)
|
||||
_ = svc.Close()
|
||||
})
|
||||
|
||||
userID := uuid.New()
|
||||
svc.IncrementCounterTestSync(t, userID, "DE")
|
||||
_, firstSeen := readCounter(t, db, userID, "DE")
|
||||
if firstSeen == nil {
|
||||
t.Fatal("first last_seen_at: want non-null")
|
||||
}
|
||||
|
||||
// Sleep long enough for now() to advance past Postgres timestamp
|
||||
// resolution (microseconds in practice).
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
svc.IncrementCounterTestSync(t, userID, "DE")
|
||||
count, secondSeen := readCounter(t, db, userID, "DE")
|
||||
if count != 2 {
|
||||
t.Fatalf("count: want 2, got %d", count)
|
||||
}
|
||||
if secondSeen == nil || !secondSeen.After(*firstSeen) {
|
||||
t.Fatalf("last_seen_at: want strictly later than %v, got %v", firstSeen, secondSeen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncrementCounterAsyncShortCircuits(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
svc.Drain(ctx)
|
||||
_ = svc.Close()
|
||||
})
|
||||
|
||||
// Empty country / zero user — exercise the synchronous validation
|
||||
// path through the public API to confirm no goroutine is launched.
|
||||
svc.IncrementCounterAsync(context.Background(), uuid.Nil, "1.2.3.4")
|
||||
svc.IncrementCounterAsync(context.Background(), uuid.New(), "")
|
||||
|
||||
rows := totalCounterRows(t, db)
|
||||
if rows != 0 {
|
||||
t.Fatalf("expected zero counter rows after short-circuit calls, got %d", rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserCountersOrdered(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
svc.Drain(ctx)
|
||||
_ = svc.Close()
|
||||
})
|
||||
|
||||
userID := uuid.New()
|
||||
svc.IncrementCounterTestSync(t, userID, "PL")
|
||||
svc.IncrementCounterTestSync(t, userID, "DE")
|
||||
svc.IncrementCounterTestSync(t, userID, "DE")
|
||||
svc.IncrementCounterTestSync(t, userID, "AU")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
|
||||
entries, err := svc.ListUserCounters(ctx, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("list: %v", err)
|
||||
}
|
||||
if len(entries) != 3 {
|
||||
t.Fatalf("entries: want 3, got %d (%+v)", len(entries), entries)
|
||||
}
|
||||
wantOrder := []string{"AU", "DE", "PL"}
|
||||
for i, e := range entries {
|
||||
if e.Country != wantOrder[i] {
|
||||
t.Errorf("entries[%d].Country = %q, want %q", i, e.Country, wantOrder[i])
|
||||
}
|
||||
if e.LastSeenAt == nil {
|
||||
t.Errorf("entries[%d].LastSeenAt: want non-nil", i)
|
||||
}
|
||||
}
|
||||
if entries[1].Count != 2 {
|
||||
t.Errorf("entries[1].Count: want 2, got %d", entries[1].Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserCountersEmpty(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
svc.Drain(ctx)
|
||||
_ = svc.Close()
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
|
||||
entries, err := svc.ListUserCounters(ctx, uuid.New())
|
||||
if err != nil {
|
||||
t.Fatalf("list unknown user: %v", err)
|
||||
}
|
||||
if len(entries) != 0 {
|
||||
t.Fatalf("entries: want empty, got %+v", entries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserCountersNilArguments(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
t.Cleanup(func() { _ = svc.Close() })
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
|
||||
if _, err := svc.ListUserCounters(ctx, uuid.Nil); err == nil {
|
||||
t.Fatal("ListUserCounters(uuid.Nil): want error")
|
||||
}
|
||||
|
||||
var nilSvc *geo.Service
|
||||
if _, err := nilSvc.ListUserCounters(ctx, uuid.New()); err == nil {
|
||||
t.Fatal("nil receiver ListUserCounters: want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDrainAwaitsInFlightCounters(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
|
||||
userID := uuid.New()
|
||||
// Inject country directly through the test seam so the lookup never
|
||||
// returns empty even though the resolver is unset.
|
||||
svc.IncrementCounterTestSync(t, userID, "FR")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
svc.Drain(ctx)
|
||||
if err := svc.Close(); err != nil {
|
||||
t.Fatalf("close: %v", err)
|
||||
}
|
||||
count, _ := readCounter(t, db, userID, "FR")
|
||||
if count != 1 {
|
||||
t.Fatalf("count after drain+close: want 1, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func readCounter(t *testing.T, db *sql.DB, userID uuid.UUID, country string) (int64, *time.Time) {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
|
||||
var (
|
||||
count int64
|
||||
lastSeenAt sql.NullTime
|
||||
)
|
||||
err := db.QueryRowContext(ctx, `
|
||||
SELECT count, last_seen_at FROM backend.user_country_counters
|
||||
WHERE user_id = $1 AND country = $2
|
||||
`, userID, country).Scan(&count, &lastSeenAt)
|
||||
if err != nil {
|
||||
t.Fatalf("read counter (%s/%s): %v", userID, country, err)
|
||||
}
|
||||
if !lastSeenAt.Valid {
|
||||
return count, nil
|
||||
}
|
||||
ts := lastSeenAt.Time.UTC()
|
||||
return count, &ts
|
||||
}
|
||||
|
||||
func totalCounterRows(t *testing.T, db *sql.DB) int {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
|
||||
var n int
|
||||
if err := db.QueryRowContext(ctx, `
|
||||
SELECT count(*) FROM backend.user_country_counters
|
||||
`).Scan(&n); err != nil {
|
||||
t.Fatalf("count rows: %v", err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package geo
|
||||
|
||||
import "strings"
|
||||
|
||||
// countryToLanguage maps an uppercase ISO 3166-1 alpha-2 country code to
|
||||
// an ISO 639-1 lowercase language code. The set is intentionally minimal
|
||||
// — covering the top-traffic Galaxy locales — and is consulted as a
|
||||
// fallback when neither the request body nor the Accept-Language header
|
||||
// supplied a locale at send-email-code. Unknown countries map to the
|
||||
// empty string so the auth flow can default to "en".
|
||||
//
|
||||
// The mapping is intentionally hard-coded rather than derived from the
|
||||
// GeoLite2 database: countries with multiple official languages collapse
|
||||
// to the single most common UI locale to keep the registration path
|
||||
// deterministic. The implementation may revise this table without changing the
|
||||
// surface auth depends on.
|
||||
var countryToLanguage = map[string]string{
|
||||
// English-default territories and the platform fallback.
|
||||
"US": "en", "GB": "en", "AU": "en", "NZ": "en", "IE": "en", "CA": "en",
|
||||
// Western Europe.
|
||||
"DE": "de", "AT": "de", "CH": "de",
|
||||
"FR": "fr", "BE": "fr", "LU": "fr",
|
||||
"ES": "es", "MX": "es", "AR": "es", "CL": "es", "CO": "es",
|
||||
"IT": "it",
|
||||
"PT": "pt", "BR": "pt",
|
||||
"NL": "nl",
|
||||
// Central / Eastern Europe.
|
||||
"PL": "pl",
|
||||
"RU": "ru", "BY": "ru", "KZ": "ru",
|
||||
"UA": "uk",
|
||||
"CZ": "cs",
|
||||
"SK": "sk",
|
||||
"HU": "hu",
|
||||
"RO": "ro",
|
||||
"BG": "bg",
|
||||
// Northern Europe.
|
||||
"SE": "sv",
|
||||
"NO": "no",
|
||||
"DK": "da",
|
||||
"FI": "fi",
|
||||
// Asia.
|
||||
"JP": "ja",
|
||||
"KR": "ko",
|
||||
"CN": "zh", "TW": "zh", "HK": "zh", "SG": "zh",
|
||||
"VN": "vi",
|
||||
"TH": "th",
|
||||
"ID": "id",
|
||||
"IN": "en",
|
||||
"IL": "he",
|
||||
"TR": "tr",
|
||||
// Middle East and North Africa.
|
||||
"SA": "ar", "AE": "ar", "EG": "ar",
|
||||
}
|
||||
|
||||
// languageForCountry returns the ISO 639-1 language code mapped to
|
||||
// country, or "" when no mapping is known. country is normalised to
|
||||
// uppercase before lookup.
|
||||
func languageForCountry(country string) string {
|
||||
if country == "" {
|
||||
return ""
|
||||
}
|
||||
return countryToLanguage[strings.ToUpper(strings.TrimSpace(country))]
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package geo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"galaxy/backend/internal/postgres/jet/backend/table"
|
||||
|
||||
"github.com/go-jet/jet/v2/postgres"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SetDeclaredCountryAtRegistration writes the geoip-derived country to
|
||||
// `accounts.declared_country` for userID, and only when the column is
|
||||
// currently NULL. The semantics match PLAN.md §5.8: declared_country is
|
||||
// captured at first registration and never updated thereafter, so
|
||||
// repeated calls on the same account are no-ops.
|
||||
//
|
||||
// The geoip lookup itself is best-effort: a missing or invalid country
|
||||
// returns nil (no UPDATE executed) and never blocks the auth flow. Errors
|
||||
// from the database UPDATE itself surface to the caller so the auth
|
||||
// service can decide whether to log or escalate.
|
||||
func (s *Service) SetDeclaredCountryAtRegistration(ctx context.Context, userID uuid.UUID, sourceIP string) error {
|
||||
if s == nil {
|
||||
return errors.New("geo: nil service")
|
||||
}
|
||||
country := s.LookupCountry(sourceIP)
|
||||
if country == "" {
|
||||
return nil
|
||||
}
|
||||
stmt := table.Accounts.UPDATE(table.Accounts.DeclaredCountry, table.Accounts.UpdatedAt).
|
||||
SET(postgres.String(country), postgres.NOW()).
|
||||
WHERE(
|
||||
table.Accounts.UserID.EQ(postgres.UUID(userID)).
|
||||
AND(table.Accounts.DeclaredCountry.IS_NULL()).
|
||||
AND(table.Accounts.DeletedAt.IS_NULL()),
|
||||
)
|
||||
if _, err := stmt.ExecContext(ctx, s.db); err != nil {
|
||||
return fmt.Errorf("geo: set declared_country for %s: %w", userID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package geo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// NewServiceForTest builds a Service with no GeoLite2 resolver. It is
|
||||
// the entry point external tests use when they want to exercise the
|
||||
// counter / admin paths without spinning up a real mmdb file. The
|
||||
// returned Service still owns its background context and logger so
|
||||
// IncrementCounterAsync and ListUserCounters behave exactly as they do
|
||||
// in production.
|
||||
func NewServiceForTest(db *sql.DB) (*Service, error) {
|
||||
if db == nil {
|
||||
return nil, errors.New("geo: db must not be nil")
|
||||
}
|
||||
bgCtx, bgCancel := context.WithCancel(context.Background())
|
||||
return &Service{
|
||||
db: db,
|
||||
logger: zap.NewNop(),
|
||||
bgCtx: bgCtx,
|
||||
bgCancel: bgCancel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IncrementCounterTestSync runs the package-private upsert path
|
||||
// synchronously so external tests can assert on counter rows without
|
||||
// having to deal with goroutine scheduling. Failure to upsert fails the
|
||||
// test rather than being silently logged.
|
||||
func (s *Service) IncrementCounterTestSync(t *testing.T, userID uuid.UUID, country string) {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), counterUpsertTimeout)
|
||||
defer cancel()
|
||||
if err := s.upsertCounter(ctx, userID, country); err != nil {
|
||||
t.Fatalf("upsert counter (%s/%s): %v", userID, country, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
// Package geo wraps the GeoLite2 country resolver and exposes the
|
||||
// platform-level geo helpers consumed by `backend/internal/auth` at user
|
||||
// registration time and by the user-surface middleware on every
|
||||
// authenticated request.
|
||||
//
|
||||
// The implementation shipped `LookupCountry`, `LanguageForIP` and
|
||||
// `SetDeclaredCountryAtRegistration`. The implementation added the
|
||||
// `OnUserDeleted` cascade leg. The implementation layers `IncrementCounterAsync`
|
||||
// and `ListUserCounters` on top of the same Service plus the
|
||||
// background-goroutine machinery (cancellable context and WaitGroup)
|
||||
// needed to drain pending counter upserts on shutdown.
|
||||
package geo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"galaxy/geoip"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Service is the geo-domain entry point. It is safe for concurrent use.
|
||||
type Service struct {
|
||||
db *sql.DB
|
||||
resolver *geoip.Resolver
|
||||
|
||||
logger *zap.Logger
|
||||
|
||||
// bgCtx is the lifetime context passed to fire-and-forget goroutines
|
||||
// launched by IncrementCounterAsync. It is cancelled by Close so that
|
||||
// in-flight counter upserts observe shutdown promptly. The matching
|
||||
// WaitGroup tracks live goroutines so Drain (and Close) can wait for
|
||||
// them.
|
||||
bgCtx context.Context
|
||||
bgCancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
// NewService constructs a Service backed by the GeoLite2 country database
|
||||
// at databasePath and the supplied Postgres pool. Closing the returned
|
||||
// Service releases the memory-mapped database file; the database pool is
|
||||
// owned by the caller.
|
||||
//
|
||||
// A trimmed-empty databasePath is rejected with a non-nil error so that
|
||||
// boot fails fast rather than silently hiding lookups behind a permanent
|
||||
// failure path. Callers that explicitly want a no-op Service should
|
||||
// inject their own implementation via the auth-level interfaces.
|
||||
//
|
||||
// The returned Service uses a no-op zap logger by default; callers that
|
||||
// want diagnostic output from the asynchronous counter path inject one
|
||||
// via SetLogger.
|
||||
func NewService(databasePath string, db *sql.DB) (*Service, error) {
|
||||
if db == nil {
|
||||
return nil, errors.New("geo: db must not be nil")
|
||||
}
|
||||
resolver, err := geoip.Open(databasePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("geo: open resolver: %w", err)
|
||||
}
|
||||
bgCtx, bgCancel := context.WithCancel(context.Background())
|
||||
return &Service{
|
||||
db: db,
|
||||
resolver: resolver,
|
||||
logger: zap.NewNop(),
|
||||
bgCtx: bgCtx,
|
||||
bgCancel: bgCancel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetLogger replaces the diagnostic logger used by the asynchronous
|
||||
// counter path. A nil argument resets the logger to a no-op so that
|
||||
// production wiring can supply a real logger after construction without
|
||||
// the test paths having to thread one through. SetLogger is nil-safe on
|
||||
// the Service receiver.
|
||||
func (s *Service) SetLogger(logger *zap.Logger) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
s.logger = logger.Named("geo")
|
||||
}
|
||||
|
||||
// Drain blocks until every fire-and-forget goroutine launched through
|
||||
// IncrementCounterAsync has finished, or until ctx is done. It cancels
|
||||
// the Service-internal background context so live goroutines observe
|
||||
// shutdown and stop waiting on the database. Drain is nil-safe and
|
||||
// idempotent: subsequent calls return immediately.
|
||||
//
|
||||
// Drain does not close the GeoLite2 resolver — Close does. The split
|
||||
// lets the boot orchestrator wait for in-flight writes within the
|
||||
// shutdown deadline before the resolver and database pool are torn
|
||||
// down.
|
||||
func (s *Service) Drain(ctx context.Context) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if s.bgCancel != nil {
|
||||
s.bgCancel()
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases the underlying GeoLite2 database resources. Pending
|
||||
// counter goroutines launched through IncrementCounterAsync are
|
||||
// signalled to stop via the internal background context but are NOT
|
||||
// awaited; callers that need to wait must invoke Drain first. Close is
|
||||
// idempotent and nil-safe; subsequent lookups return the empty country
|
||||
// / language ("" treated as no data).
|
||||
func (s *Service) Close() error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
if !s.closed.CompareAndSwap(false, true) {
|
||||
return nil
|
||||
}
|
||||
if s.bgCancel != nil {
|
||||
s.bgCancel()
|
||||
}
|
||||
if s.resolver == nil {
|
||||
return nil
|
||||
}
|
||||
if err := s.resolver.Close(); err != nil {
|
||||
return fmt.Errorf("geo: close resolver: %w", err)
|
||||
}
|
||||
s.resolver = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// LookupCountry resolves an uppercase ISO 3166-1 alpha-2 country code
|
||||
// from sourceIP. The lookup is best-effort: the empty string is returned
|
||||
// for any invalid address, missing record, or closed resolver. The
|
||||
// returned error is always nil; callers that need diagnostic detail
|
||||
// should query the geoip resolver directly.
|
||||
func (s *Service) LookupCountry(sourceIP string) string {
|
||||
if s == nil || s.resolver == nil || sourceIP == "" {
|
||||
return ""
|
||||
}
|
||||
code, err := s.resolver.CountryString(sourceIP)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return code
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package geo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestLanguageForCountry(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"DE": "de",
|
||||
"de": "de", // case-insensitive input
|
||||
"RU": "ru",
|
||||
"BR": "pt",
|
||||
"": "",
|
||||
"ZZ": "",
|
||||
}
|
||||
for input, want := range cases {
|
||||
if got := languageForCountry(input); got != want {
|
||||
t.Errorf("languageForCountry(%q) = %q, want %q", input, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupCountryNilSafety(t *testing.T) {
|
||||
var s *Service
|
||||
if got := s.LookupCountry("8.8.8.8"); got != "" {
|
||||
t.Errorf("nil Service LookupCountry = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLanguageForIPNilSafety(t *testing.T) {
|
||||
var s *Service
|
||||
if got := s.LanguageForIP("8.8.8.8"); got != "" {
|
||||
t.Errorf("nil Service LanguageForIP = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetLoggerNilSafety(t *testing.T) {
|
||||
var s *Service
|
||||
s.SetLogger(zap.NewNop())
|
||||
s.SetLogger(nil)
|
||||
|
||||
live := &Service{}
|
||||
live.SetLogger(nil) // does not panic; falls back to nop logger.
|
||||
}
|
||||
|
||||
func TestDrainNilSafety(t *testing.T) {
|
||||
var s *Service
|
||||
s.Drain(context.Background())
|
||||
}
|
||||
|
||||
func TestDrainReturnsWhenContextDone(t *testing.T) {
|
||||
live := &Service{}
|
||||
live.bgCtx, live.bgCancel = context.WithCancel(context.Background())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
live.Drain(ctx)
|
||||
if elapsed := time.Since(start); elapsed > 5*time.Second {
|
||||
t.Fatalf("Drain blocked too long: %s", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseIdempotent(t *testing.T) {
|
||||
live := &Service{}
|
||||
live.bgCtx, live.bgCancel = context.WithCancel(context.Background())
|
||||
if err := live.Close(); err != nil {
|
||||
t.Fatalf("first Close: %v", err)
|
||||
}
|
||||
if err := live.Close(); err != nil {
|
||||
t.Fatalf("second Close: %v", err)
|
||||
}
|
||||
var nilSvc *Service
|
||||
if err := nilSvc.Close(); err != nil {
|
||||
t.Fatalf("nil Service Close: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package geo
|
||||
|
||||
// LanguageForIP returns an ISO 639-1 language code derived from
|
||||
// sourceIP. The function looks up the country via LookupCountry and then
|
||||
// consults the static country->language table. Returns "" when the
|
||||
// country lookup fails or no language mapping exists for the country.
|
||||
//
|
||||
// Auth uses LanguageForIP as a fallback after the client-supplied locale
|
||||
// (request body or Accept-Language header). The empty string signals
|
||||
// "fall through to the platform default 'en'".
|
||||
func (s *Service) LanguageForIP(sourceIP string) string {
|
||||
country := s.LookupCountry(sourceIP)
|
||||
return languageForCountry(country)
|
||||
}
|
||||
Reference in New Issue
Block a user