package geo_test

import (
	"context"
	"database/sql"
	"net/url"
	"testing"
	"time"

	"galaxy/backend/internal/geo"
	backendpg "galaxy/backend/internal/postgres"
	pgshared "galaxy/postgres"

	"github.com/google/uuid"
	testcontainers "github.com/testcontainers/testcontainers-go"
	tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
	"github.com/testcontainers/testcontainers-go/wait"
	"go.uber.org/zap/zaptest"
)

// Testcontainer and connection parameters shared by every test in this file.
// pgStartup bounds how long the container may take to log readiness; pgOpTO
// bounds individual database operations and service calls.
const (
	pgImage    = "postgres:16-alpine"
	pgUser     = "galaxy"
	pgPassword = "galaxy"
	pgDatabase = "galaxy_backend"
	pgSchema   = "backend"
	pgStartup  = 90 * time.Second
	pgOpTO     = 10 * time.Second
)

// startPostgres mirrors the auth/notification test scaffolding: it spins up
// a Postgres testcontainer, applies the backend migrations and returns a
// *sql.DB whose DSN search_path is pgSchema.
//
// The test is skipped (not failed) when the container cannot be started.
// Context cancellation, container termination and closing the pool are all
// registered with t.Cleanup.
func startPostgres(t *testing.T) *sql.DB {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
	t.Cleanup(cancel)

	pgContainer, err := tcpostgres.Run(ctx, pgImage,
		tcpostgres.WithDatabase(pgDatabase),
		tcpostgres.WithUsername(pgUser),
		tcpostgres.WithPassword(pgPassword),
		testcontainers.WithWaitStrategy(
			wait.ForLog("database system is ready to accept connections").
				WithOccurrence(2).
				WithStartupTimeout(pgStartup),
		),
	)
	// Register termination BEFORE inspecting err: Run can hand back a
	// created-but-unhealthy container together with an error (e.g. a
	// wait-strategy timeout), and skipping without terminating it would leak
	// the container. TerminateContainer is a no-op for a nil container.
	t.Cleanup(func() {
		if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil {
			t.Errorf("terminate postgres container: %v", termErr)
		}
	})
	if err != nil {
		t.Skipf("postgres testcontainer unavailable, skipping: %v", err)
	}

	baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
	if err != nil {
		t.Fatalf("connection string: %v", err)
	}
	scoped, err := dsnWithSearchPath(baseDSN, pgSchema)
	if err != nil {
		t.Fatalf("scope dsn: %v", err)
	}

	cfg := pgshared.DefaultConfig()
	cfg.PrimaryDSN = scoped
	cfg.OperationTimeout = pgOpTO

	db, err := pgshared.OpenPrimary(ctx, cfg, backendpg.NoObservabilityOptions()...)
	if err != nil {
		t.Fatalf("open primary: %v", err)
	}
	t.Cleanup(func() { _ = db.Close() })

	if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil {
		t.Fatalf("ping: %v", err)
	}
	if err := backendpg.ApplyMigrations(ctx, db); err != nil {
		t.Fatalf("apply migrations: %v", err)
	}
	return db
}

// dsnWithSearchPath returns baseDSN with its search_path query parameter set
// to schema. It also sets sslmode=disable when the DSN carries no sslmode.
// Other query parameters are preserved. An error is returned only when
// baseDSN is not a parseable URL.
func dsnWithSearchPath(baseDSN, schema string) (string, error) {
	parsed, err := url.Parse(baseDSN)
	if err != nil {
		return "", err
	}
	values := parsed.Query()
	values.Set("search_path", schema)
	if values.Get("sslmode") == "" {
		values.Set("sslmode", "disable")
	}
	parsed.RawQuery = values.Encode()
	return parsed.String(), nil
}

// fixtureService constructs a Service that uses an injected database pool
// and skips the GeoLite2 resolver — the resolver is exercised by
// `pkg/geoip` tests, while the counter path under test is independent of
// the lookup. The caller is responsible for invoking Drain/Close (see
// closeServiceOnCleanup).
func fixtureService(t *testing.T, db *sql.DB) *geo.Service {
	t.Helper()
	svc, err := geo.NewServiceForTest(db)
	if err != nil {
		t.Fatalf("new service: %v", err)
	}
	svc.SetLogger(zaptest.NewLogger(t))
	return svc
}

