package notification_test import ( "context" "database/sql" "errors" "net/url" "sync" "testing" "time" "galaxy/backend/internal/config" "galaxy/backend/internal/notification" backendpg "galaxy/backend/internal/postgres" "galaxy/backend/internal/user" pgshared "galaxy/postgres" "github.com/google/uuid" testcontainers "github.com/testcontainers/testcontainers-go" tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" "go.uber.org/zap/zaptest" ) const ( pgImage = "postgres:16-alpine" pgUser = "galaxy" pgPassword = "galaxy" pgDatabase = "galaxy_backend" pgSchema = "backend" pgStartup = 90 * time.Second pgOpTO = 10 * time.Second ) // startPostgres mirrors the mail/auth scaffolding: spin up Postgres, // apply migrations, return *sql.DB. func startPostgres(t *testing.T) *sql.DB { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) t.Cleanup(cancel) pgContainer, err := tcpostgres.Run(ctx, pgImage, tcpostgres.WithDatabase(pgDatabase), tcpostgres.WithUsername(pgUser), tcpostgres.WithPassword(pgPassword), testcontainers.WithWaitStrategy( wait.ForLog("database system is ready to accept connections"). WithOccurrence(2). WithStartupTimeout(pgStartup), ), ) if err != nil { t.Skipf("postgres testcontainer unavailable, skipping: %v", err) } t.Cleanup(func() { if termErr := testcontainers.TerminateContainer(pgContainer); termErr != nil { t.Errorf("terminate postgres container: %v", termErr) } }) baseDSN, err := pgContainer.ConnectionString(ctx, "sslmode=disable") if err != nil { t.Fatalf("connection string: %v", err) } scoped, err := dsnWithSearchPath(baseDSN, pgSchema) if err != nil { t.Fatalf("scope dsn: %v", err) } cfg := pgshared.DefaultConfig() cfg.PrimaryDSN = scoped cfg.OperationTimeout = pgOpTO db, err := pgshared.OpenPrimary(ctx, cfg) if err != nil { t.Fatalf("open primary: %v", err) } t.Cleanup(func() { _ = db.Close() }) if err := pgshared.Ping(ctx, db, cfg.OperationTimeout); err != nil { t.Fatalf("ping: %v", err) } if err := backendpg.ApplyMigrations(ctx, db); err != nil { t.Fatalf("apply migrations: %v", err) } return db } func dsnWithSearchPath(baseDSN, schema string) (string, error) { parsed, err := url.Parse(baseDSN) if err != nil { return "", err } values := parsed.Query() values.Set("search_path", schema) if values.Get("sslmode") == "" { values.Set("sslmode", "disable") } parsed.RawQuery = values.Encode() return parsed.String(), nil } // recordingMailer captures every EnqueueTemplate call. type recordingMailer struct { mu sync.Mutex calls []recordedEnqueue err error } type recordedEnqueue struct { TemplateID string Recipient string Payload map[string]any IdempotencyKey string } func (r *recordingMailer) EnqueueTemplate(_ context.Context, templateID, recipient string, payload map[string]any, idempotencyKey string) error { r.mu.Lock() defer r.mu.Unlock() if r.err != nil { return r.err } r.calls = append(r.calls, recordedEnqueue{ TemplateID: templateID, Recipient: recipient, Payload: payload, IdempotencyKey: idempotencyKey, }) return nil } func (r *recordingMailer) Calls() []recordedEnqueue { r.mu.Lock() defer r.mu.Unlock() out := make([]recordedEnqueue, len(r.calls)) copy(out, r.calls) return out } // recordingPush captures every PublishClientEvent call. type recordingPush struct { mu sync.Mutex calls []recordedPushEvent } type recordedPushEvent struct { UserID uuid.UUID Kind string Payload map[string]any EventID string RequestID string TraceID string } func (r *recordingPush) PublishClientEvent(_ context.Context, userID uuid.UUID, _ *uuid.UUID, kind string, payload map[string]any, eventID, requestID, traceID string) error { r.mu.Lock() defer r.mu.Unlock() r.calls = append(r.calls, recordedPushEvent{ UserID: userID, Kind: kind, Payload: payload, EventID: eventID, RequestID: requestID, TraceID: traceID, }) return nil } func (r *recordingPush) Calls() []recordedPushEvent { r.mu.Lock() defer r.mu.Unlock() out := make([]recordedPushEvent, len(r.calls)) copy(out, r.calls) return out } // stubAccounts hands back a fixed account record for any user_id, so // tests don't need to seed the accounts table. type stubAccounts struct { account user.Account err error } func (s *stubAccounts) GetAccount(_ context.Context, userID uuid.UUID) (user.Account, error) { if s.err != nil { return user.Account{}, s.err } out := s.account out.UserID = userID return out, nil } func newService(t *testing.T, db *sql.DB, mailer notification.Mailer, push notification.PushPublisher, accounts notification.AccountResolver, adminEmail string) *notification.Service { t.Helper() cfg := config.NotificationConfig{ AdminEmail: adminEmail, WorkerInterval: 10 * time.Millisecond, MaxAttempts: 3, } return notification.NewService(notification.Deps{ Store: notification.NewStore(db), Mail: mailer, Push: push, Accounts: accounts, Config: cfg, Logger: zaptest.NewLogger(t), }) } func TestSubmitFansOutLobbyInviteToPushAndEmail(t *testing.T) { t.Parallel() db := startPostgres(t) mailer := &recordingMailer{} push := &recordingPush{} accounts := &stubAccounts{account: user.Account{ Email: "alice@example.test", PreferredLanguage: "en", }} svc := newService(t, db, mailer, push, accounts, "") recipient := uuid.New() id, err := svc.Submit(context.Background(), notification.Intent{ Kind: notification.KindLobbyInviteReceived, IdempotencyKey: "invite:" + uuid.NewString(), Recipients: []uuid.UUID{recipient}, Payload: map[string]any{ "game_id": uuid.NewString(), "inviter_user_id": uuid.NewString(), }, }) if err != nil { t.Fatalf("submit: %v", err) } if id == uuid.Nil { t.Fatal("submit returned nil id") } // Best-effort dispatch ran synchronously; both channels should // have observed exactly one call. if got := len(push.Calls()); got != 1 { t.Errorf("push calls=%d, want 1", got) } if got := len(mailer.Calls()); got != 1 { t.Errorf("mail calls=%d, want 1", got) } else { call := mailer.Calls()[0] if call.Recipient != "alice@example.test" { t.Errorf("mail recipient=%q", call.Recipient) } if call.TemplateID != notification.KindLobbyInviteReceived { t.Errorf("mail template=%q", call.TemplateID) } } } func TestSubmitIsIdempotent(t *testing.T) { t.Parallel() db := startPostgres(t) svc := newService(t, db, &recordingMailer{}, &recordingPush{}, &stubAccounts{account: user.Account{Email: "x@example.test"}}, "") intent := notification.Intent{ Kind: notification.KindLobbyApplicationSubmitted, IdempotencyKey: "dedupe-key", Recipients: []uuid.UUID{uuid.New()}, Payload: map[string]any{"game_id": uuid.NewString(), "application_id": uuid.NewString()}, } first, err := svc.Submit(context.Background(), intent) if err != nil { t.Fatalf("first submit: %v", err) } second, err := svc.Submit(context.Background(), intent) if err != nil { t.Fatalf("second submit: %v", err) } if first != second { t.Fatalf("idempotent submit must return same id: %s vs %s", first, second) } } func TestSubmitMalformedPersists(t *testing.T) { t.Parallel() db := startPostgres(t) svc := newService(t, db, &recordingMailer{}, &recordingPush{}, &stubAccounts{}, "") id, err := svc.Submit(context.Background(), notification.Intent{ Kind: "nonsense.kind", IdempotencyKey: "anything", Recipients: []uuid.UUID{uuid.New()}, }) if err != nil { t.Fatalf("submit: %v", err) } if id != uuid.Nil { t.Fatalf("malformed submit must return nil id, got %s", id) } page, err := svc.AdminListMalformed(context.Background(), 1, 10) if err != nil { t.Fatalf("list malformed: %v", err) } if page.Total < 1 { t.Fatalf("malformed total=%d, want >= 1", page.Total) } } func TestSubmitAdminEmailSkipsWhenNotConfigured(t *testing.T) { t.Parallel() db := startPostgres(t) mailer := &recordingMailer{} svc := newService(t, db, mailer, &recordingPush{}, &stubAccounts{}, "") id, err := svc.Submit(context.Background(), notification.Intent{ Kind: notification.KindRuntimeImagePullFailed, IdempotencyKey: "ipf-1", Payload: map[string]any{"game_id": uuid.NewString(), "image_ref": "registry/img:tag"}, }) if err != nil { t.Fatalf("submit: %v", err) } if id == uuid.Nil { t.Fatal("admin submit returned nil id") } if got := len(mailer.Calls()); got != 0 { t.Errorf("mail calls=%d, want 0 (admin email unset)", got) } } func TestSubmitAdminEmailDispatchesWhenConfigured(t *testing.T) { t.Parallel() db := startPostgres(t) mailer := &recordingMailer{} svc := newService(t, db, mailer, &recordingPush{}, &stubAccounts{}, "ops@example.test") if _, err := svc.Submit(context.Background(), notification.Intent{ Kind: notification.KindRuntimeContainerStartFailed, IdempotencyKey: "csf-1", Payload: map[string]any{"game_id": uuid.NewString()}, }); err != nil { t.Fatalf("submit: %v", err) } calls := mailer.Calls() if len(calls) != 1 { t.Fatalf("mail calls=%d, want 1", len(calls)) } if calls[0].Recipient != "ops@example.test" { t.Errorf("admin recipient=%q", calls[0].Recipient) } } func TestSubmitMissingAccountSkipsEmail(t *testing.T) { t.Parallel() db := startPostgres(t) mailer := &recordingMailer{} push := &recordingPush{} accounts := &stubAccounts{err: user.ErrAccountNotFound} svc := newService(t, db, mailer, push, accounts, "") if _, err := svc.Submit(context.Background(), notification.Intent{ Kind: notification.KindLobbyApplicationApproved, IdempotencyKey: "missing-1", Recipients: []uuid.UUID{uuid.New()}, Payload: map[string]any{"game_id": uuid.NewString()}, }); err != nil { t.Fatalf("submit: %v", err) } if got := len(mailer.Calls()); got != 0 { t.Errorf("mail calls=%d want 0 when account missing", got) } if got := len(push.Calls()); got != 0 { t.Errorf("push calls=%d want 0 when account missing", got) } } func TestWorkerRetryAndDeadLetter(t *testing.T) { t.Parallel() db := startPostgres(t) failingMailer := &recordingMailer{err: errors.New("smtp down")} push := &recordingPush{} accounts := &stubAccounts{account: user.Account{Email: "alice@example.test", PreferredLanguage: "en"}} svc := newService(t, db, failingMailer, push, accounts, "") // MaxAttempts=3 from newService config. Submit fires one // best-effort attempt; subsequent Tick calls drive attempts 2 and // 3, the last one dead-letters. if _, err := svc.Submit(context.Background(), notification.Intent{ Kind: notification.KindLobbyInviteReceived, IdempotencyKey: "fail-1", Recipients: []uuid.UUID{uuid.New()}, Payload: map[string]any{"game_id": uuid.NewString(), "inviter_user_id": uuid.NewString()}, }); err != nil { t.Fatalf("submit: %v", err) } // Force every retry to be due immediately. if _, err := db.Exec(`UPDATE backend.notification_routes SET next_attempt_at = now() WHERE channel = 'email'`); err != nil { t.Fatalf("force due: %v", err) } worker := notification.NewWorker(svc) for range 5 { if err := worker.Tick(context.Background()); err != nil { t.Fatalf("tick: %v", err) } if _, err := db.Exec(`UPDATE backend.notification_routes SET next_attempt_at = now() WHERE channel = 'email' AND status = 'retrying'`); err != nil { t.Fatalf("force due: %v", err) } } dead, err := svc.AdminListDeadLetters(context.Background(), 1, 10) if err != nil { t.Fatalf("list dead-letters: %v", err) } if dead.Total < 1 { t.Fatalf("expected dead-letter row, got total=%d (mail attempts=%d)", dead.Total, len(failingMailer.Calls())) } } func TestOnUserDeletedSkipsPendingRoutes(t *testing.T) { t.Parallel() db := startPostgres(t) failingMailer := &recordingMailer{err: errors.New("smtp down")} push := &recordingPush{} userID := uuid.New() accounts := &stubAccounts{account: user.Account{Email: "alice@example.test", PreferredLanguage: "en"}} svc := newService(t, db, failingMailer, push, accounts, "") // Submit something that owns user_id so the cascade picks it up. if _, err := svc.Submit(context.Background(), notification.Intent{ Kind: notification.KindLobbyApplicationApproved, IdempotencyKey: "cascade-1", Recipients: []uuid.UUID{userID}, Payload: map[string]any{"game_id": uuid.NewString()}, }); err != nil { t.Fatalf("submit: %v", err) } if err := svc.OnUserDeleted(context.Background(), userID); err != nil { t.Fatalf("OnUserDeleted: %v", err) } var skipped int if err := db.QueryRow(` SELECT COUNT(*) FROM backend.notification_routes r JOIN backend.notifications n ON n.notification_id = r.notification_id WHERE n.user_id = $1 AND r.status = 'skipped' `, userID).Scan(&skipped); err != nil { t.Fatalf("count skipped: %v", err) } if skipped == 0 { t.Fatal("expected at least one skipped route after cascade") } } func TestAdminGetMissing(t *testing.T) { t.Parallel() db := startPostgres(t) svc := newService(t, db, &recordingMailer{}, &recordingPush{}, &stubAccounts{}, "") if _, err := svc.AdminGetNotification(context.Background(), uuid.New()); !errors.Is(err, notification.ErrNotificationNotFound) { t.Fatalf("got %v, want ErrNotificationNotFound", err) } }