feat: backend service
This commit is contained in:
@@ -0,0 +1,320 @@
|
||||
package geo_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/backend/internal/geo"
|
||||
backendpg "galaxy/backend/internal/postgres"
|
||||
pgshared "galaxy/postgres"
|
||||
|
||||
"github.com/google/uuid"
|
||||
testcontainers "github.com/testcontainers/testcontainers-go"
|
||||
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
const (
|
||||
pgImage = "postgres:16-alpine"
|
||||
pgUser = "galaxy"
|
||||
pgPassword = "galaxy"
|
||||
pgDatabase = "galaxy_backend"
|
||||
pgSchema = "backend"
|
||||
pgStartup = 90 * time.Second
|
||||
pgOpTO = 10 * time.Second
|
||||
)
|
||||
|
||||
// startPostgres mirrors the auth/notification test scaffolding: spin up
|
||||
// a Postgres testcontainer, apply backend migrations, return *sql.DB.
|
||||
func startPostgres(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
pgContainer, err := tcpostgres.Run(ctx, pgImage,
|
||||
tcpostgres.WithDatabase(pgDatabase),
|
||||
tcpostgres.WithUsername(pgUser),
|
||||
tcpostgres.WithPassword(pgPassword),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(pgStartup),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
t.Skipf("postgres testcontainer unavailable, skipping: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil {
|
||||
t.Errorf("terminate postgres container: %v", termErr)
|
||||
}
|
||||
})
|
||||
|
||||
baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
|
||||
if err != nil {
|
||||
t.Fatalf("connection string: %v", err)
|
||||
}
|
||||
scoped, err := dsnWithSearchPath(baseDSN, pgSchema)
|
||||
if err != nil {
|
||||
t.Fatalf("scope dsn: %v", err)
|
||||
}
|
||||
cfg := pgshared.DefaultConfig()
|
||||
cfg.PrimaryDSN = scoped
|
||||
cfg.OperationTimeout = pgOpTO
|
||||
db, err := pgshared.OpenPrimary(ctx, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("open primary: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil {
|
||||
t.Fatalf("ping: %v", err)
|
||||
}
|
||||
if err := backendpg.ApplyMigrations(ctx, db); err != nil {
|
||||
t.Fatalf("apply migrations: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func dsnWithSearchPath(baseDSN, schema string) (string, error) {
|
||||
parsed, err := url.Parse(baseDSN)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
values := parsed.Query()
|
||||
values.Set("search_path", schema)
|
||||
if values.Get("sslmode") == "" {
|
||||
values.Set("sslmode", "disable")
|
||||
}
|
||||
parsed.RawQuery = values.Encode()
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
// fixtureService constructs a Service that uses an injected database
|
||||
// pool and skips the GeoLite2 resolver — the resolver is exercised by
|
||||
// `pkg/geoip` tests, while the counter path under test is independent
|
||||
// of the lookup. The caller is responsible for invoking Drain/Close.
|
||||
func fixtureService(t *testing.T, db *sql.DB) *geo.Service {
|
||||
t.Helper()
|
||||
svc, err := geo.NewServiceForTest(db)
|
||||
if err != nil {
|
||||
t.Fatalf("new service: %v", err)
|
||||
}
|
||||
svc.SetLogger(zaptest.NewLogger(t))
|
||||
return svc
|
||||
}
|
||||
|
||||
func TestIncrementCounterAsyncCreatesRow(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
svc.Drain(ctx)
|
||||
_ = svc.Close()
|
||||
})
|
||||
|
||||
userID := uuid.New()
|
||||
svc.IncrementCounterTestSync(t, userID, "DE")
|
||||
|
||||
count, lastSeen := readCounter(t, db, userID, "DE")
|
||||
if count != 1 {
|
||||
t.Fatalf("count: want 1, got %d", count)
|
||||
}
|
||||
if lastSeen == nil {
|
||||
t.Fatal("last_seen_at: want non-null, got null")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncrementCounterAsyncIncrementsExistingRow(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
svc.Drain(ctx)
|
||||
_ = svc.Close()
|
||||
})
|
||||
|
||||
userID := uuid.New()
|
||||
svc.IncrementCounterTestSync(t, userID, "DE")
|
||||
_, firstSeen := readCounter(t, db, userID, "DE")
|
||||
if firstSeen == nil {
|
||||
t.Fatal("first last_seen_at: want non-null")
|
||||
}
|
||||
|
||||
// Sleep long enough for now() to advance past Postgres timestamp
|
||||
// resolution (microseconds in practice).
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
svc.IncrementCounterTestSync(t, userID, "DE")
|
||||
count, secondSeen := readCounter(t, db, userID, "DE")
|
||||
if count != 2 {
|
||||
t.Fatalf("count: want 2, got %d", count)
|
||||
}
|
||||
if secondSeen == nil || !secondSeen.After(*firstSeen) {
|
||||
t.Fatalf("last_seen_at: want strictly later than %v, got %v", firstSeen, secondSeen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncrementCounterAsyncShortCircuits(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
svc.Drain(ctx)
|
||||
_ = svc.Close()
|
||||
})
|
||||
|
||||
// Empty country / zero user — exercise the synchronous validation
|
||||
// path through the public API to confirm no goroutine is launched.
|
||||
svc.IncrementCounterAsync(context.Background(), uuid.Nil, "1.2.3.4")
|
||||
svc.IncrementCounterAsync(context.Background(), uuid.New(), "")
|
||||
|
||||
rows := totalCounterRows(t, db)
|
||||
if rows != 0 {
|
||||
t.Fatalf("expected zero counter rows after short-circuit calls, got %d", rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserCountersOrdered(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
svc.Drain(ctx)
|
||||
_ = svc.Close()
|
||||
})
|
||||
|
||||
userID := uuid.New()
|
||||
svc.IncrementCounterTestSync(t, userID, "PL")
|
||||
svc.IncrementCounterTestSync(t, userID, "DE")
|
||||
svc.IncrementCounterTestSync(t, userID, "DE")
|
||||
svc.IncrementCounterTestSync(t, userID, "AU")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
|
||||
entries, err := svc.ListUserCounters(ctx, userID)
|
||||
if err != nil {
|
||||
t.Fatalf("list: %v", err)
|
||||
}
|
||||
if len(entries) != 3 {
|
||||
t.Fatalf("entries: want 3, got %d (%+v)", len(entries), entries)
|
||||
}
|
||||
wantOrder := []string{"AU", "DE", "PL"}
|
||||
for i, e := range entries {
|
||||
if e.Country != wantOrder[i] {
|
||||
t.Errorf("entries[%d].Country = %q, want %q", i, e.Country, wantOrder[i])
|
||||
}
|
||||
if e.LastSeenAt == nil {
|
||||
t.Errorf("entries[%d].LastSeenAt: want non-nil", i)
|
||||
}
|
||||
}
|
||||
if entries[1].Count != 2 {
|
||||
t.Errorf("entries[1].Count: want 2, got %d", entries[1].Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserCountersEmpty(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
svc.Drain(ctx)
|
||||
_ = svc.Close()
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
|
||||
entries, err := svc.ListUserCounters(ctx, uuid.New())
|
||||
if err != nil {
|
||||
t.Fatalf("list unknown user: %v", err)
|
||||
}
|
||||
if len(entries) != 0 {
|
||||
t.Fatalf("entries: want empty, got %+v", entries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserCountersNilArguments(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
t.Cleanup(func() { _ = svc.Close() })
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
|
||||
if _, err := svc.ListUserCounters(ctx, uuid.Nil); err == nil {
|
||||
t.Fatal("ListUserCounters(uuid.Nil): want error")
|
||||
}
|
||||
|
||||
var nilSvc *geo.Service
|
||||
if _, err := nilSvc.ListUserCounters(ctx, uuid.New()); err == nil {
|
||||
t.Fatal("nil receiver ListUserCounters: want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDrainAwaitsInFlightCounters(t *testing.T) {
|
||||
db := startPostgres(t)
|
||||
svc := fixtureService(t, db)
|
||||
|
||||
userID := uuid.New()
|
||||
// Inject country directly through the test seam so the lookup never
|
||||
// returns empty even though the resolver is unset.
|
||||
svc.IncrementCounterTestSync(t, userID, "FR")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
svc.Drain(ctx)
|
||||
if err := svc.Close(); err != nil {
|
||||
t.Fatalf("close: %v", err)
|
||||
}
|
||||
count, _ := readCounter(t, db, userID, "FR")
|
||||
if count != 1 {
|
||||
t.Fatalf("count after drain+close: want 1, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func readCounter(t *testing.T, db *sql.DB, userID uuid.UUID, country string) (int64, *time.Time) {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
|
||||
var (
|
||||
count int64
|
||||
lastSeenAt sql.NullTime
|
||||
)
|
||||
err := db.QueryRowContext(ctx, `
|
||||
SELECT count, last_seen_at FROM backend.user_country_counters
|
||||
WHERE user_id = $1 AND country = $2
|
||||
`, userID, country).Scan(&count, &lastSeenAt)
|
||||
if err != nil {
|
||||
t.Fatalf("read counter (%s/%s): %v", userID, country, err)
|
||||
}
|
||||
if !lastSeenAt.Valid {
|
||||
return count, nil
|
||||
}
|
||||
ts := lastSeenAt.Time.UTC()
|
||||
return count, &ts
|
||||
}
|
||||
|
||||
func totalCounterRows(t *testing.T, db *sql.DB) int {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pgOpTO)
|
||||
defer cancel()
|
||||
|
||||
var n int
|
||||
if err := db.QueryRowContext(ctx, `
|
||||
SELECT count(*) FROM backend.user_country_counters
|
||||
`).Scan(&n); err != nil {
|
||||
t.Fatalf("count rows: %v", err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
Reference in New Issue
Block a user