// Repository file-view metadata:
//   Path:        galaxy-game/backend/internal/geo/counter_test.go
//   Last change: 2026-05-07 00:58:53 +03:00 (T)
//   Size:        321 lines, 8.7 KiB (Go)

package geo_test
import (
"context"
"database/sql"
"net/url"
"testing"
"time"
"galaxy/backend/internal/geo"
backendpg "galaxy/backend/internal/postgres"
pgshared "galaxy/postgres"
"github.com/google/uuid"
testcontainers "github.com/testcontainers/testcontainers-go"
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"go.uber.org/zap/zaptest"
)
// Postgres testcontainer identity used by every test in this file.
const (
	pgImage    = "postgres:16-alpine" // container image to launch
	pgUser     = "galaxy"             // superuser created in the container
	pgPassword = "galaxy"             // password for pgUser
	pgDatabase = "galaxy_backend"     // database created at startup
	pgSchema   = "backend"            // schema pinned via search_path
)

// Time budgets for the Postgres test scaffolding.
const (
	pgStartup = 90 * time.Second // max wait for the readiness log line
	pgOpTO    = 10 * time.Second // per-operation timeout for queries/calls
)
// startPostgres mirrors the auth/notification test scaffolding. It:
//   - spins up a Postgres testcontainer;
//   - opens a pool whose DSN pins search_path to pgSchema;
//   - pings the pool and applies the backend migrations;
//   - returns the ready *sql.DB.
//
// The test is skipped (not failed) when the testcontainer cannot be
// started. Every resource created here is released through t.Cleanup.
// Cleanups run LIFO, so the pool is closed before the container is
// terminated, and the setup context is cancelled last.
func startPostgres(t *testing.T) *sql.DB {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
	t.Cleanup(cancel)

	pgContainer, err := tcpostgres.Run(ctx, pgImage,
		tcpostgres.WithDatabase(pgDatabase),
		tcpostgres.WithUsername(pgUser),
		tcpostgres.WithPassword(pgPassword),
		testcontainers.WithWaitStrategy(
			// The readiness line is logged twice: once by the init-time
			// server and once by the final server after initdb.
			wait.ForLog("database system is ready to accept connections").
				WithOccurrence(2).
				WithStartupTimeout(pgStartup),
		),
	)
	// Register termination BEFORE inspecting err: Run can return a container
	// that was created but failed its wait strategy. Skipping first would
	// leak that container.
	t.Cleanup(func() {
		if pgContainer == nil {
			return
		}
		if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil {
			t.Errorf("terminate postgres container: %v", termErr)
		}
	})
	if err != nil {
		t.Skipf("postgres testcontainer unavailable, skipping: %v", err)
	}

	baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
	if err != nil {
		t.Fatalf("connection string: %v", err)
	}
	scoped, err := dsnWithSearchPath(baseDSN, pgSchema)
	if err != nil {
		t.Fatalf("scope dsn: %v", err)
	}

	cfg := pgshared.DefaultConfig()
	cfg.PrimaryDSN = scoped
	cfg.OperationTimeout = pgOpTO
	db, err := pgshared.OpenPrimary(ctx, cfg, backendpg.NoObservabilityOptions()...)
	if err != nil {
		t.Fatalf("open primary: %v", err)
	}
	// Best-effort: a close failure during teardown is not actionable here.
	t.Cleanup(func() { _ = db.Close() })

	if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil {
		t.Fatalf("ping: %v", err)
	}
	if err := backendpg.ApplyMigrations(ctx, db); err != nil {
		t.Fatalf("apply migrations: %v", err)
	}
	return db
}
// dsnWithSearchPath rewrites a URL-form Postgres DSN so that the
// search_path query parameter equals schema, replacing any existing value.
// If the DSN carries no (or an empty) sslmode, it gains sslmode=disable; an
// explicit sslmode is kept. The query string is re-encoded in sorted key
// order. A DSN that net/url cannot parse is returned as an error.
func dsnWithSearchPath(baseDSN, schema string) (string, error) {
	u, parseErr := url.Parse(baseDSN)
	if parseErr != nil {
		return "", parseErr
	}

	q := u.Query()
	q.Set("search_path", schema)
	if len(q.Get("sslmode")) == 0 {
		q.Set("sslmode", "disable")
	}

	u.RawQuery = q.Encode()
	return u.String(), nil
}
// fixtureService builds a *geo.Service on top of the given pool via the
// package's NewServiceForTest seam and routes its logging to the test log.
// The GeoLite2 resolver is not wired up; lookups are covered by the
// `pkg/geoip` tests, and the counter path exercised here does not depend on
// them. The caller owns the returned service and must Drain/Close it.
func fixtureService(t *testing.T, db *sql.DB) *geo.Service {
	t.Helper()

	service, newErr := geo.NewServiceForTest(db)
	if newErr != nil {
		t.Fatalf("new service: %v", newErr)
	}

	logger := zaptest.NewLogger(t)
	service.SetLogger(logger)
	return service
}
// TestIncrementCounterAsyncCreatesRow verifies that the first increment for a
// (user, country) pair produces a row with count 1 and a non-NULL
// last_seen_at.
func TestIncrementCounterAsyncCreatesRow(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)
	t.Cleanup(func() {
		drainCtx, drainCancel := context.WithTimeout(context.Background(), pgOpTO)
		defer drainCancel()
		svc.Drain(drainCtx)
		// Teardown errors are not under test here.
		_ = svc.Close()
	})

	const country = "DE"
	user := uuid.New()
	svc.IncrementCounterTestSync(t, user, country)

	gotCount, gotSeen := readCounter(t, db, user, country)
	switch {
	case gotCount != 1:
		t.Fatalf("count: want 1, got %d", gotCount)
	case gotSeen == nil:
		t.Fatal("last_seen_at: want non-null, got null")
	}
}
// TestIncrementCounterAsyncIncrementsExistingRow verifies that a second
// increment for the same (user, country) pair bumps the count to 2. It also
// checks that last_seen_at moves strictly forward.
func TestIncrementCounterAsyncIncrementsExistingRow(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)
	t.Cleanup(func() {
		drainCtx, drainCancel := context.WithTimeout(context.Background(), pgOpTO)
		defer drainCancel()
		svc.Drain(drainCtx)
		// Teardown errors are not under test here.
		_ = svc.Close()
	})

	user := uuid.New()

	svc.IncrementCounterTestSync(t, user, "DE")
	_, before := readCounter(t, db, user, "DE")
	if before == nil {
		t.Fatal("first last_seen_at: want non-null")
	}

	// Pause so now() can advance past Postgres timestamp resolution
	// (microseconds in practice). That makes the second write observably
	// later.
	time.Sleep(2 * time.Millisecond)

	svc.IncrementCounterTestSync(t, user, "DE")
	total, after := readCounter(t, db, user, "DE")
	if total != 2 {
		t.Fatalf("count: want 2, got %d", total)
	}
	if after == nil || !after.After(*before) {
		t.Fatalf("last_seen_at: want strictly later than %v, got %v", before, after)
	}
}
// TestIncrementCounterAsyncShortCircuits calls the public IncrementCounterAsync
// with inputs it should reject up front:
//   - the zero user ID (uuid.Nil);
//   - an empty third argument (presumably the client address, given the
//     "1.2.3.4" used alongside it; confirm against geo.Service).
//
// It asserts that no counter row is ever written.
func TestIncrementCounterAsyncShortCircuits(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)
	// Drain happens explicitly below, before the assertion; teardown only
	// needs to release the service. Close errors are not under test here.
	t.Cleanup(func() { _ = svc.Close() })

	svc.IncrementCounterAsync(context.Background(), uuid.Nil, "1.2.3.4")
	svc.IncrementCounterAsync(context.Background(), uuid.New(), "")

	// Wait for any background work before reading. Otherwise a write that
	// was wrongly launched could land after the count and the test would
	// pass vacuously.
	ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
	defer cancel()
	svc.Drain(ctx)

	if rows := totalCounterRows(t, db); rows != 0 {
		t.Fatalf("expected zero counter rows after short-circuit calls, got %d", rows)
	}
}
// TestListUserCountersOrdered seeds three countries for one user (DE twice).
// It checks that ListUserCounters returns them in AU, DE, PL order, each
// with a last_seen_at, and that DE's count is 2.
func TestListUserCountersOrdered(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)
	t.Cleanup(func() {
		drainCtx, drainCancel := context.WithTimeout(context.Background(), pgOpTO)
		defer drainCancel()
		svc.Drain(drainCtx)
		// Teardown errors are not under test here.
		_ = svc.Close()
	})

	user := uuid.New()
	for _, country := range []string{"PL", "DE", "DE", "AU"} {
		svc.IncrementCounterTestSync(t, user, country)
	}

	listCtx, listCancel := context.WithTimeout(context.Background(), pgOpTO)
	defer listCancel()
	got, err := svc.ListUserCounters(listCtx, user)
	if err != nil {
		t.Fatalf("list: %v", err)
	}

	wantCountries := [...]string{"AU", "DE", "PL"}
	if len(got) != len(wantCountries) {
		t.Fatalf("entries: want 3, got %d (%+v)", len(got), got)
	}
	for idx := range got {
		if got[idx].Country != wantCountries[idx] {
			t.Errorf("entries[%d].Country = %q, want %q", idx, got[idx].Country, wantCountries[idx])
		}
		if got[idx].LastSeenAt == nil {
			t.Errorf("entries[%d].LastSeenAt: want non-nil", idx)
		}
	}
	if got[1].Count != 2 {
		t.Errorf("entries[1].Count: want 2, got %d", got[1].Count)
	}
}
// TestListUserCountersEmpty checks that a user with no counters yields an
// empty result and no error.
func TestListUserCountersEmpty(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)
	t.Cleanup(func() {
		drainCtx, drainCancel := context.WithTimeout(context.Background(), pgOpTO)
		defer drainCancel()
		svc.Drain(drainCtx)
		// Teardown errors are not under test here.
		_ = svc.Close()
	})

	ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
	defer cancel()

	stranger := uuid.New()
	entries, err := svc.ListUserCounters(ctx, stranger)
	switch {
	case err != nil:
		t.Fatalf("list unknown user: %v", err)
	case len(entries) != 0:
		t.Fatalf("entries: want empty, got %+v", entries)
	}
}
// TestListUserCountersNilArguments checks that ListUserCounters returns an
// error both for the zero user ID and when called on a nil *geo.Service.
func TestListUserCountersNilArguments(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)
	// Nothing asynchronous is started here, so Close alone suffices;
	// its error is not under test.
	t.Cleanup(func() { _ = svc.Close() })

	ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
	defer cancel()

	_, errZeroUser := svc.ListUserCounters(ctx, uuid.Nil)
	if errZeroUser == nil {
		t.Fatal("ListUserCounters(uuid.Nil): want error")
	}

	var missing *geo.Service
	_, errNilReceiver := missing.ListUserCounters(ctx, uuid.New())
	if errNilReceiver == nil {
		t.Fatal("nil receiver ListUserCounters: want error")
	}
}
// TestDrainAwaitsInFlightCounters records one counter for "FR" and checks
// that it is visible with count 1 after Drain followed by a successful
// Close. The country is supplied directly through the test seam, so the
// (unset) resolver is never consulted.
//
// NOTE(review): IncrementCounterTestSync appears, by its name, to write
// synchronously. If so, the row already exists before Drain runs, and this
// test does not prove that Drain waits for an in-flight goroutine.
// TODO confirm against the geo package's test seams.
func TestDrainAwaitsInFlightCounters(t *testing.T) {
	db := startPostgres(t)
	svc := fixtureService(t, db)

	user := uuid.New()
	svc.IncrementCounterTestSync(t, user, "FR")

	drainCtx, drainCancel := context.WithTimeout(context.Background(), pgOpTO)
	defer drainCancel()
	svc.Drain(drainCtx)

	if closeErr := svc.Close(); closeErr != nil {
		t.Fatalf("close: %v", closeErr)
	}

	if got, _ := readCounter(t, db, user, "FR"); got != 1 {
		t.Fatalf("count after drain+close: want 1, got %d", got)
	}
}
// readCounter fetches count and last_seen_at for one (user, country) row of
// backend.user_country_counters. last_seen_at is returned in UTC, or nil when
// the column is NULL. Any query or scan error, including a missing row
// (sql.ErrNoRows), fails the test immediately.
func readCounter(t *testing.T, db *sql.DB, userID uuid.UUID, country string) (int64, *time.Time) {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
	defer cancel()

	const query = `
SELECT count, last_seen_at FROM backend.user_country_counters
WHERE user_id = $1 AND country = $2
`
	var (
		n    int64
		seen sql.NullTime
	)
	if err := db.QueryRowContext(ctx, query, userID, country).Scan(&n, &seen); err != nil {
		t.Fatalf("read counter (%s/%s): %v", userID, country, err)
	}

	if seen.Valid {
		utc := seen.Time.UTC()
		return n, &utc
	}
	return n, nil
}
// totalCounterRows returns the number of rows currently in
// backend.user_country_counters, across all users. A query failure fails
// the test.
func totalCounterRows(t *testing.T, db *sql.DB) int {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
	defer cancel()

	row := db.QueryRowContext(ctx, `
SELECT count(*) FROM backend.user_country_counters
`)
	var total int
	if err := row.Scan(&total); err != nil {
		t.Fatalf("count rows: %v", err)
	}
	return total
}