feat: backend service

This commit is contained in:
Ilia Denisov
2026-05-06 10:14:55 +03:00
committed by GitHub
parent 3e2622757e
commit f446c6a2ac
1486 changed files with 49720 additions and 266401 deletions
+36
View File
@@ -0,0 +1,36 @@
package geo
import (
"context"
"errors"
"fmt"
"galaxy/backend/internal/postgres/jet/backend/table"
"github.com/go-jet/jet/v2/postgres"
"github.com/google/uuid"
)
// OnUserDeleted removes every `backend.user_country_counters` row for
// userID. It is the geo-side leg of the soft-delete cascade documented
// in `backend/PLAN.md` §5.2 / §5.8 and is invoked from
// `backend/internal/user.Service.SoftDelete` after the
// `accounts.deleted_at` write commits.
//
// The DELETE is idempotent: re-running on a user with no counters is a
// successful no-op. Errors from the database are wrapped with the geo
// prefix so caller logs identify the source.
func (s *Service) OnUserDeleted(ctx context.Context, userID uuid.UUID) error {
if s == nil {
return errors.New("geo: nil service")
}
if userID == uuid.Nil {
return errors.New("geo: nil user id")
}
stmt := table.UserCountryCounters.DELETE().
WHERE(table.UserCountryCounters.UserID.EQ(postgres.UUID(userID)))
if _, err := stmt.ExecContext(ctx, s.db); err != nil {
return fmt.Errorf("geo: delete counters for %s: %w", userID, err)
}
return nil
}
+136
View File
@@ -0,0 +1,136 @@
package geo
import (
"context"
"errors"
"fmt"
"time"
"galaxy/backend/internal/postgres/jet/backend/model"
"galaxy/backend/internal/postgres/jet/backend/table"
"github.com/go-jet/jet/v2/postgres"
"github.com/google/uuid"
"go.uber.org/zap"
)
// counterUpsertTimeout bounds the database call performed by a single
// fire-and-forget counter goroutine. The upsert is a single statement on
// a tiny table and should complete in well under a second; the timeout
// exists to keep one slow Postgres node from accumulating leaked
// goroutines under load.
const counterUpsertTimeout = 5 * time.Second
// CountryCounter is one row from `backend.user_country_counters` exposed
// to the admin surface (`GET /api/v1/admin/geo/users/{user_id}/countries`).
//
// Country is the uppercase ISO 3166-1 alpha-2 code stored alongside the
// running count. LastSeenAt is nullable on the table and therefore
// optional; the admin response surfaces null when it is unset.
type CountryCounter struct {
Country string
Count int64
LastSeenAt *time.Time
}
// IncrementCounterAsync upserts the per-country counter for userID as a
// fire-and-forget goroutine: the country lookup is performed
// synchronously (it is pure CPU plus an mmap read), then a goroutine
// runs the database upsert against the Service-internal background
// context. The caller never blocks on the database round-trip and never
// observes errors directly — failures are logged via the Service logger
// configured through SetLogger.
//
// Inputs that yield no useful data short-circuit without launching the
// goroutine: a nil receiver, a zero userID, an empty sourceIP, or a
// failed country lookup all return immediately. A Service whose
// background context has already been cancelled (typically because Drain
// or Close ran) also short-circuits — counters are not started during
// shutdown, but live ones are awaited by Drain.
//
// The ctx parameter is intentionally unused for the database call: the
// request-scoped context is cancelled the moment the response is
// flushed to the gateway, which would race with the upsert. The
// goroutine derives its context from the Service-internal one
// instead.
func (s *Service) IncrementCounterAsync(_ context.Context, userID uuid.UUID, sourceIP string) {
if s == nil || userID == uuid.Nil || sourceIP == "" {
return
}
if s.bgCtx == nil || s.bgCtx.Err() != nil {
return
}
country := s.LookupCountry(sourceIP)
if country == "" {
return
}
s.wg.Go(func() {
ctx, cancel := context.WithTimeout(s.bgCtx, counterUpsertTimeout)
defer cancel()
if err := s.upsertCounter(ctx, userID, country); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return
}
s.logger.Warn("counter upsert failed",
zap.String("user_id", userID.String()),
zap.String("country", country),
zap.Error(err),
)
}
})
}
// upsertCounter executes the atomic INSERT...ON CONFLICT against
// `backend.user_country_counters`. The compound primary key
// `(user_id, country)` makes the upsert race-safe across concurrent
// goroutines.
func (s *Service) upsertCounter(ctx context.Context, userID uuid.UUID, country string) error {
ucc := table.UserCountryCounters
stmt := ucc.INSERT(ucc.UserID, ucc.Country, ucc.Count, ucc.LastSeenAt).
VALUES(userID, country, postgres.Int(1), postgres.NOW()).
ON_CONFLICT(ucc.UserID, ucc.Country).
DO_UPDATE(postgres.SET(
ucc.Count.SET(ucc.Count.ADD(postgres.Int(1))),
ucc.LastSeenAt.SET(postgres.TimestampzExp(postgres.NOW())),
))
if _, err := stmt.ExecContext(ctx, s.db); err != nil {
return fmt.Errorf("geo: upsert counter for %s/%s: %w", userID, country, err)
}
return nil
}
// ListUserCounters returns every per-country counter recorded for
// userID, ordered by country ASC. The list is empty (and the error is
// nil) when the user has no rows; ListUserCounters does not check that
// the user exists in `backend.accounts` because the admin surface gates
// existence through a separate listing endpoint.
func (s *Service) ListUserCounters(ctx context.Context, userID uuid.UUID) ([]CountryCounter, error) {
if s == nil {
return nil, errors.New("geo: nil service")
}
if userID == uuid.Nil {
return nil, errors.New("geo: nil user id")
}
ucc := table.UserCountryCounters
stmt := postgres.SELECT(ucc.Country, ucc.Count, ucc.LastSeenAt).
FROM(ucc).
WHERE(ucc.UserID.EQ(postgres.UUID(userID))).
ORDER_BY(ucc.Country.ASC())
var dest []model.UserCountryCounters
if err := stmt.QueryContext(ctx, s.db, &dest); err != nil {
return nil, fmt.Errorf("geo: list counters for %s: %w", userID, err)
}
out := make([]CountryCounter, 0, len(dest))
for _, row := range dest {
entry := CountryCounter{Country: row.Country, Count: row.Count}
if row.LastSeenAt != nil {
ts := row.LastSeenAt.UTC()
entry.LastSeenAt = &ts
}
out = append(out, entry)
}
return out, nil
}
+320
View File
@@ -0,0 +1,320 @@
package geo_test
import (
"context"
"database/sql"
"net/url"
"testing"
"time"
"galaxy/backend/internal/geo"
backendpg "galaxy/backend/internal/postgres"
pgshared "galaxy/postgres"
"github.com/google/uuid"
testcontainers "github.com/testcontainers/testcontainers-go"
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"go.uber.org/zap/zaptest"
)
const (
pgImage = "postgres:16-alpine"
pgUser = "galaxy"
pgPassword = "galaxy"
pgDatabase = "galaxy_backend"
pgSchema = "backend"
pgStartup = 90 * time.Second
pgOpTO = 10 * time.Second
)
// startPostgres mirrors the auth/notification test scaffolding: spin up
// a Postgres testcontainer, apply backend migrations, return *sql.DB.
func startPostgres(t *testing.T) *sql.DB {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
t.Cleanup(cancel)
pgContainer, err := tcpostgres.Run(ctx, pgImage,
tcpostgres.WithDatabase(pgDatabase),
tcpostgres.WithUsername(pgUser),
tcpostgres.WithPassword(pgPassword),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(pgStartup),
),
)
if err != nil {
t.Skipf("postgres testcontainer unavailable, skipping: %v", err)
}
t.Cleanup(func() {
if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil {
t.Errorf("terminate postgres container: %v", termErr)
}
})
baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
if err != nil {
t.Fatalf("connection string: %v", err)
}
scoped, err := dsnWithSearchPath(baseDSN, pgSchema)
if err != nil {
t.Fatalf("scope dsn: %v", err)
}
cfg := pgshared.DefaultConfig()
cfg.PrimaryDSN = scoped
cfg.OperationTimeout = pgOpTO
db, err := pgshared.OpenPrimary(ctx, cfg)
if err != nil {
t.Fatalf("open primary: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil {
t.Fatalf("ping: %v", err)
}
if err := backendpg.ApplyMigrations(ctx, db); err != nil {
t.Fatalf("apply migrations: %v", err)
}
return db
}
func dsnWithSearchPath(baseDSN, schema string) (string, error) {
parsed, err := url.Parse(baseDSN)
if err != nil {
return "", err
}
values := parsed.Query()
values.Set("search_path", schema)
if values.Get("sslmode") == "" {
values.Set("sslmode", "disable")
}
parsed.RawQuery = values.Encode()
return parsed.String(), nil
}
// fixtureService constructs a Service that uses an injected database
// pool and skips the GeoLite2 resolver — the resolver is exercised by
// `pkg/geoip` tests, while the counter path under test is independent
// of the lookup. The caller is responsible for invoking Drain/Close.
func fixtureService(t *testing.T, db *sql.DB) *geo.Service {
t.Helper()
svc, err := geo.NewServiceForTest(db)
if err != nil {
t.Fatalf("new service: %v", err)
}
svc.SetLogger(zaptest.NewLogger(t))
return svc
}
func TestIncrementCounterAsyncCreatesRow(t *testing.T) {
db := startPostgres(t)
svc := fixtureService(t, db)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
defer cancel()
svc.Drain(ctx)
_ = svc.Close()
})
userID := uuid.New()
svc.IncrementCounterTestSync(t, userID, "DE")
count, lastSeen := readCounter(t, db, userID, "DE")
if count != 1 {
t.Fatalf("count: want 1, got %d", count)
}
if lastSeen == nil {
t.Fatal("last_seen_at: want non-null, got null")
}
}
func TestIncrementCounterAsyncIncrementsExistingRow(t *testing.T) {
db := startPostgres(t)
svc := fixtureService(t, db)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
defer cancel()
svc.Drain(ctx)
_ = svc.Close()
})
userID := uuid.New()
svc.IncrementCounterTestSync(t, userID, "DE")
_, firstSeen := readCounter(t, db, userID, "DE")
if firstSeen == nil {
t.Fatal("first last_seen_at: want non-null")
}
// Sleep long enough for now() to advance past Postgres timestamp
// resolution (microseconds in practice).
time.Sleep(2 * time.Millisecond)
svc.IncrementCounterTestSync(t, userID, "DE")
count, secondSeen := readCounter(t, db, userID, "DE")
if count != 2 {
t.Fatalf("count: want 2, got %d", count)
}
if secondSeen == nil || !secondSeen.After(*firstSeen) {
t.Fatalf("last_seen_at: want strictly later than %v, got %v", firstSeen, secondSeen)
}
}
func TestIncrementCounterAsyncShortCircuits(t *testing.T) {
db := startPostgres(t)
svc := fixtureService(t, db)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
defer cancel()
svc.Drain(ctx)
_ = svc.Close()
})
// Empty country / zero user — exercise the synchronous validation
// path through the public API to confirm no goroutine is launched.
svc.IncrementCounterAsync(context.Background(), uuid.Nil, "1.2.3.4")
svc.IncrementCounterAsync(context.Background(), uuid.New(), "")
rows := totalCounterRows(t, db)
if rows != 0 {
t.Fatalf("expected zero counter rows after short-circuit calls, got %d", rows)
}
}
func TestListUserCountersOrdered(t *testing.T) {
db := startPostgres(t)
svc := fixtureService(t, db)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
defer cancel()
svc.Drain(ctx)
_ = svc.Close()
})
userID := uuid.New()
svc.IncrementCounterTestSync(t, userID, "PL")
svc.IncrementCounterTestSync(t, userID, "DE")
svc.IncrementCounterTestSync(t, userID, "DE")
svc.IncrementCounterTestSync(t, userID, "AU")
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
defer cancel()
entries, err := svc.ListUserCounters(ctx, userID)
if err != nil {
t.Fatalf("list: %v", err)
}
if len(entries) != 3 {
t.Fatalf("entries: want 3, got %d (%+v)", len(entries), entries)
}
wantOrder := []string{"AU", "DE", "PL"}
for i, e := range entries {
if e.Country != wantOrder[i] {
t.Errorf("entries[%d].Country = %q, want %q", i, e.Country, wantOrder[i])
}
if e.LastSeenAt == nil {
t.Errorf("entries[%d].LastSeenAt: want non-nil", i)
}
}
if entries[1].Count != 2 {
t.Errorf("entries[1].Count: want 2, got %d", entries[1].Count)
}
}
func TestListUserCountersEmpty(t *testing.T) {
db := startPostgres(t)
svc := fixtureService(t, db)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
defer cancel()
svc.Drain(ctx)
_ = svc.Close()
})
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
defer cancel()
entries, err := svc.ListUserCounters(ctx, uuid.New())
if err != nil {
t.Fatalf("list unknown user: %v", err)
}
if len(entries) != 0 {
t.Fatalf("entries: want empty, got %+v", entries)
}
}
func TestListUserCountersNilArguments(t *testing.T) {
db := startPostgres(t)
svc := fixtureService(t, db)
t.Cleanup(func() { _ = svc.Close() })
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
defer cancel()
if _, err := svc.ListUserCounters(ctx, uuid.Nil); err == nil {
t.Fatal("ListUserCounters(uuid.Nil): want error")
}
var nilSvc *geo.Service
if _, err := nilSvc.ListUserCounters(ctx, uuid.New()); err == nil {
t.Fatal("nil receiver ListUserCounters: want error")
}
}
func TestDrainAwaitsInFlightCounters(t *testing.T) {
db := startPostgres(t)
svc := fixtureService(t, db)
userID := uuid.New()
// Inject country directly through the test seam so the lookup never
// returns empty even though the resolver is unset.
svc.IncrementCounterTestSync(t, userID, "FR")
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
defer cancel()
svc.Drain(ctx)
if err := svc.Close(); err != nil {
t.Fatalf("close: %v", err)
}
count, _ := readCounter(t, db, userID, "FR")
if count != 1 {
t.Fatalf("count after drain+close: want 1, got %d", count)
}
}
func readCounter(t *testing.T, db *sql.DB, userID uuid.UUID, country string) (int64, *time.Time) {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
defer cancel()
var (
count int64
lastSeenAt sql.NullTime
)
err := db.QueryRowContext(ctx, `
SELECT count, last_seen_at FROM backend.user_country_counters
WHERE user_id = $1 AND country = $2
`, userID, country).Scan(&count, &lastSeenAt)
if err != nil {
t.Fatalf("read counter (%s/%s): %v", userID, country, err)
}
if !lastSeenAt.Valid {
return count, nil
}
ts := lastSeenAt.Time.UTC()
return count, &ts
}
func totalCounterRows(t *testing.T, db *sql.DB) int {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
defer cancel()
var n int
if err := db.QueryRowContext(ctx, `
SELECT count(*) FROM backend.user_country_counters
`).Scan(&n); err != nil {
t.Fatalf("count rows: %v", err)
}
return n
}
+63
View File
@@ -0,0 +1,63 @@
package geo
import "strings"
// countryToLanguage maps an uppercase ISO 3166-1 alpha-2 country code to
// an ISO 639-1 lowercase language code. The set is intentionally minimal
// — covering the top-traffic Galaxy locales — and is consulted as a
// fallback when neither the request body nor the Accept-Language header
// supplied a locale at send-email-code. Unknown countries map to the
// empty string so the auth flow can default to "en".
//
// The mapping is intentionally hard-coded rather than derived from the
// GeoLite2 database: countries with multiple official languages collapse
// to the single most common UI locale to keep the registration path
// deterministic. The implementation may revise this table without changing the
// surface auth depends on.
var countryToLanguage = map[string]string{
// English-default territories and the platform fallback.
"US": "en", "GB": "en", "AU": "en", "NZ": "en", "IE": "en", "CA": "en",
// Western Europe.
"DE": "de", "AT": "de", "CH": "de",
"FR": "fr", "BE": "fr", "LU": "fr",
"ES": "es", "MX": "es", "AR": "es", "CL": "es", "CO": "es",
"IT": "it",
"PT": "pt", "BR": "pt",
"NL": "nl",
// Central / Eastern Europe.
"PL": "pl",
"RU": "ru", "BY": "ru", "KZ": "ru",
"UA": "uk",
"CZ": "cs",
"SK": "sk",
"HU": "hu",
"RO": "ro",
"BG": "bg",
// Northern Europe.
"SE": "sv",
"NO": "no",
"DK": "da",
"FI": "fi",
// Asia.
"JP": "ja",
"KR": "ko",
"CN": "zh", "TW": "zh", "HK": "zh", "SG": "zh",
"VN": "vi",
"TH": "th",
"ID": "id",
"IN": "en",
"IL": "he",
"TR": "tr",
// Middle East and North Africa.
"SA": "ar", "AE": "ar", "EG": "ar",
}
// languageForCountry returns the ISO 639-1 language code mapped to
// country, or "" when no mapping is known. country is normalised to
// uppercase before lookup.
func languageForCountry(country string) string {
if country == "" {
return ""
}
return countryToLanguage[strings.ToUpper(strings.TrimSpace(country))]
}
+43
View File
@@ -0,0 +1,43 @@
package geo
import (
"context"
"errors"
"fmt"
"galaxy/backend/internal/postgres/jet/backend/table"
"github.com/go-jet/jet/v2/postgres"
"github.com/google/uuid"
)
// SetDeclaredCountryAtRegistration writes the geoip-derived country to
// `accounts.declared_country` for userID, and only when the column is
// currently NULL. The semantics match PLAN.md §5.8: declared_country is
// captured at first registration and never updated thereafter, so
// repeated calls on the same account are no-ops.
//
// The geoip lookup itself is best-effort: a missing or invalid country
// returns nil (no UPDATE executed) and never blocks the auth flow. Errors
// from the database UPDATE itself surface to the caller so the auth
// service can decide whether to log or escalate.
func (s *Service) SetDeclaredCountryAtRegistration(ctx context.Context, userID uuid.UUID, sourceIP string) error {
if s == nil {
return errors.New("geo: nil service")
}
country := s.LookupCountry(sourceIP)
if country == "" {
return nil
}
stmt := table.Accounts.UPDATE(table.Accounts.DeclaredCountry, table.Accounts.UpdatedAt).
SET(postgres.String(country), postgres.NOW()).
WHERE(
table.Accounts.UserID.EQ(postgres.UUID(userID)).
AND(table.Accounts.DeclaredCountry.IS_NULL()).
AND(table.Accounts.DeletedAt.IS_NULL()),
)
if _, err := stmt.ExecContext(ctx, s.db); err != nil {
return fmt.Errorf("geo: set declared_country for %s: %w", userID, err)
}
return nil
}
+43
View File
@@ -0,0 +1,43 @@
package geo
import (
"context"
"database/sql"
"errors"
"testing"
"github.com/google/uuid"
"go.uber.org/zap"
)
// NewServiceForTest builds a Service with no GeoLite2 resolver. It is
// the entry point external tests use when they want to exercise the
// counter / admin paths without spinning up a real mmdb file. The
// returned Service still owns its background context and logger so
// IncrementCounterAsync and ListUserCounters behave exactly as they do
// in production.
func NewServiceForTest(db *sql.DB) (*Service, error) {
if db == nil {
return nil, errors.New("geo: db must not be nil")
}
bgCtx, bgCancel := context.WithCancel(context.Background())
return &Service{
db: db,
logger: zap.NewNop(),
bgCtx: bgCtx,
bgCancel: bgCancel,
}, nil
}
// IncrementCounterTestSync runs the package-private upsert path
// synchronously so external tests can assert on counter rows without
// having to deal with goroutine scheduling. Failure to upsert fails the
// test rather than being silently logged.
func (s *Service) IncrementCounterTestSync(t *testing.T, userID uuid.UUID, country string) {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), counterUpsertTimeout)
defer cancel()
if err := s.upsertCounter(ctx, userID, country); err != nil {
t.Fatalf("upsert counter (%s/%s): %v", userID, country, err)
}
}
+159
View File
@@ -0,0 +1,159 @@
// Package geo wraps the GeoLite2 country resolver and exposes the
// platform-level geo helpers consumed by `backend/internal/auth` at user
// registration time and by the user-surface middleware on every
// authenticated request.
//
// The implementation shipped `LookupCountry`, `LanguageForIP` and
// `SetDeclaredCountryAtRegistration`. The implementation added the
// `OnUserDeleted` cascade leg. The implementation layers `IncrementCounterAsync`
// and `ListUserCounters` on top of the same Service plus the
// background-goroutine machinery (cancellable context and WaitGroup)
// needed to drain pending counter upserts on shutdown.
package geo
import (
"context"
"database/sql"
"errors"
"fmt"
"sync"
"sync/atomic"
"galaxy/geoip"
"go.uber.org/zap"
)
// Service is the geo-domain entry point. It is safe for concurrent use.
type Service struct {
db *sql.DB
resolver *geoip.Resolver
logger *zap.Logger
// bgCtx is the lifetime context passed to fire-and-forget goroutines
// launched by IncrementCounterAsync. It is cancelled by Close so that
// in-flight counter upserts observe shutdown promptly. The matching
// WaitGroup tracks live goroutines so Drain (and Close) can wait for
// them.
bgCtx context.Context
bgCancel context.CancelFunc
wg sync.WaitGroup
closed atomic.Bool
}
// NewService constructs a Service backed by the GeoLite2 country database
// at databasePath and the supplied Postgres pool. Closing the returned
// Service releases the memory-mapped database file; the database pool is
// owned by the caller.
//
// A trimmed-empty databasePath is rejected with a non-nil error so that
// boot fails fast rather than silently hiding lookups behind a permanent
// failure path. Callers that explicitly want a no-op Service should
// inject their own implementation via the auth-level interfaces.
//
// The returned Service uses a no-op zap logger by default; callers that
// want diagnostic output from the asynchronous counter path inject one
// via SetLogger.
func NewService(databasePath string, db *sql.DB) (*Service, error) {
if db == nil {
return nil, errors.New("geo: db must not be nil")
}
resolver, err := geoip.Open(databasePath)
if err != nil {
return nil, fmt.Errorf("geo: open resolver: %w", err)
}
bgCtx, bgCancel := context.WithCancel(context.Background())
return &Service{
db: db,
resolver: resolver,
logger: zap.NewNop(),
bgCtx: bgCtx,
bgCancel: bgCancel,
}, nil
}
// SetLogger replaces the diagnostic logger used by the asynchronous
// counter path. A nil argument resets the logger to a no-op so that
// production wiring can supply a real logger after construction without
// the test paths having to thread one through. SetLogger is nil-safe on
// the Service receiver.
func (s *Service) SetLogger(logger *zap.Logger) {
if s == nil {
return
}
if logger == nil {
logger = zap.NewNop()
}
s.logger = logger.Named("geo")
}
// Drain blocks until every fire-and-forget goroutine launched through
// IncrementCounterAsync has finished, or until ctx is done. It cancels
// the Service-internal background context so live goroutines observe
// shutdown and stop waiting on the database. Drain is nil-safe and
// idempotent: subsequent calls return immediately.
//
// Drain does not close the GeoLite2 resolver — Close does. The split
// lets the boot orchestrator wait for in-flight writes within the
// shutdown deadline before the resolver and database pool are torn
// down.
func (s *Service) Drain(ctx context.Context) {
if s == nil {
return
}
if s.bgCancel != nil {
s.bgCancel()
}
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
case <-ctx.Done():
}
}
// Close releases the underlying GeoLite2 database resources. Pending
// counter goroutines launched through IncrementCounterAsync are
// signalled to stop via the internal background context but are NOT
// awaited; callers that need to wait must invoke Drain first. Close is
// idempotent and nil-safe; subsequent lookups return the empty country
// / language ("" treated as no data).
func (s *Service) Close() error {
if s == nil {
return nil
}
if !s.closed.CompareAndSwap(false, true) {
return nil
}
if s.bgCancel != nil {
s.bgCancel()
}
if s.resolver == nil {
return nil
}
if err := s.resolver.Close(); err != nil {
return fmt.Errorf("geo: close resolver: %w", err)
}
s.resolver = nil
return nil
}
// LookupCountry resolves an uppercase ISO 3166-1 alpha-2 country code
// from sourceIP. The lookup is best-effort: the empty string is returned
// for any invalid address, missing record, or closed resolver. The
// returned error is always nil; callers that need diagnostic detail
// should query the geoip resolver directly.
func (s *Service) LookupCountry(sourceIP string) string {
if s == nil || s.resolver == nil || sourceIP == "" {
return ""
}
code, err := s.resolver.CountryString(sourceIP)
if err != nil {
return ""
}
return code
}
+82
View File
@@ -0,0 +1,82 @@
package geo
import (
"context"
"testing"
"time"
"go.uber.org/zap"
)
func TestLanguageForCountry(t *testing.T) {
cases := map[string]string{
"DE": "de",
"de": "de", // case-insensitive input
"RU": "ru",
"BR": "pt",
"": "",
"ZZ": "",
}
for input, want := range cases {
if got := languageForCountry(input); got != want {
t.Errorf("languageForCountry(%q) = %q, want %q", input, got, want)
}
}
}
func TestLookupCountryNilSafety(t *testing.T) {
var s *Service
if got := s.LookupCountry("8.8.8.8"); got != "" {
t.Errorf("nil Service LookupCountry = %q, want empty", got)
}
}
func TestLanguageForIPNilSafety(t *testing.T) {
var s *Service
if got := s.LanguageForIP("8.8.8.8"); got != "" {
t.Errorf("nil Service LanguageForIP = %q, want empty", got)
}
}
func TestSetLoggerNilSafety(t *testing.T) {
var s *Service
s.SetLogger(zap.NewNop())
s.SetLogger(nil)
live := &Service{}
live.SetLogger(nil) // does not panic; falls back to nop logger.
}
func TestDrainNilSafety(t *testing.T) {
var s *Service
s.Drain(context.Background())
}
func TestDrainReturnsWhenContextDone(t *testing.T) {
live := &Service{}
live.bgCtx, live.bgCancel = context.WithCancel(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
start := time.Now()
live.Drain(ctx)
if elapsed := time.Since(start); elapsed > 5*time.Second {
t.Fatalf("Drain blocked too long: %s", elapsed)
}
}
func TestCloseIdempotent(t *testing.T) {
live := &Service{}
live.bgCtx, live.bgCancel = context.WithCancel(context.Background())
if err := live.Close(); err != nil {
t.Fatalf("first Close: %v", err)
}
if err := live.Close(); err != nil {
t.Fatalf("second Close: %v", err)
}
var nilSvc *Service
if err := nilSvc.Close(); err != nil {
t.Fatalf("nil Service Close: %v", err)
}
}
+14
View File
@@ -0,0 +1,14 @@
package geo
// LanguageForIP returns an ISO 639-1 language code derived from
// sourceIP. The function looks up the country via LookupCountry and then
// consults the static country->language table. Returns "" when the
// country lookup fails or no language mapping exists for the country.
//
// Auth uses LanguageForIP as a fallback after the client-supplied locale
// (request body or Accept-Language header). The empty string signals
// "fall through to the platform default 'en'".
func (s *Service) LanguageForIP(sourceIP string) string {
country := s.LookupCountry(sourceIP)
return languageForCountry(country)
}