// closeServiceOnCleanup registers a t.Cleanup that calls svc.Drain with a
// pgOpTO-bounded context and then svc.Close. The Close error is ignored on
// purpose: cleanup runs after the test's assertions, and Drain+Close is
// checked explicitly by TestDrainAwaitsInFlightCounters.
func closeServiceOnCleanup(t *testing.T, svc *geo.Service) {
	t.Helper()
	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
		defer cancel()
		svc.Drain(ctx)
		_ = svc.Close()
	})
}

// TestIncrementCounterAsyncCreatesRow checks that the first increment for a
// (user, country) pair yields count 1 and a non-null last_seen_at.
func TestIncrementCounterAsyncCreatesRow(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)
	closeServiceOnCleanup(t, svc)

	userID := uuid.New()
	svc.IncrementCounterTestSync(t, userID, "DE")

	count, lastSeen := readCounter(t, db, userID, "DE")
	if count != 1 {
		t.Fatalf("count: want 1, got %d", count)
	}
	if lastSeen == nil {
		t.Fatal("last_seen_at: want non-null, got null")
	}
}

// TestIncrementCounterAsyncIncrementsExistingRow checks that a second
// increment for the same pair bumps count to 2 and moves last_seen_at
// strictly forward.
func TestIncrementCounterAsyncIncrementsExistingRow(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)
	closeServiceOnCleanup(t, svc)

	userID := uuid.New()
	svc.IncrementCounterTestSync(t, userID, "DE")

	_, firstSeen := readCounter(t, db, userID, "DE")
	if firstSeen == nil {
		t.Fatal("first last_seen_at: want non-null")
	}

	// Sleep long enough for now() to advance past Postgres timestamp
	// resolution (microseconds in practice).
	time.Sleep(2 * time.Millisecond)
	svc.IncrementCounterTestSync(t, userID, "DE")

	count, secondSeen := readCounter(t, db, userID, "DE")
	if count != 2 {
		t.Fatalf("count: want 2, got %d", count)
	}
	if secondSeen == nil || !secondSeen.After(*firstSeen) {
		t.Fatalf("last_seen_at: want strictly later than %v, got %v", firstSeen, secondSeen)
	}
}

// TestIncrementCounterAsyncShortCircuits checks that invalid input to the
// public async API leaves the counter table untouched.
func TestIncrementCounterAsyncShortCircuits(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)
	// Drain is invoked explicitly below, so cleanup only needs to Close.
	t.Cleanup(func() { _ = svc.Close() })

	// A zero user ID, then an empty lookup input: both should be rejected
	// by the synchronous validation path of the public API.
	svc.IncrementCounterAsync(context.Background(), uuid.Nil, "1.2.3.4")
	svc.IncrementCounterAsync(context.Background(), uuid.New(), "")

	// Drain before counting. If either call did launch a goroutine, its write
	// must land before the assertion; otherwise the check could pass simply
	// because the write had not happened yet.
	ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
	defer cancel()
	svc.Drain(ctx)

	if rows := totalCounterRows(t, db); rows != 0 {
		t.Fatalf("expected zero counter rows after short-circuit calls, got %d", rows)
	}
}

