176 lines
6.1 KiB
Go
176 lines
6.1 KiB
Go
package notification
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"errors"
|
||
"fmt"
|
||
"math/rand/v2"
|
||
"time"
|
||
|
||
"go.opentelemetry.io/otel/trace"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// traceIDFromContext returns the W3C trace id of the active span as a
|
||
// hex string, or an empty string when ctx carries no recording span.
|
||
// The id is forwarded to gateway as ClientEvent.trace_id so push
|
||
// envelopes can be correlated to the producing trace.
|
||
func traceIDFromContext(ctx context.Context) string {
|
||
if ctx == nil {
|
||
return ""
|
||
}
|
||
spanCtx := trace.SpanContextFromContext(ctx)
|
||
if !spanCtx.HasTraceID() {
|
||
return ""
|
||
}
|
||
return spanCtx.TraceID().String()
|
||
}
|
||
|
||
// finaliseDispatch records the outcome of a single delivery attempt
|
||
// inside tx. The status transition table mirrors README §10 and the
|
||
// `notification_routes`'s CHECK constraint:
|
||
//
|
||
// - success → published (next_attempt_at NULL)
|
||
// - failure with attempt < max → retrying (next_attempt_at armed)
|
||
// - failure with attempt >= max → dead_lettered (+ insert
|
||
// notification_dead_letters row)
|
||
//
|
||
// The function does not commit tx: the caller (worker / Submit best-
|
||
// effort) owns the transaction so it can compose the dispatch with the
|
||
// preceding ClaimDueRoutes lock.
|
||
func (s *Service) finaliseDispatch(ctx context.Context, tx *sql.Tx, claim ClaimedRoute, dispatchErr error, at time.Time) error {
|
||
if dispatchErr == nil {
|
||
return s.deps.Store.MarkRoutePublished(ctx, tx, claim.Route.RouteID, at)
|
||
}
|
||
attempt := claim.Route.Attempts + 1
|
||
reason := dispatchErr.Error()
|
||
maxAttempts := claim.Route.MaxAttempts
|
||
if maxAttempts <= 0 {
|
||
maxAttempts = int32(s.deps.Config.MaxAttempts)
|
||
}
|
||
if attempt >= maxAttempts {
|
||
s.deps.Logger.Warn("notification route dead-lettered",
|
||
zap.String("kind", claim.Notification.Kind),
|
||
zap.String("channel", claim.Route.Channel),
|
||
zap.String("route_id", claim.Route.RouteID.String()),
|
||
zap.Int32("attempt", attempt),
|
||
zap.Error(dispatchErr),
|
||
)
|
||
return s.deps.Store.MarkRouteDeadLettered(ctx, tx, claim.Notification.NotificationID, claim.Route.RouteID, at, reason)
|
||
}
|
||
nextAt := at.Add(routeBackoff(attempt))
|
||
s.deps.Logger.Info("notification route retry scheduled",
|
||
zap.String("kind", claim.Notification.Kind),
|
||
zap.String("channel", claim.Route.Channel),
|
||
zap.String("route_id", claim.Route.RouteID.String()),
|
||
zap.Int32("attempt", attempt),
|
||
zap.Time("next_attempt_at", nextAt),
|
||
zap.Error(dispatchErr),
|
||
)
|
||
return s.deps.Store.ScheduleRouteRetry(ctx, tx, claim.Route.RouteID, at, nextAt, reason)
|
||
}
|
||
|
||
// bestEffortDispatch is invoked from Submit immediately after a route
|
||
// is durably persisted. It opens its own short transaction, runs the
|
||
// channel call, and writes the outcome with the same Mark* helpers
|
||
// the worker uses. Failures here are logged at debug level — the
|
||
// worker will retry on the next tick, so the producer never sees the
|
||
// synchronous failure.
|
||
func (s *Service) bestEffortDispatch(ctx context.Context, n Notification, route Route) {
|
||
if route.Status != RouteStatusPending {
|
||
return
|
||
}
|
||
claim := ClaimedRoute{Route: route, Notification: n}
|
||
tx, err := s.deps.Store.BeginTx(ctx)
|
||
if err != nil {
|
||
s.deps.Logger.Debug("best-effort dispatch: begin tx failed",
|
||
zap.String("route_id", route.RouteID.String()),
|
||
zap.Error(err))
|
||
return
|
||
}
|
||
defer func() { _ = tx.Rollback() }()
|
||
|
||
dispatchErr := s.performDispatch(ctx, claim)
|
||
at := s.nowUTC()
|
||
if err := s.finaliseDispatch(ctx, tx, claim, dispatchErr, at); err != nil {
|
||
s.deps.Logger.Debug("best-effort dispatch finalise failed",
|
||
zap.String("route_id", route.RouteID.String()),
|
||
zap.Error(err))
|
||
return
|
||
}
|
||
if err := tx.Commit(); err != nil {
|
||
s.deps.Logger.Debug("best-effort dispatch commit failed",
|
||
zap.String("route_id", route.RouteID.String()),
|
||
zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// performDispatch runs the channel-specific delivery. Returns nil on
|
||
// success and any error otherwise. The caller decides between retry
|
||
// and dead-letter based on the attempt counter and persisted state.
|
||
func (s *Service) performDispatch(ctx context.Context, claim ClaimedRoute) error {
|
||
if ctx.Err() != nil {
|
||
return ctx.Err()
|
||
}
|
||
switch claim.Route.Channel {
|
||
case ChannelPush:
|
||
if claim.Route.UserID == nil {
|
||
return errors.New("push route missing user_id")
|
||
}
|
||
eventID := claim.Route.RouteID.String()
|
||
requestID := claim.Notification.IdempotencyKey
|
||
traceID := traceIDFromContext(ctx)
|
||
return s.deps.Push.PublishClientEvent(ctx, *claim.Route.UserID, claim.Route.DeviceSessionID, claim.Notification.Kind, claim.Notification.Payload, eventID, requestID, traceID)
|
||
case ChannelEmail:
|
||
entry, ok := LookupCatalog(claim.Notification.Kind)
|
||
if !ok {
|
||
return fmt.Errorf("unknown kind %q", claim.Notification.Kind)
|
||
}
|
||
recipient := claim.Route.ResolvedEmail
|
||
if trimSpace(recipient) == "" {
|
||
return errors.New("email route missing resolved recipient")
|
||
}
|
||
// Use the route id as idempotency_key so the mail outbox
|
||
// UNIQUE(template_id, idempotency_key) catches a duplicate
|
||
// enqueue if the worker re-claims after a crash before
|
||
// commit. Producers should never need to know the route id.
|
||
return s.deps.Mail.EnqueueTemplate(ctx, entry.MailTemplateID, recipient, claim.Notification.Payload, claim.Route.RouteID.String())
|
||
default:
|
||
return fmt.Errorf("unknown channel %q", claim.Route.Channel)
|
||
}
|
||
}
|
||
|
||
// routeBackoff computes the per-attempt delay using the package
|
||
// constants and ±backoffJitter randomisation. attempt is 1-indexed
|
||
// (the value the row will carry after Mark*); attempt==1 maps to
|
||
// `backoffBase × backoffFactor⁰`.
|
||
func routeBackoff(attempt int32) time.Duration {
|
||
if attempt <= 1 {
|
||
return jitter(backoffBase)
|
||
}
|
||
d := float64(backoffBase)
|
||
for i := int32(1); i < attempt; i++ {
|
||
d *= backoffFactor
|
||
if time.Duration(d) >= backoffMax {
|
||
return jitter(backoffMax)
|
||
}
|
||
}
|
||
return jitter(time.Duration(d))
|
||
}
|
||
|
||
// jitter applies the package-standard ±backoffJitter swing using the
|
||
// new global v2 rand source.
|
||
func jitter(d time.Duration) time.Duration {
|
||
if backoffJitter <= 0 {
|
||
return d
|
||
}
|
||
span := float64(d) * backoffJitter
|
||
delta := (rand.Float64()*2 - 1) * span
|
||
out := time.Duration(float64(d) + delta)
|
||
if out < 0 {
|
||
return d
|
||
}
|
||
return out
|
||
}
|