321 lines
8.6 KiB
Go
321 lines
8.6 KiB
Go
package geo_test
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
|
|
"galaxy/backend/internal/geo"
|
|
backendpg "galaxy/backend/internal/postgres"
|
|
pgshared "galaxy/postgres"
|
|
|
|
"github.com/google/uuid"
|
|
testcontainers "github.com/testcontainers/testcontainers-go"
|
|
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
|
"github.com/testcontainers/testcontainers-go/wait"
|
|
"go.uber.org/zap/zaptest"
|
|
)
|
|
|
|
// Connection settings for the disposable Postgres container started by
// startPostgres. pgSchema is the schema put on the connection's search_path.
const (
	pgImage    = "postgres:16-alpine"
	pgUser     = "galaxy"
	pgPassword = "galaxy"
	pgDatabase = "galaxy_backend"
	pgSchema   = "backend"
)

// Timeouts: pgStartup bounds the wait for the container's readiness log
// line; pgOpTO bounds individual database operations and cleanup steps.
const (
	pgStartup = 90 * time.Second
	pgOpTO    = 10 * time.Second
)
|
|
|
|
// startPostgres mirrors the auth/notification test scaffolding: spin up
|
|
// a Postgres testcontainer, apply backend migrations, return *sql.DB.
|
|
func startPostgres(t *testing.T) *sql.DB {
|
|
t.Helper()
|
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
|
t.Cleanup(cancel)
|
|
|
|
pgContainer, err := tcpostgres.Run(ctx, pgImage,
|
|
tcpostgres.WithDatabase(pgDatabase),
|
|
tcpostgres.WithUsername(pgUser),
|
|
tcpostgres.WithPassword(pgPassword),
|
|
testcontainers.WithWaitStrategy(
|
|
wait.ForLog("database system is ready to accept connections").
|
|
WithOccurrence(2).
|
|
WithStartupTimeout(pgStartup),
|
|
),
|
|
)
|
|
if err != nil {
|
|
t.Skipf("postgres testcontainer unavailable, skipping: %v", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil {
|
|
t.Errorf("terminate postgres container: %v", termErr)
|
|
}
|
|
})
|
|
|
|
baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
|
|
if err != nil {
|
|
t.Fatalf("connection string: %v", err)
|
|
}
|
|
scoped, err := dsnWithSearchPath(baseDSN, pgSchema)
|
|
if err != nil {
|
|
t.Fatalf("scope dsn: %v", err)
|
|
}
|
|
cfg := pgshared.DefaultConfig()
|
|
cfg.PrimaryDSN = scoped
|
|
cfg.OperationTimeout = pgOpTO
|
|
db, err := pgshared.OpenPrimary(ctx, cfg)
|
|
if err != nil {
|
|
t.Fatalf("open primary: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = db.Close() })
|
|
if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil {
|
|
t.Fatalf("ping: %v", err)
|
|
}
|
|
if err := backendpg.ApplyMigrations(ctx, db); err != nil {
|
|
t.Fatalf("apply migrations: %v", err)
|
|
}
|
|
return db
|
|
}
|
|
|
|
// dsnWithSearchPath rewrites a URL-form Postgres DSN so that its
// search_path query parameter equals schema. If sslmode is missing or empty,
// it is set to "disable"; every other existing parameter is kept. The
// query string is re-encoded, so parameters come out sorted by key.
// A DSN that url.Parse rejects yields that parse error unchanged.
func dsnWithSearchPath(baseDSN, schema string) (string, error) {
	u, parseErr := url.Parse(baseDSN)
	if parseErr != nil {
		return "", parseErr
	}

	query := u.Query()
	query.Set("search_path", schema)
	if sslmode := query.Get("sslmode"); sslmode == "" {
		query.Set("sslmode", "disable")
	}
	u.RawQuery = query.Encode()

	return u.String(), nil
}
|
|
|
|
// fixtureService constructs a Service that uses an injected database
|
|
// pool and skips the GeoLite2 resolver — the resolver is exercised by
|
|
// `pkg/geoip` tests, while the counter path under test is independent
|
|
// of the lookup. The caller is responsible for invoking Drain/Close.
|
|
func fixtureService(t *testing.T, db *sql.DB) *geo.Service {
|
|
t.Helper()
|
|
svc, err := geo.NewServiceForTest(db)
|
|
if err != nil {
|
|
t.Fatalf("new service: %v", err)
|
|
}
|
|
svc.SetLogger(zaptest.NewLogger(t))
|
|
return svc
|
|
}
|
|
|
|
func TestIncrementCounterAsyncCreatesRow(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc := fixtureService(t, db)
|
|
t.Cleanup(func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
|
defer cancel()
|
|
svc.Drain(ctx)
|
|
_ = svc.Close()
|
|
})
|
|
|
|
userID := uuid.New()
|
|
svc.IncrementCounterTestSync(t, userID, "DE")
|
|
|
|
count, lastSeen := readCounter(t, db, userID, "DE")
|
|
if count != 1 {
|
|
t.Fatalf("count: want 1, got %d", count)
|
|
}
|
|
if lastSeen == nil {
|
|
t.Fatal("last_seen_at: want non-null, got null")
|
|
}
|
|
}
|
|
|
|
func TestIncrementCounterAsyncIncrementsExistingRow(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc := fixtureService(t, db)
|
|
t.Cleanup(func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
|
defer cancel()
|
|
svc.Drain(ctx)
|
|
_ = svc.Close()
|
|
})
|
|
|
|
userID := uuid.New()
|
|
svc.IncrementCounterTestSync(t, userID, "DE")
|
|
_, firstSeen := readCounter(t, db, userID, "DE")
|
|
if firstSeen == nil {
|
|
t.Fatal("first last_seen_at: want non-null")
|
|
}
|
|
|
|
// Sleep long enough for now() to advance past Postgres timestamp
|
|
// resolution (microseconds in practice).
|
|
time.Sleep(2 * time.Millisecond)
|
|
|
|
svc.IncrementCounterTestSync(t, userID, "DE")
|
|
count, secondSeen := readCounter(t, db, userID, "DE")
|
|
if count != 2 {
|
|
t.Fatalf("count: want 2, got %d", count)
|
|
}
|
|
if secondSeen == nil || !secondSeen.After(*firstSeen) {
|
|
t.Fatalf("last_seen_at: want strictly later than %v, got %v", firstSeen, secondSeen)
|
|
}
|
|
}
|
|
|
|
func TestIncrementCounterAsyncShortCircuits(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc := fixtureService(t, db)
|
|
t.Cleanup(func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
|
defer cancel()
|
|
svc.Drain(ctx)
|
|
_ = svc.Close()
|
|
})
|
|
|
|
// Empty country / zero user — exercise the synchronous validation
|
|
// path through the public API to confirm no goroutine is launched.
|
|
svc.IncrementCounterAsync(context.Background(), uuid.Nil, "1.2.3.4")
|
|
svc.IncrementCounterAsync(context.Background(), uuid.New(), "")
|
|
|
|
rows := totalCounterRows(t, db)
|
|
if rows != 0 {
|
|
t.Fatalf("expected zero counter rows after short-circuit calls, got %d", rows)
|
|
}
|
|
}
|
|
|
|
func TestListUserCountersOrdered(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc := fixtureService(t, db)
|
|
t.Cleanup(func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
|
defer cancel()
|
|
svc.Drain(ctx)
|
|
_ = svc.Close()
|
|
})
|
|
|
|
userID := uuid.New()
|
|
svc.IncrementCounterTestSync(t, userID, "PL")
|
|
svc.IncrementCounterTestSync(t, userID, "DE")
|
|
svc.IncrementCounterTestSync(t, userID, "DE")
|
|
svc.IncrementCounterTestSync(t, userID, "AU")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
|
defer cancel()
|
|
|
|
entries, err := svc.ListUserCounters(ctx, userID)
|
|
if err != nil {
|
|
t.Fatalf("list: %v", err)
|
|
}
|
|
if len(entries) != 3 {
|
|
t.Fatalf("entries: want 3, got %d (%+v)", len(entries), entries)
|
|
}
|
|
wantOrder := []string{"AU", "DE", "PL"}
|
|
for i, e := range entries {
|
|
if e.Country != wantOrder[i] {
|
|
t.Errorf("entries[%d].Country = %q, want %q", i, e.Country, wantOrder[i])
|
|
}
|
|
if e.LastSeenAt == nil {
|
|
t.Errorf("entries[%d].LastSeenAt: want non-nil", i)
|
|
}
|
|
}
|
|
if entries[1].Count != 2 {
|
|
t.Errorf("entries[1].Count: want 2, got %d", entries[1].Count)
|
|
}
|
|
}
|
|
|
|
func TestListUserCountersEmpty(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc := fixtureService(t, db)
|
|
t.Cleanup(func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
|
defer cancel()
|
|
svc.Drain(ctx)
|
|
_ = svc.Close()
|
|
})
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
|
defer cancel()
|
|
|
|
entries, err := svc.ListUserCounters(ctx, uuid.New())
|
|
if err != nil {
|
|
t.Fatalf("list unknown user: %v", err)
|
|
}
|
|
if len(entries) != 0 {
|
|
t.Fatalf("entries: want empty, got %+v", entries)
|
|
}
|
|
}
|
|
|
|
func TestListUserCountersNilArguments(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc := fixtureService(t, db)
|
|
t.Cleanup(func() { _ = svc.Close() })
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
|
defer cancel()
|
|
|
|
if _, err := svc.ListUserCounters(ctx, uuid.Nil); err == nil {
|
|
t.Fatal("ListUserCounters(uuid.Nil): want error")
|
|
}
|
|
|
|
var nilSvc *geo.Service
|
|
if _, err := nilSvc.ListUserCounters(ctx, uuid.New()); err == nil {
|
|
t.Fatal("nil receiver ListUserCounters: want error")
|
|
}
|
|
}
|
|
|
|
func TestDrainAwaitsInFlightCounters(t *testing.T) {
|
|
db := startPostgres(t)
|
|
svc := fixtureService(t, db)
|
|
|
|
userID := uuid.New()
|
|
// Inject country directly through the test seam so the lookup never
|
|
// returns empty even though the resolver is unset.
|
|
svc.IncrementCounterTestSync(t, userID, "FR")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
|
defer cancel()
|
|
svc.Drain(ctx)
|
|
if err := svc.Close(); err != nil {
|
|
t.Fatalf("close: %v", err)
|
|
}
|
|
count, _ := readCounter(t, db, userID, "FR")
|
|
if count != 1 {
|
|
t.Fatalf("count after drain+close: want 1, got %d", count)
|
|
}
|
|
}
|
|
|
|
func readCounter(t *testing.T, db *sql.DB, userID uuid.UUID, country string) (int64, *time.Time) {
|
|
t.Helper()
|
|
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
|
defer cancel()
|
|
|
|
var (
|
|
count int64
|
|
lastSeenAt sql.NullTime
|
|
)
|
|
err := db.QueryRowContext(ctx, `
|
|
SELECT count, last_seen_at FROM backend.user_country_counters
|
|
WHERE user_id = $1 AND country = $2
|
|
`, userID, country).Scan(&count, &lastSeenAt)
|
|
if err != nil {
|
|
t.Fatalf("read counter (%s/%s): %v", userID, country, err)
|
|
}
|
|
if !lastSeenAt.Valid {
|
|
return count, nil
|
|
}
|
|
ts := lastSeenAt.Time.UTC()
|
|
return count, &ts
|
|
}
|
|
|
|
func totalCounterRows(t *testing.T, db *sql.DB) int {
|
|
t.Helper()
|
|
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
|
defer cancel()
|
|
|
|
var n int
|
|
if err := db.QueryRowContext(ctx, `
|
|
SELECT count(*) FROM backend.user_country_counters
|
|
`).Scan(&n); err != nil {
|
|
t.Fatalf("count rows: %v", err)
|
|
}
|
|
return n
|
|
}
|