Files
galaxy-game/backend/internal/notification/dispatcher.go
T
2026-05-06 10:14:55 +03:00

176 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package notification
import (
"context"
"database/sql"
"errors"
"fmt"
"math/rand/v2"
"time"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
// traceIDFromContext returns the W3C trace id found in ctx's span context,
// rendered as a hex string. It is forwarded to the gateway as
// ClientEvent.trace_id so push envelopes can be correlated with the
// producing trace.
//
// The result is "" when ctx is nil or its span context has no trace id.
// Only HasTraceID is checked, so a span context that is not recording
// still yields its trace id.
func traceIDFromContext(ctx context.Context) string {
	if ctx != nil {
		if sc := trace.SpanContextFromContext(ctx); sc.HasTraceID() {
			return sc.TraceID().String()
		}
	}
	return ""
}
// finaliseDispatch records the outcome of one delivery attempt for claim
// inside tx. The status transitions mirror README §10 and the CHECK
// constraint on `notification_routes`:
//
//   - dispatchErr == nil        → published (next_attempt_at NULL)
//   - failure, attempt < limit  → retrying, next attempt armed at
//     at + routeBackoff(attempt)
//   - failure, attempt >= limit → dead_lettered, plus a
//     notification_dead_letters row
//
// attempt is the route's stored Attempts plus one. limit is the route's
// own MaxAttempts, falling back to Config.MaxAttempts when the route's
// value is not positive. dispatchErr.Error() is persisted as the reason.
//
// tx is neither committed nor rolled back here. The caller (the worker,
// or Submit's best-effort path) owns it, so this write can share a
// transaction with the preceding ClaimDueRoutes lock.
func (s *Service) finaliseDispatch(ctx context.Context, tx *sql.Tx, claim ClaimedRoute, dispatchErr error, at time.Time) error {
	route := claim.Route
	if dispatchErr == nil {
		return s.deps.Store.MarkRoutePublished(ctx, tx, route.RouteID, at)
	}

	attempt := route.Attempts + 1
	reason := dispatchErr.Error()
	limit := route.MaxAttempts
	if limit <= 0 {
		// The route carries no usable cap of its own; use the service default.
		limit = int32(s.deps.Config.MaxAttempts)
	}

	kind := claim.Notification.Kind
	routeID := route.RouteID.String()

	if attempt < limit {
		nextAt := at.Add(routeBackoff(attempt))
		s.deps.Logger.Info("notification route retry scheduled",
			zap.String("kind", kind),
			zap.String("channel", route.Channel),
			zap.String("route_id", routeID),
			zap.Int32("attempt", attempt),
			zap.Time("next_attempt_at", nextAt),
			zap.Error(dispatchErr),
		)
		return s.deps.Store.ScheduleRouteRetry(ctx, tx, route.RouteID, at, nextAt, reason)
	}

	// Retry budget exhausted: park the route in the dead-letter table.
	s.deps.Logger.Warn("notification route dead-lettered",
		zap.String("kind", kind),
		zap.String("channel", route.Channel),
		zap.String("route_id", routeID),
		zap.Int32("attempt", attempt),
		zap.Error(dispatchErr),
	)
	return s.deps.Store.MarkRouteDeadLettered(ctx, tx, claim.Notification.NotificationID, route.RouteID, at, reason)
}
// bestEffortDispatch tries to deliver route immediately. Submit invokes it
// right after the route has been durably persisted. Routes that are not
// in RouteStatusPending are ignored.
//
// A transaction is opened before the channel call and stays open until
// finaliseDispatch (the same helper the worker uses) has recorded the
// outcome and the transaction is committed. Its lifetime therefore
// includes the channel's latency.
//
// Failures to begin, finalise or commit are only logged at debug level and
// never reach the producer; the worker retries the route on a later tick.
//
// NOTE(review): this path does not lock the route row itself. Whether the
// worker can claim and dispatch the same pending route concurrently depends
// on ClaimDueRoutes / the Store implementation — confirm.
func (s *Service) bestEffortDispatch(ctx context.Context, n Notification, route Route) {
	if route.Status != RouteStatusPending {
		return
	}

	routeField := zap.String("route_id", route.RouteID.String())
	debugFailure := func(msg string, err error) {
		s.deps.Logger.Debug(msg, routeField, zap.Error(err))
	}

	tx, err := s.deps.Store.BeginTx(ctx)
	if err != nil {
		debugFailure("best-effort dispatch: begin tx failed", err)
		return
	}
	// After a successful Commit, Rollback just returns sql.ErrTxDone, which
	// is deliberately discarded; on every early return it undoes the work.
	defer func() { _ = tx.Rollback() }()

	claim := ClaimedRoute{Route: route, Notification: n}
	outcome := s.performDispatch(ctx, claim)
	// The timestamp is taken after the channel call returns.
	if ferr := s.finaliseDispatch(ctx, tx, claim, outcome, s.nowUTC()); ferr != nil {
		debugFailure("best-effort dispatch finalise failed", ferr)
		return
	}
	if cerr := tx.Commit(); cerr != nil {
		debugFailure("best-effort dispatch commit failed", cerr)
	}
}
// performDispatch performs the channel-specific delivery for claim. It
// returns nil on success and the failure otherwise. The caller decides
// between retry and dead-letter from the attempt counter and persisted
// state.
//
// If ctx is already done, ctx.Err() is returned before any channel is
// contacted.
//   - ChannelPush requires a user id. It publishes a client event whose
//     event id is the route id, whose request id is the notification's
//     IdempotencyKey, and whose trace id comes from ctx.
//   - ChannelEmail requires a catalog entry for the notification kind and
//     a resolved recipient that is non-empty after trimSpace. It then
//     enqueues the entry's mail template.
//   - Any other channel is reported as an error.
func (s *Service) performDispatch(ctx context.Context, claim ClaimedRoute) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	route, note := claim.Route, claim.Notification
	switch route.Channel {
	case ChannelPush:
		userID := route.UserID
		if userID == nil {
			return errors.New("push route missing user_id")
		}
		return s.deps.Push.PublishClientEvent(
			ctx,
			*userID,
			route.DeviceSessionID,
			note.Kind,
			note.Payload,
			route.RouteID.String(),  // event id
			note.IdempotencyKey,     // request id
			traceIDFromContext(ctx), // trace id
		)

	case ChannelEmail:
		entry, known := LookupCatalog(note.Kind)
		if !known {
			return fmt.Errorf("unknown kind %q", note.Kind)
		}
		// NOTE(review): the emptiness check trims the address, but the
		// untrimmed value is what gets enqueued — confirm ResolvedEmail is
		// already normalised upstream.
		to := route.ResolvedEmail
		if trimSpace(to) == "" {
			return errors.New("email route missing resolved recipient")
		}
		// The route id doubles as the mail idempotency key, so the outbox's
		// UNIQUE(template_id, idempotency_key) rejects a duplicate enqueue
		// if the worker re-claims after a crash before commit. Producers
		// never need to know the route id.
		return s.deps.Mail.EnqueueTemplate(ctx, entry.MailTemplateID, to, note.Payload, route.RouteID.String())
	}

	return fmt.Errorf("unknown channel %q", route.Channel)
}
// routeBackoff returns the delay before the next delivery attempt.
// attempt is 1-indexed: it is the value the row carries after Mark*.
//
// The un-jittered delay is backoffBase × backoffFactor^(attempt-1), so
// attempt 1 (and any value below it) maps to backoffBase. While the delay
// grows, it is clamped to backoffMax as soon as it reaches that value. The
// result always passes through jitter, whose ±backoffJitter swing can take
// a clamped delay slightly above backoffMax.
func routeBackoff(attempt int32) time.Duration {
	if attempt < 2 {
		return jitter(backoffBase)
	}
	delay := float64(backoffBase)
	// attempt >= 2 here, so attempt-1 cannot overflow int32.
	for remaining := attempt - 1; remaining > 0; remaining-- {
		delay *= backoffFactor
		if time.Duration(delay) >= backoffMax {
			// Stop multiplying once capped; this also bounds the loop.
			return jitter(backoffMax)
		}
	}
	return jitter(time.Duration(delay))
}
// jitter returns d moved by a uniformly drawn offset in
// [-backoffJitter·d, +backoffJitter·d), using the global math/rand/v2
// source.
//
// A non-positive backoffJitter disables randomisation and returns d
// unchanged. If the jittered value would be negative, d itself is returned
// instead.
func jitter(d time.Duration) time.Duration {
	if backoffJitter <= 0 {
		return d
	}
	base := float64(d)
	// rand.Float64 is in [0, 1); 2r-1 maps it onto [-1, 1).
	offset := (rand.Float64()*2 - 1) * (base * backoffJitter)
	if jittered := time.Duration(base + offset); jittered >= 0 {
		return jittered
	}
	return d
}