feat: backend service
@@ -0,0 +1,48 @@
package push

import (
	"fmt"
	"strconv"
	"sync/atomic"
)

// cursorWidth is the zero-padded decimal width applied to every cursor.
// 20 digits accommodate the full uint64 range so lexicographic order
// matches numeric order across the entire process lifetime.
const cursorWidth = 20

// cursorGenerator hands out monotonically increasing uint64 sequence
// numbers. Cursors restart from 0 on process boot; the ring buffer's
// freshness-window TTL bounds how long a cursor remains valid, so a
// fresh process intentionally invalidates every previously-issued
// cursor.
type cursorGenerator struct {
	seq atomic.Uint64
}

// next returns the next sequence number. The first call returns 1.
func (g *cursorGenerator) next() uint64 {
	return g.seq.Add(1)
}

// formatCursor renders n in the canonical zero-padded form so cursor
// strings sort identically to their numeric counterparts.
func formatCursor(n uint64) string {
	return fmt.Sprintf("%0*d", cursorWidth, n)
}

// parseCursor decodes a cursor string back to its numeric value. An
// empty string maps to 0 ("subscribe from now"); malformed input also
// maps to 0 with ok=false so callers can log without rejecting the
// subscription — gateway is trusted but reconnects can race against a
// process restart that scrambled the in-memory sequence.
func parseCursor(s string) (uint64, bool) {
	if s == "" {
		return 0, true
	}
	n, err := strconv.ParseUint(s, 10, 64)
	if err != nil {
		return 0, false
	}
	return n, true
}
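The zero padding is what lets consumers compare cursors as plain strings. A minimal illustration of that property, assuming a caller that only needs to remember the newest cursor it has seen; maxCursorString is hypothetical and not part of this commit:

// maxCursorString keeps the newer of two cursors using string comparison
// only. This works because formatCursor pads every value to cursorWidth
// digits, so lexicographic order equals numeric order.
func maxCursorString(current, incoming string) string {
	if incoming > current {
		return incoming
	}
	return current
}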
@@ -0,0 +1,79 @@
package push

import (
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestCursorGeneratorMonotonicAndConcurrent(t *testing.T) {
	t.Parallel()

	var g cursorGenerator
	const goroutines = 64
	const perGoroutine = 1000
	results := make(chan uint64, goroutines*perGoroutine)
	var wg sync.WaitGroup
	wg.Add(goroutines)
	for range goroutines {
		go func() {
			defer wg.Done()
			for range perGoroutine {
				results <- g.next()
			}
		}()
	}
	wg.Wait()
	close(results)

	seen := make(map[uint64]struct{}, goroutines*perGoroutine)
	var max uint64
	for n := range results {
		_, dup := seen[n]
		require.Falsef(t, dup, "duplicate cursor %d", n)
		seen[n] = struct{}{}
		if n > max {
			max = n
		}
	}
	assert.EqualValues(t, goroutines*perGoroutine, max)
}

func TestFormatAndParseCursor(t *testing.T) {
	t.Parallel()

	cases := []struct {
		in  uint64
		out string
	}{
		{0, "00000000000000000000"},
		{1, "00000000000000000001"},
		{1234567890, "00000000001234567890"},
	}
	for _, tc := range cases {
		s := formatCursor(tc.in)
		assert.Equal(t, tc.out, s)
		assert.Len(t, s, cursorWidth)
		n, ok := parseCursor(s)
		require.True(t, ok)
		assert.Equal(t, tc.in, n)
	}

	n, ok := parseCursor("")
	assert.True(t, ok)
	assert.Zero(t, n)

	n, ok = parseCursor("not-a-number")
	assert.False(t, ok)
	assert.Zero(t, n)
}

func TestFormatCursorLexicographicOrder(t *testing.T) {
	t.Parallel()

	a := formatCursor(9)
	b := formatCursor(10)
	assert.Less(t, a, b, "lexicographic order must match numeric order")
}
@@ -0,0 +1,161 @@
package push

import (
	"context"
	"encoding/json"
	"testing"
	"time"

	pushv1 "galaxy/backend/proto/push/v1"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func newTestService(t *testing.T) *Service {
	t.Helper()
	svc, err := NewService(ServiceConfig{
		FreshnessWindow: time.Minute,
		RingCapacity:    16,
		PerConnBuffer:   8,
	}, nil, nil)
	require.NoError(t, err)
	return svc
}

func TestPublishClientEventStampsCursorAndPayload(t *testing.T) {
	t.Parallel()

	svc := newTestService(t)
	t.Cleanup(svc.Close)

	userID := uuid.New()
	devID := uuid.New()
	payload := map[string]any{"game_id": "g1", "n": 7.0}
	require.NoError(t, svc.PublishClientEvent(context.Background(), userID, &devID, "lobby.invite.received", payload, "route-1", "req-1", "trace-1"))

	events, stale := svc.ring.since(0, time.Now())
	require.False(t, stale)
	require.Len(t, events, 1)

	ev := events[0]
	assert.Equal(t, formatCursor(1), ev.Cursor)
	ce := ev.GetClientEvent()
	require.NotNil(t, ce)
	assert.Equal(t, userID.String(), ce.UserId)
	assert.Equal(t, devID.String(), ce.DeviceSessionId)
	assert.Equal(t, "lobby.invite.received", ce.Kind)
	assert.Equal(t, "route-1", ce.EventId)
	assert.Equal(t, "req-1", ce.RequestId)
	assert.Equal(t, "trace-1", ce.TraceId)

	var got map[string]any
	require.NoError(t, json.Unmarshal(ce.Payload, &got))
	assert.Equal(t, "g1", got["game_id"])
	assert.EqualValues(t, 7.0, got["n"])
}

func TestPublishClientEventOmitsDeviceSessionWhenNil(t *testing.T) {
	t.Parallel()

	svc := newTestService(t)
	t.Cleanup(svc.Close)

	userID := uuid.New()
	require.NoError(t, svc.PublishClientEvent(context.Background(), userID, nil, "x", nil, "", "", ""))

	events, _ := svc.ring.since(0, time.Now())
	require.Len(t, events, 1)
	assert.Empty(t, events[0].GetClientEvent().DeviceSessionId)
}

func TestPublishClientEventRequiresUserAndKind(t *testing.T) {
	t.Parallel()

	svc := newTestService(t)
	t.Cleanup(svc.Close)

	require.Error(t, svc.PublishClientEvent(context.Background(), uuid.Nil, nil, "k", nil, "", "", ""))
	require.Error(t, svc.PublishClientEvent(context.Background(), uuid.New(), nil, " ", nil, "", "", ""))
}

func TestPublishSessionInvalidationStampsCursor(t *testing.T) {
	t.Parallel()

	svc := newTestService(t)
	t.Cleanup(svc.Close)

	userID := uuid.New()
	devID := uuid.New()
	svc.PublishSessionInvalidation(context.Background(), devID, userID, "auth.revoke_session")

	events, _ := svc.ring.since(0, time.Now())
	require.Len(t, events, 1)
	si := events[0].GetSessionInvalidation()
	require.NotNil(t, si)
	assert.Equal(t, userID.String(), si.UserId)
	assert.Equal(t, devID.String(), si.DeviceSessionId)
	assert.Equal(t, "auth.revoke_session", si.Reason)
}

func TestPublishSessionInvalidationFanOutOmitsDeviceSession(t *testing.T) {
	t.Parallel()

	svc := newTestService(t)
	t.Cleanup(svc.Close)

	userID := uuid.New()
	svc.PublishSessionInvalidation(context.Background(), uuid.Nil, userID, "auth.revoke_all_for_user")

	events, _ := svc.ring.since(0, time.Now())
	require.Len(t, events, 1)
	si := events[0].GetSessionInvalidation()
	assert.Empty(t, si.DeviceSessionId)
	assert.Equal(t, userID.String(), si.UserId)
}

func TestPublishCursorMonotonic(t *testing.T) {
	t.Parallel()

	svc := newTestService(t)
	t.Cleanup(svc.Close)

	userID := uuid.New()
	for range 5 {
		require.NoError(t, svc.PublishClientEvent(context.Background(), userID, nil, "k", nil, "", "", ""))
	}
	events, _ := svc.ring.since(0, time.Now())
	require.Len(t, events, 5)
	for i, ev := range events {
		assert.Equal(t, formatCursor(uint64(i+1)), ev.Cursor)
	}
}

func TestPublishOnClosedServiceIsNoop(t *testing.T) {
	t.Parallel()

	svc := newTestService(t)
	svc.Close()
	require.NoError(t, svc.PublishClientEvent(context.Background(), uuid.New(), nil, "k", nil, "", "", ""))
	events, _ := svc.ring.since(0, time.Now())
	assert.Empty(t, events)
}

// Compile-time interface checks: Service must satisfy the publisher
// contracts that internal/auth and internal/notification import.
var (
	_ pushClientEventPublisher       = (*Service)(nil)
	_ pushSessionInvalidationEmitter = (*Service)(nil)
)

type pushClientEventPublisher interface {
	PublishClientEvent(ctx context.Context, userID uuid.UUID, deviceSessionID *uuid.UUID, kind string, payload map[string]any, eventID, requestID, traceID string) error
}

type pushSessionInvalidationEmitter interface {
	PublishSessionInvalidation(ctx context.Context, deviceSessionID, userID uuid.UUID, reason string)
}

// Make sure the publisher satisfies pushv1.PushServer at the type level.
var _ pushv1.PushServer = (*Service)(nil)
@@ -0,0 +1,108 @@
package push

import (
	"time"

	pushv1 "galaxy/backend/proto/push/v1"
)

// ringEntry is one event stored in the in-memory replay buffer. The
// cursor is duplicated here for O(1) comparison without re-parsing
// event.Cursor.
type ringEntry struct {
	cursor  uint64
	addedAt time.Time
	event   *pushv1.PushEvent
}

// ring is the in-memory replay buffer. Entries are evicted by either
// freshness-window TTL or capacity, whichever triggers first. The ring
// is not safe for concurrent use; the owning Service serialises access
// under its mutex.
type ring struct {
	capacity       int
	ttl            time.Duration
	entries        []ringEntry
	lastEvicted    uint64 // largest cursor evicted from the buffer
	hasLastEvicted bool
}

func newRing(capacity int, ttl time.Duration) *ring {
	return &ring{
		capacity: capacity,
		ttl:      ttl,
		entries:  make([]ringEntry, 0, capacity),
	}
}

// append records ev with its cursor and evicts entries past TTL or
// capacity. The caller is responsible for setting ev.Cursor to
// formatCursor(cursor) before calling.
func (r *ring) append(cursor uint64, ev *pushv1.PushEvent, now time.Time) {
	r.evictExpired(now)
	for len(r.entries) >= r.capacity {
		r.evictHead()
	}
	r.entries = append(r.entries, ringEntry{cursor: cursor, addedAt: now, event: ev})
}

// since returns the events with cursor strictly greater than fromCursor
// in ascending cursor order. The boolean is true when the requested
// cursor is "stale" — either older than the oldest retained event or
// older than the last evicted cursor — meaning the caller missed at
// least one event that the ring no longer holds. Stale callers receive
// no replay and must resume from the live tail.
func (r *ring) since(fromCursor uint64, now time.Time) ([]*pushv1.PushEvent, bool) {
	r.evictExpired(now)
	if len(r.entries) == 0 {
		// An empty ring is never stale: gateway is either fully caught
		// up or there has been no traffic.
		return nil, false
	}
	if r.hasLastEvicted && fromCursor < r.lastEvicted {
		return nil, true
	}
	first := r.entries[0].cursor
	if fromCursor+1 < first {
		return nil, true
	}
	out := make([]*pushv1.PushEvent, 0)
	for i := range r.entries {
		if r.entries[i].cursor > fromCursor {
			out = append(out, r.entries[i].event)
		}
	}
	return out, false
}

// len reports the current number of retained entries; intended for
// tests and metrics.
func (r *ring) len() int {
	return len(r.entries)
}

func (r *ring) evictExpired(now time.Time) {
	if r.ttl <= 0 {
		return
	}
	cutoff := now.Add(-r.ttl)
	drop := 0
	for drop < len(r.entries) && r.entries[drop].addedAt.Before(cutoff) {
		drop++
	}
	if drop == 0 {
		return
	}
	r.lastEvicted = r.entries[drop-1].cursor
	r.hasLastEvicted = true
	r.entries = append(r.entries[:0], r.entries[drop:]...)
}

func (r *ring) evictHead() {
	if len(r.entries) == 0 {
		return
	}
	r.lastEvicted = r.entries[0].cursor
	r.hasLastEvicted = true
	r.entries = append(r.entries[:0], r.entries[1:]...)
}
@@ -0,0 +1,105 @@
package push

import (
	"testing"
	"time"

	pushv1 "galaxy/backend/proto/push/v1"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func mkEvent(cursor uint64, label string) *pushv1.PushEvent {
	return &pushv1.PushEvent{
		Cursor: formatCursor(cursor),
		Kind: &pushv1.PushEvent_ClientEvent{
			ClientEvent: &pushv1.ClientEvent{
				Kind:    label,
				Payload: []byte(label),
			},
		},
	}
}

func TestRingAppendAndSinceReturnsTail(t *testing.T) {
	t.Parallel()

	now := time.Unix(1_700_000_000, 0)
	r := newRing(8, time.Minute)
	for i := uint64(1); i <= 5; i++ {
		r.append(i, mkEvent(i, "e"), now)
	}

	got, stale := r.since(2, now)
	require.False(t, stale)
	require.Len(t, got, 3)
	assert.Equal(t, formatCursor(3), got[0].Cursor)
	assert.Equal(t, formatCursor(4), got[1].Cursor)
	assert.Equal(t, formatCursor(5), got[2].Cursor)
}

func TestRingSinceReturnsEmptyWhenCaughtUp(t *testing.T) {
	t.Parallel()

	now := time.Unix(1_700_000_000, 0)
	r := newRing(8, time.Minute)
	for i := uint64(1); i <= 3; i++ {
		r.append(i, mkEvent(i, "e"), now)
	}

	got, stale := r.since(3, now)
	require.False(t, stale)
	assert.Empty(t, got)

	got, stale = r.since(99, now)
	require.False(t, stale)
	assert.Empty(t, got)
}

func TestRingSinceFlagsStaleCursorBelowEvictedRange(t *testing.T) {
	t.Parallel()

	now := time.Unix(1_700_000_000, 0)
	r := newRing(3, time.Minute)
	for i := uint64(1); i <= 5; i++ {
		r.append(i, mkEvent(i, "e"), now)
	}
	// Capacity=3 means cursors 1 and 2 were evicted.
	require.Equal(t, 3, r.len())

	got, stale := r.since(1, now)
	assert.True(t, stale)
	assert.Empty(t, got)

	got, stale = r.since(2, now)
	assert.False(t, stale)
	require.Len(t, got, 3)
	assert.Equal(t, formatCursor(3), got[0].Cursor)
}

func TestRingEvictsExpiredEntries(t *testing.T) {
	t.Parallel()

	t0 := time.Unix(1_700_000_000, 0)
	r := newRing(8, 10*time.Second)
	r.append(1, mkEvent(1, "e"), t0)
	r.append(2, mkEvent(2, "e"), t0.Add(2*time.Second))
	r.append(3, mkEvent(3, "e"), t0.Add(15*time.Second))

	// At t0+13s the first two entries are past their 10s TTL but the
	// third (added at t0+15s) is still within the freshness window.
	got, stale := r.since(0, t0.Add(13*time.Second))
	assert.True(t, stale)
	assert.Empty(t, got)
	assert.Equal(t, 1, r.len())
}

func TestRingEmptyIsNeverStale(t *testing.T) {
	t.Parallel()

	r := newRing(4, time.Minute)
	got, stale := r.since(42, time.Now())
	assert.False(t, stale)
	assert.Empty(t, got)
}
@@ -0,0 +1,145 @@
// Package push hosts the backend gRPC listener used by gateway.
//
// Server owns the TCP listener and gRPC machinery. Service implements
// the PushServer interface and is registered against the gRPC server
// before Serve begins. On shutdown the server signals the service to
// drop its subscriptions, then performs the usual GracefulStop /
// forced-stop sequence.
package push

import (
	"context"
	"errors"
	"fmt"
	"net"
	"sync"

	"galaxy/backend/internal/config"
	"galaxy/backend/internal/telemetry"
	pushv1 "galaxy/backend/proto/push/v1"

	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
	"go.uber.org/zap"
	"google.golang.org/grpc"
)

// Server owns the gRPC push listener.
type Server struct {
	cfg     config.GRPCPushConfig
	svc     *Service
	logger  *zap.Logger
	runtime *telemetry.Runtime

	stateMu  sync.RWMutex
	server   *grpc.Server
	listener net.Listener
}

// NewServer constructs a gRPC push server bound to cfg. svc must not be
// nil; it is registered as the pushv1.PushServer implementation when
// Run starts.
func NewServer(cfg config.GRPCPushConfig, svc *Service, logger *zap.Logger, runtime *telemetry.Runtime) *Server {
	if logger == nil {
		logger = zap.NewNop()
	}

	return &Server{
		cfg:     cfg,
		svc:     svc,
		logger:  logger.Named("grpc_push"),
		runtime: runtime,
	}
}

// Run binds the listener and serves the gRPC surface until Shutdown closes
// the server.
func (s *Server) Run(ctx context.Context) error {
	if ctx == nil {
		return errors.New("run backend gRPC push server: nil context")
	}
	if err := ctx.Err(); err != nil {
		return err
	}
	if s.svc == nil {
		return errors.New("run backend gRPC push server: nil service")
	}

	listener, err := net.Listen("tcp", s.cfg.Addr)
	if err != nil {
		return fmt.Errorf("run backend gRPC push server: listen on %q: %w", s.cfg.Addr, err)
	}

	grpcServer := grpc.NewServer(
		grpc.StatsHandler(otelgrpc.NewServerHandler()),
	)
	pushv1.RegisterPushServer(grpcServer, s.svc)

	s.stateMu.Lock()
	s.server = grpcServer
	s.listener = listener
	s.stateMu.Unlock()

	s.logger.Info("backend gRPC push server started", zap.String("addr", listener.Addr().String()))

	defer func() {
		s.stateMu.Lock()
		s.server = nil
		s.listener = nil
		s.stateMu.Unlock()
	}()

	err = grpcServer.Serve(listener)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, grpc.ErrServerStopped):
		s.logger.Info("backend gRPC push server stopped")
		return nil
	default:
		return fmt.Errorf("run backend gRPC push server: serve on %q: %w", s.cfg.Addr, err)
	}
}

// Shutdown attempts a graceful stop within ctx, falling back to a forced stop
// when ctx expires before GracefulStop returns. The configured per-listener
// timeout further bounds the wait. Active SubscribePush streams are closed
// first so GracefulStop is not blocked by long-lived server-streaming RPCs.
func (s *Server) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown backend gRPC push server: nil context")
	}

	s.stateMu.RLock()
	server := s.server
	s.stateMu.RUnlock()

	if server == nil {
		return nil
	}

	if s.svc != nil {
		s.svc.Close()
	}

	shutdownCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	if s.cfg.ShutdownTimeout > 0 {
		shutdownCtx, cancel = context.WithTimeout(ctx, s.cfg.ShutdownTimeout)
		defer cancel()
	}

	stopped := make(chan struct{})
	go func() {
		server.GracefulStop()
		close(stopped)
	}()

	select {
	case <-stopped:
		return nil
	case <-shutdownCtx.Done():
		server.Stop()
		<-stopped
		return fmt.Errorf("shutdown backend gRPC push server: %w", shutdownCtx.Err())
	}
}
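The doc comments describe the Run / Shutdown lifecycle, but the wiring site itself is not part of this commit. A minimal sketch of how a caller such as main.go might drive it, assuming the constructors above; the address, timeouts, freshness window, and the runPush name are illustrative, not the project's real configuration:

// runPush is a hypothetical lifecycle wrapper; values are placeholders.
func runPush(ctx context.Context, logger *zap.Logger, rt *telemetry.Runtime) error {
	svc, err := NewService(ServiceConfig{FreshnessWindow: 2 * time.Minute}, logger, rt)
	if err != nil {
		return err
	}
	srv := NewServer(config.GRPCPushConfig{Addr: ":9090", ShutdownTimeout: 5 * time.Second}, svc, logger, rt)

	errCh := make(chan error, 1)
	go func() { errCh <- srv.Run(context.Background()) }()

	select {
	case err := <-errCh:
		return err
	case <-ctx.Done():
		// Shutdown closes the Service first, then GracefulStop / forced Stop.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := srv.Shutdown(shutdownCtx); err != nil {
			return err
		}
		return <-errCh
	}
}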
@@ -0,0 +1,327 @@
// Package push hosts the backend gRPC SubscribePush server and the
// publisher API consumed by other backend domains.
//
// Service implements pushv1.PushServer. It maintains:
//
//   - a connection registry keyed by GatewaySubscribeRequest.gateway_client_id;
//   - an in-memory ring buffer of recent PushEvent values with TTL equal
//     to BACKEND_FRESHNESS_WINDOW;
//   - a monotonic cursor generator stamped on every published event.
//
// Publisher methods (PublishClientEvent, PublishSessionInvalidation)
// satisfy the SessionInvalidator interface in internal/auth and the
// PushPublisher interface in internal/notification — main.go injects
// a single *Service into both wiring sites.
//
// See `backend/README.md` §7 and `backend/docs/flows.md` for cursor,
// ring buffer, and backpressure semantics.
package push

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	"galaxy/backend/internal/telemetry"
	pushv1 "galaxy/backend/proto/push/v1"

	"github.com/google/uuid"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/metric"
	"go.uber.org/zap"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// Default sizing for the ring buffer and per-connection delivery queue.
// The values are intentionally hard-coded: ring TTL is the operational
// dial (BACKEND_FRESHNESS_WINDOW) and the buffer sizes are chosen to
// comfortably absorb a freshness window of traffic at MVP rates.
const (
	defaultRingCapacity  = 1024
	defaultPerConnBuffer = 256
)

// ServiceConfig configures a Service. FreshnessWindow is required and
// fixes the ring buffer's per-event TTL. RingCapacity and PerConnBuffer
// fall back to the package defaults when zero. Now overrides time.Now
// for deterministic tests.
type ServiceConfig struct {
	FreshnessWindow time.Duration
	RingCapacity    int
	PerConnBuffer   int
	Now             func() time.Time
}

// Service implements pushv1.PushServer and exposes the publisher API.
// One Service is shared by every backend domain that needs to push;
// it is safe for concurrent use.
type Service struct {
	pushv1.UnimplementedPushServer

	logger *zap.Logger
	now    func() time.Time

	perConnBuffer int

	mu        sync.Mutex
	closed    bool
	subs      map[string]*subscription
	ring      *ring
	cursorGen cursorGenerator

	eventsTotal  metric.Int64Counter
	droppedTotal metric.Int64Counter
}

// NewService constructs a Service. A nil logger falls back to
// zap.NewNop. A nil runtime disables metric emission so tests can
// instantiate the service without the OpenTelemetry runtime.
func NewService(cfg ServiceConfig, logger *zap.Logger, runtime *telemetry.Runtime) (*Service, error) {
	if cfg.FreshnessWindow <= 0 {
		return nil, errors.New("push.NewService: FreshnessWindow must be positive")
	}
	if logger == nil {
		logger = zap.NewNop()
	}
	if cfg.Now == nil {
		cfg.Now = time.Now
	}
	if cfg.RingCapacity <= 0 {
		cfg.RingCapacity = defaultRingCapacity
	}
	if cfg.PerConnBuffer <= 0 {
		cfg.PerConnBuffer = defaultPerConnBuffer
	}

	s := &Service{
		logger:        logger.Named("push"),
		now:           cfg.Now,
		perConnBuffer: cfg.PerConnBuffer,
		subs:          make(map[string]*subscription),
		ring:          newRing(cfg.RingCapacity, cfg.FreshnessWindow),
	}

	if runtime != nil {
		if err := s.registerMetrics(runtime); err != nil {
			return nil, fmt.Errorf("push.NewService: register metrics: %w", err)
		}
	}

	return s, nil
}

// Close drops every active subscription and refuses new ones. It is
// safe to call multiple times. The owning Server must call Close before
// initiating GracefulStop so streaming handlers exit promptly.
func (s *Service) Close() {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return
	}
	s.closed = true
	for clientID, sub := range s.subs {
		close(sub.done)
		delete(s.subs, clientID)
	}
}

// PublishClientEvent enqueues a ClientEvent for delivery. payload is
// marshalled to JSON; deviceSessionID is optional. eventID, requestID
// and traceID are correlation identifiers that gateway forwards
// verbatim into the signed client envelope (typically the producing
// route id, the originating client request id, and the trace id of the
// span that produced the event); empty strings are forwarded
// unchanged. The method satisfies notification.PushPublisher.
func (s *Service) PublishClientEvent(_ context.Context, userID uuid.UUID, deviceSessionID *uuid.UUID, kind string, payload map[string]any, eventID, requestID, traceID string) error {
	if userID == uuid.Nil {
		return errors.New("push.PublishClientEvent: userID is required")
	}
	if strings.TrimSpace(kind) == "" {
		return errors.New("push.PublishClientEvent: kind is required")
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("push.PublishClientEvent: marshal payload: %w", err)
	}
	ev := &pushv1.PushEvent{
		Kind: &pushv1.PushEvent_ClientEvent{
			ClientEvent: &pushv1.ClientEvent{
				UserId:    userID.String(),
				Kind:      kind,
				Payload:   encoded,
				EventId:   eventID,
				RequestId: requestID,
				TraceId:   traceID,
			},
		},
	}
	if deviceSessionID != nil {
		ev.GetClientEvent().DeviceSessionId = deviceSessionID.String()
	}
	s.publish(ev, "client_event")
	return nil
}

// PublishSessionInvalidation enqueues a SessionInvalidation event. It
// satisfies auth.SessionInvalidator. deviceSessionID may be uuid.Nil to
// invalidate every session of userID.
func (s *Service) PublishSessionInvalidation(_ context.Context, deviceSessionID, userID uuid.UUID, reason string) {
	if userID == uuid.Nil {
		s.logger.Warn("push session invalidation skipped: userID is required",
			zap.String("device_session_id", deviceSessionID.String()),
			zap.String("reason", reason),
		)
		return
	}
	ev := &pushv1.PushEvent{
		Kind: &pushv1.PushEvent_SessionInvalidation{
			SessionInvalidation: &pushv1.SessionInvalidation{
				UserId: userID.String(),
				Reason: reason,
			},
		},
	}
	if deviceSessionID != uuid.Nil {
		ev.GetSessionInvalidation().DeviceSessionId = deviceSessionID.String()
	}
	s.publish(ev, "session_invalidation")
}

func (s *Service) publish(ev *pushv1.PushEvent, kindLabel string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return
	}
	cursor := s.cursorGen.next()
	ev.Cursor = formatCursor(cursor)
	s.ring.append(cursor, ev, s.now())
	if s.eventsTotal != nil {
		s.eventsTotal.Add(context.Background(), 1, metric.WithAttributes(attribute.String("kind", kindLabel)))
	}
	for clientID, sub := range s.subs {
		if dropped := sub.deliver(ev); dropped {
			if s.droppedTotal != nil {
				s.droppedTotal.Add(context.Background(), 1, metric.WithAttributes(attribute.String("gateway_client_id", clientID)))
			}
			s.logger.Warn("push subscription dropped event",
				zap.String("gateway_client_id", clientID),
				zap.String("cursor", ev.Cursor),
				zap.String("event_kind", kindLabel),
			)
		}
	}
}

// register installs a new subscription for clientID and returns the
// replay slice the caller must send before draining the live channel.
// An existing subscription for the same clientID is closed first so
// the previous reader goroutine exits.
func (s *Service) register(clientID, cursor string) (*subscription, []*pushv1.PushEvent, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.closed {
		return nil, nil, status.Error(codes.Unavailable, "push service stopped")
	}
	if existing, ok := s.subs[clientID]; ok {
		close(existing.done)
		delete(s.subs, clientID)
		s.logger.Info("push subscription replaced",
			zap.String("gateway_client_id", clientID),
		)
	}
	sub := &subscription{
		clientID: clientID,
		ch:       make(chan *pushv1.PushEvent, s.perConnBuffer),
		done:     make(chan struct{}),
	}
	s.subs[clientID] = sub

	from, ok := parseCursor(cursor)
	if !ok {
		s.logger.Warn("push subscribe with malformed cursor; resuming from live tail",
			zap.String("gateway_client_id", clientID),
			zap.String("cursor", cursor),
		)
	}
	replay, stale := s.ring.since(from, s.now())
	if stale {
		s.logger.Info("push subscribe cursor stale; replay skipped",
			zap.String("gateway_client_id", clientID),
			zap.String("cursor", cursor),
		)
	} else if len(replay) > 0 {
		s.logger.Info("push subscribe replay",
			zap.String("gateway_client_id", clientID),
			zap.String("cursor", cursor),
			zap.Int("events", len(replay)),
		)
	}
	return sub, replay, nil
}

// unregister removes sub from the registry when the reader goroutine
// exits. It is a no-op when sub has already been replaced — the
// replacement subscription owns the entry under the same clientID.
func (s *Service) unregister(sub *subscription) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if cur, ok := s.subs[sub.clientID]; ok && cur == sub {
		delete(s.subs, sub.clientID)
	}
}

// SubscriberCount reports the number of active subscriptions; used by
// metrics callbacks and tests.
func (s *Service) SubscriberCount() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	return len(s.subs)
}

func (s *Service) registerMetrics(runtime *telemetry.Runtime) error {
	meter := runtime.MeterProvider().Meter("galaxy.backend/push")

	subscribers, err := meter.Int64ObservableGauge(
		"grpc_push_subscribers",
		metric.WithDescription("Number of gateway clients currently subscribed to the backend push stream."),
		metric.WithUnit("1"),
	)
	if err != nil {
		return err
	}
	if _, err := meter.RegisterCallback(func(_ context.Context, o metric.Observer) error {
		o.ObserveInt64(subscribers, int64(s.SubscriberCount()))
		return nil
	}, subscribers); err != nil {
		return err
	}

	eventsTotal, err := meter.Int64Counter(
		"grpc_push_events_total",
		metric.WithDescription("Number of push events published, partitioned by event kind."),
		metric.WithUnit("1"),
	)
	if err != nil {
		return err
	}
	s.eventsTotal = eventsTotal

	droppedTotal, err := meter.Int64Counter(
		"grpc_push_dropped_total",
		metric.WithDescription("Number of push events dropped because a subscriber buffer was full, partitioned by gateway client id."),
		metric.WithUnit("1"),
	)
	if err != nil {
		return err
	}
	s.droppedTotal = droppedTotal

	return nil
}
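As a usage illustration of the publisher API above, here is a hypothetical caller in another backend domain. The function name, payload fields, and event id are invented for the example; the kind string mirrors the one used in the tests:

// notifyLobbyInvite is an illustrative caller, not part of this commit.
func notifyLobbyInvite(ctx context.Context, pushSvc *Service, userID, devID uuid.UUID, requestID, traceID string) error {
	payload := map[string]any{"game_id": "g1", "from_user_id": userID.String()}
	// deviceSessionID targets a single device session; passing nil instead
	// fans out to every session of the user. eventID/requestID/traceID are
	// forwarded verbatim to gateway for correlation.
	return pushSvc.PublishClientEvent(ctx, userID, &devID, "lobby.invite.received", payload, "lobby.invite", requestID, traceID)
}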
@@ -0,0 +1,240 @@
package push

import (
	"context"
	"net"
	"testing"
	"time"

	pushv1 "galaxy/backend/proto/push/v1"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/status"
	"google.golang.org/grpc/test/bufconn"
)

const bufconnBufferSize = 1024 * 1024

// startBufconnServer wires svc into an in-process gRPC server reachable
// through a bufconn dialer. The returned cleanup function stops the
// server and closes the listener.
func startBufconnServer(t *testing.T, svc *Service) (pushv1.PushClient, func()) {
	t.Helper()

	lis := bufconn.Listen(bufconnBufferSize)
	server := grpc.NewServer()
	pushv1.RegisterPushServer(server, svc)

	go func() {
		_ = server.Serve(lis)
	}()

	conn, err := grpc.NewClient(
		"passthrough://bufnet",
		grpc.WithContextDialer(func(_ context.Context, _ string) (net.Conn, error) {
			return lis.DialContext(context.Background())
		}),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	require.NoError(t, err)

	cleanup := func() {
		_ = conn.Close()
		server.Stop()
		_ = lis.Close()
	}
	return pushv1.NewPushClient(conn), cleanup
}

func recvOne(t *testing.T, stream pushv1.Push_SubscribePushClient, timeout time.Duration) (*pushv1.PushEvent, error) {
	t.Helper()
	type result struct {
		ev  *pushv1.PushEvent
		err error
	}
	ch := make(chan result, 1)
	go func() {
		ev, err := stream.Recv()
		ch <- result{ev, err}
	}()
	select {
	case r := <-ch:
		return r.ev, r.err
	case <-time.After(timeout):
		t.Fatalf("timed out waiting for push event after %s", timeout)
		return nil, nil
	}
}

func TestSubscribePushDeliversLiveEvents(t *testing.T) {
	t.Parallel()

	svc, err := NewService(ServiceConfig{FreshnessWindow: time.Minute, RingCapacity: 16, PerConnBuffer: 8}, nil, nil)
	require.NoError(t, err)
	t.Cleanup(svc.Close)

	client, cleanup := startBufconnServer(t, svc)
	defer cleanup()

	stream, err := client.SubscribePush(t.Context(), &pushv1.GatewaySubscribeRequest{GatewayClientId: "gw-1"})
	require.NoError(t, err)

	require.Eventually(t, func() bool { return svc.SubscriberCount() == 1 }, time.Second, 5*time.Millisecond)

	userID := uuid.New()
	require.NoError(t, svc.PublishClientEvent(context.Background(), userID, nil, "k", nil, "", "", ""))

	ev, err := recvOne(t, stream, time.Second)
	require.NoError(t, err)
	assert.Equal(t, formatCursor(1), ev.Cursor)
	assert.Equal(t, userID.String(), ev.GetClientEvent().UserId)
}

func TestSubscribePushReplaysPastEventsOnReconnect(t *testing.T) {
	t.Parallel()

	svc, err := NewService(ServiceConfig{FreshnessWindow: time.Minute, RingCapacity: 16, PerConnBuffer: 8}, nil, nil)
	require.NoError(t, err)
	t.Cleanup(svc.Close)

	userID := uuid.New()
	for range 3 {
		require.NoError(t, svc.PublishClientEvent(context.Background(), userID, nil, "k", nil, "", "", ""))
	}

	client, cleanup := startBufconnServer(t, svc)
	defer cleanup()

	stream, err := client.SubscribePush(t.Context(), &pushv1.GatewaySubscribeRequest{GatewayClientId: "gw-1", Cursor: formatCursor(1)})
	require.NoError(t, err)

	for i := uint64(2); i <= 3; i++ {
		ev, err := recvOne(t, stream, time.Second)
		require.NoError(t, err)
		assert.Equal(t, formatCursor(i), ev.Cursor)
	}
}

func TestSubscribePushSkipsReplayWhenCursorStale(t *testing.T) {
	t.Parallel()

	svc, err := NewService(ServiceConfig{FreshnessWindow: time.Minute, RingCapacity: 2, PerConnBuffer: 8}, nil, nil)
	require.NoError(t, err)
	t.Cleanup(svc.Close)

	userID := uuid.New()
	for range 4 {
		require.NoError(t, svc.PublishClientEvent(context.Background(), userID, nil, "k", nil, "", "", ""))
	}
	// Ring capacity 2 means cursors 1 and 2 are evicted.

	client, cleanup := startBufconnServer(t, svc)
	defer cleanup()

	stream, err := client.SubscribePush(t.Context(), &pushv1.GatewaySubscribeRequest{GatewayClientId: "gw-1", Cursor: formatCursor(1)})
	require.NoError(t, err)
	require.Eventually(t, func() bool { return svc.SubscriberCount() == 1 }, time.Second, 5*time.Millisecond)

	// Stale cursor → no replay; live publish must arrive.
	require.NoError(t, svc.PublishClientEvent(context.Background(), userID, nil, "k", nil, "", "", ""))
	ev, err := recvOne(t, stream, time.Second)
	require.NoError(t, err)
	assert.Equal(t, formatCursor(5), ev.Cursor)
}

func TestSubscribePushReplacesExistingClientID(t *testing.T) {
	t.Parallel()

	svc, err := NewService(ServiceConfig{FreshnessWindow: time.Minute, RingCapacity: 8, PerConnBuffer: 8}, nil, nil)
	require.NoError(t, err)
	t.Cleanup(svc.Close)

	client, cleanup := startBufconnServer(t, svc)
	defer cleanup()

	stream1, err := client.SubscribePush(t.Context(), &pushv1.GatewaySubscribeRequest{GatewayClientId: "gw-1"})
	require.NoError(t, err)
	require.Eventually(t, func() bool { return svc.SubscriberCount() == 1 }, time.Second, 5*time.Millisecond)

	stream2, err := client.SubscribePush(t.Context(), &pushv1.GatewaySubscribeRequest{GatewayClientId: "gw-1"})
	require.NoError(t, err)

	// First stream must terminate with Aborted.
	_, err = recvOne(t, stream1, time.Second)
	require.Error(t, err)
	assert.Equal(t, codes.Aborted, status.Code(err))

	// Subscriber count returns to one (the replacement).
	require.Eventually(t, func() bool { return svc.SubscriberCount() == 1 }, time.Second, 5*time.Millisecond)

	// Live publish reaches the replacement.
	require.NoError(t, svc.PublishClientEvent(context.Background(), uuid.New(), nil, "k", nil, "", "", ""))
	ev, err := recvOne(t, stream2, time.Second)
	require.NoError(t, err)
	assert.NotEmpty(t, ev.Cursor)
}

func TestSubscribePushRejectsEmptyClientID(t *testing.T) {
	t.Parallel()

	svc, err := NewService(ServiceConfig{FreshnessWindow: time.Minute, RingCapacity: 4, PerConnBuffer: 4}, nil, nil)
	require.NoError(t, err)
	t.Cleanup(svc.Close)

	client, cleanup := startBufconnServer(t, svc)
	defer cleanup()

	stream, err := client.SubscribePush(t.Context(), &pushv1.GatewaySubscribeRequest{})
	require.NoError(t, err)

	_, err = stream.Recv()
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
}

func TestSubscriptionDeliverDropsOldestOnOverflow(t *testing.T) {
	t.Parallel()

	sub := &subscription{
		clientID: "gw-1",
		ch:       make(chan *pushv1.PushEvent, 2),
		done:     make(chan struct{}),
	}
	first := mkEvent(1, "a")
	second := mkEvent(2, "b")
	third := mkEvent(3, "c")

	assert.False(t, sub.deliver(first))
	assert.False(t, sub.deliver(second))
	assert.True(t, sub.deliver(third), "third deliver must report a drop")

	got1 := <-sub.ch
	got2 := <-sub.ch
	assert.Equal(t, second, got1, "oldest event (first) was dropped")
	assert.Equal(t, third, got2)
}

func TestServiceCloseTerminatesActiveStream(t *testing.T) {
	t.Parallel()

	svc, err := NewService(ServiceConfig{FreshnessWindow: time.Minute, RingCapacity: 4, PerConnBuffer: 4}, nil, nil)
	require.NoError(t, err)

	client, cleanup := startBufconnServer(t, svc)
	defer cleanup()

	stream, err := client.SubscribePush(t.Context(), &pushv1.GatewaySubscribeRequest{GatewayClientId: "gw-1"})
	require.NoError(t, err)
	require.Eventually(t, func() bool { return svc.SubscriberCount() == 1 }, time.Second, 5*time.Millisecond)

	svc.Close()

	_, err = recvOne(t, stream, time.Second)
	require.Error(t, err)
	assert.Equal(t, codes.Aborted, status.Code(err))
}
@@ -0,0 +1,48 @@
package push

import (
	"strings"

	pushv1 "galaxy/backend/proto/push/v1"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// SubscribePush is the gRPC server handler. It registers the connection
// in the subscription registry, replays any in-buffer events newer than
// the requested cursor, and then streams live events until the client
// cancels, the subscription is replaced by a newer connection from the
// same gateway client id, or the Service is shut down.
func (s *Service) SubscribePush(req *pushv1.GatewaySubscribeRequest, stream grpc.ServerStreamingServer[pushv1.PushEvent]) error {
	if req == nil || strings.TrimSpace(req.GetGatewayClientId()) == "" {
		return status.Error(codes.InvalidArgument, "gateway_client_id is required")
	}

	sub, replay, err := s.register(req.GetGatewayClientId(), req.GetCursor())
	if err != nil {
		return err
	}
	defer s.unregister(sub)

	for _, ev := range replay {
		if err := stream.Send(ev); err != nil {
			return err
		}
	}

	ctx := stream.Context()
	for {
		select {
		case <-ctx.Done():
			return nil
		case <-sub.done:
			return status.Error(codes.Aborted, "push subscription replaced or service stopped")
		case ev := <-sub.ch:
			if err := stream.Send(ev); err != nil {
				return err
			}
		}
	}
}
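The gateway-side consumer of this handler is not part of this commit. The sketch below shows how a client could resume with its last seen cursor across reconnects; the consumePush name and the fixed one-second backoff are assumptions, not the real gateway policy:

// consumePush is an illustrative gateway-side loop over SubscribePush.
func consumePush(ctx context.Context, client pushv1.PushClient, gatewayClientID string, handle func(*pushv1.PushEvent)) error {
	cursor := "" // empty cursor means "subscribe from now"
	for {
		stream, err := client.SubscribePush(ctx, &pushv1.GatewaySubscribeRequest{
			GatewayClientId: gatewayClientID,
			Cursor:          cursor,
		})
		if err != nil {
			return err
		}
		for {
			ev, err := stream.Recv()
			if err != nil {
				break // reconnect below, resuming from the last seen cursor
			}
			cursor = ev.Cursor
			handle(ev)
		}
		if ctx.Err() != nil {
			return ctx.Err()
		}
		time.Sleep(time.Second) // crude backoff, for illustration only
	}
}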
@@ -0,0 +1,43 @@
package push

import (
	pushv1 "galaxy/backend/proto/push/v1"
)

// subscription is the per-gateway-instance delivery queue. Each
// subscription owns a buffered channel; the publisher writes into it
// without blocking by dropping the oldest queued event when the buffer
// is full. The done channel is closed by the Service when the
// subscription is replaced (a new connection arrived for the same
// gateway_client_id) or when the Service is shutting down.
type subscription struct {
	clientID string
	ch       chan *pushv1.PushEvent
	done     chan struct{}
	dropped  uint64
}

// deliver enqueues ev into the subscription's buffer. When the buffer
// is full, the oldest queued event is dropped to make room and the
// dropped counter increments. The bool reports whether a drop occurred,
// so the publisher can update its drop metric.
//
// The Service holds its mutex while calling deliver, which means at
// most one publisher writes to ch at a time. The reader goroutine runs
// independently and only consumes from ch, so the second send below is
// guaranteed not to block: after evicting the head, the channel has at
// least one free slot which no other publisher can fill.
func (s *subscription) deliver(ev *pushv1.PushEvent) bool {
	select {
	case s.ch <- ev:
		return false
	default:
	}
	select {
	case <-s.ch:
	default:
	}
	s.ch <- ev
	s.dropped++
	return true
}