408 lines
9.6 KiB
Go
408 lines
9.6 KiB
Go
// Package push provides the in-memory hub used to fan out internal
// client-facing events to active authenticated push streams.
|
|
package push
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
const (
	// defaultSubscriptionQueueCapacity is the per-subscription event buffer
	// size a hub falls back to when it is built with a non-positive capacity.
	defaultSubscriptionQueueCapacity = 64
)
|
|
|
|
var (
	// ErrSubscriptionOverflow is the terminal error of a push stream whose
	// bounded event queue was full when the hub tried to deliver to it, i.e.
	// the stream fell behind its publishers.
	ErrSubscriptionOverflow = errors.New("push stream overflowed")

	// ErrSubscriptionRevoked is the terminal error of a push stream whose
	// authenticated device session (or whole user) was revoked; the stream
	// must stop.
	ErrSubscriptionRevoked = errors.New("device session is revoked")

	// ErrHubShuttingDown is the terminal error given to every active push
	// stream when the gateway shuts the hub down; streams should end promptly.
	ErrHubShuttingDown = errors.New("gateway is shutting down")
)
|
|
|
|
// StreamBinding is the authenticated identity a push stream is registered
// under in a Hub. Hub.Register trims both fields and rejects blank values.
type StreamBinding struct {
	// UserID is the verified user that owns the stream.
	UserID string

	// DeviceSessionID is the verified device session the stream was opened
	// from.
	DeviceSessionID string
}
|
|
|
|
// Event is one client-facing message that arrives from internal pub/sub and
// is fanned out by Hub.Publish to the matching push streams.
type Event struct {
	// UserID selects the user whose streams receive the event. Required:
	// an event with a blank UserID is not delivered.
	UserID string

	// DeviceSessionID, when non-blank, restricts delivery to the user's
	// streams bound to this device session only.
	DeviceSessionID string

	// EventType names the stable client-facing event category. Required.
	EventType string

	// EventID is the stable identifier used to correlate the event. Required.
	EventID string

	// PayloadBytes is the opaque event body. Every receiving stream gets its
	// own copy.
	PayloadBytes []byte

	// RequestID optionally ties the event to an earlier client request.
	RequestID string

	// TraceID optionally carries a tracing correlation identifier.
	TraceID string
}
|
|
|
|
// Subscription represents one active push stream registered in Hub.
|
|
type Subscription struct {
|
|
hub *Hub
|
|
id uint64
|
|
binding StreamBinding
|
|
events chan Event
|
|
done chan struct{}
|
|
|
|
closeOnce sync.Once
|
|
stateMu sync.RWMutex
|
|
err error
|
|
}
|
|
|
|
// Observer receives push stream lifecycle notifications suitable for metrics
|
|
// bookkeeping.
|
|
type Observer interface {
|
|
// Registered reports one active push stream binding.
|
|
Registered(binding StreamBinding)
|
|
|
|
// Unregistered reports that binding stopped with err. A nil err means the
|
|
// stream ended without a hub-enforced terminal reason.
|
|
Unregistered(binding StreamBinding, err error)
|
|
}
|
|
|
|
// Events returns the ordered event queue for the subscription.
|
|
func (s *Subscription) Events() <-chan Event {
|
|
if s == nil {
|
|
return nil
|
|
}
|
|
|
|
return s.events
|
|
}
|
|
|
|
// Done closes when the subscription has been removed from the hub.
|
|
func (s *Subscription) Done() <-chan struct{} {
|
|
if s == nil {
|
|
return nil
|
|
}
|
|
|
|
return s.done
|
|
}
|
|
|
|
// Err returns the terminal subscription error, if any.
|
|
func (s *Subscription) Err() error {
|
|
if s == nil {
|
|
return nil
|
|
}
|
|
|
|
s.stateMu.RLock()
|
|
defer s.stateMu.RUnlock()
|
|
|
|
return s.err
|
|
}
|
|
|
|
// Close unregisters the subscription from its hub.
|
|
func (s *Subscription) Close() {
|
|
if s == nil || s.hub == nil {
|
|
return
|
|
}
|
|
|
|
s.hub.unregister(s.id, nil)
|
|
}
|
|
|
|
func (s *Subscription) enqueue(event Event) bool {
|
|
if s == nil {
|
|
return true
|
|
}
|
|
|
|
cloned := cloneEvent(event)
|
|
|
|
select {
|
|
case <-s.done:
|
|
return true
|
|
default:
|
|
}
|
|
|
|
select {
|
|
case s.events <- cloned:
|
|
return true
|
|
case <-s.done:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (s *Subscription) closeWithError(err error) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
|
|
s.closeOnce.Do(func() {
|
|
s.stateMu.Lock()
|
|
s.err = err
|
|
s.stateMu.Unlock()
|
|
close(s.done)
|
|
})
|
|
}
|
|
|
|
// Hub tracks active authenticated push streams and fans out client-facing
|
|
// events to the matching subscriptions.
|
|
type Hub struct {
|
|
mu sync.RWMutex
|
|
nextID uint64
|
|
queueCapacity int
|
|
observer Observer
|
|
byID map[uint64]*Subscription
|
|
byUser map[string]map[uint64]*Subscription
|
|
bySession map[string]map[uint64]*Subscription
|
|
}
|
|
|
|
// NewHub constructs a push hub with one bounded in-memory queue per
|
|
// subscription. Non-positive queueCapacity falls back to the package default.
|
|
func NewHub(queueCapacity int) *Hub {
|
|
return NewHubWithObserver(queueCapacity, nil)
|
|
}
|
|
|
|
// NewHubWithObserver constructs a push hub that also reports stream lifecycle
|
|
// changes to observer.
|
|
func NewHubWithObserver(queueCapacity int, observer Observer) *Hub {
|
|
if queueCapacity <= 0 {
|
|
queueCapacity = defaultSubscriptionQueueCapacity
|
|
}
|
|
|
|
return &Hub{
|
|
queueCapacity: queueCapacity,
|
|
observer: observer,
|
|
byID: make(map[uint64]*Subscription),
|
|
byUser: make(map[string]map[uint64]*Subscription),
|
|
bySession: make(map[string]map[uint64]*Subscription),
|
|
}
|
|
}
|
|
|
|
// Register adds one authenticated push stream to the hub and returns its
|
|
// subscription handle.
|
|
func (h *Hub) Register(binding StreamBinding) (*Subscription, error) {
|
|
if h == nil {
|
|
return nil, errors.New("register push subscription: nil hub")
|
|
}
|
|
|
|
userID := strings.TrimSpace(binding.UserID)
|
|
if userID == "" {
|
|
return nil, errors.New("register push subscription: user id must not be empty")
|
|
}
|
|
|
|
deviceSessionID := strings.TrimSpace(binding.DeviceSessionID)
|
|
if deviceSessionID == "" {
|
|
return nil, errors.New("register push subscription: device session id must not be empty")
|
|
}
|
|
|
|
h.mu.Lock()
|
|
|
|
h.nextID++
|
|
subscription := &Subscription{
|
|
hub: h,
|
|
id: h.nextID,
|
|
binding: StreamBinding{
|
|
UserID: userID,
|
|
DeviceSessionID: deviceSessionID,
|
|
},
|
|
events: make(chan Event, h.queueCapacity),
|
|
done: make(chan struct{}),
|
|
}
|
|
h.byID[subscription.id] = subscription
|
|
addIndexedSubscription(h.byUser, userID, subscription)
|
|
addIndexedSubscription(h.bySession, deviceSessionID, subscription)
|
|
h.mu.Unlock()
|
|
|
|
if h.observer != nil {
|
|
h.observer.Registered(subscription.binding)
|
|
}
|
|
|
|
return subscription, nil
|
|
}
|
|
|
|
// Publish fans out event to the matching active subscriptions. When one
|
|
// subscription queue overflows, only that subscription is closed.
|
|
func (h *Hub) Publish(event Event) {
|
|
if h == nil {
|
|
return
|
|
}
|
|
|
|
targets := h.targets(event)
|
|
for _, target := range targets {
|
|
if target.enqueue(event) {
|
|
continue
|
|
}
|
|
|
|
h.unregister(target.id, ErrSubscriptionOverflow)
|
|
}
|
|
}
|
|
|
|
// RevokeDeviceSession closes all active subscriptions bound to the exact
|
|
// authenticated device session identifier.
|
|
func (h *Hub) RevokeDeviceSession(deviceSessionID string) {
|
|
if h == nil {
|
|
return
|
|
}
|
|
|
|
deviceSessionID = strings.TrimSpace(deviceSessionID)
|
|
if deviceSessionID == "" {
|
|
return
|
|
}
|
|
|
|
h.mu.RLock()
|
|
targets := cloneSubscriptions(h.bySession[deviceSessionID])
|
|
h.mu.RUnlock()
|
|
|
|
for _, target := range targets {
|
|
h.unregister(target.id, ErrSubscriptionRevoked)
|
|
}
|
|
}
|
|
|
|
// RevokeAllForUser closes every active subscription bound to userID,
|
|
// regardless of device-session id. Used when backend emits a
|
|
// SessionInvalidation that targets every session of a user.
|
|
func (h *Hub) RevokeAllForUser(userID string) {
|
|
if h == nil {
|
|
return
|
|
}
|
|
|
|
userID = strings.TrimSpace(userID)
|
|
if userID == "" {
|
|
return
|
|
}
|
|
|
|
h.mu.RLock()
|
|
targets := cloneSubscriptions(h.byUser[userID])
|
|
h.mu.RUnlock()
|
|
|
|
for _, target := range targets {
|
|
h.unregister(target.id, ErrSubscriptionRevoked)
|
|
}
|
|
}
|
|
|
|
// Shutdown closes every active subscription because the gateway is shutting
|
|
// down.
|
|
func (h *Hub) Shutdown() {
|
|
if h == nil {
|
|
return
|
|
}
|
|
|
|
h.mu.RLock()
|
|
targets := cloneSubscriptions(h.byID)
|
|
h.mu.RUnlock()
|
|
|
|
for _, target := range targets {
|
|
h.unregister(target.id, ErrHubShuttingDown)
|
|
}
|
|
}
|
|
|
|
func (h *Hub) targets(event Event) []*Subscription {
|
|
userID := strings.TrimSpace(event.UserID)
|
|
eventType := strings.TrimSpace(event.EventType)
|
|
eventID := strings.TrimSpace(event.EventID)
|
|
if h == nil || userID == "" || eventType == "" || eventID == "" {
|
|
return nil
|
|
}
|
|
|
|
deviceSessionID := strings.TrimSpace(event.DeviceSessionID)
|
|
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
|
|
if deviceSessionID == "" {
|
|
return cloneSubscriptions(h.byUser[userID])
|
|
}
|
|
|
|
sessionMatches := cloneSubscriptions(h.bySession[deviceSessionID])
|
|
filtered := sessionMatches[:0]
|
|
for _, subscription := range sessionMatches {
|
|
if subscription.binding.UserID == userID {
|
|
filtered = append(filtered, subscription)
|
|
}
|
|
}
|
|
|
|
return filtered
|
|
}
|
|
|
|
func (h *Hub) unregister(id uint64, err error) {
|
|
if h == nil || id == 0 {
|
|
return
|
|
}
|
|
|
|
h.mu.Lock()
|
|
subscription, ok := h.byID[id]
|
|
if !ok {
|
|
h.mu.Unlock()
|
|
return
|
|
}
|
|
|
|
delete(h.byID, id)
|
|
removeIndexedSubscription(h.byUser, subscription.binding.UserID, id)
|
|
removeIndexedSubscription(h.bySession, subscription.binding.DeviceSessionID, id)
|
|
h.mu.Unlock()
|
|
|
|
subscription.closeWithError(err)
|
|
if h.observer != nil {
|
|
h.observer.Unregistered(subscription.binding, err)
|
|
}
|
|
}
|
|
|
|
func addIndexedSubscription(index map[string]map[uint64]*Subscription, key string, subscription *Subscription) {
|
|
if _, ok := index[key]; !ok {
|
|
index[key] = make(map[uint64]*Subscription)
|
|
}
|
|
index[key][subscription.id] = subscription
|
|
}
|
|
|
|
func removeIndexedSubscription(index map[string]map[uint64]*Subscription, key string, id uint64) {
|
|
bucket, ok := index[key]
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
delete(bucket, id)
|
|
if len(bucket) == 0 {
|
|
delete(index, key)
|
|
}
|
|
}
|
|
|
|
func cloneSubscriptions(bucket map[uint64]*Subscription) []*Subscription {
|
|
if len(bucket) == 0 {
|
|
return nil
|
|
}
|
|
|
|
cloned := make([]*Subscription, 0, len(bucket))
|
|
for _, subscription := range bucket {
|
|
cloned = append(cloned, subscription)
|
|
}
|
|
|
|
return cloned
|
|
}
|
|
|
|
func cloneEvent(event Event) Event {
|
|
return Event{
|
|
UserID: event.UserID,
|
|
DeviceSessionID: event.DeviceSessionID,
|
|
EventType: event.EventType,
|
|
EventID: event.EventID,
|
|
PayloadBytes: bytes.Clone(event.PayloadBytes),
|
|
RequestID: event.RequestID,
|
|
TraceID: event.TraceID,
|
|
}
|
|
}
|