// TestListUserCountersOrdered checks that ListUserCounters returns one
// entry per country, ordered by country code, with per-country counts and
// non-nil LastSeenAt.
func TestListUserCountersOrdered(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)
	closeServiceOnCleanup(t, svc)

	userID := uuid.New()
	svc.IncrementCounterTestSync(t, userID, "PL")
	svc.IncrementCounterTestSync(t, userID, "DE")
	svc.IncrementCounterTestSync(t, userID, "DE")
	svc.IncrementCounterTestSync(t, userID, "AU")

	ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
	defer cancel()
	entries, err := svc.ListUserCounters(ctx, userID)
	if err != nil {
		t.Fatalf("list: %v", err)
	}
	if len(entries) != 3 {
		t.Fatalf("entries: want 3, got %d (%+v)", len(entries), entries)
	}

	wantOrder := []string{"AU", "DE", "PL"}
	for i, e := range entries {
		if e.Country != wantOrder[i] {
			t.Errorf("entries[%d].Country = %q, want %q", i, e.Country, wantOrder[i])
		}
		if e.LastSeenAt == nil {
			t.Errorf("entries[%d].LastSeenAt: want non-nil", i)
		}
	}
	// "DE" was incremented twice and sorts second.
	if entries[1].Count != 2 {
		t.Errorf("entries[1].Count: want 2, got %d", entries[1].Count)
	}
}

// TestListUserCountersEmpty checks that an unknown user yields no entries
// and no error.
func TestListUserCountersEmpty(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)
	closeServiceOnCleanup(t, svc)

	ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
	defer cancel()
	entries, err := svc.ListUserCounters(ctx, uuid.New())
	if err != nil {
		t.Fatalf("list unknown user: %v", err)
	}
	if len(entries) != 0 {
		t.Fatalf("entries: want empty, got %+v", entries)
	}
}

// TestListUserCountersNilArguments checks that a zero user ID and a nil
// *geo.Service receiver are both reported as errors.
func TestListUserCountersNilArguments(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)
	t.Cleanup(func() { _ = svc.Close() })

	ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
	defer cancel()

	if _, err := svc.ListUserCounters(ctx, uuid.Nil); err == nil {
		t.Fatal("ListUserCounters(uuid.Nil): want error")
	}

	var nilSvc *geo.Service
	if _, err := nilSvc.ListUserCounters(ctx, uuid.New()); err == nil {
		t.Fatal("nil receiver ListUserCounters: want error")
	}
}

// TestDrainAwaitsInFlightCounters checks that a counter write issued before
// Drain+Close is persisted and that Close returns no error.
//
// NOTE(review): the write goes through IncrementCounterTestSync, which
// (judging by its name — confirm) completes before returning, so nothing is
// actually in flight when Drain runs. Exercising the "await" behaviour needs
// an asynchronous path that bypasses the unset resolver.
func TestDrainAwaitsInFlightCounters(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)

	userID := uuid.New()
	// Inject country directly through the test seam so the lookup never
	// returns empty even though the resolver is unset.
	svc.IncrementCounterTestSync(t, userID, "FR")

	ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
	defer cancel()
	svc.Drain(ctx)
	if err := svc.Close(); err != nil {
		t.Fatalf("close: %v", err)
	}

	count, _ := readCounter(t, db, userID, "FR")
	if count != 1 {
		t.Fatalf("count after drain+close: want 1, got %d", count)
	}
}

// readCounter returns count and last_seen_at for (userID, country) from
// backend.user_country_counters. It fails the test when the query errors,
// including when no such row exists (sql.ErrNoRows). A NULL last_seen_at is
// returned as nil; otherwise the timestamp is converted to UTC.
func readCounter(t *testing.T, db *sql.DB, userID uuid.UUID, country string) (int64, *time.Time) {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
	defer cancel()

	var (
		count      int64
		lastSeenAt sql.NullTime
	)
	err := db.QueryRowContext(ctx, `
		SELECT count, last_seen_at
		FROM backend.user_country_counters
		WHERE user_id = $1 AND country = $2
	`, userID, country).Scan(&count, &lastSeenAt)
	if err != nil {
		t.Fatalf("read counter (%s/%s): %v", userID, country, err)
	}
	if !lastSeenAt.Valid {
		return count, nil
	}
	ts := lastSeenAt.Time.UTC()
	return count, &ts
}

// totalCounterRows returns the number of rows in
// backend.user_country_counters, failing the test on query error.
func totalCounterRows(t *testing.T, db *sql.DB) int {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
	defer cancel()

	var n int
	if err := db.QueryRowContext(ctx, `
		SELECT count(*) FROM backend.user_country_counters
	`).Scan(&n); err != nil {
		t.Fatalf("count rows: %v", err)
	}
	return n
}