feat: edge gateway service
This commit is contained in:
@@ -0,0 +1,133 @@
|
||||
// Package adminapi exposes the optional private admin HTTP listener used for
|
||||
// operational endpoints such as Prometheus metrics.
|
||||
package adminapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Server owns the optional admin HTTP listener exposed by the gateway.
type Server struct {
	// cfg carries the listener address and HTTP timeouts; an empty cfg.Addr
	// disables the admin listener entirely (see Enabled).
	cfg config.AdminHTTPConfig
	// handler serves every admin request (e.g. Prometheus metrics).
	handler http.Handler
	logger  *zap.Logger

	// stateMu guards server and listener, which are published for the
	// duration of Run and read concurrently by Shutdown and listenAddr.
	stateMu  sync.RWMutex
	server   *http.Server
	listener net.Listener
}
|
||||
|
||||
// NewServer constructs an admin HTTP server for cfg and handler.
|
||||
func NewServer(cfg config.AdminHTTPConfig, handler http.Handler, logger *zap.Logger) *Server {
|
||||
if handler == nil {
|
||||
handler = http.NotFoundHandler()
|
||||
}
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
handler: handler,
|
||||
logger: logger.Named("admin_http"),
|
||||
}
|
||||
}
|
||||
|
||||
// Enabled reports whether the admin listener should run.
|
||||
func (s *Server) Enabled() bool {
|
||||
return s != nil && s.cfg.Addr != ""
|
||||
}
|
||||
|
||||
// Run binds the configured listener and serves the admin HTTP surface until
// Shutdown closes the server. A disabled admin server returns when ctx is
// canceled.
func (s *Server) Run(ctx context.Context) error {
	if ctx == nil {
		return errors.New("run admin HTTP server: nil context")
	}
	if err := ctx.Err(); err != nil {
		// Fail fast on an already-canceled context before binding the port.
		return err
	}
	if !s.Enabled() {
		// Disabled listener: still block until cancellation so the component
		// lifecycle treats this as a normal long-running Run.
		<-ctx.Done()
		return nil
	}

	listener, err := net.Listen("tcp", s.cfg.Addr)
	if err != nil {
		return fmt.Errorf("run admin HTTP server: listen on %q: %w", s.cfg.Addr, err)
	}

	server := &http.Server{
		Handler:           s.handler,
		ReadHeaderTimeout: s.cfg.ReadHeaderTimeout,
		ReadTimeout:       s.cfg.ReadTimeout,
		IdleTimeout:       s.cfg.IdleTimeout,
	}

	// Publish the live server/listener so a concurrent Shutdown can reach them.
	s.stateMu.Lock()
	s.server = server
	s.listener = listener
	s.stateMu.Unlock()

	s.logger.Info("admin HTTP server started", zap.String("addr", listener.Addr().String()))

	// Clear the published state once Serve returns, whatever the reason.
	defer func() {
		s.stateMu.Lock()
		s.server = nil
		s.listener = nil
		s.stateMu.Unlock()
	}()

	err = server.Serve(listener)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, http.ErrServerClosed):
		// Expected path: Shutdown was invoked and Serve returned cleanly.
		s.logger.Info("admin HTTP server stopped")
		return nil
	default:
		return fmt.Errorf("run admin HTTP server: serve on %q: %w", s.cfg.Addr, err)
	}
}
|
||||
|
||||
// Shutdown gracefully stops the admin HTTP server within ctx.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("shutdown admin HTTP server: nil context")
|
||||
}
|
||||
|
||||
s.stateMu.RLock()
|
||||
server := s.server
|
||||
s.stateMu.RUnlock()
|
||||
|
||||
if server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("shutdown admin HTTP server: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) listenAddr() string {
|
||||
s.stateMu.RLock()
|
||||
defer s.stateMu.RUnlock()
|
||||
|
||||
if s.listener == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return s.listener.Addr().String()
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package adminapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/restapi"
|
||||
"galaxy/gateway/internal/testutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMetricsAreReachableOnlyOnAdminListener boots the full application with
// separate public and admin listeners and asserts that /metrics is served on
// the admin surface only, never on the public one.
func TestMetricsAreReachableOnlyOnAdminListener(t *testing.T) {
	t.Parallel()

	logger, _ := testutil.NewObservedLogger(t)
	telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)

	// Reserve two distinct loopback addresses so the surfaces cannot collide.
	// NOTE(review): the port could in principle be re-grabbed between release
	// and bind — acceptable flake risk for a test.
	publicAddr := unusedTCPAddr(t)
	adminAddr := unusedTCPAddr(t)

	publicCfg := config.DefaultPublicHTTPConfig()
	publicCfg.Addr = publicAddr
	adminCfg := config.DefaultAdminHTTPConfig()
	adminCfg.Addr = adminAddr

	restServer := restapi.NewServer(publicCfg, restapi.ServerDependencies{
		Logger:    logger,
		Telemetry: telemetryRuntime,
	})
	adminServer := NewServer(adminCfg, telemetryRuntime.Handler(), logger)

	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			PublicHTTP:        publicCfg,
			AdminHTTP:         adminCfg,
			AuthenticatedGRPC: config.DefaultAuthenticatedGRPCConfig(),
		},
		restServer,
		adminServer,
	)

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	// Stop the application at the end of the test and require a clean exit.
	defer func() {
		cancel()
		select {
		case err := <-resultCh:
			require.NoError(t, err)
		case <-time.After(time.Second):
			require.FailNow(t, "application did not stop")
		}
	}()

	// Wait until both listeners are actually serving before asserting.
	waitForHTTPStatus(t, "http://"+publicAddr+"/healthz", http.StatusOK)
	waitForHTTPStatus(t, "http://"+adminAddr+"/metrics", http.StatusOK)

	// The public surface must not expose Prometheus metrics.
	publicMetricsResp, err := http.Get("http://" + publicAddr + "/metrics")
	require.NoError(t, err)
	defer func() {
		require.NoError(t, publicMetricsResp.Body.Close())
	}()
	assert.Equal(t, http.StatusNotFound, publicMetricsResp.StatusCode)
}
|
||||
|
||||
func waitForHTTPStatus(t *testing.T, rawURL string, wantStatus int) {
|
||||
t.Helper()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
resp, err := http.Get(rawURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
return resp.StatusCode == wantStatus
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func unusedTCPAddr(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := listener.Addr().String()
|
||||
require.NoError(t, listener.Close())
|
||||
|
||||
return addr
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
// Package app wires the gateway process lifecycle and coordinates component
|
||||
// startup and graceful shutdown.
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
)
|
||||
|
||||
// Component is a long-lived gateway subsystem that participates in coordinated
// startup and graceful shutdown.
type Component interface {
	// Run starts the component and blocks until it stops. Implementations are
	// expected to exit once the supplied context is canceled.
	Run(context.Context) error

	// Shutdown stops the component within the provided timeout-bounded context.
	Shutdown(context.Context) error
}
|
||||
|
||||
// App owns the process-level lifecycle of the gateway and its registered
|
||||
// components.
|
||||
type App struct {
|
||||
cfg config.Config
|
||||
components []Component
|
||||
}
|
||||
|
||||
// New constructs an App with a defensive copy of the supplied components.
|
||||
func New(cfg config.Config, components ...Component) *App {
|
||||
clonedComponents := append([]Component(nil), components...)
|
||||
|
||||
return &App{
|
||||
cfg: cfg,
|
||||
components: clonedComponents,
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts all configured components, waits for cancellation or the first
// component failure, and then executes best-effort graceful shutdown for every
// component.
func (a *App) Run(ctx context.Context) error {
	if ctx == nil {
		return errors.New("run gateway app: nil context")
	}
	if err := a.validate(); err != nil {
		return err
	}
	if len(a.components) == 0 {
		// Nothing to manage: still behave like a long-running process.
		<-ctx.Done()
		return nil
	}

	runCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Buffered to len(components) so every goroutine can deliver its result
	// without blocking, even though only the first result is consumed.
	results := make(chan componentResult, len(a.components))
	var runWG sync.WaitGroup

	for idx, component := range a.components {
		runWG.Add(1)

		go func(index int, component Component) {
			defer runWG.Done()
			results <- componentResult{
				index: index,
				err:   component.Run(runCtx),
			}
		}(idx, component)
	}

	var runErr error

	// Block until either the parent context is canceled or any component exits.
	select {
	case <-ctx.Done():
	case result := <-results:
		runErr = classifyComponentResult(ctx, result)
	}

	// Cancel the shared run context so every remaining component stops.
	cancel()

	shutdownErr := a.shutdownComponents()
	waitErr := a.waitForComponents(&runWG)

	return errors.Join(runErr, shutdownErr, waitErr)
}
|
||||
|
||||
// componentResult captures the first observed exit from a running component.
type componentResult struct {
	// index is the component's position in App.components, used in error messages.
	index int
	// err is whatever Component.Run returned (nil included).
	err error
}
|
||||
|
||||
// validate confirms that the App has a safe shutdown budget and no nil
|
||||
// components before goroutines are started.
|
||||
func (a *App) validate() error {
|
||||
if a.cfg.ShutdownTimeout <= 0 {
|
||||
return fmt.Errorf("run gateway app: shutdown timeout must be positive, got %s", a.cfg.ShutdownTimeout)
|
||||
}
|
||||
|
||||
for idx, component := range a.components {
|
||||
if component == nil {
|
||||
return fmt.Errorf("run gateway app: component %d is nil", idx)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// classifyComponentResult maps the first component exit into the error that
|
||||
// should control the application result.
|
||||
func classifyComponentResult(parentCtx context.Context, result componentResult) error {
|
||||
switch {
|
||||
case result.err == nil:
|
||||
if parentCtx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("run gateway app: component %d exited without error before shutdown", result.index)
|
||||
case errors.Is(result.err, context.Canceled) && parentCtx.Err() != nil:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("run gateway app: component %d: %w", result.index, result.err)
|
||||
}
|
||||
}
|
||||
|
||||
// shutdownComponents calls Shutdown on every registered component using a fresh
|
||||
// timeout-bounded context per component and joins any shutdown failures.
|
||||
func (a *App) shutdownComponents() error {
|
||||
var shutdownWG sync.WaitGroup
|
||||
errs := make(chan error, len(a.components))
|
||||
|
||||
for idx, component := range a.components {
|
||||
shutdownWG.Add(1)
|
||||
|
||||
go func(index int, component Component) {
|
||||
defer shutdownWG.Done()
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), a.cfg.ShutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := component.Shutdown(shutdownCtx); err != nil {
|
||||
errs <- fmt.Errorf("shutdown gateway component %d: %w", index, err)
|
||||
}
|
||||
}(idx, component)
|
||||
}
|
||||
|
||||
shutdownWG.Wait()
|
||||
close(errs)
|
||||
|
||||
var joined error
|
||||
for err := range errs {
|
||||
joined = errors.Join(joined, err)
|
||||
}
|
||||
|
||||
return joined
|
||||
}
|
||||
|
||||
// waitForComponents waits for running components to return after shutdown and
|
||||
// reports when they outlive the configured shutdown budget.
|
||||
func (a *App) waitForComponents(runWG *sync.WaitGroup) error {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
runWG.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(context.Background(), a.cfg.ShutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-waitCtx.Done():
|
||||
return fmt.Errorf("wait for gateway components: %w", waitCtx.Err())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,268 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestAppRunWaitsForCancellationWithoutComponents verifies that an App with no
// components still blocks in Run until its context is canceled, then returns
// nil.
func TestAppRunWaitsForCancellationWithoutComponents(t *testing.T) {
	t.Parallel()

	application := New(config.Config{ShutdownTimeout: 50 * time.Millisecond})
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	// Run must keep blocking while the context is still live.
	select {
	case err := <-resultCh:
		require.FailNowf(t, "Run() returned early", "error=%v", err)
	case <-time.After(50 * time.Millisecond):
	}

	cancel()

	// After cancellation Run must return promptly and without error.
	select {
	case err := <-resultCh:
		require.NoError(t, err)
	case <-time.After(time.Second):
		require.FailNow(t, "Run() did not return after cancellation")
	}
}
|
||||
|
||||
// TestAppRunCancelsComponentsAndCallsShutdownOnce verifies the orderly path:
// cancellation stops every component and each one receives exactly one
// Shutdown call.
func TestAppRunCancelsComponentsAndCallsShutdownOnce(t *testing.T) {
	t.Parallel()

	first := newLifecycleComponent()
	second := newLifecycleComponent()

	application := New(
		config.Config{ShutdownTimeout: time.Second},
		first,
		second,
	)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	// Make sure both components are actually running before canceling.
	first.waitStarted(t)
	second.waitStarted(t)

	cancel()

	select {
	case err := <-resultCh:
		require.NoError(t, err)
	case <-time.After(time.Second):
		require.FailNow(t, "Run() did not return after cancellation")
	}

	first.waitRunExited(t)
	second.waitRunExited(t)

	// Each component must have been shut down exactly once.
	assert.Equal(t, 1, first.shutdownCalls())
	assert.Equal(t, 1, second.shutdownCalls())
}
|
||||
|
||||
// TestAppRunReturnsComponentErrorAndStillShutsDown verifies the failure path:
// a component error surfaces from Run, and every component — including the
// failed one — still gets a Shutdown call.
func TestAppRunReturnsComponentErrorAndStillShutsDown(t *testing.T) {
	t.Parallel()

	runErr := errors.New("boom")
	failing := newFailingComponent(runErr)
	blocking := newLifecycleComponent()

	application := New(
		config.Config{ShutdownTimeout: time.Second},
		failing,
		blocking,
	)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	failing.waitStarted(t)
	blocking.waitStarted(t)
	// Trigger the failure once both components are running.
	failing.releaseRun()

	select {
	case err := <-resultCh:
		require.Error(t, err)
		assert.ErrorIs(t, err, runErr)
	case <-time.After(time.Second):
		require.FailNow(t, "Run() did not return after component failure")
	}

	failing.waitRunExited(t)
	blocking.waitRunExited(t)

	assert.Equal(t, 1, failing.shutdownCalls())
	assert.Equal(t, 1, blocking.shutdownCalls())
}
|
||||
|
||||
// lifecycleComponent blocks in Run until the application calls Shutdown.
|
||||
type lifecycleComponent struct {
|
||||
startedCh chan struct{}
|
||||
runDoneCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
shutdownMu sync.Mutex
|
||||
shutdownCnt int
|
||||
}
|
||||
|
||||
// newLifecycleComponent builds a component that exits Run only after Shutdown
|
||||
// signals its stop channel.
|
||||
func newLifecycleComponent() *lifecycleComponent {
|
||||
return &lifecycleComponent{
|
||||
startedCh: make(chan struct{}),
|
||||
runDoneCh: make(chan struct{}),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Run marks the component as started, waits for cancellation, and then blocks
|
||||
// until Shutdown releases the stop channel.
|
||||
func (c *lifecycleComponent) Run(ctx context.Context) error {
|
||||
close(c.startedCh)
|
||||
defer close(c.runDoneCh)
|
||||
|
||||
<-ctx.Done()
|
||||
<-c.stopCh
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown records the call and releases the run loop.
|
||||
func (c *lifecycleComponent) Shutdown(context.Context) error {
|
||||
c.shutdownMu.Lock()
|
||||
defer c.shutdownMu.Unlock()
|
||||
|
||||
c.shutdownCnt++
|
||||
if c.shutdownCnt == 1 {
|
||||
close(c.stopCh)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitStarted blocks until Run has started or fails the test on timeout.
|
||||
func (c *lifecycleComponent) waitStarted(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case <-c.startedCh:
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "component did not start")
|
||||
}
|
||||
}
|
||||
|
||||
// waitRunExited blocks until Run exits or fails the test on timeout.
|
||||
func (c *lifecycleComponent) waitRunExited(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case <-c.runDoneCh:
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "component run did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
// shutdownCalls returns the number of observed Shutdown invocations.
|
||||
func (c *lifecycleComponent) shutdownCalls() int {
|
||||
c.shutdownMu.Lock()
|
||||
defer c.shutdownMu.Unlock()
|
||||
|
||||
return c.shutdownCnt
|
||||
}
|
||||
|
||||
// failingComponent returns a predefined error once released by the test and
|
||||
// still tracks shutdown calls.
|
||||
type failingComponent struct {
|
||||
startedCh chan struct{}
|
||||
releaseCh chan struct{}
|
||||
runDoneCh chan struct{}
|
||||
shutdownMu sync.Mutex
|
||||
shutdownCnt int
|
||||
err error
|
||||
}
|
||||
|
||||
// newFailingComponent builds a component whose Run returns err after release.
|
||||
func newFailingComponent(err error) *failingComponent {
|
||||
return &failingComponent{
|
||||
startedCh: make(chan struct{}),
|
||||
releaseCh: make(chan struct{}),
|
||||
runDoneCh: make(chan struct{}),
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Run waits until the test releases it and then returns the configured error.
|
||||
func (c *failingComponent) Run(context.Context) error {
|
||||
close(c.startedCh)
|
||||
defer close(c.runDoneCh)
|
||||
|
||||
<-c.releaseCh
|
||||
return c.err
|
||||
}
|
||||
|
||||
// Shutdown records that the application attempted graceful shutdown.
|
||||
func (c *failingComponent) Shutdown(context.Context) error {
|
||||
c.shutdownMu.Lock()
|
||||
defer c.shutdownMu.Unlock()
|
||||
|
||||
c.shutdownCnt++
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitStarted blocks until Run has started or fails the test on timeout.
|
||||
func (c *failingComponent) waitStarted(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case <-c.startedCh:
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "failing component did not start")
|
||||
}
|
||||
}
|
||||
|
||||
// releaseRun allows Run to return its configured error.
|
||||
func (c *failingComponent) releaseRun() {
|
||||
close(c.releaseCh)
|
||||
}
|
||||
|
||||
// waitRunExited blocks until Run exits or fails the test on timeout.
|
||||
func (c *failingComponent) waitRunExited(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case <-c.runDoneCh:
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "failing component run did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
// shutdownCalls returns the number of observed Shutdown invocations.
|
||||
func (c *failingComponent) shutdownCalls() int {
|
||||
c.shutdownMu.Lock()
|
||||
defer c.shutdownMu.Unlock()
|
||||
|
||||
return c.shutdownCnt
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const (
	// EventDomainMarkerV1 binds the v1 server event signature to the Galaxy
	// gateway transport contract.
	EventDomainMarkerV1 = "galaxy-event-v1"
)

var (
	// ErrInvalidEventSignature reports that a gateway stream event signature is
	// not a raw Ed25519 signature for the canonical event signing input.
	ErrInvalidEventSignature = errors.New("invalid event signature")
)

// EventSigningFields contains the canonical v1 stream-event fields that are
// bound into the server signing input.
type EventSigningFields struct {
	// EventType identifies the stable client-facing event category.
	EventType string

	// EventID is the stable event correlation identifier.
	EventID string

	// TimestampMS carries the server event timestamp in milliseconds.
	TimestampMS int64

	// RequestID optionally correlates the event to the opening client request.
	RequestID string

	// TraceID optionally carries the client-supplied tracing correlation value.
	TraceID string

	// PayloadHash is the raw SHA-256 digest of event payload bytes.
	PayloadHash []byte
}

// BuildEventSigningInput returns the canonical byte sequence the v1 gateway
// stream-event signature covers. Every string/byte field is written as
// uvarint(len(field)) followed by its raw bytes, while TimestampMS is written
// as an 8-byte big-endian uint64.
func BuildEventSigningInput(fields EventSigningFields) []byte {
	// appendField writes uvarint(len(value)) || value onto dst.
	appendField := func(dst []byte, value []byte) []byte {
		dst = binary.AppendUvarint(dst, uint64(len(value)))
		return append(dst, value...)
	}

	// Upper bound: six length-prefixed fields plus the fixed 8-byte timestamp.
	capacity := 8 + 6*binary.MaxVarintLen64 +
		len(EventDomainMarkerV1) + len(fields.EventType) + len(fields.EventID) +
		len(fields.RequestID) + len(fields.TraceID) + len(fields.PayloadHash)

	input := make([]byte, 0, capacity)
	input = appendField(input, []byte(EventDomainMarkerV1))
	input = appendField(input, []byte(fields.EventType))
	input = appendField(input, []byte(fields.EventID))
	input = binary.BigEndian.AppendUint64(input, uint64(fields.TimestampMS))
	input = appendField(input, []byte(fields.RequestID))
	input = appendField(input, []byte(fields.TraceID))
	input = appendField(input, fields.PayloadHash)

	return input
}
|
||||
|
||||
// VerifyEventSignature verifies that signature authenticates fields under
|
||||
// publicKey using the canonical v1 event signing input.
|
||||
func VerifyEventSignature(publicKey ed25519.PublicKey, signature []byte, fields EventSigningFields) error {
|
||||
if len(publicKey) != ed25519.PublicKeySize || len(signature) != ed25519.SignatureSize {
|
||||
return ErrInvalidEventSignature
|
||||
}
|
||||
if !ed25519.Verify(publicKey, BuildEventSigningInput(fields), signature) {
|
||||
return ErrInvalidEventSignature
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestBuildEventSigningInputChangesWhenSignedFieldChanges asserts that every
// signed field is actually bound into the canonical signing input: mutating
// any one of them must change the produced bytes.
func TestBuildEventSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
	t.Parallel()

	base := EventSigningFields{
		EventType:   "gateway.server_time",
		EventID:     "request-123",
		TimestampMS: 123456789,
		RequestID:   "request-123",
		TraceID:     "trace-123",
		PayloadHash: mustSHA256([]byte("payload")),
	}

	baseInput := BuildEventSigningInput(base)

	tests := []struct {
		name   string
		mutate func(EventSigningFields) EventSigningFields
	}{
		{
			name: "event type",
			mutate: func(fields EventSigningFields) EventSigningFields {
				fields.EventType = "gateway.other"
				return fields
			},
		},
		{
			name: "event id",
			mutate: func(fields EventSigningFields) EventSigningFields {
				fields.EventID = "request-456"
				return fields
			},
		},
		{
			name: "timestamp",
			mutate: func(fields EventSigningFields) EventSigningFields {
				fields.TimestampMS++
				return fields
			},
		},
		{
			name: "request id",
			mutate: func(fields EventSigningFields) EventSigningFields {
				fields.RequestID = "request-456"
				return fields
			},
		},
		{
			name: "trace id",
			mutate: func(fields EventSigningFields) EventSigningFields {
				fields.TraceID = "trace-456"
				return fields
			},
		},
		{
			name: "payload hash",
			mutate: func(fields EventSigningFields) EventSigningFields {
				fields.PayloadHash = mustSHA256([]byte("other"))
				return fields
			},
		},
	}

	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-Go 1.22 loop semantics)

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Mutating a signed field must change the signing input byte-wise.
			mutated := BuildEventSigningInput(tt.mutate(base))
			assert.False(t, bytes.Equal(baseInput, mutated))
		})
	}
}
|
||||
|
||||
// TestSignAndVerifyEventSignature round-trips an event signature through the
// Ed25519 response signer and checks that mutating a signed field afterwards
// invalidates verification.
func TestSignAndVerifyEventSignature(t *testing.T) {
	t.Parallel()

	_, privateKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	signer, err := NewEd25519ResponseSigner(privateKey)
	require.NoError(t, err)

	fields := EventSigningFields{
		EventType:   "gateway.server_time",
		EventID:     "request-123",
		TimestampMS: 123456789,
		RequestID:   "request-123",
		TraceID:     "trace-123",
		PayloadHash: mustSHA256([]byte("payload")),
	}

	// The freshly produced signature must verify against the same fields.
	signature, err := signer.SignEvent(fields)
	require.NoError(t, err)
	require.NoError(t, VerifyEventSignature(signer.PublicKey(), signature, fields))

	// Any change to a signed field must break verification.
	fields.TraceID = "changed"
	require.ErrorIs(t, VerifyEventSignature(signer.PublicKey(), signature, fields), ErrInvalidEventSignature)
}
|
||||
@@ -0,0 +1,101 @@
|
||||
// Package authn defines authenticated transport helpers shared by the gateway
|
||||
// edge verification pipeline.
|
||||
package authn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const (
	// RequestDomainMarkerV1 binds the v1 client request signature to the Galaxy
	// gateway transport contract.
	RequestDomainMarkerV1 = "galaxy-request-v1"
)

var (
	// ErrInvalidPayloadHash reports that payloadHash is not a raw SHA-256 digest.
	ErrInvalidPayloadHash = errors.New("payload_hash must be a 32-byte SHA-256 digest")

	// ErrPayloadHashMismatch reports that payloadHash does not match payloadBytes.
	ErrPayloadHashMismatch = errors.New("payload_hash does not match payload_bytes")
)

// RequestSigningFields contains the canonical v1 request fields that are bound
// into the client signing input after the gateway validates and normalizes the
// request envelope.
type RequestSigningFields struct {
	// ProtocolVersion identifies the transport envelope version.
	ProtocolVersion string

	// DeviceSessionID identifies the authenticated device session bound to the
	// request.
	DeviceSessionID string

	// MessageType is the stable downstream routing key.
	MessageType string

	// TimestampMS carries the client request timestamp in milliseconds.
	TimestampMS int64

	// RequestID is the transport correlation and anti-replay identifier.
	RequestID string

	// PayloadHash is the raw SHA-256 digest of payload bytes.
	PayloadHash []byte
}

// BuildRequestSigningInput returns the canonical byte sequence the v1 client
// request signature covers. Every string/byte field is written as
// uvarint(len(field)) followed by its raw bytes, while TimestampMS is written
// as an 8-byte big-endian uint64. Callers are expected to pass fields that
// already passed envelope validation.
func BuildRequestSigningInput(fields RequestSigningFields) []byte {
	// Upper bound: six length-prefixed fields plus the fixed 8-byte timestamp.
	capacity := 8 + 6*binary.MaxVarintLen64 +
		len(RequestDomainMarkerV1) + len(fields.ProtocolVersion) +
		len(fields.DeviceSessionID) + len(fields.MessageType) +
		len(fields.RequestID) + len(fields.PayloadHash)

	input := make([]byte, 0, capacity)
	input = appendLengthPrefixedString(input, RequestDomainMarkerV1)
	input = appendLengthPrefixedString(input, fields.ProtocolVersion)
	input = appendLengthPrefixedString(input, fields.DeviceSessionID)
	input = appendLengthPrefixedString(input, fields.MessageType)
	input = binary.BigEndian.AppendUint64(input, uint64(fields.TimestampMS))
	input = appendLengthPrefixedString(input, fields.RequestID)
	input = appendLengthPrefixedBytes(input, fields.PayloadHash)

	return input
}

// VerifyPayloadHash checks that payloadHash is the raw SHA-256 digest of
// payloadBytes. Empty payloadBytes are valid and must use sha256.Sum256(nil).
func VerifyPayloadHash(payloadBytes, payloadHash []byte) error {
	if len(payloadHash) != sha256.Size {
		return ErrInvalidPayloadHash
	}

	computed := sha256.Sum256(payloadBytes)
	if bytes.Equal(computed[:], payloadHash) {
		return nil
	}

	return ErrPayloadHashMismatch
}

// appendLengthPrefixedString appends uvarint(len(value)) || value to dst.
func appendLengthPrefixedString(dst []byte, value string) []byte {
	return appendLengthPrefixedBytes(dst, []byte(value))
}

// appendLengthPrefixedBytes appends uvarint(len(value)) || value to dst.
func appendLengthPrefixedBytes(dst []byte, value []byte) []byte {
	out := binary.AppendUvarint(dst, uint64(len(value)))

	return append(out, value...)
}
|
||||
@@ -0,0 +1,163 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestVerifyPayloadHash covers the accept and reject paths of payload hash
// verification, including the empty-payload convention sha256.Sum256(nil).
func TestVerifyPayloadHash(t *testing.T) {
	t.Parallel()

	payloadSum := sha256.Sum256([]byte("payload"))
	emptySum := sha256.Sum256(nil)
	otherSum := sha256.Sum256([]byte("other"))

	tests := []struct {
		name        string
		payload     []byte
		payloadHash []byte
		wantErr     error
	}{
		{
			name:        "matches non-empty payload",
			payload:     []byte("payload"),
			payloadHash: payloadSum[:],
		},
		{
			name:        "matches empty payload",
			payload:     nil,
			payloadHash: emptySum[:],
		},
		{
			name:        "rejects digest with invalid length",
			payload:     []byte("payload"),
			payloadHash: []byte("short"),
			wantErr:     ErrInvalidPayloadHash,
		},
		{
			name:        "rejects digest mismatch",
			payload:     []byte("payload"),
			payloadHash: otherSum[:],
			wantErr:     ErrPayloadHashMismatch,
		},
	}

	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-Go 1.22 loop semantics)

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			err := VerifyPayloadHash(tt.payload, tt.payloadHash)
			if tt.wantErr == nil {
				require.NoError(t, err)
				return
			}

			// Sentinel identity matters: callers branch on errors.Is.
			require.ErrorIs(t, err, tt.wantErr)
		})
	}
}
|
||||
|
||||
func TestBuildRequestSigningInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fields := RequestSigningFields{
|
||||
ProtocolVersion: "v1",
|
||||
DeviceSessionID: "device-session-123",
|
||||
MessageType: "fleet.move",
|
||||
TimestampMS: 123456789,
|
||||
RequestID: "request-123",
|
||||
PayloadHash: mustSHA256([]byte("payload")),
|
||||
}
|
||||
|
||||
got := BuildRequestSigningInput(fields)
|
||||
|
||||
want, err := hex.DecodeString("1167616c6178792d726571756573742d7631027631126465766963652d73657373696f6e2d3132330a666c6565742e6d6f766500000000075bcd150b726571756573742d31323320239f59ed55e737c77147cf55ad0c1b030b6d7ee748a7426952f9b852d5a935e5")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestBuildRequestSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := RequestSigningFields{
|
||||
ProtocolVersion: "v1",
|
||||
DeviceSessionID: "device-session-123",
|
||||
MessageType: "fleet.move",
|
||||
TimestampMS: 123456789,
|
||||
RequestID: "request-123",
|
||||
PayloadHash: mustSHA256([]byte("payload")),
|
||||
}
|
||||
|
||||
baseInput := BuildRequestSigningInput(base)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(RequestSigningFields) RequestSigningFields
|
||||
}{
|
||||
{
|
||||
name: "protocol version",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.ProtocolVersion = "v2"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "device session id",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.DeviceSessionID = "device-session-456"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message type",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.MessageType = "fleet.attack"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "timestamp",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.TimestampMS++
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request id",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.RequestID = "request-456"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "payload hash",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.PayloadHash = mustSHA256([]byte("other"))
|
||||
return fields
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mutated := BuildRequestSigningInput(tt.mutate(base))
|
||||
assert.False(t, bytes.Equal(baseInput, mutated))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mustSHA256 returns the raw 32-byte SHA-256 digest of payload as a slice.
func mustSHA256(payload []byte) []byte {
	digest := sha256.Sum256(payload)
	return digest[:]
}
|
||||
@@ -0,0 +1,189 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"crypto/x509"
|
||||
"encoding/binary"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
const (
	// ResponseDomainMarkerV1 is the domain-separation marker that binds v1
	// server response signatures to the Galaxy gateway transport contract.
	ResponseDomainMarkerV1 = "galaxy-response-v1"
)

var (
	// ErrInvalidResponsePrivateKeyPEM reports that the configured response
	// signer private key is not a strict PKCS#8 PEM-encoded private key.
	ErrInvalidResponsePrivateKeyPEM = errors.New("response signer private key is not a valid PKCS#8 PEM block")

	// ErrInvalidResponsePrivateKey reports that the configured response signer
	// private key is not an Ed25519 private key.
	ErrInvalidResponsePrivateKey = errors.New("response signer private key must be an Ed25519 PKCS#8 private key")

	// ErrInvalidResponseSignature reports that a server response signature is
	// not a raw Ed25519 signature over the canonical response signing input.
	ErrInvalidResponseSignature = errors.New("invalid response signature")
)
|
||||
|
||||
// ResponseSigningFields contains the canonical v1 response fields that are
|
||||
// bound into the server signing input.
|
||||
type ResponseSigningFields struct {
|
||||
// ProtocolVersion identifies the transport envelope version.
|
||||
ProtocolVersion string
|
||||
|
||||
// RequestID is the transport correlation identifier copied from the
|
||||
// authenticated request.
|
||||
RequestID string
|
||||
|
||||
// TimestampMS carries the server response timestamp in milliseconds.
|
||||
TimestampMS int64
|
||||
|
||||
// ResultCode is the opaque downstream result code returned to the client.
|
||||
ResultCode string
|
||||
|
||||
// PayloadHash is the raw SHA-256 digest of response payload bytes.
|
||||
PayloadHash []byte
|
||||
}
|
||||
|
||||
// ResponseSigner signs authenticated unary responses and client-facing stream
|
||||
// events with one server-side key.
|
||||
type ResponseSigner interface {
|
||||
// SignResponse returns the raw Ed25519 signature for the canonical response
|
||||
// signing input built from fields.
|
||||
SignResponse(fields ResponseSigningFields) ([]byte, error)
|
||||
|
||||
// SignEvent returns the raw Ed25519 signature for the canonical event
|
||||
// signing input built from fields.
|
||||
SignEvent(fields EventSigningFields) ([]byte, error)
|
||||
}
|
||||
|
||||
// Ed25519ResponseSigner signs authenticated responses with one Ed25519 private
|
||||
// key loaded during process startup.
|
||||
type Ed25519ResponseSigner struct {
|
||||
privateKey ed25519.PrivateKey
|
||||
}
|
||||
|
||||
// NewEd25519ResponseSigner validates privateKey and constructs a signer using
|
||||
// a defensive key copy.
|
||||
func NewEd25519ResponseSigner(privateKey ed25519.PrivateKey) (*Ed25519ResponseSigner, error) {
|
||||
if len(privateKey) != ed25519.PrivateKeySize {
|
||||
return nil, ErrInvalidResponsePrivateKey
|
||||
}
|
||||
|
||||
return &Ed25519ResponseSigner{
|
||||
privateKey: bytes.Clone(privateKey),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoadEd25519ResponseSignerFromPEMFile loads a strict PKCS#8 PEM-encoded
|
||||
// Ed25519 private key from path and constructs a signer.
|
||||
func LoadEd25519ResponseSignerFromPEMFile(path string) (*Ed25519ResponseSigner, error) {
|
||||
pemBytes, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response signer private key PEM: %w", err)
|
||||
}
|
||||
|
||||
signer, err := ParseEd25519ResponseSignerPEM(pemBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
// ParseEd25519ResponseSignerPEM parses one strict PKCS#8 PEM-encoded Ed25519
|
||||
// private key and constructs a signer from it.
|
||||
func ParseEd25519ResponseSignerPEM(pemBytes []byte) (*Ed25519ResponseSigner, error) {
|
||||
block, rest := pem.Decode(pemBytes)
|
||||
if block == nil || block.Type != "PRIVATE KEY" || len(bytes.TrimSpace(rest)) > 0 {
|
||||
return nil, ErrInvalidResponsePrivateKeyPEM
|
||||
}
|
||||
|
||||
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidResponsePrivateKeyPEM
|
||||
}
|
||||
|
||||
privateKey, ok := parsedKey.(ed25519.PrivateKey)
|
||||
if !ok {
|
||||
return nil, ErrInvalidResponsePrivateKey
|
||||
}
|
||||
|
||||
return NewEd25519ResponseSigner(privateKey)
|
||||
}
|
||||
|
||||
// PublicKey returns the Ed25519 public key that corresponds to the configured
|
||||
// response signer private key.
|
||||
func (s *Ed25519ResponseSigner) PublicKey() ed25519.PublicKey {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
publicKey, _ := s.privateKey.Public().(ed25519.PublicKey)
|
||||
return bytes.Clone(publicKey)
|
||||
}
|
||||
|
||||
// SignResponse signs the canonical v1 response signing input built from
|
||||
// fields.
|
||||
func (s *Ed25519ResponseSigner) SignResponse(fields ResponseSigningFields) ([]byte, error) {
|
||||
if s == nil || len(s.privateKey) != ed25519.PrivateKeySize {
|
||||
return nil, ErrInvalidResponsePrivateKey
|
||||
}
|
||||
|
||||
signature := ed25519.Sign(s.privateKey, BuildResponseSigningInput(fields))
|
||||
return bytes.Clone(signature), nil
|
||||
}
|
||||
|
||||
// SignEvent signs the canonical v1 stream-event signing input built from
|
||||
// fields.
|
||||
func (s *Ed25519ResponseSigner) SignEvent(fields EventSigningFields) ([]byte, error) {
|
||||
if s == nil || len(s.privateKey) != ed25519.PrivateKeySize {
|
||||
return nil, ErrInvalidResponsePrivateKey
|
||||
}
|
||||
|
||||
signature := ed25519.Sign(s.privateKey, BuildEventSigningInput(fields))
|
||||
return bytes.Clone(signature), nil
|
||||
}
|
||||
|
||||
// BuildResponseSigningInput returns the canonical byte sequence the v1 server
|
||||
// response signature covers. String and byte fields are length-prefixed with
|
||||
// uvarint(len(field)) followed by raw bytes, while TimestampMS is appended as
|
||||
// an 8-byte big-endian uint64.
|
||||
func BuildResponseSigningInput(fields ResponseSigningFields) []byte {
|
||||
size := len(ResponseDomainMarkerV1) +
|
||||
len(fields.ProtocolVersion) +
|
||||
len(fields.RequestID) +
|
||||
len(fields.ResultCode) +
|
||||
len(fields.PayloadHash) +
|
||||
(5 * binary.MaxVarintLen64) +
|
||||
8
|
||||
|
||||
buf := make([]byte, 0, size)
|
||||
buf = appendLengthPrefixedString(buf, ResponseDomainMarkerV1)
|
||||
buf = appendLengthPrefixedString(buf, fields.ProtocolVersion)
|
||||
buf = appendLengthPrefixedString(buf, fields.RequestID)
|
||||
buf = binary.BigEndian.AppendUint64(buf, uint64(fields.TimestampMS))
|
||||
buf = appendLengthPrefixedString(buf, fields.ResultCode)
|
||||
buf = appendLengthPrefixedBytes(buf, fields.PayloadHash)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// VerifyResponseSignature verifies that signature authenticates fields under
|
||||
// publicKey using the canonical v1 response signing input.
|
||||
func VerifyResponseSignature(publicKey ed25519.PublicKey, signature []byte, fields ResponseSigningFields) error {
|
||||
if len(publicKey) != ed25519.PublicKeySize || len(signature) != ed25519.SignatureSize {
|
||||
return ErrInvalidResponseSignature
|
||||
}
|
||||
if !ed25519.Verify(publicKey, BuildResponseSigningInput(fields), signature) {
|
||||
return ErrInvalidResponseSignature
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildResponseSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := ResponseSigningFields{
|
||||
ProtocolVersion: "v1",
|
||||
RequestID: "request-123",
|
||||
TimestampMS: 123456789,
|
||||
ResultCode: "ok",
|
||||
PayloadHash: mustSHA256([]byte("payload")),
|
||||
}
|
||||
|
||||
baseInput := BuildResponseSigningInput(base)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(ResponseSigningFields) ResponseSigningFields
|
||||
}{
|
||||
{
|
||||
name: "protocol version",
|
||||
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
|
||||
fields.ProtocolVersion = "v2"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request id",
|
||||
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
|
||||
fields.RequestID = "request-456"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "timestamp",
|
||||
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
|
||||
fields.TimestampMS++
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "result code",
|
||||
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
|
||||
fields.ResultCode = "denied"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "payload hash",
|
||||
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
|
||||
fields.PayloadHash = mustSHA256([]byte("other"))
|
||||
return fields
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mutated := BuildResponseSigningInput(tt.mutate(base))
|
||||
assert.False(t, bytes.Equal(baseInput, mutated))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseEd25519ResponseSignerPEMRejectsMalformedPEM(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := ParseEd25519ResponseSignerPEM([]byte("not-pem"))
|
||||
require.ErrorIs(t, err, ErrInvalidResponsePrivateKeyPEM)
|
||||
}
|
||||
|
||||
func TestParseEd25519ResponseSignerPEMRejectsNonPKCS8PEM(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
block := pem.Block{
|
||||
Type: "ED25519 PRIVATE KEY",
|
||||
Bytes: pemBytes,
|
||||
}
|
||||
|
||||
_, err = ParseEd25519ResponseSignerPEM(pem.EncodeToMemory(&block))
|
||||
require.ErrorIs(t, err, ErrInvalidResponsePrivateKeyPEM)
|
||||
}
|
||||
|
||||
func TestParseEd25519ResponseSignerPEMRejectsNonEd25519Key(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
require.NoError(t, err)
|
||||
|
||||
pemBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ParseEd25519ResponseSignerPEM(pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: pemBytes,
|
||||
}))
|
||||
require.ErrorIs(t, err, ErrInvalidResponsePrivateKey)
|
||||
}
|
||||
|
||||
func TestSignAndVerifyResponseSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
signer, err := NewEd25519ResponseSigner(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
fields := ResponseSigningFields{
|
||||
ProtocolVersion: "v1",
|
||||
RequestID: "request-123",
|
||||
TimestampMS: 123456789,
|
||||
ResultCode: "ok",
|
||||
PayloadHash: mustSHA256([]byte("payload")),
|
||||
}
|
||||
|
||||
signature, err := signer.SignResponse(fields)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, VerifyResponseSignature(signer.PublicKey(), signature, fields))
|
||||
|
||||
fields.ResultCode = "changed"
|
||||
require.ErrorIs(t, VerifyResponseSignature(signer.PublicKey(), signature, fields), ErrInvalidResponseSignature)
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidClientPublicKey reports that cached client public key material
|
||||
// is not a base64-encoded raw Ed25519 public key.
|
||||
ErrInvalidClientPublicKey = errors.New("client_public_key is not a valid base64-encoded Ed25519 public key")
|
||||
|
||||
// ErrInvalidRequestSignature reports that a request signature is not a raw
|
||||
// Ed25519 signature for the canonical request signing input.
|
||||
ErrInvalidRequestSignature = errors.New("invalid request signature")
|
||||
)
|
||||
|
||||
// VerifyRequestSignature validates the base64-encoded raw Ed25519 public key
|
||||
// from session cache, builds the canonical v1 signing input from fields, and
|
||||
// verifies that signature authenticates the request.
|
||||
func VerifyRequestSignature(clientPublicKey string, signature []byte, fields RequestSigningFields) error {
|
||||
publicKey, err := decodeClientPublicKey(clientPublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(signature) != ed25519.SignatureSize {
|
||||
return ErrInvalidRequestSignature
|
||||
}
|
||||
if !ed25519.Verify(publicKey, BuildRequestSigningInput(fields), signature) {
|
||||
return ErrInvalidRequestSignature
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeClientPublicKey(value string) (ed25519.PublicKey, error) {
|
||||
decoded, err := base64.StdEncoding.Strict().DecodeString(value)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidClientPublicKey
|
||||
}
|
||||
if len(decoded) != ed25519.PublicKeySize {
|
||||
return nil, ErrInvalidClientPublicKey
|
||||
}
|
||||
|
||||
return ed25519.PublicKey(decoded), nil
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVerifyRequestSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientPrivateKey := newTestPrivateKey("primary")
|
||||
clientPublicKey := clientPrivateKey.Public().(ed25519.PublicKey)
|
||||
otherPrivateKey := newTestPrivateKey("other")
|
||||
|
||||
fields := RequestSigningFields{
|
||||
ProtocolVersion: "v1",
|
||||
DeviceSessionID: "device-session-123",
|
||||
MessageType: "fleet.move",
|
||||
TimestampMS: 123456789,
|
||||
RequestID: "request-123",
|
||||
PayloadHash: mustSHA256([]byte("payload")),
|
||||
}
|
||||
|
||||
signature := ed25519.Sign(clientPrivateKey, BuildRequestSigningInput(fields))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
clientPublicKey string
|
||||
signature []byte
|
||||
fields RequestSigningFields
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "valid signature",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
|
||||
signature: signature,
|
||||
fields: fields,
|
||||
},
|
||||
{
|
||||
name: "message type change rejects signature",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
|
||||
signature: signature,
|
||||
fields: func() RequestSigningFields {
|
||||
mutated := fields
|
||||
mutated.MessageType = "fleet.attack"
|
||||
return mutated
|
||||
}(),
|
||||
wantErr: ErrInvalidRequestSignature,
|
||||
},
|
||||
{
|
||||
name: "request id change rejects signature",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
|
||||
signature: signature,
|
||||
fields: func() RequestSigningFields {
|
||||
mutated := fields
|
||||
mutated.RequestID = "request-456"
|
||||
return mutated
|
||||
}(),
|
||||
wantErr: ErrInvalidRequestSignature,
|
||||
},
|
||||
{
|
||||
name: "payload hash change rejects signature",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
|
||||
signature: signature,
|
||||
fields: func() RequestSigningFields {
|
||||
mutated := fields
|
||||
mutated.PayloadHash = mustSHA256([]byte("other"))
|
||||
return mutated
|
||||
}(),
|
||||
wantErr: ErrInvalidRequestSignature,
|
||||
},
|
||||
{
|
||||
name: "wrong key rejects signature",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(otherPrivateKey.Public().(ed25519.PublicKey)),
|
||||
signature: signature,
|
||||
fields: fields,
|
||||
wantErr: ErrInvalidRequestSignature,
|
||||
},
|
||||
{
|
||||
name: "bit flipped signature rejects",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
|
||||
signature: func() []byte {
|
||||
corrupted := append([]byte(nil), signature...)
|
||||
corrupted[0] ^= 0xff
|
||||
return corrupted
|
||||
}(),
|
||||
fields: fields,
|
||||
wantErr: ErrInvalidRequestSignature,
|
||||
},
|
||||
{
|
||||
name: "invalid signature length rejects",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
|
||||
signature: signature[:len(signature)-1],
|
||||
fields: fields,
|
||||
wantErr: ErrInvalidRequestSignature,
|
||||
},
|
||||
{
|
||||
name: "invalid base64 public key rejects",
|
||||
clientPublicKey: "%%%not-base64%%%",
|
||||
signature: signature,
|
||||
fields: fields,
|
||||
wantErr: ErrInvalidClientPublicKey,
|
||||
},
|
||||
{
|
||||
name: "invalid public key length rejects",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString([]byte("short")),
|
||||
signature: signature,
|
||||
fields: fields,
|
||||
wantErr: ErrInvalidClientPublicKey,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := VerifyRequestSignature(tt.clientPublicKey, tt.signature, tt.fields)
|
||||
if tt.wantErr == nil {
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newTestPrivateKey derives a deterministic Ed25519 private key from label so
// tests are reproducible without real randomness.
func newTestPrivateKey(label string) ed25519.PrivateKey {
	seed := sha256.Sum256([]byte("gateway-authn-signature-test-" + label))
	return ed25519.NewKeyFromSeed(seed[:])
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// Package clock provides the gateway time source abstraction used by
|
||||
// authenticated transport checks.
|
||||
package clock
|
||||
|
||||
import "time"
|
||||
|
||||
// Clock returns current server time for freshness checks and time-dependent
// transport behavior.
type Clock interface {
	// Now returns the current server time.
	Now() time.Time
}

// System is a Clock backed by the local process system clock.
type System struct{}

// Now returns the current system time normalized to UTC.
func (System) Now() time.Time {
	now := time.Now()
	return now.UTC()
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,108 @@
|
||||
// Package downstream defines the verified internal command contract used by the
|
||||
// gateway after the authenticated edge pipeline succeeds.
|
||||
package downstream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
	// ErrRouteNotFound reports that Router does not have an exact-match handler
	// for the supplied authenticated message type.
	ErrRouteNotFound = errors.New("downstream route not found")

	// ErrDownstreamUnavailable reports that the resolved downstream dependency is
	// temporarily unavailable.
	ErrDownstreamUnavailable = errors.New("downstream service is unavailable")
)

// AuthenticatedCommand is the minimum verified unary command context the
// gateway may forward to downstream business services.
type AuthenticatedCommand struct {
	// ProtocolVersion is the authenticated transport protocol version accepted
	// by the gateway.
	ProtocolVersion string

	// UserID is the authenticated user identity resolved from SessionCache.
	UserID string

	// DeviceSessionID is the authenticated device session that originated the
	// command.
	DeviceSessionID string

	// MessageType is the stable exact-match downstream routing key.
	MessageType string

	// TimestampMS is the client-supplied request timestamp that already passed
	// freshness verification.
	TimestampMS int64

	// RequestID is the transport correlation and anti-replay identifier.
	RequestID string

	// TraceID is the optional client-supplied correlation identifier.
	TraceID string

	// PayloadBytes carries the verified opaque business payload bytes.
	PayloadBytes []byte
}

// UnaryResult is the minimum downstream unary result the gateway needs in
// order to build a signed authenticated client response.
type UnaryResult struct {
	// ResultCode is the stable opaque downstream result code returned to the
	// client without business reinterpretation by the gateway.
	ResultCode string

	// PayloadBytes carries the opaque downstream response payload bytes.
	PayloadBytes []byte
}

// Client executes a verified authenticated unary command against one concrete
// downstream service or adapter.
type Client interface {
	// ExecuteCommand executes command and returns the downstream unary result.
	ExecuteCommand(ctx context.Context, command AuthenticatedCommand) (UnaryResult, error)
}

// Router resolves the downstream unary client for one exact authenticated
// message_type value.
type Router interface {
	// Route returns the downstream client for messageType. Implementations must
	// wrap ErrRouteNotFound when the route table does not contain messageType.
	Route(messageType string) (Client, error)
}

// StaticRouter resolves exact message_type literals from an immutable route
// map supplied at construction time.
type StaticRouter struct {
	routes map[string]Client
}

// NewStaticRouter constructs a StaticRouter with a defensive copy of routes.
// Entries with a nil client are dropped during the copy.
func NewStaticRouter(routes map[string]Client) *StaticRouter {
	cloned := make(map[string]Client, len(routes))
	for messageType, client := range routes {
		if client != nil {
			cloned[messageType] = client
		}
	}

	return &StaticRouter{routes: cloned}
}

// Route returns the exact-match client for messageType, or ErrRouteNotFound
// when no usable route exists (including a nil receiver).
func (r *StaticRouter) Route(messageType string) (Client, error) {
	if r == nil {
		return nil, ErrRouteNotFound
	}
	if client, ok := r.routes[messageType]; ok && client != nil {
		return client, nil
	}

	return nil, ErrRouteNotFound
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package downstream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStaticRouterRoutesExactMessageType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
want := &stubClient{}
|
||||
router := NewStaticRouter(map[string]Client{
|
||||
"fleet.move": want,
|
||||
})
|
||||
|
||||
got, err := router.Route("fleet.move")
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, want, got)
|
||||
}
|
||||
|
||||
func TestStaticRouterRejectsUnknownMessageType(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
router := NewStaticRouter(map[string]Client{
|
||||
"fleet.move": &stubClient{},
|
||||
})
|
||||
|
||||
_, err := router.Route("fleet.rename")
|
||||
require.ErrorIs(t, err, ErrRouteNotFound)
|
||||
}
|
||||
|
||||
type stubClient struct{}
|
||||
|
||||
func (*stubClient) ExecuteCommand(context.Context, AuthenticatedCommand) (UnaryResult, error) {
|
||||
return UnaryResult{}, nil
|
||||
}
|
||||
@@ -0,0 +1,341 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const clientEventReadCount int64 = 128
|
||||
|
||||
// ClientEventPublisher accepts decoded client-facing events from the internal
|
||||
// event subscriber.
|
||||
type ClientEventPublisher interface {
|
||||
// Publish fans out event to the currently active push streams.
|
||||
Publish(event push.Event)
|
||||
}
|
||||
|
||||
// RedisClientEventSubscriber consumes client-facing events from one Redis
|
||||
// Stream and forwards them to the configured publisher.
|
||||
type RedisClientEventSubscriber struct {
|
||||
client *redis.Client
|
||||
stream string
|
||||
pingTimeout time.Duration
|
||||
readBlockTimeout time.Duration
|
||||
publisher ClientEventPublisher
|
||||
logger *zap.Logger
|
||||
metrics *telemetry.Runtime
|
||||
|
||||
closeOnce sync.Once
|
||||
startedOnce sync.Once
|
||||
started chan struct{}
|
||||
}
|
||||
|
||||
// NewRedisClientEventSubscriber constructs a Redis Stream subscriber that
|
||||
// reuses the SessionCache Redis connection settings and forwards decoded
|
||||
// client-facing events to publisher.
|
||||
func NewRedisClientEventSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher) (*RedisClientEventSubscriber, error) {
|
||||
return NewRedisClientEventSubscriberWithObservability(sessionCfg, eventsCfg, publisher, nil, nil)
|
||||
}
|
||||
|
||||
// NewRedisClientEventSubscriberWithObservability constructs a Redis Stream
|
||||
// subscriber that also records malformed or dropped internal events.
|
||||
func NewRedisClientEventSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisClientEventSubscriber, error) {
|
||||
if strings.TrimSpace(sessionCfg.Addr) == "" {
|
||||
return nil, errors.New("new redis client event subscriber: redis addr must not be empty")
|
||||
}
|
||||
if sessionCfg.DB < 0 {
|
||||
return nil, errors.New("new redis client event subscriber: redis db must not be negative")
|
||||
}
|
||||
if sessionCfg.LookupTimeout <= 0 {
|
||||
return nil, errors.New("new redis client event subscriber: lookup timeout must be positive")
|
||||
}
|
||||
if strings.TrimSpace(eventsCfg.Stream) == "" {
|
||||
return nil, errors.New("new redis client event subscriber: stream must not be empty")
|
||||
}
|
||||
if eventsCfg.ReadBlockTimeout <= 0 {
|
||||
return nil, errors.New("new redis client event subscriber: read block timeout must be positive")
|
||||
}
|
||||
if publisher == nil {
|
||||
return nil, errors.New("new redis client event subscriber: nil publisher")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: sessionCfg.Addr,
|
||||
Username: sessionCfg.Username,
|
||||
Password: sessionCfg.Password,
|
||||
DB: sessionCfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if sessionCfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return &RedisClientEventSubscriber{
|
||||
client: redis.NewClient(options),
|
||||
stream: eventsCfg.Stream,
|
||||
pingTimeout: sessionCfg.LookupTimeout,
|
||||
readBlockTimeout: eventsCfg.ReadBlockTimeout,
|
||||
publisher: publisher,
|
||||
logger: logger.Named("client_event_subscriber"),
|
||||
metrics: metrics,
|
||||
started: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Ping verifies that the Redis backend used for client-facing event fan-out is
|
||||
// reachable within the configured timeout budget.
|
||||
func (s *RedisClientEventSubscriber) Ping(ctx context.Context) error {
|
||||
if s == nil || s.client == nil {
|
||||
return errors.New("ping redis client event subscriber: nil subscriber")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("ping redis client event subscriber: nil context")
|
||||
}
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, s.pingTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := s.client.Ping(pingCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis client event subscriber: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run consumes client-facing events until ctx is canceled or Redis returns an
|
||||
// unexpected error.
|
||||
func (s *RedisClientEventSubscriber) Run(ctx context.Context) error {
|
||||
if s == nil || s.client == nil {
|
||||
return errors.New("run redis client event subscriber: nil subscriber")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("run redis client event subscriber: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lastID, err := s.resolveStartID(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.signalStarted()
|
||||
|
||||
for {
|
||||
streams, err := s.client.XRead(ctx, &redis.XReadArgs{
|
||||
Streams: []string{s.stream, lastID},
|
||||
Count: clientEventReadCount,
|
||||
Block: s.readBlockTimeout,
|
||||
}).Result()
|
||||
switch {
|
||||
case err == nil:
|
||||
for _, stream := range streams {
|
||||
for _, message := range stream.Messages {
|
||||
s.publishMessage(message)
|
||||
lastID = message.ID
|
||||
}
|
||||
}
|
||||
continue
|
||||
case errors.Is(err, redis.Nil):
|
||||
continue
|
||||
case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
|
||||
return ctx.Err()
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded), errors.Is(err, redis.ErrClosed):
|
||||
return fmt.Errorf("run redis client event subscriber: %w", err)
|
||||
default:
|
||||
return fmt.Errorf("run redis client event subscriber: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RedisClientEventSubscriber) resolveStartID(ctx context.Context) (string, error) {
|
||||
messages, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
|
||||
switch {
|
||||
case err == nil:
|
||||
case errors.Is(err, redis.Nil):
|
||||
return "0-0", nil
|
||||
default:
|
||||
return "", fmt.Errorf("run redis client event subscriber: resolve stream tail: %w", err)
|
||||
}
|
||||
|
||||
if len(messages) == 0 {
|
||||
return "0-0", nil
|
||||
}
|
||||
|
||||
return messages[0].ID, nil
|
||||
}
|
||||
|
||||
// Shutdown closes the Redis client so a blocking stream read can terminate
|
||||
// promptly during gateway shutdown.
|
||||
func (s *RedisClientEventSubscriber) Shutdown(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("shutdown redis client event subscriber: nil context")
|
||||
}
|
||||
|
||||
return s.Close()
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (s *RedisClientEventSubscriber) Close() error {
|
||||
if s == nil || s.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.closeOnce.Do(func() {
|
||||
err = s.client.Close()
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *RedisClientEventSubscriber) signalStarted() {
|
||||
s.startedOnce.Do(func() {
|
||||
close(s.started)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *RedisClientEventSubscriber) publishMessage(message redis.XMessage) {
|
||||
event, err := decodeClientEvent(message.Values)
|
||||
if err != nil {
|
||||
s.logger.Warn("dropped malformed client event",
|
||||
zap.String("stream", s.stream),
|
||||
zap.String("message_id", message.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
s.metrics.RecordInternalEventDrop(context.Background(),
|
||||
attribute.String("component", "client_event_subscriber"),
|
||||
attribute.String("reason", "malformed_event"),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
s.publisher.Publish(event)
|
||||
}
|
||||
|
||||
func decodeClientEvent(values map[string]any) (push.Event, error) {
|
||||
requiredKeys := map[string]struct{}{
|
||||
"user_id": {},
|
||||
"event_type": {},
|
||||
"event_id": {},
|
||||
"payload_bytes": {},
|
||||
}
|
||||
optionalKeys := map[string]struct{}{
|
||||
"device_session_id": {},
|
||||
"request_id": {},
|
||||
"trace_id": {},
|
||||
}
|
||||
|
||||
for key := range values {
|
||||
if _, ok := requiredKeys[key]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := optionalKeys[key]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
return push.Event{}, fmt.Errorf("decode client event: unsupported field %q", key)
|
||||
}
|
||||
|
||||
userID, err := requiredStringField(values, "user_id")
|
||||
if err != nil {
|
||||
return push.Event{}, err
|
||||
}
|
||||
eventType, err := requiredStringField(values, "event_type")
|
||||
if err != nil {
|
||||
return push.Event{}, err
|
||||
}
|
||||
eventID, err := requiredStringField(values, "event_id")
|
||||
if err != nil {
|
||||
return push.Event{}, err
|
||||
}
|
||||
payloadBytes, err := requiredBytesField(values, "payload_bytes")
|
||||
if err != nil {
|
||||
return push.Event{}, err
|
||||
}
|
||||
|
||||
event := push.Event{
|
||||
UserID: userID,
|
||||
EventType: eventType,
|
||||
EventID: eventID,
|
||||
PayloadBytes: payloadBytes,
|
||||
}
|
||||
|
||||
if deviceSessionID, ok, err := optionalStringField(values, "device_session_id"); err != nil {
|
||||
return push.Event{}, err
|
||||
} else if ok {
|
||||
event.DeviceSessionID = strings.TrimSpace(deviceSessionID)
|
||||
}
|
||||
|
||||
if requestID, ok, err := optionalStringField(values, "request_id"); err != nil {
|
||||
return push.Event{}, err
|
||||
} else if ok {
|
||||
event.RequestID = requestID
|
||||
}
|
||||
|
||||
if traceID, ok, err := optionalStringField(values, "trace_id"); err != nil {
|
||||
return push.Event{}, err
|
||||
} else if ok {
|
||||
event.TraceID = traceID
|
||||
}
|
||||
|
||||
return event, nil
|
||||
}
|
||||
|
||||
func requiredBytesField(values map[string]any, field string) ([]byte, error) {
|
||||
value, ok := values[field]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("decode client event: missing %s", field)
|
||||
}
|
||||
|
||||
byteValue, err := coerceBytes(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode client event: %s: %w", field, err)
|
||||
}
|
||||
|
||||
return byteValue, nil
|
||||
}
|
||||
|
||||
func optionalStringField(values map[string]any, field string) (string, bool, error) {
|
||||
value, ok := values[field]
|
||||
if !ok {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
stringValue, err := coerceString(value)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("decode client event: %s: %w", field, err)
|
||||
}
|
||||
|
||||
return stringValue, true, nil
|
||||
}
|
||||
|
||||
// coerceBytes converts a Redis stream value into an owned byte slice. String
// values are copied by conversion; byte slices are cloned so the caller never
// aliases the driver's buffer. Any other type is rejected.
func coerceBytes(value any) ([]byte, error) {
	if text, ok := value.(string); ok {
		return []byte(text), nil
	}
	if raw, ok := value.([]byte); ok {
		return bytes.Clone(raw), nil
	}
	return nil, fmt.Errorf("unsupported type %T", value)
}
|
||||
@@ -0,0 +1,294 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/testutil"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRedisClientEventSubscriberPublishesValidEvent verifies that a fully
// populated stream entry is decoded and forwarded to the publisher intact.
func TestRedisClientEventSubscriberPublishesValidEvent(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":           "user-123",
		"device_session_id": "device-session-123",
		"event_type":        "fleet.updated",
		"event_id":          "event-123",
		"payload_bytes":     []byte("payload-123"),
		"request_id":        "request-123",
		"trace_id":          "trace-123",
	})

	require.Eventually(t, func() bool {
		return len(publisher.events()) == 1
	}, time.Second, 10*time.Millisecond)

	assert.Equal(t, []push.Event{{
		UserID:          "user-123",
		DeviceSessionID: "device-session-123",
		EventType:       "fleet.updated",
		EventID:         "event-123",
		PayloadBytes:    []byte("payload-123"),
		RequestID:       "request-123",
		TraceID:         "trace-123",
	}}, publisher.events())
}
|
||||
|
||||
// TestRedisClientEventSubscriberSkipsMalformedEventAndContinues verifies that
// an entry with an unsupported field is dropped while later entries still flow.
func TestRedisClientEventSubscriberSkipsMalformedEventAndContinues(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)

	// "unexpected" is not part of the event schema, so this entry is dropped.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-bad",
		"payload_bytes": []byte("payload-bad"),
		"unexpected":    "boom",
	})
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-good",
		"payload_bytes": []byte("payload-good"),
	})

	require.Eventually(t, func() bool {
		events := publisher.events()
		return len(events) == 1 && events[0].EventID == "event-good"
	}, time.Second, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisClientEventSubscriberStartsFromCurrentTail verifies that events
// already in the stream at startup are skipped and only new entries publish.
func TestRedisClientEventSubscriberStartsFromCurrentTail(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}

	// Written before the subscriber starts; must never be delivered.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-old",
		"payload_bytes": []byte("payload-old"),
	})

	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)

	assert.Never(t, func() bool {
		return len(publisher.events()) > 0
	}, 100*time.Millisecond, 10*time.Millisecond)

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-new",
		"payload_bytes": []byte("payload-new"),
	})

	require.Eventually(t, func() bool {
		events := publisher.events()
		return len(events) == 1 && events[0].EventID == "event-new"
	}, time.Second, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisClientEventSubscriberShutdownInterruptsBlockingRead verifies that
// Shutdown unblocks an in-flight stream read and Run reports ctx cancellation.
func TestRedisClientEventSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()

	// Wait for the read loop before canceling so shutdown interrupts a real read.
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}

	cancel()
	require.NoError(t, subscriber.Shutdown(context.Background()))

	select {
	case err := <-resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop after shutdown")
	}
}
|
||||
|
||||
// TestRedisClientEventSubscriberLogsAndCountsMalformedEvents verifies that a
// dropped malformed event is both logged and recorded in the drop metric.
func TestRedisClientEventSubscriberLogsAndCountsMalformedEvents(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	logger, logBuffer := testutil.NewObservedLogger(t)
	telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)

	subscriber, err := NewRedisClientEventSubscriberWithObservability(
		config.SessionCacheRedisConfig{
			Addr:          server.Addr(),
			LookupTimeout: 250 * time.Millisecond,
		},
		config.ClientEventsRedisConfig{
			Stream:           "gateway:client_events",
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		publisher,
		logger,
		telemetryRuntime,
	)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, subscriber.Close())
	})

	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)

	// "unexpected" makes the entry malformed and triggers the drop path.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-bad",
		"payload_bytes": []byte("payload-bad"),
		"unexpected":    "boom",
	})

	require.Eventually(t, func() bool {
		return strings.Contains(logBuffer.String(), "dropped malformed client event")
	}, time.Second, 10*time.Millisecond)

	metricsText := testutil.ScrapeMetrics(t, telemetryRuntime.Handler())
	assert.Contains(t, metricsText, `gateway_internal_event_drops_total`)
	assert.Contains(t, metricsText, `component="client_event_subscriber"`)
	assert.Contains(t, metricsText, `reason="malformed_event"`)
}
|
||||
|
||||
// newTestRedisClientEventSubscriber builds a subscriber wired to the given
// miniredis server with short test timeouts, and registers cleanup to close it.
func newTestRedisClientEventSubscriber(t *testing.T, server *miniredis.Miniredis, publisher ClientEventPublisher) *RedisClientEventSubscriber {
	t.Helper()

	subscriber, err := NewRedisClientEventSubscriber(
		config.SessionCacheRedisConfig{
			Addr:          server.Addr(),
			LookupTimeout: 250 * time.Millisecond,
		},
		config.ClientEventsRedisConfig{
			Stream:           "gateway:client_events",
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		publisher,
	)
	require.NoError(t, err)

	t.Cleanup(func() {
		assert.NoError(t, subscriber.Close())
	})

	return subscriber
}
|
||||
|
||||
func addClientEvent(t *testing.T, server *miniredis.Miniredis, stream string, values map[string]any) {
|
||||
t.Helper()
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: server.Addr(),
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
})
|
||||
defer func() {
|
||||
assert.NoError(t, client.Close())
|
||||
}()
|
||||
|
||||
err := client.XAdd(context.Background(), &redis.XAddArgs{
|
||||
Stream: stream,
|
||||
Values: values,
|
||||
}).Err()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// runningClientEventSubscriber tracks a subscriber goroutine started by
// runTestClientEventSubscriber so tests can cancel it and await its result.
type runningClientEventSubscriber struct {
	cancel   context.CancelFunc // stops the Run loop
	resultCh chan error         // receives Run's return value
}
|
||||
|
||||
// runTestClientEventSubscriber starts subscriber.Run in a goroutine, waits for
// it to signal readiness, and returns a handle used to stop it later.
func runTestClientEventSubscriber(t *testing.T, subscriber *RedisClientEventSubscriber) runningClientEventSubscriber {
	t.Helper()

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()

	// Block until the subscriber reaches its read loop so tests don't race it.
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}

	return runningClientEventSubscriber{
		cancel:   cancel,
		resultCh: resultCh,
	}
}
|
||||
|
||||
// stop cancels the subscriber and asserts it exits with context.Canceled
// within one second.
func (r runningClientEventSubscriber) stop(t *testing.T) {
	t.Helper()

	r.cancel()

	select {
	case err := <-r.resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop")
	}
}
|
||||
|
||||
// recordingClientEventPublisher is a thread-safe test double that captures
// every push.Event handed to Publish.
type recordingClientEventPublisher struct {
	mu      sync.Mutex   // guards records
	records []push.Event // captured events in arrival order
}
|
||||
|
||||
func (p *recordingClientEventPublisher) Publish(event push.Event) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.records = append(p.records, event)
|
||||
}
|
||||
|
||||
func (p *recordingClientEventPublisher) events() []push.Event {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
cloned := make([]push.Event, len(p.records))
|
||||
copy(cloned, p.records)
|
||||
return cloned
|
||||
}
|
||||
@@ -0,0 +1,385 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/grpcapi"
|
||||
"galaxy/gateway/internal/replay"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// testNow is the fixed wall-clock instant injected into the gateway under
// test so request freshness checks are deterministic.
var testNow = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
// TestAuthenticatedGatewayWarmsLocalSessionCache verifies that the first
// command populates the local cache from the fallback, and subsequent commands
// do not hit the fallback again.
func TestAuthenticatedGatewayWarmsLocalSessionCache(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	local := session.NewMemoryCache()
	fallback := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	readThrough, err := session.NewReadThroughCache(local, fallback)
	require.NoError(t, err)

	subscriber := newTestRedisSessionSubscriber(t, server, local)
	downstreamClient := &recordingDownstreamClient{}
	addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())

	// Second command must be served from the warmed local cache.
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())
	assert.Len(t, downstreamClient.commands(), 2)
}
|
||||
|
||||
// TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup verifies
// that a session-update stream event refreshes the local cache in place, so
// later commands observe the new user without another fallback lookup.
func TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	local := session.NewMemoryCache()
	fallback := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	readThrough, err := session.NewReadThroughCache(local, fallback)
	require.NoError(t, err)

	subscriber := newTestRedisSessionSubscriber(t, server, local)
	downstreamClient := &recordingDownstreamClient{}
	addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())

	// Rebind the session to a new user purely via the event stream.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-456",
		"client_public_key": testClientPublicKeyBase64(),
		"status":            string(session.StatusActive),
	})

	require.Eventually(t, func() bool {
		record, lookupErr := local.Lookup(context.Background(), "device-session-123")
		return lookupErr == nil && record.UserID == "user-456"
	}, time.Second, 10*time.Millisecond)

	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())

	commands := downstreamClient.commands()
	require.Len(t, commands, 2)
	assert.Equal(t, "user-456", commands[1].UserID)
}
|
||||
|
||||
// TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup
// verifies that a revocation event flips the cached session and the gateway
// then rejects commands with FailedPrecondition, still without a fallback hit.
func TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	local := session.NewMemoryCache()
	fallback := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	readThrough, err := session.NewReadThroughCache(local, fallback)
	require.NoError(t, err)

	subscriber := newTestRedisSessionSubscriber(t, server, local)
	downstreamClient := &recordingDownstreamClient{}
	addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())

	// Revoke the session purely via the event stream.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": testClientPublicKeyBase64(),
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	require.Eventually(t, func() bool {
		record, lookupErr := local.Lookup(context.Background(), "device-session-123")
		return lookupErr == nil && record.Status == session.StatusRevoked
	}, time.Second, 10*time.Millisecond)

	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
	assert.Equal(t, 1, fallback.lookupCalls())
}
|
||||
|
||||
// runningAuthenticatedGateway tracks a gateway application started by
// runAuthenticatedGateway so tests can cancel it and await its exit result.
type runningAuthenticatedGateway struct {
	cancel   context.CancelFunc // stops the application
	resultCh chan error         // receives application.Run's return value
}
|
||||
|
||||
// runAuthenticatedGateway boots a full gateway application (gRPC server plus
// session subscriber) on an ephemeral port, blocks until the subscriber is
// ready, and returns the listen address and a stop handle.
func runAuthenticatedGateway(t *testing.T, sessionCache session.Cache, subscriber *RedisSessionSubscriber, downstreamClient downstream.Client) (string, runningAuthenticatedGateway) {
	t.Helper()

	addr := unusedTCPAddr(t)
	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = addr
	grpcCfg.FreshnessWindow = 5 * time.Minute

	router := downstream.NewStaticRouter(map[string]downstream.Client{
		"fleet.move": downstreamClient,
	})

	gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
		Router:         router,
		ResponseSigner: newTestResponseSigner(t),
		SessionCache:   sessionCache,
		ReplayStore:    staticReplayStore{},
		Clock:          fixedClock{now: testNow},
	})

	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			AuthenticatedGRPC: grpcCfg,
		},
		gateway,
		subscriber,
	)

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	// Wait for the session subscriber so tests don't race the event pipeline.
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "session subscriber did not start")
	}

	return addr, runningAuthenticatedGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}
|
||||
|
||||
// stop cancels the gateway application and asserts it shuts down cleanly
// within two seconds.
func (g runningAuthenticatedGateway) stop(t *testing.T) {
	t.Helper()

	g.cancel()

	select {
	case err := <-g.resultCh:
		require.NoError(t, err)
	case <-time.After(2 * time.Second):
		require.FailNow(t, "gateway did not stop after cancellation")
	}
}
|
||||
|
||||
// dialGatewayClient opens a blocking, plaintext gRPC connection to addr,
// failing the test if the connection is not established within one second.
// NOTE(review): grpc.DialContext and WithBlock are deprecated in newer
// grpc-go releases; migrating to grpc.NewClient would change the blocking
// semantics this helper relies on — confirm before upgrading.
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	conn, err := grpc.DialContext(
		ctx,
		addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
	)
	require.NoError(t, err)

	return conn
}
|
||||
|
||||
func unusedTCPAddr(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := listener.Addr().String()
|
||||
require.NoError(t, listener.Close())
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
// newExecuteCommandRequest builds a fully signed ExecuteCommand request for
// the shared test device session; requestID distinguishes successive calls.
func newExecuteCommandRequest(requestID string) *gatewayv1.ExecuteCommandRequest {
	payloadBytes := []byte("payload")
	payloadHash := sha256.Sum256(payloadBytes)

	req := &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: "v1",
		DeviceSessionId: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMs:     testNow.UnixMilli(),
		RequestId:       requestID,
		PayloadBytes:    payloadBytes,
		PayloadHash:     payloadHash[:],
		TraceId:         "trace-123",
	}
	// Sign over the canonical field set so the gateway's verifier accepts it.
	req.Signature = ed25519.Sign(testClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: req.GetProtocolVersion(),
		DeviceSessionID: req.GetDeviceSessionId(),
		MessageType:     req.GetMessageType(),
		TimestampMS:     req.GetTimestampMs(),
		RequestID:       req.GetRequestId(),
		PayloadHash:     req.GetPayloadHash(),
	}))

	return req
}
|
||||
|
||||
func newActiveSessionRecord(userID string) session.Record {
|
||||
return session.Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: userID,
|
||||
ClientPublicKey: testClientPublicKeyBase64(),
|
||||
Status: session.StatusActive,
|
||||
}
|
||||
}
|
||||
|
||||
func testClientPrivateKey() ed25519.PrivateKey {
|
||||
seed := sha256.Sum256([]byte("gateway-events-grpc-test-client"))
|
||||
return ed25519.NewKeyFromSeed(seed[:])
|
||||
}
|
||||
|
||||
func testClientPublicKeyBase64() string {
|
||||
return base64.StdEncoding.EncodeToString(testClientPrivateKey().Public().(ed25519.PublicKey))
|
||||
}
|
||||
|
||||
// newTestResponseSigner builds a deterministic Ed25519 response signer so
// signed gateway responses are reproducible across test runs.
func newTestResponseSigner(t *testing.T) authn.ResponseSigner {
	t.Helper()

	seed := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
	signer, err := authn.NewEd25519ResponseSigner(ed25519.NewKeyFromSeed(seed[:]))
	require.NoError(t, err)

	return signer
}
|
||||
|
||||
// fixedClock is a clock.Clock that always reports the same instant, keeping
// freshness-window checks deterministic in tests.
type fixedClock struct {
	now time.Time // the instant returned by Now
}

// Now returns the fixed instant.
func (c fixedClock) Now() time.Time {
	return c.now
}

// Compile-time interface check.
var _ clock.Clock = fixedClock{}
|
||||
|
||||
// staticReplayStore is a replay.Store that accepts every reservation,
// effectively disabling replay protection for tests.
type staticReplayStore struct{}

// Reserve always succeeds.
func (staticReplayStore) Reserve(context.Context, string, string, time.Duration) error {
	return nil
}

// Compile-time interface check.
var _ replay.Store = staticReplayStore{}
|
||||
|
||||
// countingSessionCache is a session cache test double that serves canned
// records and counts Lookup calls so tests can assert cache warm-up behavior.
type countingSessionCache struct {
	mu          sync.Mutex                // guards records and lookupCount
	records     map[string]session.Record // canned records served by Lookup
	lookupCount int                       // number of Lookup invocations
}
|
||||
|
||||
// Lookup returns the canned record for the shared test session and increments
// the call counter.
// NOTE(review): the requested session ID is intentionally ignored — the
// double always serves "device-session-123" regardless of the key passed in.
func (c *countingSessionCache) Lookup(context.Context, string) (session.Record, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.lookupCount++

	record, ok := c.records["device-session-123"]
	if !ok {
		return session.Record{}, errors.New("lookup session from counting cache: session cache record not found")
	}

	return record, nil
}
|
||||
|
||||
func (c *countingSessionCache) lookupCalls() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
return c.lookupCount
|
||||
}
|
||||
|
||||
// recordingDownstreamClient is a downstream.Client test double that captures
// every command it receives and returns a canned success.
type recordingDownstreamClient struct {
	mu       sync.Mutex                       // guards captured
	captured []downstream.AuthenticatedCommand // commands in arrival order
}
|
||||
|
||||
func (c *recordingDownstreamClient) ExecuteCommand(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
||||
c.mu.Lock()
|
||||
c.captured = append(c.captured, command)
|
||||
c.mu.Unlock()
|
||||
|
||||
return downstream.UnaryResult{
|
||||
ResultCode: "ok",
|
||||
PayloadBytes: []byte("response"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *recordingDownstreamClient) commands() []downstream.AuthenticatedCommand {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
cloned := make([]downstream.AuthenticatedCommand, len(c.captured))
|
||||
copy(cloned, c.captured)
|
||||
return cloned
|
||||
}
|
||||
@@ -0,0 +1,416 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/grpcapi"
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions verifies that
// an event addressed only to a user is delivered to every session of that
// user and to no one else.
func TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-3", "user-999")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	// Two subscriptions for user-123 and one for an unrelated user; each must
	// first receive its bootstrap event before fan-out is exercised.
	targetOneCtx, cancelTargetOne := context.WithCancel(context.Background())
	defer cancelTargetOne()
	targetOne, err := client.SubscribeEvents(targetOneCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetOne), "request-1", "trace-device-session-1")

	targetTwoCtx, cancelTargetTwo := context.WithCancel(context.Background())
	defer cancelTargetTwo()
	targetTwo, err := client.SubscribeEvents(targetTwoCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetTwo), "request-2", "trace-device-session-2")

	unrelatedCtx, cancelUnrelated := context.WithCancel(context.Background())
	defer cancelUnrelated()
	unrelated, err := client.SubscribeEvents(unrelatedCtx, newPushSubscribeEventsRequest("device-session-3", "request-3"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, unrelated), "request-3", "trace-device-session-3")

	// No device_session_id: the event targets every session of user-123.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-123",
		"payload_bytes": []byte("payload-123"),
		"request_id":    "request-123",
		"trace_id":      "trace-123",
	})

	assertSignedPushEvent(t, recvPushEvent(t, targetOne), push.Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-123",
		PayloadBytes: []byte("payload-123"),
		RequestID:    "request-123",
		TraceID:      "trace-123",
	})
	assertSignedPushEvent(t, recvPushEvent(t, targetTwo), push.Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-123",
		PayloadBytes: []byte("payload-123"),
		RequestID:    "request-123",
		TraceID:      "trace-123",
	})
	assertNoPushEvent(t, unrelated, cancelUnrelated)
}
|
||||
|
||||
// TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession verifies
// that a client event carrying a device_session_id is delivered only to the
// stream bound to that session, while other streams of the same user stay idle.
func TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession(t *testing.T) {
	t.Parallel()

	// Two active sessions for the same user so targeting can be distinguished
	// from plain per-user fan-out.
	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	// Stream for the NON-targeted session.
	otherCtx, cancelOther := context.WithCancel(context.Background())
	defer cancelOther()
	otherStream, err := client.SubscribeEvents(otherCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, otherStream), "request-1", "trace-device-session-1")

	// Stream for the targeted session.
	targetCtx, cancelTarget := context.WithCancel(context.Background())
	defer cancelTarget()
	targetStream, err := client.SubscribeEvents(targetCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetStream), "request-2", "trace-device-session-2")

	// Publish an event explicitly targeted at device-session-2.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":           "user-123",
		"device_session_id": "device-session-2",
		"event_type":        "fleet.updated",
		"event_id":          "event-456",
		"payload_bytes":     []byte("payload-456"),
	})

	// Only the matching session receives it; the sibling stream stays silent.
	assertSignedPushEvent(t, recvPushEvent(t, targetStream), push.Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
		EventType:       "fleet.updated",
		EventID:         "event-456",
		PayloadBytes:    []byte("payload-456"),
	})
	assertNoPushEvent(t, otherStream, cancelOther)
}
|
||||
|
||||
// TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen verifies that
// a revocation snapshot on the session event stream (a) terminates the active
// subscribe stream with FailedPrecondition and (b) causes a subsequent
// subscribe for the same session to be rejected with the same status.
func TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	sessionSubscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, sessionCache, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber, sessionSubscriber)
	defer running.stop(t)

	// The session subscriber must be reading before the revoke is published,
	// otherwise the event could be missed.
	select {
	case <-sessionSubscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "session subscriber did not start")
	}

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	streamCtx, cancelStream := context.WithCancel(context.Background())
	defer cancelStream()

	stream, err := client.SubscribeEvents(streamCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")

	// Publish the revocation snapshot for the subscribed session.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-1",
		"user_id":           "user-123",
		"client_public_key": pushClientPublicKeyBase64(),
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	// Wait for the snapshot to be applied to the local cache before asserting
	// on the stream's fate.
	require.Eventually(t, func() bool {
		record, lookupErr := sessionCache.Lookup(context.Background(), "device-session-1")
		return lookupErr == nil && record.Status == session.StatusRevoked
	}, time.Second, 10*time.Millisecond)

	// The active stream must be closed with FailedPrecondition.
	recvErrCh := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvErrCh <- recvErr
	}()

	select {
	case recvErr := <-recvErrCh:
		require.Error(t, recvErr)
		assert.Equal(t, codes.FailedPrecondition, status.Code(recvErr))
		assert.Equal(t, "device session is revoked", status.Convert(recvErr).Message())
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not close after revoke")
	}

	// Re-subscribing with the revoked session must fail the same way. The
	// error may surface either on SubscribeEvents or on the first Recv.
	reopened, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-2"))
	if err == nil {
		_, err = reopened.Recv()
	}

	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
}
|
||||
|
||||
// TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown verifies that
// canceling the gateway's run context terminates active subscribe streams
// with codes.Unavailable rather than leaving clients hanging.
func TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	stream, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")

	// Block on Recv in the background, then trigger shutdown.
	recvErrCh := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvErrCh <- recvErr
	}()

	running.cancel()

	select {
	case recvErr := <-recvErrCh:
		require.Error(t, recvErr)
		assert.Equal(t, codes.Unavailable, status.Code(recvErr))
		assert.Equal(t, "gateway is shutting down", status.Convert(recvErr).Message())
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not close after gateway shutdown")
	}
}
|
||||
|
||||
// runPushGateway wires a full authenticated gateway application (gRPC server,
// push hub fan-out service, client event subscriber, plus any extraComponents)
// on an unused TCP port, starts it in the background, and blocks until the
// client event subscriber reports it has started reading.
//
// It returns the listen address and a handle whose cancel stops the
// application; callers must defer running.stop(t).
func runPushGateway(t *testing.T, sessionCache session.Cache, pushHub *push.Hub, clientSubscriber *RedisClientEventSubscriber, extraComponents ...app.Component) (string, runningAuthenticatedGateway) {
	t.Helper()

	addr := unusedTCPAddr(t)
	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = addr
	// Generous freshness window so signed test requests built from the fixed
	// testNow clock are accepted.
	grpcCfg.FreshnessWindow = 5 * time.Minute

	responseSigner := newTestResponseSigner(t)
	gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
		Service:        grpcapi.NewFanOutPushStreamService(pushHub, responseSigner, fixedClock{now: testNow}, zap.NewNop()),
		ResponseSigner: responseSigner,
		SessionCache:   sessionCache,
		ReplayStore:    staticReplayStore{},
		Clock:          fixedClock{now: testNow},
		PushHub:        pushHub,
	})

	components := []app.Component{gateway, clientSubscriber}
	components = append(components, extraComponents...)
	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			AuthenticatedGRPC: grpcCfg,
		},
		components...,
	)

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	// Do not hand control back before the subscriber is consuming, or events
	// published early in a test could be lost.
	select {
	case <-clientSubscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "client event subscriber did not start")
	}

	return addr, runningAuthenticatedGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}
|
||||
|
||||
func newPushActiveSessionRecord(deviceSessionID string, userID string) session.Record {
|
||||
return session.Record{
|
||||
DeviceSessionID: deviceSessionID,
|
||||
UserID: userID,
|
||||
ClientPublicKey: pushClientPublicKeyBase64(),
|
||||
Status: session.StatusActive,
|
||||
}
|
||||
}
|
||||
|
||||
// newPushSubscribeEventsRequest builds a fully signed SubscribeEventsRequest
// for deviceSessionID/requestID, using the fixed testNow timestamp, an empty
// payload (hash of nil), and a trace ID derived from the session ID.
func newPushSubscribeEventsRequest(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
	// A subscribe request carries no payload, so the hash covers zero bytes.
	payloadHash := sha256.Sum256(nil)
	traceID := "trace-" + deviceSessionID

	req := &gatewayv1.SubscribeEventsRequest{
		ProtocolVersion: "v1",
		DeviceSessionId: deviceSessionID,
		MessageType:     "gateway.subscribe",
		TimestampMs:     testNow.UnixMilli(),
		RequestId:       requestID,
		PayloadHash:     payloadHash[:],
		TraceId:         traceID,
	}
	// Sign over the canonical field encoding the server verifies against.
	req.Signature = ed25519.Sign(pushClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: req.GetProtocolVersion(),
		DeviceSessionID: req.GetDeviceSessionId(),
		MessageType:     req.GetMessageType(),
		TimestampMS:     req.GetTimestampMs(),
		RequestID:       req.GetRequestId(),
		PayloadHash:     req.GetPayloadHash(),
	}))

	return req
}
|
||||
|
||||
func recvPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
|
||||
t.Helper()
|
||||
|
||||
event, err := stream.Recv()
|
||||
require.NoError(t, err)
|
||||
return event
|
||||
}
|
||||
|
||||
// assertPushBootstrapEvent asserts that event is the initial
// "gateway.server_time" bootstrap event for a new stream: its event ID echoes
// the request ID, the trace ID matches, the payload hash covers the payload,
// and the signature verifies against the test response signer key.
func assertPushBootstrapEvent(t *testing.T, event *gatewayv1.GatewayEvent, wantRequestID string, wantTraceID string) {
	t.Helper()

	require.NotNil(t, event)
	assert.Equal(t, "gateway.server_time", event.GetEventType())
	// The bootstrap event reuses the request ID as its event ID.
	assert.Equal(t, wantRequestID, event.GetEventId())
	assert.Equal(t, wantRequestID, event.GetRequestId())
	assert.Equal(t, wantTraceID, event.GetTraceId())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))
}
|
||||
|
||||
// assertSignedPushEvent asserts that event matches the expected push.Event
// fields and carries a valid payload hash and a signature that verifies
// against the test response signer key.
func assertSignedPushEvent(t *testing.T, event *gatewayv1.GatewayEvent, want push.Event) {
	t.Helper()

	require.NotNil(t, event)
	assert.Equal(t, want.EventType, event.GetEventType())
	assert.Equal(t, want.EventID, event.GetEventId())
	assert.Equal(t, want.RequestID, event.GetRequestId())
	assert.Equal(t, want.TraceID, event.GetTraceId())
	assert.Equal(t, want.PayloadBytes, event.GetPayloadBytes())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))
}
|
||||
|
||||
func assertNoPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent], cancel context.CancelFunc) {
|
||||
t.Helper()
|
||||
|
||||
recvCh := make(chan *gatewayv1.GatewayEvent, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
recvCh <- event
|
||||
}()
|
||||
|
||||
select {
|
||||
case event := <-recvCh:
|
||||
require.FailNowf(t, "unexpected push event delivered", "%+v", event)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
cancel()
|
||||
case err := <-errCh:
|
||||
require.FailNowf(t, "stream closed unexpectedly", "%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// pushClientPrivateKey returns the deterministic Ed25519 client key shared by
// the push gRPC tests, derived from a fixed seed string.
func pushClientPrivateKey() ed25519.PrivateKey {
	digest := sha256.Sum256([]byte("gateway-push-grpc-test-client"))
	key := ed25519.NewKeyFromSeed(digest[:])
	return key
}
|
||||
|
||||
func pushClientPublicKeyBase64() string {
|
||||
return base64.StdEncoding.EncodeToString(pushClientPrivateKey().Public().(ed25519.PublicKey))
|
||||
}
|
||||
|
||||
// pushResponseSignerPublicKey returns the public half of the deterministic
// Ed25519 key the test response signer uses.
func pushResponseSignerPublicKey() ed25519.PublicKey {
	digest := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
	privateKey := ed25519.NewKeyFromSeed(digest[:])
	return privateKey.Public().(ed25519.PublicKey)
}
|
||||
@@ -0,0 +1,389 @@
|
||||
// Package events subscribes to internal session lifecycle streams used to keep
|
||||
// the gateway hot-path session cache synchronized without per-request upstream
|
||||
// lookups.
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/session"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// sessionEventReadCount caps how many stream entries a single XREAD call may
// return per poll of the session event loop.
const sessionEventReadCount int64 = 128
|
||||
|
||||
// SessionRevocationHandler reacts to a successfully applied revoked session
// snapshot and may tear down active resources bound to that session.
type SessionRevocationHandler interface {
	// RevokeDeviceSession tears down active resources bound to deviceSessionID.
	// It is invoked synchronously from the subscriber's event loop, after the
	// revoked snapshot has been applied to the local store.
	RevokeDeviceSession(deviceSessionID string)
}
|
||||
|
||||
// RedisSessionSubscriber consumes full session snapshots from one Redis Stream
// and applies them to a process-local session snapshot store.
type RedisSessionSubscriber struct {
	client            *redis.Client            // dedicated connection for stream reads and pings
	stream            string                   // Redis Stream key carrying session snapshots
	pingTimeout       time.Duration            // budget for a single health-check ping
	readBlockTimeout  time.Duration            // XREAD BLOCK duration per poll
	store             session.SnapshotStore    // process-local snapshot cache kept in sync
	revocationHandler SessionRevocationHandler // optional; nil disables teardown callbacks
	logger            *zap.Logger              // never nil (defaulted to a no-op logger)
	// metrics is nil when constructed via NewRedisSessionSubscriber.
	// NOTE(review): applyMessage invokes it unconditionally — confirm
	// *telemetry.Runtime methods tolerate a nil receiver.
	metrics *telemetry.Runtime

	closeOnce   sync.Once     // ensures the Redis client is closed at most once
	startedOnce sync.Once     // ensures started is closed at most once
	started     chan struct{} // closed once the start position has been resolved
}
|
||||
|
||||
// NewRedisSessionSubscriber constructs a Redis Stream subscriber that reuses
// the SessionCache Redis connection settings and applies updates to store.
// It runs without a revocation handler, logger, or metrics sink.
func NewRedisSessionSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore) (*RedisSessionSubscriber, error) {
	return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, nil, nil, nil)
}
|
||||
|
||||
// NewRedisSessionSubscriberWithRevocationHandler constructs a Redis Stream
// subscriber that reuses the SessionCache Redis connection settings, applies
// updates to store, and optionally tears down active resources for revoked
// sessions. It runs without a logger or metrics sink.
func NewRedisSessionSubscriberWithRevocationHandler(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler) (*RedisSessionSubscriber, error) {
	return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, revocationHandler, nil, nil)
}
|
||||
|
||||
// NewRedisSessionSubscriberWithObservability constructs a Redis Stream
|
||||
// subscriber that also logs and counts malformed internal session events.
|
||||
func NewRedisSessionSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisSessionSubscriber, error) {
|
||||
if strings.TrimSpace(sessionCfg.Addr) == "" {
|
||||
return nil, errors.New("new redis session subscriber: redis addr must not be empty")
|
||||
}
|
||||
if sessionCfg.DB < 0 {
|
||||
return nil, errors.New("new redis session subscriber: redis db must not be negative")
|
||||
}
|
||||
if sessionCfg.LookupTimeout <= 0 {
|
||||
return nil, errors.New("new redis session subscriber: lookup timeout must be positive")
|
||||
}
|
||||
if strings.TrimSpace(eventsCfg.Stream) == "" {
|
||||
return nil, errors.New("new redis session subscriber: stream must not be empty")
|
||||
}
|
||||
if eventsCfg.ReadBlockTimeout <= 0 {
|
||||
return nil, errors.New("new redis session subscriber: read block timeout must be positive")
|
||||
}
|
||||
if store == nil {
|
||||
return nil, errors.New("new redis session subscriber: nil session snapshot store")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: sessionCfg.Addr,
|
||||
Username: sessionCfg.Username,
|
||||
Password: sessionCfg.Password,
|
||||
DB: sessionCfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if sessionCfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return &RedisSessionSubscriber{
|
||||
client: redis.NewClient(options),
|
||||
stream: eventsCfg.Stream,
|
||||
pingTimeout: sessionCfg.LookupTimeout,
|
||||
readBlockTimeout: eventsCfg.ReadBlockTimeout,
|
||||
store: store,
|
||||
revocationHandler: revocationHandler,
|
||||
logger: logger.Named("session_subscriber"),
|
||||
metrics: metrics,
|
||||
started: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Ping verifies that the Redis backend used for session lifecycle events is
|
||||
// reachable within the configured timeout budget.
|
||||
func (s *RedisSessionSubscriber) Ping(ctx context.Context) error {
|
||||
if s == nil || s.client == nil {
|
||||
return errors.New("ping redis session subscriber: nil subscriber")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("ping redis session subscriber: nil context")
|
||||
}
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, s.pingTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := s.client.Ping(pingCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis session subscriber: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run consumes session lifecycle events until ctx is canceled or Redis returns
|
||||
// an unexpected error.
|
||||
func (s *RedisSessionSubscriber) Run(ctx context.Context) error {
|
||||
if s == nil || s.client == nil {
|
||||
return errors.New("run redis session subscriber: nil subscriber")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("run redis session subscriber: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lastID, err := s.resolveStartID(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.signalStarted()
|
||||
|
||||
for {
|
||||
streams, err := s.client.XRead(ctx, &redis.XReadArgs{
|
||||
Streams: []string{s.stream, lastID},
|
||||
Count: sessionEventReadCount,
|
||||
Block: s.readBlockTimeout,
|
||||
}).Result()
|
||||
switch {
|
||||
case err == nil:
|
||||
for _, stream := range streams {
|
||||
for _, message := range stream.Messages {
|
||||
s.applyMessage(message)
|
||||
lastID = message.ID
|
||||
}
|
||||
}
|
||||
continue
|
||||
case errors.Is(err, redis.Nil):
|
||||
continue
|
||||
case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
|
||||
return ctx.Err()
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded), errors.Is(err, redis.ErrClosed):
|
||||
return fmt.Errorf("run redis session subscriber: %w", err)
|
||||
default:
|
||||
return fmt.Errorf("run redis session subscriber: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RedisSessionSubscriber) resolveStartID(ctx context.Context) (string, error) {
|
||||
messages, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
|
||||
switch {
|
||||
case err == nil:
|
||||
case errors.Is(err, redis.Nil):
|
||||
return "0-0", nil
|
||||
default:
|
||||
return "", fmt.Errorf("run redis session subscriber: resolve stream tail: %w", err)
|
||||
}
|
||||
|
||||
if len(messages) == 0 {
|
||||
return "0-0", nil
|
||||
}
|
||||
|
||||
return messages[0].ID, nil
|
||||
}
|
||||
|
||||
// Shutdown closes the Redis client so a blocking stream read can terminate
// promptly during gateway shutdown. ctx is only validated; the close itself
// is not bounded by it.
func (s *RedisSessionSubscriber) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown redis session subscriber: nil context")
	}

	return s.Close()
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (s *RedisSessionSubscriber) Close() error {
|
||||
if s == nil || s.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.closeOnce.Do(func() {
|
||||
err = s.client.Close()
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// signalStarted closes the started channel exactly once, letting waiters know
// the subscriber has resolved its start position and is about to read.
func (s *RedisSessionSubscriber) signalStarted() {
	s.startedOnce.Do(func() {
		close(s.started)
	})
}
|
||||
|
||||
func (s *RedisSessionSubscriber) applyMessage(message redis.XMessage) {
|
||||
record, err := decodeSessionRecordSnapshot(message.Values)
|
||||
if err != nil {
|
||||
s.logger.Warn("dropped malformed session event",
|
||||
zap.String("stream", s.stream),
|
||||
zap.String("message_id", message.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
s.metrics.RecordInternalEventDrop(context.Background(),
|
||||
attribute.String("component", "session_subscriber"),
|
||||
attribute.String("reason", "malformed_event"),
|
||||
)
|
||||
if deviceSessionID, ok := extractDeviceSessionID(message.Values); ok {
|
||||
s.store.Delete(deviceSessionID)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.store.Upsert(record); err != nil {
|
||||
s.logger.Warn("dropped session snapshot after store failure",
|
||||
zap.String("stream", s.stream),
|
||||
zap.String("message_id", message.ID),
|
||||
zap.String("device_session_id", record.DeviceSessionID),
|
||||
zap.Error(err),
|
||||
)
|
||||
s.metrics.RecordInternalEventDrop(context.Background(),
|
||||
attribute.String("component", "session_subscriber"),
|
||||
attribute.String("reason", "store_failure"),
|
||||
)
|
||||
s.store.Delete(record.DeviceSessionID)
|
||||
return
|
||||
}
|
||||
|
||||
if record.Status == session.StatusRevoked && s.revocationHandler != nil {
|
||||
s.revocationHandler.RevokeDeviceSession(record.DeviceSessionID)
|
||||
}
|
||||
}
|
||||
|
||||
func decodeSessionRecordSnapshot(values map[string]any) (session.Record, error) {
|
||||
requiredKeys := map[string]struct{}{
|
||||
"device_session_id": {},
|
||||
"user_id": {},
|
||||
"client_public_key": {},
|
||||
"status": {},
|
||||
}
|
||||
optionalKeys := map[string]struct{}{
|
||||
"revoked_at_ms": {},
|
||||
}
|
||||
|
||||
for key := range values {
|
||||
if _, ok := requiredKeys[key]; ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := optionalKeys[key]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
return session.Record{}, fmt.Errorf("decode session event: unsupported field %q", key)
|
||||
}
|
||||
|
||||
deviceSessionID, err := requiredStringField(values, "device_session_id")
|
||||
if err != nil {
|
||||
return session.Record{}, err
|
||||
}
|
||||
userID, err := requiredStringField(values, "user_id")
|
||||
if err != nil {
|
||||
return session.Record{}, err
|
||||
}
|
||||
clientPublicKey, err := requiredStringField(values, "client_public_key")
|
||||
if err != nil {
|
||||
return session.Record{}, err
|
||||
}
|
||||
statusValue, err := requiredStringField(values, "status")
|
||||
if err != nil {
|
||||
return session.Record{}, err
|
||||
}
|
||||
|
||||
record := session.Record{
|
||||
DeviceSessionID: deviceSessionID,
|
||||
UserID: userID,
|
||||
ClientPublicKey: clientPublicKey,
|
||||
Status: session.Status(statusValue),
|
||||
}
|
||||
|
||||
if rawRevokedAtMS, ok := values["revoked_at_ms"]; ok {
|
||||
revokedAtMS, err := parseInt64Field(rawRevokedAtMS, "revoked_at_ms")
|
||||
if err != nil {
|
||||
return session.Record{}, err
|
||||
}
|
||||
record.RevokedAtMS = &revokedAtMS
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func extractDeviceSessionID(values map[string]any) (string, bool) {
|
||||
value, ok := values["device_session_id"]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
deviceSessionID, err := coerceString(value)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
if strings.TrimSpace(deviceSessionID) == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return deviceSessionID, true
|
||||
}
|
||||
|
||||
func requiredStringField(values map[string]any, field string) (string, error) {
|
||||
value, ok := values[field]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("decode session event: missing %s", field)
|
||||
}
|
||||
|
||||
stringValue, err := coerceString(value)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode session event: %s: %w", field, err)
|
||||
}
|
||||
if strings.TrimSpace(stringValue) == "" {
|
||||
return "", fmt.Errorf("decode session event: %s must not be empty", field)
|
||||
}
|
||||
|
||||
return stringValue, nil
|
||||
}
|
||||
|
||||
func parseInt64Field(value any, field string) (int64, error) {
|
||||
stringValue, err := coerceString(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("decode session event: %s: %w", field, err)
|
||||
}
|
||||
|
||||
parsed, err := strconv.ParseInt(strings.TrimSpace(stringValue), 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("decode session event: %s: %w", field, err)
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// coerceString renders the value kinds a Redis stream payload may contain
// (string, []byte, fmt.Stringer, and common integer widths) as a string, and
// errors on anything else.
func coerceString(value any) (string, error) {
	if direct, ok := value.(string); ok {
		return direct, nil
	}

	switch typed := value.(type) {
	case []byte:
		return string(typed), nil
	case fmt.Stringer:
		return typed.String(), nil
	case int:
		return strconv.Itoa(typed), nil
	case int64:
		return strconv.FormatInt(typed, 10), nil
	case uint64:
		return strconv.FormatUint(typed, 10), nil
	}

	return "", fmt.Errorf("unsupported value type %T", value)
}
|
||||
@@ -0,0 +1,366 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/session"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRedisSessionSubscriberAppliesActiveSnapshot verifies that an active
// session snapshot published on the stream is upserted into the local store.
func TestRedisSessionSubscriberAppliesActiveSnapshot(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})

	// The subscriber applies events asynchronously; poll until visible.
	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil {
			return false
		}

		return record.UserID == "user-123" && record.Status == session.StatusActive
	}, time.Second, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisSessionSubscriberAppliesRevokedSnapshot verifies that a revoked
// snapshot overwrites an existing active record, including the revocation
// timestamp.
func TestRedisSessionSubscriberAppliesRevokedSnapshot(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	// Seed an active record so the test exercises an overwrite, not an insert.
	require.NoError(t, store.Upsert(session.Record{
		DeviceSessionID: "device-session-123",
		UserID:          "user-123",
		ClientPublicKey: "public-key-123",
		Status:          session.StatusActive,
	}))

	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil || record.RevokedAtMS == nil {
			return false
		}

		return record.Status == session.StatusRevoked && *record.RevokedAtMS == 123456789
	}, time.Second, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisSessionSubscriberRevokedSnapshotTriggersRevocationHandler verifies
// that applying a revoked snapshot invokes the configured revocation handler
// exactly for the revoked session.
func TestRedisSessionSubscriberRevokedSnapshotTriggersRevocationHandler(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	handler := &recordingSessionRevocationHandler{}
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil || record.Status != session.StatusRevoked {
			return false
		}

		// Non-fatal equality check inside the poll loop; require.Eventually
		// converts persistent mismatch into a failure.
		return assert.ObjectsAreEqual([]string{"device-session-123"}, handler.revocations())
	}, time.Second, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisSessionSubscriberActiveSnapshotDoesNotTriggerRevocationHandler
// verifies that applying an active snapshot never invokes the revocation
// handler.
func TestRedisSessionSubscriberActiveSnapshotDoesNotTriggerRevocationHandler(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	handler := &recordingSessionRevocationHandler{}
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})

	// Give the subscriber a short window in which the handler must stay idle.
	assert.Never(t, func() bool {
		return len(handler.revocations()) != 0
	}, 100*time.Millisecond, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisSessionSubscriberStoreFailureDoesNotTriggerRevocationHandler
// verifies that when the snapshot store rejects the upsert, the revocation
// handler is not invoked even for a revoked snapshot — the handler only fires
// after a successfully applied snapshot.
func TestRedisSessionSubscriberStoreFailureDoesNotTriggerRevocationHandler(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	handler := &recordingSessionRevocationHandler{}
	// failingSnapshotStore makes every Upsert fail.
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, failingSnapshotStore{}, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	assert.Never(t, func() bool {
		return len(handler.revocations()) != 0
	}, 100*time.Millisecond, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisSessionSubscriberLaterEventWins verifies last-write-wins
// semantics: two events for the same device_session_id leave the cache
// holding the fields of the later event.
func TestRedisSessionSubscriberLaterEventWins(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	// Same device session, two different user/key snapshots in stream order.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-456",
		"client_public_key": "public-key-456",
		"status":            string(session.StatusActive),
	})

	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil {
			return false
		}

		// Only the second event's identity may remain visible.
		return record.UserID == "user-456" && record.ClientPublicKey == "public-key-456"
	}, time.Second, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisSessionSubscriberMalformedEventEvictsAndContinues verifies that an
// event with an unknown status evicts any cached record for that session
// (fail closed) and that the subscriber keeps consuming later, valid events.
func TestRedisSessionSubscriberMalformedEventEvictsAndContinues(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	// Seed the cache so the eviction triggered by the malformed event is
	// observable as a Lookup failure.
	require.NoError(t, store.Upsert(session.Record{
		DeviceSessionID: "device-session-123",
		UserID:          "user-123",
		ClientPublicKey: "public-key-123",
		Status:          session.StatusActive,
	}))

	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	// "paused" is not a recognized session status.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            "paused",
	})

	// The seeded record must be evicted.
	require.Eventually(t, func() bool {
		_, err := store.Lookup(context.Background(), "device-session-123")
		return err != nil
	}, time.Second, 10*time.Millisecond)

	// A subsequent well-formed event must still be applied, proving the
	// subscriber did not stop at the malformed entry.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-456",
		"client_public_key": "public-key-456",
		"status":            string(session.StatusActive),
	})

	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil {
			return false
		}

		return record.UserID == "user-456" && record.Status == session.StatusActive
	}, time.Second, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisSessionSubscriberShutdownInterruptsBlockingRead verifies that
// cancellation plus Shutdown unblocks a subscriber parked in a blocking
// stream read, and that Run reports context.Canceled.
func TestRedisSessionSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, server, store)

	// Run is driven manually (not via runTestSubscriber) because this test
	// also exercises the explicit Shutdown path.
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()

	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}

	cancel()
	require.NoError(t, subscriber.Shutdown(context.Background()))

	select {
	case err := <-resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop after shutdown")
	}
}
|
||||
|
||||
// newTestRedisSessionSubscriber builds a subscriber wired to server and store
// with no revocation handler installed.
func newTestRedisSessionSubscriber(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore) *RedisSessionSubscriber {
	t.Helper()

	return newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, nil)
}
|
||||
|
||||
// newTestRedisSessionSubscriberWithRevocationHandler builds a subscriber
// against the given miniredis instance, store, and optional revocation
// handler, registering cleanup so the subscriber is closed when the test
// finishes. Timeouts are kept short to keep the test suite fast.
func newTestRedisSessionSubscriberWithRevocationHandler(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore, revocationHandler SessionRevocationHandler) *RedisSessionSubscriber {
	t.Helper()

	subscriber, err := NewRedisSessionSubscriberWithRevocationHandler(
		config.SessionCacheRedisConfig{
			Addr:          server.Addr(),
			LookupTimeout: 250 * time.Millisecond,
		},
		config.SessionEventsRedisConfig{
			Stream:           "gateway:session_events",
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		store,
		revocationHandler,
	)
	require.NoError(t, err)

	t.Cleanup(func() {
		assert.NoError(t, subscriber.Close())
	})

	return subscriber
}
|
||||
|
||||
// recordingSessionRevocationHandler is a thread-safe test double that records
// every device session ID passed to RevokeDeviceSession.
type recordingSessionRevocationHandler struct {
	mu         sync.Mutex
	revokedIDs []string
}

// RevokeDeviceSession records deviceSessionID for later inspection.
func (h *recordingSessionRevocationHandler) RevokeDeviceSession(deviceSessionID string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	h.revokedIDs = append(h.revokedIDs, deviceSessionID)
}

// revocations returns an independent copy of the recorded IDs so callers can
// inspect them without racing against concurrent RevokeDeviceSession calls.
func (h *recordingSessionRevocationHandler) revocations() []string {
	h.mu.Lock()
	defer h.mu.Unlock()

	snapshot := make([]string, len(h.revokedIDs))
	copy(snapshot, h.revokedIDs)

	return snapshot
}
|
||||
|
||||
// failingSnapshotStore is a test double whose Upsert always fails, used to
// verify that store failures suppress downstream side effects.
type failingSnapshotStore struct{}

// Lookup reports every session as absent.
func (failingSnapshotStore) Lookup(context.Context, string) (session.Record, error) {
	return session.Record{}, session.ErrNotFound
}

// Upsert always fails with an arbitrary non-nil error.
func (failingSnapshotStore) Upsert(session.Record) error {
	return context.DeadlineExceeded
}

// Delete is a no-op.
func (failingSnapshotStore) Delete(string) {}
|
||||
|
||||
func addSessionEvent(t *testing.T, server *miniredis.Miniredis, stream string, fields map[string]string) {
|
||||
t.Helper()
|
||||
|
||||
values := make([]string, 0, len(fields)*2)
|
||||
for key, value := range fields {
|
||||
values = append(values, key, value)
|
||||
}
|
||||
|
||||
_, err := server.XAdd(stream, "*", values)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// runningSubscriber tracks a subscriber launched by runTestSubscriber so a
// test can cancel it and wait for Run to return.
//
// Note: the previous revision declared a stopOnce bool field that was never
// read or written anywhere in the file; it has been removed as dead state.
type runningSubscriber struct {
	cancel   context.CancelFunc // cancels the context driving Run
	resultCh chan error         // receives Run's return value exactly once
}
|
||||
|
||||
func runTestSubscriber(t *testing.T, subscriber *RedisSessionSubscriber) runningSubscriber {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
resultCh <- subscriber.Run(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-subscriber.started:
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "subscriber did not start")
|
||||
}
|
||||
|
||||
return runningSubscriber{
|
||||
cancel: cancel,
|
||||
resultCh: resultCh,
|
||||
}
|
||||
}
|
||||
|
||||
func (r runningSubscriber) stop(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
r.cancel()
|
||||
|
||||
select {
|
||||
case err := <-r.resultCh:
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "subscriber did not stop")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// commandRoutingService translates the verified authenticated request context
// into an internal downstream command and signs successful unary responses.
type commandRoutingService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	// subscribeDelegate handles SubscribeEvents; this service owns only the
	// unary ExecuteCommand path.
	subscribeDelegate gatewayv1.EdgeGatewayServer
	// router resolves a downstream client by exact message_type.
	router downstream.Router
	// responseSigner signs the unary response fields before they leave the
	// gateway.
	responseSigner authn.ResponseSigner
	// clock supplies response timestamps (injectable for tests).
	clock clock.Clock
	// downstreamTimeout bounds each downstream ExecuteCommand call.
	downstreamTimeout time.Duration
}
|
||||
|
||||
// ExecuteCommand builds a verified downstream command, routes it by exact
// message_type, executes it, and signs the resulting unary response.
//
// Downstream failures are collapsed into coarse, stable gRPC codes so that
// internal error details never leak to edge clients.
func (s commandRoutingService) ExecuteCommand(ctx context.Context, _ *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
	// The raw request is deliberately ignored: all trusted fields come from
	// the context populated by the earlier validation/authentication layers.
	command, err := authenticatedCommandFromContext(ctx)
	if err != nil {
		return nil, err
	}

	client, err := s.router.Route(command.MessageType)
	switch {
	case err == nil:
	case errors.Is(err, downstream.ErrRouteNotFound):
		return nil, status.Error(codes.Unimplemented, "message_type is not routed")
	case errors.Is(err, downstream.ErrDownstreamUnavailable):
		return nil, status.Error(codes.Unavailable, "downstream service is unavailable")
	default:
		return nil, status.Error(codes.Internal, "downstream route resolution failed")
	}

	// Bound the downstream call with the configured per-call timeout.
	downstreamCtx, cancel := context.WithTimeout(ctx, s.downstreamTimeout)
	defer cancel()

	result, err := client.ExecuteCommand(downstreamCtx, command)
	switch {
	case err == nil:
	// Timeouts and cancellations are reported as Unavailable alongside
	// explicit downstream unavailability.
	case errors.Is(err, downstream.ErrDownstreamUnavailable),
		errors.Is(err, context.DeadlineExceeded),
		errors.Is(err, context.Canceled):
		return nil, status.Error(codes.Unavailable, "downstream service is unavailable")
	default:
		return nil, status.Error(codes.Internal, "downstream execution failed")
	}

	// A downstream reply without a result code is treated as a protocol
	// violation rather than being forwarded to the client.
	if strings.TrimSpace(result.ResultCode) == "" {
		return nil, status.Error(codes.Internal, "downstream response is invalid")
	}

	responseTimestampMS := s.clock.Now().UTC().UnixMilli()
	payloadHash := sha256.Sum256(result.PayloadBytes)
	// The signature covers exactly the response fields the client can verify.
	signature, err := s.responseSigner.SignResponse(authn.ResponseSigningFields{
		ProtocolVersion: command.ProtocolVersion,
		RequestID:       command.RequestID,
		TimestampMS:     responseTimestampMS,
		ResultCode:      result.ResultCode,
		PayloadHash:     payloadHash[:],
	})
	if err != nil {
		return nil, status.Error(codes.Unavailable, "response signer is unavailable")
	}

	// Byte slices are cloned so the response does not alias downstream
	// buffers that may be reused.
	return &gatewayv1.ExecuteCommandResponse{
		ProtocolVersion: command.ProtocolVersion,
		RequestId:       command.RequestID,
		TimestampMs:     responseTimestampMS,
		ResultCode:      result.ResultCode,
		PayloadBytes:    bytes.Clone(result.PayloadBytes),
		PayloadHash:     bytes.Clone(payloadHash[:]),
		Signature:       signature,
	}, nil
}
|
||||
|
||||
// SubscribeEvents delegates to the authenticated streaming service
// implementation selected during server construction; this service adds no
// behavior on the streaming path.
func (s commandRoutingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	return s.subscribeDelegate.SubscribeEvents(req, stream)
}
|
||||
|
||||
// newCommandRoutingService constructs the final authenticated service that
|
||||
// owns verified unary routing while preserving the delegated streaming path.
|
||||
func newCommandRoutingService(subscribeDelegate gatewayv1.EdgeGatewayServer, router downstream.Router, responseSigner authn.ResponseSigner, clk clock.Clock, downstreamTimeout time.Duration) gatewayv1.EdgeGatewayServer {
|
||||
return commandRoutingService{
|
||||
subscribeDelegate: subscribeDelegate,
|
||||
router: router,
|
||||
responseSigner: responseSigner,
|
||||
clock: clk,
|
||||
downstreamTimeout: downstreamTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func authenticatedCommandFromContext(ctx context.Context) (downstream.AuthenticatedCommand, error) {
|
||||
envelope, ok := parsedEnvelopeFromContext(ctx)
|
||||
if !ok {
|
||||
return downstream.AuthenticatedCommand{}, status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
record, ok := resolvedSessionFromContext(ctx)
|
||||
if !ok {
|
||||
return downstream.AuthenticatedCommand{}, status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
return downstream.AuthenticatedCommand{
|
||||
ProtocolVersion: envelope.ProtocolVersion,
|
||||
UserID: record.UserID,
|
||||
DeviceSessionID: record.DeviceSessionID,
|
||||
MessageType: envelope.MessageType,
|
||||
TimestampMS: envelope.TimestampMS,
|
||||
RequestID: envelope.RequestID,
|
||||
TraceID: envelope.TraceID,
|
||||
PayloadBytes: bytes.Clone(envelope.PayloadBytes),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// unavailableResponseSigner is a ResponseSigner stand-in whose every signing
// operation fails, used when no real signer is configured.
type unavailableResponseSigner struct{}

// SignResponse always fails.
func (unavailableResponseSigner) SignResponse(authn.ResponseSigningFields) ([]byte, error) {
	return nil, errors.New("response signer is unavailable")
}

// SignEvent always fails.
func (unavailableResponseSigner) SignEvent(authn.EventSigningFields) ([]byte, error) {
	return nil, errors.New("response signer is unavailable")
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = commandRoutingService{}
|
||||
@@ -0,0 +1,296 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/testutil"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// TestExecuteCommandRoutesVerifiedCommandAndSignsResponse is the happy-path
// test: a valid request is routed to the client registered for its
// message_type, the downstream command carries the verified envelope and
// session identity, and the response hash and signature verify against the
// signer's public key.
func TestExecuteCommandRoutesVerifiedCommandAndSignsResponse(t *testing.T) {
	t.Parallel()

	signer := newTestEd25519ResponseSigner()
	moveClient := &recordingDownstreamClient{
		executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
			// The command must contain exactly the verified transport fields.
			assert.Equal(t, downstream.AuthenticatedCommand{
				ProtocolVersion: "v1",
				UserID:          "user-123",
				DeviceSessionID: "device-session-123",
				MessageType:     "fleet.move",
				TimestampMS:     testCurrentTime.UnixMilli(),
				RequestID:       "request-123",
				TraceID:         "trace-123",
				PayloadBytes:    []byte("payload"),
			}, command)

			return downstream.UnaryResult{
				ResultCode:   "accepted",
				PayloadBytes: []byte("downstream-response"),
			}, nil
		},
	}
	// A second route proves routing is by exact message_type.
	renameClient := &recordingDownstreamClient{}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move":   moveClient,
			"fleet.rename": renameClient,
		}),
		ResponseSigner: signer,
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)

	assert.Equal(t, "v1", response.GetProtocolVersion())
	assert.Equal(t, "request-123", response.GetRequestId())
	assert.Equal(t, testCurrentTime.UnixMilli(), response.GetTimestampMs())
	assert.Equal(t, "accepted", response.GetResultCode())
	assert.Equal(t, []byte("downstream-response"), response.GetPayloadBytes())
	assert.Equal(t, 1, moveClient.executeCalls)
	assert.Zero(t, renameClient.executeCalls)

	// The payload hash and signature must verify with the signer's key.
	wantHash := sha256.Sum256([]byte("downstream-response"))
	assert.Equal(t, wantHash[:], response.GetPayloadHash())
	require.NoError(t, authn.VerifyPayloadHash(response.GetPayloadBytes(), response.GetPayloadHash()))
	require.NoError(t, authn.VerifyResponseSignature(signer.PublicKey(), response.GetSignature(), authn.ResponseSigningFields{
		ProtocolVersion: response.GetProtocolVersion(),
		RequestID:       response.GetRequestId(),
		TimestampMS:     response.GetTimestampMs(),
		ResultCode:      response.GetResultCode(),
		PayloadHash:     response.GetPayloadHash(),
	}))
}
|
||||
|
||||
// TestExecuteCommandRouteMissReturnsUnimplemented verifies that a valid
// request whose message_type has no route is rejected with the stable
// Unimplemented code and message.
func TestExecuteCommandRouteMissReturnsUnimplemented(t *testing.T) {
	t.Parallel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		// An empty router guarantees a route miss for every message_type.
		Router:         downstream.NewStaticRouter(nil),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		ResponseSigner: newTestResponseSigner(),
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unimplemented, status.Code(err))
	assert.Equal(t, "message_type is not routed", status.Convert(err).Message())
}
|
||||
|
||||
// TestExecuteCommandMapsDownstreamUnavailableToUnavailable verifies that a
// downstream failure wrapping ErrDownstreamUnavailable surfaces as the stable
// Unavailable code and message, and that the downstream client was actually
// invoked exactly once.
func TestExecuteCommandMapsDownstreamUnavailableToUnavailable(t *testing.T) {
	t.Parallel()

	failingClient := &recordingDownstreamClient{
		executeFunc: func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
			// Wrapped so the gateway must match with errors.Is, not equality.
			return downstream.UnaryResult{}, fmt.Errorf("rpc transport failed: %w", downstream.ErrDownstreamUnavailable)
		},
	}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": failingClient,
		}),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		ResponseSigner: newTestResponseSigner(),
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "downstream service is unavailable", status.Convert(err).Message())
	assert.Equal(t, 1, failingClient.executeCalls)
}
|
||||
|
||||
// TestExecuteCommandPropagatesOTelSpanContextToDownstream verifies that the
// context handed to the downstream client carries a valid OpenTelemetry span
// context and that the client-supplied trace_id is preserved on the command.
func TestExecuteCommandPropagatesOTelSpanContextToDownstream(t *testing.T) {
	t.Parallel()

	logger := zap.NewNop()
	telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)

	// Captured inside the downstream stub and inspected after the RPC.
	var (
		seenSpanContext trace.SpanContext
		seenCommand     downstream.AuthenticatedCommand
	)

	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": &recordingDownstreamClient{
				executeFunc: func(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
					seenSpanContext = trace.SpanContextFromContext(ctx)
					seenCommand = command

					return downstream.UnaryResult{
						ResultCode:   "accepted",
						PayloadBytes: []byte("downstream-response"),
					}, nil
				},
			},
		}),
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		Logger:         logger,
		Telemetry:      telemetryRuntime,
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)

	assert.True(t, seenSpanContext.IsValid())
	assert.Equal(t, "trace-123", seenCommand.TraceID)
}
|
||||
|
||||
// TestExecuteCommandDrainsInFlightUnaryDuringShutdown verifies graceful
// drain: a unary call blocked inside the downstream client is NOT cut off
// when the gateway context is canceled, and completes successfully once the
// downstream call is released.
func TestExecuteCommandDrainsInFlightUnaryDuringShutdown(t *testing.T) {
	t.Parallel()

	// started signals the downstream stub has been entered; release lets it
	// finish, simulating a slow in-flight downstream call.
	started := make(chan struct{})
	release := make(chan struct{})

	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": &recordingDownstreamClient{
				executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
					close(started)
					<-release

					return downstream.UnaryResult{
						ResultCode:   "accepted",
						PayloadBytes: []byte("downstream-response"),
					}, nil
				},
			},
		}),
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	resultCh := make(chan error, 1)
	go func() {
		_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
		resultCh <- err
	}()

	// Wait until the request is parked inside the downstream stub.
	require.Eventually(t, func() bool {
		select {
		case <-started:
			return true
		default:
			return false
		}
	}, time.Second, 10*time.Millisecond, "downstream execution did not start")

	// Begin gateway shutdown while the request is still in flight.
	runGateway.cancel()

	// The in-flight request must not be aborted by shutdown.
	require.Never(t, func() bool {
		select {
		case <-resultCh:
			return true
		default:
			return false
		}
	}, 100*time.Millisecond, 10*time.Millisecond, "unary request returned before downstream release")

	close(release)

	// Once released, the request must drain successfully.
	var err error
	require.Eventually(t, func() bool {
		select {
		case err = <-resultCh:
			return true
		default:
			return false
		}
	}, time.Second, 10*time.Millisecond, "unary request did not drain before shutdown timeout")
	require.NoError(t, err)
}
|
||||
|
||||
// TestExecuteCommandLogsDoNotContainSensitiveTransportMaterial verifies that
// a successful unary round trip emits no log lines containing payload hash,
// signature, or raw payload material.
func TestExecuteCommandLogsDoNotContainSensitiveTransportMaterial(t *testing.T) {
	t.Parallel()

	// Observed logger captures everything the gateway logs during the RPC.
	logger, logBuffer := testutil.NewObservedLogger(t)

	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": &recordingDownstreamClient{},
		}),
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		Logger:         logger,
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)

	logOutput := logBuffer.String()
	assert.NotContains(t, logOutput, "payload_hash")
	assert.NotContains(t, logOutput, "signature")
	assert.NotContains(t, logOutput, `"payload"`)
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"buf.build/go/protovalidate"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const supportedProtocolVersion = "v1"
|
||||
|
||||
// parsedEnvelope captures the authenticated transport fields extracted from a
// request envelope after validation succeeds. Later wrappers may enrich this
// structure without changing the raw gRPC request types.
//
// Byte-slice fields are clones of the request's slices, so the envelope owns
// its data independently of the protobuf message.
type parsedEnvelope struct {
	ProtocolVersion string
	DeviceSessionID string
	MessageType     string
	TimestampMS     int64
	RequestID       string
	TraceID         string
	PayloadBytes    []byte
	PayloadHash     []byte
	Signature       []byte
}
|
||||
|
||||
// parsedEnvelopeFromContext returns the parsed envelope previously attached to
|
||||
// ctx by the envelope-validating gRPC service wrapper.
|
||||
func parsedEnvelopeFromContext(ctx context.Context) (parsedEnvelope, bool) {
|
||||
if ctx == nil {
|
||||
return parsedEnvelope{}, false
|
||||
}
|
||||
|
||||
envelope, ok := ctx.Value(parsedEnvelopeContextKey{}).(parsedEnvelope)
|
||||
if !ok {
|
||||
return parsedEnvelope{}, false
|
||||
}
|
||||
|
||||
return envelope, true
|
||||
}
|
||||
|
||||
// envelopeValidatingService applies envelope parsing and the protocol gate
// before delegating to the configured service implementation.
type envelopeValidatingService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	// delegate receives requests only after envelope validation succeeds,
	// with the parsed envelope attached to the request/stream context.
	delegate gatewayv1.EdgeGatewayServer
}
|
||||
|
||||
// ExecuteCommand validates req and only then forwards it to the configured
|
||||
// delegate with the parsed envelope attached to ctx.
|
||||
func (s envelopeValidatingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
envelope, err := parseExecuteCommandRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.delegate.ExecuteCommand(context.WithValue(ctx, parsedEnvelopeContextKey{}, envelope), req)
|
||||
}
|
||||
|
||||
// SubscribeEvents validates req and only then forwards it to the configured
|
||||
// delegate with the parsed envelope attached to the stream context.
|
||||
func (s envelopeValidatingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
envelope, err := parseSubscribeEventsRequest(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.delegate.SubscribeEvents(req, envelopeContextStream{
|
||||
ServerStreamingServer: stream,
|
||||
ctx: context.WithValue(stream.Context(), parsedEnvelopeContextKey{}, envelope),
|
||||
})
|
||||
}
|
||||
|
||||
// parseExecuteCommandRequest validates req according to the request-envelope
|
||||
// rules and returns a cloned parsed envelope suitable for later auth steps.
|
||||
func parseExecuteCommandRequest(req *gatewayv1.ExecuteCommandRequest) (parsedEnvelope, error) {
|
||||
if req == nil {
|
||||
return parsedEnvelope{}, newMalformedEnvelopeError("request envelope must not be nil")
|
||||
}
|
||||
if err := protovalidate.Validate(req); err != nil {
|
||||
return parsedEnvelope{}, canonicalExecuteCommandValidationError(req)
|
||||
}
|
||||
if req.GetProtocolVersion() != supportedProtocolVersion {
|
||||
return parsedEnvelope{}, newUnsupportedProtocolVersionError(req.GetProtocolVersion())
|
||||
}
|
||||
|
||||
return parsedEnvelope{
|
||||
ProtocolVersion: req.GetProtocolVersion(),
|
||||
DeviceSessionID: req.GetDeviceSessionId(),
|
||||
MessageType: req.GetMessageType(),
|
||||
TimestampMS: req.GetTimestampMs(),
|
||||
RequestID: req.GetRequestId(),
|
||||
TraceID: req.GetTraceId(),
|
||||
PayloadBytes: bytes.Clone(req.GetPayloadBytes()),
|
||||
PayloadHash: bytes.Clone(req.GetPayloadHash()),
|
||||
Signature: bytes.Clone(req.GetSignature()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseSubscribeEventsRequest validates req according to the request-envelope
|
||||
// rules and returns a cloned parsed envelope suitable for later auth steps.
|
||||
func parseSubscribeEventsRequest(req *gatewayv1.SubscribeEventsRequest) (parsedEnvelope, error) {
|
||||
if req == nil {
|
||||
return parsedEnvelope{}, newMalformedEnvelopeError("request envelope must not be nil")
|
||||
}
|
||||
if err := protovalidate.Validate(req); err != nil {
|
||||
return parsedEnvelope{}, canonicalSubscribeEventsValidationError(req)
|
||||
}
|
||||
if req.GetProtocolVersion() != supportedProtocolVersion {
|
||||
return parsedEnvelope{}, newUnsupportedProtocolVersionError(req.GetProtocolVersion())
|
||||
}
|
||||
|
||||
return parsedEnvelope{
|
||||
ProtocolVersion: req.GetProtocolVersion(),
|
||||
DeviceSessionID: req.GetDeviceSessionId(),
|
||||
MessageType: req.GetMessageType(),
|
||||
TimestampMS: req.GetTimestampMs(),
|
||||
RequestID: req.GetRequestId(),
|
||||
TraceID: req.GetTraceId(),
|
||||
PayloadBytes: bytes.Clone(req.GetPayloadBytes()),
|
||||
PayloadHash: bytes.Clone(req.GetPayloadHash()),
|
||||
Signature: bytes.Clone(req.GetSignature()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newEnvelopeValidatingService wraps delegate with the envelope-validation
|
||||
// gate.
|
||||
func newEnvelopeValidatingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer {
|
||||
return envelopeValidatingService{delegate: delegate}
|
||||
}
|
||||
|
||||
// canonicalExecuteCommandValidationError maps any ExecuteCommand validation
|
||||
// failure into the stable canonical error chosen by field order.
|
||||
func canonicalExecuteCommandValidationError(req *gatewayv1.ExecuteCommandRequest) error {
|
||||
switch {
|
||||
case req.GetProtocolVersion() == "":
|
||||
return newMalformedEnvelopeError("protocol_version must not be empty")
|
||||
case req.GetDeviceSessionId() == "":
|
||||
return newMalformedEnvelopeError("device_session_id must not be empty")
|
||||
case req.GetMessageType() == "":
|
||||
return newMalformedEnvelopeError("message_type must not be empty")
|
||||
case req.GetTimestampMs() <= 0:
|
||||
return newMalformedEnvelopeError("timestamp_ms must be greater than zero")
|
||||
case req.GetRequestId() == "":
|
||||
return newMalformedEnvelopeError("request_id must not be empty")
|
||||
case len(req.GetPayloadBytes()) == 0:
|
||||
return newMalformedEnvelopeError("payload_bytes must not be empty")
|
||||
case len(req.GetPayloadHash()) == 0:
|
||||
return newMalformedEnvelopeError("payload_hash must not be empty")
|
||||
case len(req.GetSignature()) == 0:
|
||||
return newMalformedEnvelopeError("signature must not be empty")
|
||||
default:
|
||||
return newMalformedEnvelopeError("request envelope is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
// canonicalSubscribeEventsValidationError maps any SubscribeEvents validation
|
||||
// failure into the stable canonical error chosen by field order.
|
||||
func canonicalSubscribeEventsValidationError(req *gatewayv1.SubscribeEventsRequest) error {
|
||||
switch {
|
||||
case req.GetProtocolVersion() == "":
|
||||
return newMalformedEnvelopeError("protocol_version must not be empty")
|
||||
case req.GetDeviceSessionId() == "":
|
||||
return newMalformedEnvelopeError("device_session_id must not be empty")
|
||||
case req.GetMessageType() == "":
|
||||
return newMalformedEnvelopeError("message_type must not be empty")
|
||||
case req.GetTimestampMs() <= 0:
|
||||
return newMalformedEnvelopeError("timestamp_ms must be greater than zero")
|
||||
case req.GetRequestId() == "":
|
||||
return newMalformedEnvelopeError("request_id must not be empty")
|
||||
case len(req.GetPayloadHash()) == 0:
|
||||
return newMalformedEnvelopeError("payload_hash must not be empty")
|
||||
case len(req.GetSignature()) == 0:
|
||||
return newMalformedEnvelopeError("signature must not be empty")
|
||||
default:
|
||||
return newMalformedEnvelopeError("request envelope is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
// newMalformedEnvelopeError returns the stable malformed-envelope reject used
|
||||
// before the gateway performs any auth or routing work.
|
||||
func newMalformedEnvelopeError(message string) error {
|
||||
return status.Error(codes.InvalidArgument, message)
|
||||
}
|
||||
|
||||
// newUnsupportedProtocolVersionError returns the stable reject for a non-empty
|
||||
// but unsupported protocol_version literal.
|
||||
func newUnsupportedProtocolVersionError(version string) error {
|
||||
return status.Error(codes.FailedPrecondition, fmt.Sprintf("unsupported protocol_version %q", version))
|
||||
}
|
||||
|
||||
type parsedEnvelopeContextKey struct{}
|
||||
|
||||
// envelopeContextStream wraps a server stream so that Context() returns a
// derived context (carrying the parsed envelope) instead of the original one.
type envelopeContextStream struct {
	grpc.ServerStreamingServer[gatewayv1.GatewayEvent]

	// ctx overrides the embedded stream's context when non-nil.
	ctx context.Context
}
|
||||
|
||||
func (s envelopeContextStream) Context() context.Context {
|
||||
if s.ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = envelopeValidatingService{}
|
||||
@@ -0,0 +1,420 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// TestParseExecuteCommandRequest covers every canonical validation reject for
// the ExecuteCommand envelope plus the happy path, and verifies that the
// parsed envelope is a defensive copy of the request's byte fields.
func TestParseExecuteCommandRequest(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		mutate      func(*gatewayv1.ExecuteCommandRequest) // applied to a fresh valid request
		wantCode    codes.Code
		wantMessage string
		assertValid func(*testing.T, *gatewayv1.ExecuteCommandRequest, parsedEnvelope)
	}{
		{
			name:        "nil request",
			wantCode:    codes.InvalidArgument,
			wantMessage: "request envelope must not be nil",
		},
		{
			name: "empty protocol version",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.ProtocolVersion = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "protocol_version must not be empty",
		},
		{
			name: "empty device session id",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.DeviceSessionId = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "device_session_id must not be empty",
		},
		{
			name: "empty message type",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.MessageType = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "message_type must not be empty",
		},
		{
			name: "zero timestamp",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.TimestampMs = 0
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "timestamp_ms must be greater than zero",
		},
		{
			name: "empty request id",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.RequestId = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "request_id must not be empty",
		},
		{
			name: "empty payload bytes",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.PayloadBytes = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "payload_bytes must not be empty",
		},
		{
			name: "empty payload hash",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.PayloadHash = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "payload_hash must not be empty",
		},
		{
			name: "empty signature",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.Signature = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "signature must not be empty",
		},
		{
			name: "unsupported protocol version",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.ProtocolVersion = "v2"
			},
			wantCode:    codes.FailedPrecondition,
			wantMessage: `unsupported protocol_version "v2"`,
		},
		{
			name:     "valid request",
			wantCode: codes.OK,
			assertValid: func(t *testing.T, req *gatewayv1.ExecuteCommandRequest, envelope parsedEnvelope) {
				t.Helper()

				assert.Equal(t, supportedProtocolVersion, envelope.ProtocolVersion)
				assert.Equal(t, req.GetDeviceSessionId(), envelope.DeviceSessionID)
				assert.Equal(t, req.GetMessageType(), envelope.MessageType)
				assert.Equal(t, req.GetTimestampMs(), envelope.TimestampMS)
				assert.Equal(t, req.GetRequestId(), envelope.RequestID)
				assert.Equal(t, req.GetTraceId(), envelope.TraceID)
				assert.Equal(t, req.GetPayloadBytes(), envelope.PayloadBytes)
				assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash)
				assert.Equal(t, req.GetSignature(), envelope.Signature)

				// Mutating the envelope's byte slices must not leak back into
				// the request: the parser is expected to copy them.
				originalPayloadBytes := append([]byte(nil), req.GetPayloadBytes()...)
				originalPayloadHash := append([]byte(nil), req.GetPayloadHash()...)
				originalSignature := append([]byte(nil), req.GetSignature()...)

				envelope.PayloadBytes[0] = 'X'
				envelope.PayloadHash[0] = 'Y'
				envelope.Signature[0] = 'Z'

				assert.Equal(t, originalPayloadBytes, req.GetPayloadBytes())
				assert.Equal(t, originalPayloadHash, req.GetPayloadHash())
				assert.Equal(t, originalSignature, req.GetSignature())
			},
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			var req *gatewayv1.ExecuteCommandRequest
			if tt.name != "nil request" {
				req = newValidExecuteCommandRequest()
				if tt.mutate != nil {
					tt.mutate(req)
				}
			}

			envelope, err := parseExecuteCommandRequest(req)
			if tt.wantCode != codes.OK {
				require.Error(t, err)
				assert.Equal(t, tt.wantCode, status.Code(err))
				assert.Equal(t, tt.wantMessage, status.Convert(err).Message())
				return
			}

			require.NoError(t, err)
			require.NotNil(t, tt.assertValid)
			tt.assertValid(t, req, envelope)
		})
	}
}
|
||||
|
||||
// TestParseSubscribeEventsRequest mirrors the ExecuteCommand parsing tests for
// the SubscribeEvents envelope; note that payload_bytes may legitimately be
// empty for subscriptions, so there is no reject case for it here.
func TestParseSubscribeEventsRequest(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		mutate      func(*gatewayv1.SubscribeEventsRequest) // applied to a fresh valid request
		wantCode    codes.Code
		wantMessage string
		assertValid func(*testing.T, *gatewayv1.SubscribeEventsRequest, parsedEnvelope)
	}{
		{
			name:        "nil request",
			wantCode:    codes.InvalidArgument,
			wantMessage: "request envelope must not be nil",
		},
		{
			name: "empty protocol version",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.ProtocolVersion = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "protocol_version must not be empty",
		},
		{
			name: "empty device session id",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.DeviceSessionId = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "device_session_id must not be empty",
		},
		{
			name: "empty message type",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.MessageType = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "message_type must not be empty",
		},
		{
			name: "zero timestamp",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.TimestampMs = 0
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "timestamp_ms must be greater than zero",
		},
		{
			name: "empty request id",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.RequestId = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "request_id must not be empty",
		},
		{
			name: "empty payload hash",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.PayloadHash = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "payload_hash must not be empty",
		},
		{
			name: "empty signature",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.Signature = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "signature must not be empty",
		},
		{
			name: "unsupported protocol version",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.ProtocolVersion = "v2"
			},
			wantCode:    codes.FailedPrecondition,
			wantMessage: `unsupported protocol_version "v2"`,
		},
		{
			name:     "valid request with empty payload bytes",
			wantCode: codes.OK,
			assertValid: func(t *testing.T, req *gatewayv1.SubscribeEventsRequest, envelope parsedEnvelope) {
				t.Helper()

				assert.Empty(t, req.GetPayloadBytes())
				assert.Empty(t, envelope.PayloadBytes)
				assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash)
				assert.Equal(t, req.GetSignature(), envelope.Signature)
			},
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			var req *gatewayv1.SubscribeEventsRequest
			if tt.name != "nil request" {
				req = newValidSubscribeEventsRequest()
				if tt.mutate != nil {
					tt.mutate(req)
				}
			}

			envelope, err := parseSubscribeEventsRequest(req)
			if tt.wantCode != codes.OK {
				require.Error(t, err)
				assert.Equal(t, tt.wantCode, status.Code(err))
				assert.Equal(t, tt.wantMessage, status.Convert(err).Message())
				return
			}

			require.NoError(t, err)
			require.NotNil(t, tt.assertValid)
			tt.assertValid(t, req, envelope)
		})
	}
}
|
||||
|
||||
func TestEnvelopeValidatingServiceExecuteCommandRejectsInvalidRequestBeforeDelegate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
service := newEnvelopeValidatingService(delegate)
|
||||
|
||||
_, err := service.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{})
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestEnvelopeValidatingServiceSubscribeEventsRejectsInvalidRequestBeforeDelegate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
service := newEnvelopeValidatingService(delegate)
|
||||
|
||||
err := service.SubscribeEvents(&gatewayv1.SubscribeEventsRequest{}, stubGatewayEventStream{})
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestEnvelopeValidatingServiceExecuteCommandAttachesParsedEnvelope(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
want := newValidExecuteCommandRequest()
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
envelope, ok := parsedEnvelopeFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, want.GetRequestId(), envelope.RequestID)
|
||||
assert.Equal(t, want.GetDeviceSessionId(), envelope.DeviceSessionID)
|
||||
assert.Equal(t, want.GetMessageType(), envelope.MessageType)
|
||||
assert.Equal(t, want.GetPayloadBytes(), envelope.PayloadBytes)
|
||||
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
|
||||
},
|
||||
}
|
||||
service := newEnvelopeValidatingService(delegate)
|
||||
|
||||
response, err := service.ExecuteCommand(context.Background(), want)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, want.GetRequestId(), response.GetRequestId())
|
||||
assert.Equal(t, 1, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestEnvelopeValidatingServiceSubscribeEventsAttachesParsedEnvelope(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
want := newValidSubscribeEventsRequest()
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
envelope, ok := parsedEnvelopeFromContext(stream.Context())
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, want.GetRequestId(), envelope.RequestID)
|
||||
assert.Equal(t, want.GetDeviceSessionId(), envelope.DeviceSessionID)
|
||||
assert.Equal(t, want.GetMessageType(), envelope.MessageType)
|
||||
assert.Equal(t, want.GetPayloadHash(), envelope.PayloadHash)
|
||||
assert.Equal(t, want.GetSignature(), envelope.Signature)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
service := newEnvelopeValidatingService(delegate)
|
||||
|
||||
err := service.SubscribeEvents(want, stubGatewayEventStream{})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
// recordingEdgeGatewayService is a test double that counts RPC invocations and
// lets tests override each RPC with a custom function.
type recordingEdgeGatewayService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	// Invocation counters inspected by tests; not synchronized, so only use
	// from a single goroutine at a time.
	executeCalls   int
	subscribeCalls int
	// Optional overrides; when nil the methods return benign defaults.
	executeCommandFunc  func(context.Context, *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error)
	subscribeEventsFunc func(*gatewayv1.SubscribeEventsRequest, grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error
}
|
||||
|
||||
func (s *recordingEdgeGatewayService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
s.executeCalls++
|
||||
if s.executeCommandFunc != nil {
|
||||
return s.executeCommandFunc(ctx, req)
|
||||
}
|
||||
|
||||
return &gatewayv1.ExecuteCommandResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *recordingEdgeGatewayService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
s.subscribeCalls++
|
||||
if s.subscribeEventsFunc != nil {
|
||||
return s.subscribeEventsFunc(req, stream)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stubGatewayEventStream is a minimal server-streaming stub used to exercise
// SubscribeEvents without a real gRPC transport.
type stubGatewayEventStream struct {
	grpc.ServerStream

	// ctx is optional; Context() falls back to context.Background() when nil.
	ctx context.Context
}
|
||||
|
||||
// Send discards the event; envelope-handling tests do not inspect stream output.
func (s stubGatewayEventStream) Send(*gatewayv1.GatewayEvent) error {
	return nil
}

// SetHeader is a no-op for the stub stream.
func (s stubGatewayEventStream) SetHeader(metadata.MD) error {
	return nil
}

// SendHeader is a no-op for the stub stream.
func (s stubGatewayEventStream) SendHeader(metadata.MD) error {
	return nil
}

// SetTrailer is a no-op for the stub stream.
func (s stubGatewayEventStream) SetTrailer(metadata.MD) {}

// Context returns the injected context, defaulting to context.Background()
// when none was set.
func (s stubGatewayEventStream) Context() context.Context {
	if s.ctx == nil {
		return context.Background()
	}

	return s.ctx
}

// SendMsg is a no-op for the stub stream.
func (s stubGatewayEventStream) SendMsg(any) error {
	return nil
}

// RecvMsg is a no-op for the stub stream.
func (s stubGatewayEventStream) RecvMsg(any) error {
	return nil
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/replay"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const minimumReplayReservationTTL = time.Millisecond
|
||||
|
||||
// freshnessAndReplayService applies freshness and anti-replay checks after
// client-signature verification and before later policy or routing steps run.
type freshnessAndReplayService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	delegate        gatewayv1.EdgeGatewayServer // next service layer in the chain
	clock           clock.Clock                 // injectable time source for freshness checks
	replayStore     replay.Store                // reserves (session, request) pairs to detect replays
	freshnessWindow time.Duration               // allowed skew around "now" for request timestamps
}
|
||||
|
||||
// ExecuteCommand verifies request freshness and replay protection before
|
||||
// delegating to the configured service implementation.
|
||||
func (s freshnessAndReplayService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
if err := s.verifyFreshnessAndReplay(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.delegate.ExecuteCommand(ctx, req)
|
||||
}
|
||||
|
||||
// SubscribeEvents verifies request freshness and replay protection before
|
||||
// delegating to the configured service implementation.
|
||||
func (s freshnessAndReplayService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
if err := s.verifyFreshnessAndReplay(stream.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.delegate.SubscribeEvents(req, stream)
|
||||
}
|
||||
|
||||
// newFreshnessAndReplayService wraps delegate with the freshness and replay
|
||||
// gate.
|
||||
func newFreshnessAndReplayService(delegate gatewayv1.EdgeGatewayServer, clk clock.Clock, replayStore replay.Store, freshnessWindow time.Duration) gatewayv1.EdgeGatewayServer {
|
||||
return freshnessAndReplayService{
|
||||
delegate: delegate,
|
||||
clock: clk,
|
||||
replayStore: replayStore,
|
||||
freshnessWindow: freshnessWindow,
|
||||
}
|
||||
}
|
||||
|
||||
// verifyFreshnessAndReplay checks that the parsed envelope's timestamp lies
// within the configured freshness window around the injected clock's "now",
// then reserves the (device_session_id, request_id) pair in the replay store
// so duplicate deliveries are rejected.
func (s freshnessAndReplayService) verifyFreshnessAndReplay(ctx context.Context) error {
	envelope, ok := parsedEnvelopeFromContext(ctx)
	if !ok {
		// Earlier layers must attach the envelope; a miss indicates a wiring
		// bug, not a client error.
		return status.Error(codes.Internal, "authenticated request context is incomplete")
	}

	now := s.clock.Now().UTC()
	requestTime := time.UnixMilli(envelope.TimestampMS).UTC()
	if requestTime.Before(now.Add(-s.freshnessWindow)) || requestTime.After(now.Add(s.freshnessWindow)) {
		return status.Error(codes.FailedPrecondition, "request timestamp is outside the freshness window")
	}

	// Reserve only for as long as the request could still be accepted as
	// fresh, clamped so the store always receives a positive TTL.
	ttl := requestTime.Add(s.freshnessWindow).Sub(now)
	if ttl < minimumReplayReservationTTL {
		ttl = minimumReplayReservationTTL
	}

	err := s.replayStore.Reserve(ctx, envelope.DeviceSessionID, envelope.RequestID, ttl)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, replay.ErrDuplicate):
		return status.Error(codes.FailedPrecondition, "request replay detected")
	default:
		// Any other store failure rejects closed but signals retryability.
		return status.Error(codes.Unavailable, "replay store is unavailable")
	}
}
|
||||
|
||||
// unavailableReplayStore is a fail-closed replay store: every reservation
// errors, so requests gated on it are always rejected.
type unavailableReplayStore struct{}

// Reserve always fails, signalling that no real replay store is configured.
func (unavailableReplayStore) Reserve(context.Context, string, string, time.Duration) error {
	return errors.New("replay store is unavailable")
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = freshnessAndReplayService{}
|
||||
@@ -0,0 +1,509 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/replay"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// TestExecuteCommandRejectsStaleTimestamp proves requests timestamped just
// outside either edge of the freshness window are rejected end-to-end before
// the delegate runs.
func TestExecuteCommandRejectsStaleTimestamp(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		timestampMS int64
	}{
		{
			name:        "past window",
			timestampMS: testCurrentTime.Add(-testFreshnessWindow - time.Millisecond).UnixMilli(),
		},
		{
			name:        "future window",
			timestampMS: testCurrentTime.Add(testFreshnessWindow + time.Millisecond).UnixMilli(),
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			delegate := &recordingEdgeGatewayService{}
			server, runGateway := newTestGateway(t, ServerDependencies{
				Service:      delegate,
				SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
				ReplayStore:  staticReplayStore{},
			})
			defer runGateway.stop(t)

			addr := waitForListenAddr(t, server)
			conn := dialGatewayClient(t, addr)
			defer func() {
				require.NoError(t, conn.Close())
			}()

			client := gatewayv1.NewEdgeGatewayClient(conn)
			_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS))
			require.Error(t, err)
			assert.Equal(t, codes.FailedPrecondition, status.Code(err))
			assert.Equal(t, "request timestamp is outside the freshness window", status.Convert(err).Message())
			assert.Zero(t, delegate.executeCalls)
		})
	}
}
|
||||
|
||||
// TestSubscribeEventsRejectsStaleTimestamp mirrors the ExecuteCommand stale
// timestamp test for the streaming RPC.
func TestSubscribeEventsRejectsStaleTimestamp(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		timestampMS int64
	}{
		{
			name:        "past window",
			timestampMS: testCurrentTime.Add(-testFreshnessWindow - time.Millisecond).UnixMilli(),
		},
		{
			name:        "future window",
			timestampMS: testCurrentTime.Add(testFreshnessWindow + time.Millisecond).UnixMilli(),
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			delegate := &recordingEdgeGatewayService{}
			server, runGateway := newTestGateway(t, ServerDependencies{
				Service:      delegate,
				SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
				ReplayStore:  staticReplayStore{},
			})
			defer runGateway.stop(t)

			addr := waitForListenAddr(t, server)
			conn := dialGatewayClient(t, addr)
			defer func() {
				require.NoError(t, conn.Close())
			}()

			client := gatewayv1.NewEdgeGatewayClient(conn)
			err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS))
			require.Error(t, err)
			assert.Equal(t, codes.FailedPrecondition, status.Code(err))
			assert.Equal(t, "request timestamp is outside the freshness window", status.Convert(err).Message())
			assert.Zero(t, delegate.subscribeCalls)
		})
	}
}
|
||||
|
||||
// TestExecuteCommandRejectsReplay proves that resending the exact same request
// envelope is rejected as a replay and the delegate runs only once.
func TestExecuteCommandRejectsReplay(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: replayDuplicateBySessionAndRequest(),
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	req := newValidExecuteCommandRequest()

	// First delivery succeeds and reserves the (session, request) pair.
	_, err := client.ExecuteCommand(context.Background(), req)
	require.NoError(t, err)

	// Second delivery of the same envelope must be rejected as a replay.
	_, err = client.ExecuteCommand(context.Background(), req)
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "request replay detected", status.Convert(err).Message())
	assert.Equal(t, 1, delegate.executeCalls)
}
|
||||
|
||||
// TestSubscribeEventsRejectsReplay mirrors the replay rejection test for the
// streaming RPC: the first subscription bootstraps and ends, the second with
// the same envelope is rejected.
func TestSubscribeEventsRejectsReplay(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: replayDuplicateBySessionAndRequest(),
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	req := newValidSubscribeEventsRequest()

	// First subscription succeeds: bootstrap event arrives, then clean EOF.
	stream, err := client.SubscribeEvents(context.Background(), req)
	require.NoError(t, err)
	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)

	// Replaying the same envelope must be rejected.
	err = subscribeEventsError(t, context.Background(), client, req)
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "request replay detected", status.Convert(err).Message())
	assert.Equal(t, 1, delegate.subscribeCalls)
}
|
||||
|
||||
// TestExecuteCommandAllowsSameRequestIDAcrossDistinctSessions proves replay
// detection is keyed on (session, request_id), so different sessions may reuse
// the same request_id.
func TestExecuteCommandAllowsSameRequestIDAcrossDistinctSessions(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{
		executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
			return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
		},
	}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Service: delegate,
		SessionCache: staticSessionCache{
			lookupFunc: func(ctx context.Context, deviceSessionID string) (session.Record, error) {
				return newActiveSessionRecordWithSessionID(deviceSessionID), nil
			},
		},
		ReplayStore: staticReplayStore{
			reserveFunc: replayDuplicateBySessionAndRequest(),
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	// Same request_id, two distinct sessions: both must be accepted.
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-shared"))
	require.NoError(t, err)

	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-456", "request-shared"))
	require.NoError(t, err)

	assert.Equal(t, 2, delegate.executeCalls)
}
|
||||
|
||||
// TestSubscribeEventsAllowsSameRequestIDAcrossDistinctSessions mirrors the
// cross-session request_id reuse test for the streaming RPC.
func TestSubscribeEventsAllowsSameRequestIDAcrossDistinctSessions(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{
		subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
			return nil
		},
	}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Service: delegate,
		SessionCache: staticSessionCache{
			lookupFunc: func(ctx context.Context, deviceSessionID string) (session.Record, error) {
				return newActiveSessionRecordWithSessionID(deviceSessionID), nil
			},
		},
		ReplayStore: staticReplayStore{
			reserveFunc: replayDuplicateBySessionAndRequest(),
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	// First session: subscription succeeds with a bootstrap event then EOF.
	stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-shared"))
	require.NoError(t, err)
	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-shared", "trace-123", testCurrentTime.UnixMilli())
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)

	// Second session reusing the same request_id must also succeed.
	stream, err = client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-456", "request-shared"))
	require.NoError(t, err)
	event = recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-shared", "trace-123", testCurrentTime.UnixMilli())
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)

	assert.Equal(t, 2, delegate.subscribeCalls)
}
|
||||
|
||||
// TestExecuteCommandRejectsReplayStoreUnavailable proves the gateway fails
// closed with Unavailable when the replay store errors, without invoking the
// delegate.
func TestExecuteCommandRejectsReplayStoreUnavailable(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: func(context.Context, string, string, time.Duration) error {
				return errors.New("redis down")
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}
|
||||
|
||||
// TestSubscribeEventsRejectsReplayStoreUnavailable mirrors the replay-store
// outage test for the streaming RPC.
func TestSubscribeEventsRejectsReplayStoreUnavailable(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: func(context.Context, string, string, time.Duration) error {
				return errors.New("redis down")
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
	assert.Zero(t, delegate.subscribeCalls)
}
|
||||
|
||||
// TestExecuteCommandFreshRequestReachesDelegateAndUsesDynamicReplayTTL proves
// a fresh request reaches the delegate and that the replay reservation uses
// the dynamic TTL (remaining freshness), which for a request stamped at "now"
// equals the full freshness window.
func TestExecuteCommandFreshRequestReachesDelegateAndUsesDynamicReplayTTL(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{
		executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
			return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
		},
	}

	// Captured by the replay store stub so the reservation arguments can be
	// asserted after the call.
	var reservedDeviceSessionID string
	var reservedRequestID string
	var reservedTTL time.Duration

	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
				reservedDeviceSessionID = deviceSessionID
				reservedRequestID = requestID
				reservedTTL = ttl
				return nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)
	assert.Equal(t, "request-123", response.GetRequestId())
	assert.Equal(t, "device-session-123", reservedDeviceSessionID)
	assert.Equal(t, "request-123", reservedRequestID)
	assert.Equal(t, testFreshnessWindow, reservedTTL)
	assert.Equal(t, 1, delegate.executeCalls)
}
|
||||
|
||||
// TestSubscribeEventsFreshRequestReachesDelegateAndUsesDynamicReplayTTL
// verifies that a fresh SubscribeEvents request is reserved in the replay
// store with the freshness window as its TTL, that the signed server-time
// bootstrap event is emitted first, and that the delegate runs exactly once.
func TestSubscribeEventsFreshRequestReachesDelegateAndUsesDynamicReplayTTL(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{
		subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
			return nil
		},
	}

	var reservedTTL time.Duration

	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
				// The reservation key must match the request envelope.
				assert.Equal(t, "device-session-123", deviceSessionID)
				assert.Equal(t, "request-123", requestID)
				reservedTTL = ttl
				return nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
	require.NoError(t, err)
	// The gateway sends the signed server-time bootstrap event first.
	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
	// The delegate returns nil immediately, so the stream ends cleanly.
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)
	assert.Equal(t, testFreshnessWindow, reservedTTL)
	assert.Equal(t, 1, delegate.subscribeCalls)
}
|
||||
|
||||
// TestExecuteCommandFutureSkewUsesExtendedReplayTTL verifies that a request
// timestamped 2 minutes in the future is still accepted and that its replay
// reservation TTL is extended to 7 minutes — presumably the freshness window
// plus the observed future skew; confirm against the server's TTL derivation.
func TestExecuteCommandFutureSkewUsesExtendedReplayTTL(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{
		executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
			return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
		},
	}

	// Captured by the fake replay store so the TTL can be asserted.
	var reservedTTL time.Duration

	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
				reservedTTL = ttl
				return nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	// Request timestamp is 2 minutes ahead of the test clock.
	_, err := client.ExecuteCommand(
		context.Background(),
		newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(2*time.Minute).UnixMilli()),
	)
	require.NoError(t, err)
	assert.Equal(t, 7*time.Minute, reservedTTL)
	assert.Equal(t, 1, delegate.executeCalls)
}
|
||||
|
||||
// TestExecuteCommandBoundaryFreshnessUsesMinimumReplayTTL verifies that a
// request timestamped exactly one freshness window in the past is still
// accepted and that its replay reservation falls back to the minimum TTL
// floor rather than a zero or negative duration.
func TestExecuteCommandBoundaryFreshnessUsesMinimumReplayTTL(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{
		executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
			return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
		},
	}

	// Captured by the fake replay store so the TTL can be asserted.
	var reservedTTL time.Duration

	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
				reservedTTL = ttl
				return nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	// Timestamp sits exactly on the oldest-acceptable boundary.
	_, err := client.ExecuteCommand(
		context.Background(),
		newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(-testFreshnessWindow).UnixMilli()),
	)
	require.NoError(t, err)
	assert.Equal(t, minimumReplayReservationTTL, reservedTTL)
	assert.Equal(t, 1, delegate.executeCalls)
}
|
||||
|
||||
func replayDuplicateBySessionAndRequest() func(context.Context, string, string, time.Duration) error {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
seen = make(map[string]struct{})
|
||||
)
|
||||
|
||||
return func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
key := deviceSessionID + "\x00" + requestID
|
||||
if _, ok := seen[key]; ok {
|
||||
return replay.ErrDuplicate
|
||||
}
|
||||
|
||||
seen[key] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/logging"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func observabilityUnaryInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) grpc.UnaryServerInterceptor {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
start := time.Now()
|
||||
resp, err := handler(ctx, req)
|
||||
|
||||
recordGRPCRequest(logger, metrics, ctx, info.FullMethod, req, resp, err, time.Since(start), "unary")
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
func observabilityStreamInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) grpc.StreamServerInterceptor {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
start := time.Now()
|
||||
wrapped := &observabilityServerStream{ServerStream: stream}
|
||||
err := handler(srv, wrapped)
|
||||
|
||||
recordGRPCRequest(logger, metrics, stream.Context(), info.FullMethod, wrapped.request, nil, err, time.Since(start), "stream")
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// observabilityServerStream wraps a grpc.ServerStream so the interceptor can
// capture the first message received on the stream for logging and metrics.
type observabilityServerStream struct {
	grpc.ServerStream

	// request holds the first successfully received message, or nil if no
	// message has been received yet.
	request any
}
|
||||
|
||||
func (s *observabilityServerStream) RecvMsg(m any) error {
|
||||
err := s.ServerStream.RecvMsg(m)
|
||||
if err == nil && s.request == nil {
|
||||
s.request = m
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// recordGRPCRequest emits one metrics sample and one structured log entry for
// a completed authenticated RPC. streamKind is "unary" or "stream"; resp may
// be nil (always nil for streams). The log level tracks the outcome class.
func recordGRPCRequest(logger *zap.Logger, metrics *telemetry.Runtime, ctx context.Context, fullMethod string, req any, resp any, err error, duration time.Duration, streamKind string) {
	rpcMethod := path.Base(fullMethod)
	messageType, requestID, traceID := grpcEnvelopeFields(req)
	resultCode := grpcResultCode(resp)
	grpcCode, grpcMessage, outcome := grpcOutcome(err)
	rejectReason := telemetry.RejectReason(outcome)

	// Optional result_code / reject_reason attributes are appended only when
	// present to keep metric label cardinality down.
	attrs := []attribute.KeyValue{
		attribute.String("rpc_method", rpcMethod),
		attribute.String("message_type", messageType),
		attribute.String("edge_outcome", string(outcome)),
	}
	if resultCode != "" {
		attrs = append(attrs, attribute.String("result_code", resultCode))
	}
	if rejectReason != "" {
		attrs = append(attrs, attribute.String("reject_reason", rejectReason))
	}
	metrics.RecordAuthenticatedGRPC(ctx, attrs, duration)

	fields := []zap.Field{
		zap.String("component", "authenticated_grpc"),
		zap.String("transport", "grpc"),
		zap.String("stream_kind", streamKind),
		zap.String("rpc_method", rpcMethod),
		zap.String("message_type", messageType),
		zap.String("grpc_code", grpcCode.String()),
		// Duration in milliseconds with microsecond precision.
		zap.Float64("duration_ms", float64(duration.Microseconds())/1000),
		zap.String("request_id", requestID),
		zap.String("trace_id", traceID),
		zap.String("peer_ip", peerIPFromContext(ctx)),
		zap.String("edge_outcome", string(outcome)),
	}
	if resultCode != "" {
		fields = append(fields, zap.String("result_code", resultCode))
	}
	if rejectReason != "" {
		fields = append(fields, zap.String("reject_reason", rejectReason))
	}
	if grpcMessage != "" {
		fields = append(fields, zap.String("grpc_message", grpcMessage))
	}
	fields = append(fields, logging.TraceFieldsFromContext(ctx)...)

	// Success -> info; infrastructure failures -> error; everything else
	// (policy rejections, client errors) -> warn.
	switch outcome {
	case telemetry.EdgeOutcomeSuccess:
		logger.Info("authenticated gRPC request completed", fields...)
	case telemetry.EdgeOutcomeBackendUnavailable, telemetry.EdgeOutcomeDownstreamUnavailable, telemetry.EdgeOutcomeInternalError:
		logger.Error("authenticated gRPC request failed", fields...)
	default:
		logger.Warn("authenticated gRPC request rejected", fields...)
	}
}
|
||||
|
||||
func grpcEnvelopeFields(req any) (messageType string, requestID string, traceID string) {
|
||||
switch typed := req.(type) {
|
||||
case *gatewayv1.ExecuteCommandRequest:
|
||||
return typed.GetMessageType(), typed.GetRequestId(), typed.GetTraceId()
|
||||
case *gatewayv1.SubscribeEventsRequest:
|
||||
return typed.GetMessageType(), typed.GetRequestId(), typed.GetTraceId()
|
||||
default:
|
||||
return "", "", ""
|
||||
}
|
||||
}
|
||||
|
||||
func grpcResultCode(resp any) string {
|
||||
typed, ok := resp.(*gatewayv1.ExecuteCommandResponse)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return typed.GetResultCode()
|
||||
}
|
||||
|
||||
func grpcOutcome(err error) (codes.Code, string, telemetry.EdgeOutcome) {
|
||||
switch {
|
||||
case err == nil:
|
||||
return codes.OK, "", telemetry.EdgeOutcomeSuccess
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
|
||||
return codes.Canceled, err.Error(), telemetry.EdgeOutcomeSuccess
|
||||
default:
|
||||
grpcStatus := status.Convert(err)
|
||||
return grpcStatus.Code(), grpcStatus.Message(), telemetry.OutcomeFromGRPCStatus(grpcStatus.Code(), grpcStatus.Message())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// payloadHashVerifyingService applies payload-hash verification after session
// lookup and before any later auth or routing step runs.
type payloadHashVerifyingService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	// delegate receives the request only after verifyPayloadHash succeeds.
	delegate gatewayv1.EdgeGatewayServer
}
|
||||
|
||||
// ExecuteCommand verifies req payload integrity before delegating to the
|
||||
// configured service implementation.
|
||||
func (s payloadHashVerifyingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
if err := verifyPayloadHash(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.delegate.ExecuteCommand(ctx, req)
|
||||
}
|
||||
|
||||
// SubscribeEvents verifies req payload integrity before delegating to the
|
||||
// configured service implementation.
|
||||
func (s payloadHashVerifyingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
if err := verifyPayloadHash(stream.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.delegate.SubscribeEvents(req, stream)
|
||||
}
|
||||
|
||||
// newPayloadHashVerifyingService wraps delegate with the payload-hash
|
||||
// verification gate.
|
||||
func newPayloadHashVerifyingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer {
|
||||
return payloadHashVerifyingService{delegate: delegate}
|
||||
}
|
||||
|
||||
func verifyPayloadHash(ctx context.Context) error {
|
||||
envelope, ok := parsedEnvelopeFromContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
err := authn.VerifyPayloadHash(envelope.PayloadBytes, envelope.PayloadHash)
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, authn.ErrInvalidPayloadHash), errors.Is(err, authn.ErrPayloadHashMismatch):
|
||||
return status.Error(codes.InvalidArgument, err.Error())
|
||||
default:
|
||||
return status.Error(codes.Internal, "payload hash verification failed")
|
||||
}
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = payloadHashVerifyingService{}
|
||||
@@ -0,0 +1,125 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"testing"
|
||||
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// TestExecuteCommandRejectsPayloadHashWithInvalidLength verifies that a
// payload hash that is not a 32-byte SHA-256 digest is rejected with
// InvalidArgument before the delegate is reached.
func TestExecuteCommandRejectsPayloadHashWithInvalidLength(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	// 5 bytes — not a valid SHA-256 digest length.
	req := newValidExecuteCommandRequest()
	req.PayloadHash = []byte("short")

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), req)
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}
|
||||
|
||||
// TestExecuteCommandRejectsPayloadHashMismatch verifies that a well-formed
// 32-byte hash computed over different bytes is rejected with InvalidArgument
// before the delegate is reached.
func TestExecuteCommandRejectsPayloadHashMismatch(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	// Valid digest length, but computed over bytes the request doesn't carry.
	req := newValidExecuteCommandRequest()
	sum := sha256.Sum256([]byte("other"))
	req.PayloadHash = sum[:]

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), req)
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Equal(t, "payload_hash does not match payload_bytes", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}
|
||||
|
||||
// TestSubscribeEventsRejectsPayloadHashWithInvalidLength verifies that
// SubscribeEvents applies the same digest-length gate as ExecuteCommand:
// a non-32-byte hash fails with InvalidArgument before the delegate runs.
func TestSubscribeEventsRejectsPayloadHashWithInvalidLength(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	// 5 bytes — not a valid SHA-256 digest length.
	req := newValidSubscribeEventsRequest()
	req.PayloadHash = []byte("short")

	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, req)
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", status.Convert(err).Message())
	assert.Zero(t, delegate.subscribeCalls)
}
|
||||
|
||||
// TestSubscribeEventsRejectsPayloadHashMismatch verifies that a well-formed
// hash computed over different bytes is rejected on the streaming path with
// InvalidArgument before the delegate runs.
func TestSubscribeEventsRejectsPayloadHashMismatch(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	// Valid digest length, but computed over bytes the request doesn't carry.
	req := newValidSubscribeEventsRequest()
	sum := sha256.Sum256([]byte("other"))
	req.PayloadHash = sum[:]

	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, req)
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Equal(t, "payload_hash does not match payload_bytes", status.Convert(err).Message())
	assert.Zero(t, delegate.subscribeCalls)
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/logging"
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// NewFanOutPushStreamService constructs the authenticated SubscribeEvents tail
|
||||
// service that registers active streams in hub and forwards client-facing
|
||||
// events after the bootstrap event has been sent.
|
||||
func NewFanOutPushStreamService(hub *push.Hub, responseSigner authn.ResponseSigner, clk clock.Clock, logger *zap.Logger) gatewayv1.EdgeGatewayServer {
|
||||
if responseSigner == nil {
|
||||
responseSigner = unavailableResponseSigner{}
|
||||
}
|
||||
if clk == nil {
|
||||
clk = clock.System{}
|
||||
}
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return fanOutPushStreamService{
|
||||
hub: hub,
|
||||
responseSigner: responseSigner,
|
||||
clock: clk,
|
||||
logger: logger.Named("push_stream"),
|
||||
}
|
||||
}
|
||||
|
||||
// fanOutPushStreamService owns the post-bootstrap authenticated push-stream
// lifecycle backed by the in-memory push hub.
type fanOutPushStreamService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	// hub fans events out to registered streams; may be nil, which is
	// reported as an internal error at subscription time.
	hub *push.Hub
	// responseSigner signs every forwarded event.
	responseSigner authn.ResponseSigner
	// clock supplies event timestamps (injectable for tests).
	clock clock.Clock
	// logger records stream open/close lifecycle events.
	logger *zap.Logger
}
|
||||
|
||||
// SubscribeEvents registers the verified stream in the push hub and forwards
// matching client-facing events until the stream ends, either because the
// client disconnects or because the hub terminates the subscription.
func (s fanOutPushStreamService) SubscribeEvents(_ *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	// The binding must have been attached by the bootstrap stage; its
	// absence indicates a pipeline wiring bug, not a client error.
	binding, ok := authenticatedStreamBindingFromContext(stream.Context())
	if !ok {
		return status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	if s.hub == nil {
		return status.Error(codes.Internal, "push hub is unavailable")
	}

	subscription, err := s.hub.Register(push.StreamBinding{
		UserID:          binding.UserID,
		DeviceSessionID: binding.DeviceSessionID,
	})
	if err != nil {
		return status.Error(codes.Internal, "push stream registration failed")
	}
	defer subscription.Close()

	// Shared field set for the open/close lifecycle log entries.
	openFields := []zap.Field{
		zap.String("component", "authenticated_grpc"),
		zap.String("transport", "grpc"),
		zap.String("rpc_method", authenticatedRPCSubscribeEvents),
		zap.String("message_type", binding.MessageType),
		zap.String("request_id", binding.RequestID),
		zap.String("trace_id", binding.TraceID),
		zap.String("device_session_id", binding.DeviceSessionID),
		zap.String("user_id", binding.UserID),
	}
	openFields = append(openFields, logging.TraceFieldsFromContext(stream.Context())...)
	s.logger.Info("push stream opened", openFields...)

	for {
		select {
		case <-stream.Context().Done():
			// Client-driven termination (disconnect/cancel).
			s.logger.Info("push stream closed", append(openFields, zap.String("edge_outcome", string(mapSubscriptionOutcome(stream.Context().Err()))))...)
			return stream.Context().Err()
		case <-subscription.Done():
			// Hub-driven termination (revocation, overflow, shutdown).
			subscriptionErr := subscription.Err()
			// NOTE(review): reject_reason is logged with the outcome label
			// here — confirm this matches the telemetry.RejectReason mapping
			// used by the interceptors.
			s.logger.Warn("push stream closed", append(openFields,
				zap.String("edge_outcome", string(mapSubscriptionOutcome(subscriptionErr))),
				zap.String("reject_reason", string(mapSubscriptionOutcome(subscriptionErr))),
			)...)
			return mapSubscriptionError(subscriptionErr)
		case event := <-subscription.Events():
			// Sign and forward each hub event; any failure ends the stream.
			signedEvent, err := s.buildGatewayEvent(event)
			if err != nil {
				return err
			}
			if err := stream.Send(signedEvent); err != nil {
				return err
			}
		}
	}
}
|
||||
|
||||
func (s fanOutPushStreamService) buildGatewayEvent(event push.Event) (*gatewayv1.GatewayEvent, error) {
|
||||
timestampMS := s.clock.Now().UTC().UnixMilli()
|
||||
payloadHash := sha256.Sum256(event.PayloadBytes)
|
||||
|
||||
signature, err := s.responseSigner.SignEvent(authn.EventSigningFields{
|
||||
EventType: event.EventType,
|
||||
EventID: event.EventID,
|
||||
TimestampMS: timestampMS,
|
||||
RequestID: event.RequestID,
|
||||
TraceID: event.TraceID,
|
||||
PayloadHash: payloadHash[:],
|
||||
})
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unavailable, "response signer is unavailable")
|
||||
}
|
||||
|
||||
return &gatewayv1.GatewayEvent{
|
||||
EventType: event.EventType,
|
||||
EventId: event.EventID,
|
||||
TimestampMs: timestampMS,
|
||||
PayloadBytes: bytes.Clone(event.PayloadBytes),
|
||||
PayloadHash: bytes.Clone(payloadHash[:]),
|
||||
Signature: signature,
|
||||
RequestId: event.RequestID,
|
||||
TraceId: event.TraceID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func mapSubscriptionError(err error) error {
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, push.ErrSubscriptionRevoked):
|
||||
return status.Error(codes.FailedPrecondition, "device session is revoked")
|
||||
case errors.Is(err, push.ErrSubscriptionOverflow):
|
||||
return status.Error(codes.ResourceExhausted, "push stream overflowed")
|
||||
case errors.Is(err, push.ErrHubShuttingDown):
|
||||
return status.Error(codes.Unavailable, "gateway is shutting down")
|
||||
default:
|
||||
return status.Error(codes.Internal, "push stream closed unexpectedly")
|
||||
}
|
||||
}
|
||||
|
||||
func mapSubscriptionOutcome(err error) telemetry.EdgeOutcome {
|
||||
switch {
|
||||
case err == nil:
|
||||
return telemetry.EdgeOutcomeSuccess
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
|
||||
return telemetry.EdgeOutcomeSuccess
|
||||
case errors.Is(err, push.ErrSubscriptionRevoked):
|
||||
return telemetry.EdgeOutcomeRevokedSession
|
||||
case errors.Is(err, push.ErrSubscriptionOverflow):
|
||||
return telemetry.EdgeOutcomeRateLimited
|
||||
case errors.Is(err, push.ErrHubShuttingDown):
|
||||
return telemetry.EdgeOutcomeGatewayShuttingDown
|
||||
default:
|
||||
return telemetry.EdgeOutcomeInternalError
|
||||
}
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = fanOutPushStreamService{}
|
||||
@@ -0,0 +1,164 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
gatewayfbs "galaxy/schema/fbs/gateway"
|
||||
|
||||
flatbuffers "github.com/google/flatbuffers/go"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const serverTimeEventType = "gateway.server_time"
|
||||
|
||||
// authenticatedStreamBinding captures the verified identity bound to one
// authenticated SubscribeEvents stream after the full ingress pipeline
// succeeds.
type authenticatedStreamBinding struct {
	// UserID and DeviceSessionID come from the resolved session record.
	UserID          string
	DeviceSessionID string
	// MessageType, RequestID, and TraceID come from the parsed request
	// envelope.
	MessageType string
	RequestID   string
	TraceID     string
}
|
||||
|
||||
// authenticatedStreamBindingFromContext returns the verified stream binding
|
||||
// previously attached to ctx by the authenticated push-stream service.
|
||||
func authenticatedStreamBindingFromContext(ctx context.Context) (authenticatedStreamBinding, bool) {
|
||||
if ctx == nil {
|
||||
return authenticatedStreamBinding{}, false
|
||||
}
|
||||
|
||||
binding, ok := ctx.Value(authenticatedStreamBindingContextKey{}).(authenticatedStreamBinding)
|
||||
if !ok {
|
||||
return authenticatedStreamBinding{}, false
|
||||
}
|
||||
|
||||
return binding, true
|
||||
}
|
||||
|
||||
// authenticatedPushStreamService owns SubscribeEvents bootstrap behavior:
// bind the authenticated stream, send the initial signed server-time event,
// and then hand the stream lifecycle to the configured tail delegate.
type authenticatedPushStreamService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	// tailDelegate handles the stream after the bootstrap event is sent.
	tailDelegate gatewayv1.EdgeGatewayServer
	// responseSigner signs the bootstrap server-time event.
	responseSigner authn.ResponseSigner
	// clock supplies the server-time value (injectable for tests).
	clock clock.Clock
}
|
||||
|
||||
// SubscribeEvents binds the verified stream identity, sends the initial signed
// server-time event, and then delegates the remaining lifecycle to the tail
// service with the binding attached to the stream context.
func (s authenticatedPushStreamService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	// Both the parsed envelope and the resolved session must have been
	// attached by earlier pipeline stages; their absence is a wiring bug.
	envelope, ok := parsedEnvelopeFromContext(stream.Context())
	if !ok {
		return status.Error(codes.Internal, "authenticated request context is incomplete")
	}

	record, ok := resolvedSessionFromContext(stream.Context())
	if !ok {
		return status.Error(codes.Internal, "authenticated request context is incomplete")
	}

	// Expose the verified identity to the tail delegate via a wrapped
	// stream whose Context carries the binding.
	binding := authenticatedStreamBinding{
		UserID:          record.UserID,
		DeviceSessionID: record.DeviceSessionID,
		MessageType:     envelope.MessageType,
		RequestID:       envelope.RequestID,
		TraceID:         envelope.TraceID,
	}
	boundStream := authenticatedStreamContextStream{
		ServerStreamingServer: stream,
		ctx: context.WithValue(
			stream.Context(),
			authenticatedStreamBindingContextKey{},
			binding,
		),
	}

	// Build and sign the bootstrap server-time event; it reuses the request
	// ID as its event ID.
	serverTimeMS := s.clock.Now().UTC().UnixMilli()
	payloadBytes := buildServerTimeEventPayload(serverTimeMS)
	payloadHash := sha256.Sum256(payloadBytes)
	signature, err := s.responseSigner.SignEvent(authn.EventSigningFields{
		EventType:   serverTimeEventType,
		EventID:     envelope.RequestID,
		TimestampMS: serverTimeMS,
		RequestID:   envelope.RequestID,
		TraceID:     envelope.TraceID,
		PayloadHash: payloadHash[:],
	})
	if err != nil {
		return status.Error(codes.Unavailable, "response signer is unavailable")
	}

	// Clone the payload slices so the outgoing message does not alias
	// buffers reused elsewhere.
	if err := boundStream.Send(&gatewayv1.GatewayEvent{
		EventType:    serverTimeEventType,
		EventId:      envelope.RequestID,
		TimestampMs:  serverTimeMS,
		PayloadBytes: bytes.Clone(payloadBytes),
		PayloadHash:  bytes.Clone(payloadHash[:]),
		Signature:    signature,
		RequestId:    envelope.RequestID,
		TraceId:      envelope.TraceID,
	}); err != nil {
		return err
	}

	// Hand off the rest of the stream lifecycle with the binding attached.
	return s.tailDelegate.SubscribeEvents(req, boundStream)
}
|
||||
|
||||
func newAuthenticatedPushStreamService(tailDelegate gatewayv1.EdgeGatewayServer, responseSigner authn.ResponseSigner, clk clock.Clock) gatewayv1.EdgeGatewayServer {
|
||||
if tailDelegate == nil {
|
||||
tailDelegate = holdOpenSubscribeEventsService{}
|
||||
}
|
||||
|
||||
return authenticatedPushStreamService{
|
||||
tailDelegate: tailDelegate,
|
||||
responseSigner: responseSigner,
|
||||
clock: clk,
|
||||
}
|
||||
}
|
||||
|
||||
func buildServerTimeEventPayload(serverTimeMS int64) []byte {
|
||||
builder := flatbuffers.NewBuilder(32)
|
||||
gatewayfbs.ServerTimeEventStart(builder)
|
||||
gatewayfbs.ServerTimeEventAddServerTimeMs(builder, serverTimeMS)
|
||||
eventOffset := gatewayfbs.ServerTimeEventEnd(builder)
|
||||
gatewayfbs.FinishServerTimeEventBuffer(builder, eventOffset)
|
||||
|
||||
return bytes.Clone(builder.FinishedBytes())
|
||||
}
|
||||
|
||||
type authenticatedStreamBindingContextKey struct{}
|
||||
|
||||
type authenticatedStreamContextStream struct {
|
||||
grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (s authenticatedStreamContextStream) Context() context.Context {
|
||||
if s.ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
type holdOpenSubscribeEventsService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
}
|
||||
|
||||
func (holdOpenSubscribeEventsService) SubscribeEvents(_ *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
<-stream.Context().Done()
|
||||
return stream.Context().Err()
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = authenticatedPushStreamService{}
|
||||
@@ -0,0 +1,286 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/ratelimit"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
	// authenticatedGRPCBaseBucketKeyPrefix namespaces every authenticated gRPC
	// rate-limit bucket key so they never collide with other users of a shared
	// limiter (e.g. the public REST buckets).
	authenticatedGRPCBaseBucketKeyPrefix = "authenticated_grpc/"

	// Per-dimension key segments; the concrete dimension value (IP, session ID,
	// user ID, message class) is appended directly after the "=".
	authenticatedGRPCIPBucketKeySegment           = authenticatedGRPCBaseBucketKeyPrefix + "ip="
	authenticatedGRPCSessionBucketKeySegment      = authenticatedGRPCBaseBucketKeyPrefix + "session="
	authenticatedGRPCUserBucketKeySegment         = authenticatedGRPCBaseBucketKeyPrefix + "user="
	authenticatedGRPCMessageClassBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "message_class="

	// unknownAuthenticatedPeerIP is substituted when no usable peer address can
	// be derived from the gRPC connection.
	unknownAuthenticatedPeerIP = "unknown"

	// RPC method names reported to the edge-policy hook via
	// AuthenticatedRequest.RPCMethod.
	authenticatedRPCExecuteCommand  = "ExecuteCommand"
	authenticatedRPCSubscribeEvents = "SubscribeEvents"
)
|
||||
|
||||
var (
	// ErrAuthenticatedPolicyDenied reports that the authenticated request was
	// rejected by later edge policy after transport authenticity succeeded.
	// applyPolicy maps errors wrapping this sentinel to codes.PermissionDenied.
	ErrAuthenticatedPolicyDenied = errors.New("authenticated request rejected by edge policy")

	// ErrAuthenticatedPolicyUnavailable reports that authenticated policy could
	// not be evaluated because its backing dependency is unavailable.
	// applyPolicy maps errors wrapping this sentinel to codes.Unavailable.
	ErrAuthenticatedPolicyUnavailable = errors.New("authenticated request policy is unavailable")
)
|
||||
|
||||
// AuthenticatedRequestLimiter applies authenticated gRPC rate-limit policy to
// one concrete bucket key.
type AuthenticatedRequestLimiter interface {
	// Reserve evaluates key under policy and reports whether the request may
	// proceed immediately. This package consults only Decision.Allowed; a
	// false value is surfaced to the client as codes.ResourceExhausted.
	Reserve(key string, policy ratelimit.Policy) ratelimit.Decision
}
|
||||
|
||||
// AuthenticatedRequest describes the authenticated request metadata exposed to
// the edge-policy hook. Instances are assembled by
// authenticatedRequestFromContext from values the transport-authenticity layer
// stored on the request context.
type AuthenticatedRequest struct {
	// RPCMethod identifies the public gRPC method being processed.
	RPCMethod string

	// PeerIP is the transport peer IP derived from the gRPC connection. It is
	// "unknown" when no usable peer address could be resolved.
	PeerIP string

	// MessageClass is the stable rate-limit and policy class. The gateway uses
	// the full message_type literal because the v1 transport does not yet define
	// a coarser authenticated class taxonomy.
	MessageClass string

	// Envelope contains the verified transport envelope fields used by later
	// edge policy.
	Envelope AuthenticatedRequestEnvelope

	// Session contains the authenticated identity resolved from SessionCache.
	Session session.Record
}
|
||||
|
||||
// AuthenticatedRequestEnvelope describes the verified request envelope fields
// exposed to the edge-policy hook. All fields have already passed the
// transport authenticity gate (signature and freshness) by the time policy
// sees them.
type AuthenticatedRequestEnvelope struct {
	// ProtocolVersion is the supported transport protocol version literal.
	ProtocolVersion string

	// DeviceSessionID is the authenticated device-session identifier.
	DeviceSessionID string

	// MessageType is the verified downstream routing key supplied by the client.
	MessageType string

	// TimestampMS is the client timestamp that already passed freshness checks.
	TimestampMS int64

	// RequestID is the authenticated transport request identifier.
	RequestID string

	// TraceID is the optional client-supplied correlation identifier.
	TraceID string
}
|
||||
|
||||
// AuthenticatedRequestPolicy evaluates later authenticated edge policy after
// transport authenticity and rate-limit checks succeed. When no policy is
// configured the gateway falls back to noopAuthenticatedRequestPolicy, which
// allows everything.
type AuthenticatedRequestPolicy interface {
	// Evaluate returns nil when the authenticated request may proceed. It should
	// wrap ErrAuthenticatedPolicyDenied for stable reject mapping and
	// ErrAuthenticatedPolicyUnavailable when its backing dependency is
	// temporarily unavailable. Any other error is reported as codes.Internal.
	Evaluate(ctx context.Context, request AuthenticatedRequest) error
}
|
||||
|
||||
// authenticatedRateLimitService decorates an EdgeGatewayServer with the
// authenticated rate-limit and edge-policy gate. The embedded Unimplemented
// server answers any RPC this decorator does not explicitly forward.
type authenticatedRateLimitService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	// delegate receives the request once rate limits and policy both pass.
	delegate gatewayv1.EdgeGatewayServer
	// limiter evaluates each per-dimension bucket key.
	limiter AuthenticatedRequestLimiter
	// policy is the later edge-policy hook, run after rate limits.
	policy AuthenticatedRequestPolicy
	// cfg carries the per-dimension rate-limit configurations.
	cfg config.AuthenticatedGRPCAntiAbuseConfig
}
|
||||
|
||||
// ExecuteCommand applies authenticated rate limits and edge policy before
|
||||
// delegating to the configured service implementation.
|
||||
func (s authenticatedRateLimitService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
if err := s.applyRateLimitsAndPolicy(ctx, authenticatedRPCExecuteCommand); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.delegate.ExecuteCommand(ctx, req)
|
||||
}
|
||||
|
||||
// SubscribeEvents applies authenticated rate limits and edge policy before
|
||||
// delegating to the configured service implementation.
|
||||
func (s authenticatedRateLimitService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
if err := s.applyRateLimitsAndPolicy(stream.Context(), authenticatedRPCSubscribeEvents); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.delegate.SubscribeEvents(req, stream)
|
||||
}
|
||||
|
||||
// newAuthenticatedRateLimitService wraps delegate with the authenticated
|
||||
// rate-limit and edge-policy gate.
|
||||
func newAuthenticatedRateLimitService(delegate gatewayv1.EdgeGatewayServer, limiter AuthenticatedRequestLimiter, policy AuthenticatedRequestPolicy, cfg config.AuthenticatedGRPCAntiAbuseConfig) gatewayv1.EdgeGatewayServer {
|
||||
return authenticatedRateLimitService{
|
||||
delegate: delegate,
|
||||
limiter: limiter,
|
||||
policy: policy,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s authenticatedRateLimitService) applyRateLimitsAndPolicy(ctx context.Context, rpcMethod string) error {
|
||||
request, err := authenticatedRequestFromContext(ctx, rpcMethod)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.applyRateLimits(request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.applyPolicy(ctx, request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s authenticatedRateLimitService) applyRateLimits(request AuthenticatedRequest) error {
|
||||
checks := []struct {
|
||||
key string
|
||||
policy config.AuthenticatedRateLimitConfig
|
||||
}{
|
||||
{
|
||||
key: authenticatedGRPCIPBucketKey(request.PeerIP),
|
||||
policy: s.cfg.IP,
|
||||
},
|
||||
{
|
||||
key: authenticatedGRPCSessionBucketKey(request.Envelope.DeviceSessionID),
|
||||
policy: s.cfg.Session,
|
||||
},
|
||||
{
|
||||
key: authenticatedGRPCUserBucketKey(request.Session.UserID),
|
||||
policy: s.cfg.User,
|
||||
},
|
||||
{
|
||||
key: authenticatedGRPCMessageClassBucketKey(request.MessageClass),
|
||||
policy: s.cfg.MessageClass,
|
||||
},
|
||||
}
|
||||
|
||||
for _, check := range checks {
|
||||
decision := s.limiter.Reserve(check.key, ratelimit.Policy{
|
||||
Requests: check.policy.Requests,
|
||||
Window: check.policy.Window,
|
||||
Burst: check.policy.Burst,
|
||||
})
|
||||
if !decision.Allowed {
|
||||
return status.Error(codes.ResourceExhausted, "authenticated request rate limit exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s authenticatedRateLimitService) applyPolicy(ctx context.Context, request AuthenticatedRequest) error {
|
||||
err := s.policy.Evaluate(ctx, request)
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, ErrAuthenticatedPolicyDenied):
|
||||
return status.Error(codes.PermissionDenied, "authenticated request rejected by edge policy")
|
||||
case errors.Is(err, ErrAuthenticatedPolicyUnavailable):
|
||||
return status.Error(codes.Unavailable, "authenticated request policy is unavailable")
|
||||
default:
|
||||
return status.Error(codes.Internal, "authenticated request policy evaluation failed")
|
||||
}
|
||||
}
|
||||
|
||||
// authenticatedRequestFromContext assembles the edge-policy request view from
// the verified envelope and resolved session previously stored on ctx by the
// transport-authenticity layer. Both values must be present; a missing one
// indicates a gateway wiring bug and is reported as codes.Internal rather
// than being attributed to the client.
func authenticatedRequestFromContext(ctx context.Context, rpcMethod string) (AuthenticatedRequest, error) {
	envelope, ok := parsedEnvelopeFromContext(ctx)
	if !ok {
		return AuthenticatedRequest{}, status.Error(codes.Internal, "authenticated request context is incomplete")
	}

	record, ok := resolvedSessionFromContext(ctx)
	if !ok {
		return AuthenticatedRequest{}, status.Error(codes.Internal, "authenticated request context is incomplete")
	}

	return AuthenticatedRequest{
		RPCMethod: rpcMethod,
		// Peer IP is best-effort; it degrades to "unknown" rather than failing.
		PeerIP:       peerIPFromContext(ctx),
		MessageClass: authenticatedMessageClass(envelope.MessageType),
		Envelope: AuthenticatedRequestEnvelope{
			ProtocolVersion: envelope.ProtocolVersion,
			DeviceSessionID: envelope.DeviceSessionID,
			MessageType:     envelope.MessageType,
			TimestampMS:     envelope.TimestampMS,
			RequestID:       envelope.RequestID,
			TraceID:         envelope.TraceID,
		},
		Session: record,
	}, nil
}
|
||||
|
||||
func authenticatedGRPCIPBucketKey(peerIP string) string {
|
||||
return authenticatedGRPCIPBucketKeySegment + peerIP
|
||||
}
|
||||
|
||||
func authenticatedGRPCSessionBucketKey(deviceSessionID string) string {
|
||||
return authenticatedGRPCSessionBucketKeySegment + deviceSessionID
|
||||
}
|
||||
|
||||
func authenticatedGRPCUserBucketKey(userID string) string {
|
||||
return authenticatedGRPCUserBucketKeySegment + userID
|
||||
}
|
||||
|
||||
func authenticatedGRPCMessageClassBucketKey(messageClass string) string {
|
||||
return authenticatedGRPCMessageClassBucketKeySegment + messageClass
|
||||
}
|
||||
|
||||
func authenticatedMessageClass(messageType string) string {
|
||||
return messageType
|
||||
}
|
||||
|
||||
func peerIPFromContext(ctx context.Context) string {
|
||||
peerInfo, ok := peer.FromContext(ctx)
|
||||
if !ok || peerInfo.Addr == nil {
|
||||
return unknownAuthenticatedPeerIP
|
||||
}
|
||||
|
||||
value := strings.TrimSpace(peerInfo.Addr.String())
|
||||
if value == "" {
|
||||
return unknownAuthenticatedPeerIP
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(value)
|
||||
if err == nil && host != "" {
|
||||
return host
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
// noopAuthenticatedRequestPolicy is the default allow-everything policy used
// when no edge-policy hook is configured.
type noopAuthenticatedRequestPolicy struct{}

// Evaluate always allows the request.
func (noopAuthenticatedRequestPolicy) Evaluate(context.Context, AuthenticatedRequest) error {
	return nil
}

// Compile-time check that the decorator satisfies the service interface.
var _ gatewayv1.EdgeGatewayServer = authenticatedRateLimitService{}
|
||||
@@ -0,0 +1,497 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/ratelimit"
|
||||
"galaxy/gateway/internal/restapi"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// TestExecuteCommandRateLimitsByIP verifies that the per-IP bucket throttles a
// second request from the same transport peer even when it carries a
// different device session and user.
func TestExecuteCommandRateLimitsByIP(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}),
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
	require.NoError(t, err)

	// Same peer, different session and user: only the IP bucket can trip here.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.ResourceExhausted, status.Code(err))
	assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
	assert.Equal(t, 1, delegate.executeCalls)
}
|
||||
|
||||
// TestExecuteCommandRateLimitsBySession verifies that the per-session bucket
// throttles repeats from one device session while a second session belonging
// to the same user still passes.
func TestExecuteCommandRateLimitsBySession(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-1"}),
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
	require.NoError(t, err)

	// Second request on the same session trips the session bucket.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.ResourceExhausted, status.Code(err))

	// A different session of the same user is unaffected.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-3"))
	require.NoError(t, err)

	assert.Equal(t, 2, delegate.executeCalls)
}
|
||||
|
||||
// TestExecuteCommandRateLimitsByUser verifies that the per-user bucket spans
// all of a user's device sessions while other users keep their own budget.
func TestExecuteCommandRateLimitsByUser(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service: delegate,
		SessionCache: userMappedSessionCache(map[string]string{
			"device-session-1": "user-shared",
			"device-session-2": "user-shared",
			"device-session-3": "user-other",
		}),
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
	require.NoError(t, err)

	// A different session of the SAME user must trip the user bucket.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.ResourceExhausted, status.Code(err))

	// A different user is unaffected.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-3", "request-3"))
	require.NoError(t, err)

	assert.Equal(t, 2, delegate.executeCalls)
}
|
||||
|
||||
// TestExecuteCommandRateLimitsByMessageClass verifies that the message-class
// bucket is keyed by the message_type literal across sessions and users, and
// that a different message type uses a separate budget.
func TestExecuteCommandRateLimitsByMessageClass(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service: delegate,
		SessionCache: userMappedSessionCache(map[string]string{
			"device-session-1": "user-1",
			"device-session-2": "user-2",
		}),
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-1", "request-1", "fleet.move"))
	require.NoError(t, err)

	// Same class from a different session/user still trips the class bucket.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-2", "fleet.move"))
	require.Error(t, err)
	assert.Equal(t, codes.ResourceExhausted, status.Code(err))

	// A different message type is a different bucket.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-3", "fleet.rename"))
	require.NoError(t, err)

	assert.Equal(t, 2, delegate.executeCalls)
}
|
||||
|
||||
// TestAuthenticatedPolicyHookReceivesVerifiedRequest verifies that the
// edge-policy hook is invoked exactly once per request with the verified
// envelope fields, resolved session identity, RPC method name, and peer IP.
func TestAuthenticatedPolicyHookReceivesVerifiedRequest(t *testing.T) {
	t.Parallel()

	policy := &recordingAuthenticatedRequestPolicy{}
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:  staticReplayStore{},
		Policy:       policy,
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)

	require.Len(t, policy.requests, 1)
	assert.Equal(t, authenticatedRPCExecuteCommand, policy.requests[0].RPCMethod)
	assert.Equal(t, "127.0.0.1", policy.requests[0].PeerIP)
	assert.Equal(t, "fleet.move", policy.requests[0].MessageClass)
	assert.Equal(t, "device-session-123", policy.requests[0].Envelope.DeviceSessionID)
	assert.Equal(t, "request-123", policy.requests[0].Envelope.RequestID)
	assert.Equal(t, "trace-123", policy.requests[0].Envelope.TraceID)
	assert.Equal(t, "user-123", policy.requests[0].Session.UserID)
	assert.Equal(t, 1, delegate.executeCalls)
}
|
||||
|
||||
// TestExecuteCommandPolicyRejectMapsToPermissionDenied verifies that a policy
// error wrapping ErrAuthenticatedPolicyDenied is surfaced as PermissionDenied
// and that the delegate is never reached.
func TestExecuteCommandPolicyRejectMapsToPermissionDenied(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:  staticReplayStore{},
		// Wrapping the sentinel (not returning it directly) exercises the
		// errors.Is mapping path.
		Policy: authenticatedRequestPolicyFunc(func(context.Context, AuthenticatedRequest) error {
			return fmt.Errorf("policy deny: %w", ErrAuthenticatedPolicyDenied)
		}),
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.PermissionDenied, status.Code(err))
	assert.Equal(t, "authenticated request rejected by edge policy", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}
|
||||
|
||||
// TestSubscribeEventsRateLimitRejectsStream verifies that the streaming RPC is
// gated by the same authenticated buckets as the unary RPC: the first stream
// completes its signed bootstrap handshake, the second is rejected with
// ResourceExhausted before the delegate ever sees it.
func TestSubscribeEventsRateLimitRejectsStream(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}),
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)

	stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-1", "request-1"))
	require.NoError(t, err)
	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-1", "trace-123", testCurrentTime.UnixMilli())
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)

	err = subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-2", "request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.ResourceExhausted, status.Code(err))
	assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
	assert.Equal(t, 1, delegate.subscribeCalls)
}
|
||||
|
||||
// TestAuthenticatedRateLimitsStayIsolatedFromPublicREST verifies bucket-key
// isolation: with ONE limiter instance shared between the public REST surface
// and the authenticated gRPC surface, exhausting the public auth bucket must
// not consume authenticated tokens (the key prefixes differ).
func TestAuthenticatedRateLimitsStayIsolatedFromPublicREST(t *testing.T) {
	t.Parallel()

	sharedLimiter := ratelimit.NewInMemory()

	publicCfg := config.DefaultPublicHTTPConfig()
	publicCfg.Addr = unusedTCPAddr(t)
	publicCfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
		Requests: 1,
		Window:   time.Hour,
		Burst:    1,
	}
	// Keep the identity bucket generous so only the PublicAuth bucket trips.
	publicCfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}

	grpcCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.Addr = unusedTCPAddr(t)
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	})

	restServer := restapi.NewServer(publicCfg, restapi.ServerDependencies{
		AuthService: staticAuthServiceClient{},
		Limiter:     publicLimiterAdapter{limiter: sharedLimiter},
	})
	delegate := &recordingEdgeGatewayService{}
	grpcServer := NewServer(grpcCfg, ServerDependencies{
		Service:        delegate,
		Router:         executeCommandAdapterRouter{service: delegate},
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		Limiter:        sharedLimiter,
		Clock:          fixedClock{now: testCurrentTime},
	})

	application := app.New(config.Config{ShutdownTimeout: time.Second}, restServer, grpcServer)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	runGateway := runningGateway{cancel: cancel, resultCh: resultCh}
	defer runGateway.stop(t)

	waitForHTTPHealthz(t, "http://"+publicCfg.Addr+"/healthz")
	addr := waitForListenAddr(t, grpcServer)

	// Exhaust the public auth bucket.
	firstPublic := sendPublicAuthRequest(t, "http://"+publicCfg.Addr+"/api/v1/public/auth/send-email-code")
	secondPublic := sendPublicAuthRequest(t, "http://"+publicCfg.Addr+"/api/v1/public/auth/send-email-code")

	assert.Equal(t, http.StatusOK, firstPublic.StatusCode)
	assert.Equal(t, http.StatusTooManyRequests, secondPublic.StatusCode)
	require.NoError(t, firstPublic.Body.Close())
	require.NoError(t, secondPublic.Body.Close())

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	// The authenticated surface must still have its full budget.
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)
}
|
||||
|
||||
// newAuthenticatedGRPCConfigForTest builds an authenticated gRPC config bound
// to an ephemeral loopback port with deliberately generous limits in every
// bucket dimension, so each test tightens only the dimension under test via
// mutate.
func newAuthenticatedGRPCConfigForTest(mutate func(*config.AuthenticatedGRPCConfig)) config.AuthenticatedGRPCConfig {
	cfg := config.DefaultAuthenticatedGRPCConfig()
	cfg.Addr = "127.0.0.1:0"
	cfg.FreshnessWindow = testFreshnessWindow
	cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}
	cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}
	cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}
	cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}

	if mutate != nil {
		mutate(&cfg)
	}

	return cfg
}
|
||||
|
||||
// newValidExecuteCommandRequestWithMessageType builds a valid request and then
// overrides its message type. The envelope must be re-signed afterwards
// because the signature covers the message_type field.
func newValidExecuteCommandRequestWithMessageType(deviceSessionID string, requestID string, messageType string) *gatewayv1.ExecuteCommandRequest {
	req := newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID, requestID)
	req.MessageType = messageType
	req.Signature = signRequest(
		req.GetProtocolVersion(),
		req.GetDeviceSessionId(),
		req.GetMessageType(),
		req.GetTimestampMs(),
		req.GetRequestId(),
		req.GetPayloadHash(),
	)

	return req
}
|
||||
|
||||
// userMappedSessionCache returns a session cache stub whose lookups resolve
// each device-session ID to an active record owned by the mapped user ID, and
// fail with session.ErrNotFound for unmapped sessions.
func userMappedSessionCache(users map[string]string) staticSessionCache {
	return staticSessionCache{
		lookupFunc: func(_ context.Context, deviceSessionID string) (session.Record, error) {
			userID, ok := users[deviceSessionID]
			if !ok {
				return session.Record{}, session.ErrNotFound
			}

			record := newActiveSessionRecordWithSessionID(deviceSessionID)
			record.UserID = userID
			return record, nil
		},
	}
}
|
||||
|
||||
// authenticatedRequestPolicyFunc adapts a plain function to the
// AuthenticatedRequestPolicy interface.
type authenticatedRequestPolicyFunc func(context.Context, AuthenticatedRequest) error

// Evaluate invokes the wrapped function.
func (f authenticatedRequestPolicyFunc) Evaluate(ctx context.Context, request AuthenticatedRequest) error {
	return f(ctx, request)
}

// recordingAuthenticatedRequestPolicy captures every request handed to the
// policy hook and always allows it. Not safe for concurrent use; tests invoke
// it from a single request at a time.
type recordingAuthenticatedRequestPolicy struct {
	requests []AuthenticatedRequest
}

// Evaluate records request and allows it.
func (p *recordingAuthenticatedRequestPolicy) Evaluate(_ context.Context, request AuthenticatedRequest) error {
	p.requests = append(p.requests, request)
	return nil
}
|
||||
|
||||
// publicLimiterAdapter exposes a shared ratelimit.Limiter through the public
// REST surface's limiter interface, so one limiter instance can back both the
// REST and gRPC surfaces in the isolation test.
type publicLimiterAdapter struct {
	limiter ratelimit.Limiter
}

// Reserve translates the public rate-limit config into a ratelimit.Policy,
// reserves against the shared limiter, and maps the decision back into the
// REST surface's decision type.
func (a publicLimiterAdapter) Reserve(key string, policy config.PublicRateLimitConfig) restapi.PublicRateLimitDecision {
	decision := a.limiter.Reserve(key, ratelimit.Policy{
		Requests: policy.Requests,
		Window:   policy.Window,
		Burst:    policy.Burst,
	})

	return restapi.PublicRateLimitDecision{
		Allowed:    decision.Allowed,
		RetryAfter: decision.RetryAfter,
	}
}
|
||||
|
||||
// staticAuthServiceClient is an always-succeeding auth-service stub with fixed
// identifiers, used where tests only need the public auth endpoints to return
// 200.
type staticAuthServiceClient struct{}

// SendEmailCode always succeeds with a fixed challenge ID.
func (staticAuthServiceClient) SendEmailCode(context.Context, restapi.SendEmailCodeInput) (restapi.SendEmailCodeResult, error) {
	return restapi.SendEmailCodeResult{ChallengeID: "challenge-123"}, nil
}

// ConfirmEmailCode always succeeds with a fixed device-session ID.
func (staticAuthServiceClient) ConfirmEmailCode(context.Context, restapi.ConfirmEmailCodeInput) (restapi.ConfirmEmailCodeResult, error) {
	return restapi.ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"}, nil
}
|
||||
|
||||
// waitForHTTPHealthz polls url until it answers 200 OK, failing the test if
// the public REST server does not become healthy within two seconds.
func waitForHTTPHealthz(t *testing.T, url string) {
	t.Helper()

	// Short per-request timeout so a hung listener does not eat the overall
	// Eventually budget in one probe.
	client := &http.Client{Timeout: 200 * time.Millisecond}
	require.Eventually(t, func() bool {
		response, err := client.Get(url)
		if err != nil {
			return false
		}
		require.NoError(t, response.Body.Close())

		return response.StatusCode == http.StatusOK
	}, 2*time.Second, 10*time.Millisecond, "public REST server did not become healthy: %s", url)
}
|
||||
|
||||
// sendPublicAuthRequest POSTs a fixed JSON email payload to url and returns
// the raw response. Callers own closing the response body.
func sendPublicAuthRequest(t *testing.T, url string) *http.Response {
	t.Helper()

	request, err := http.NewRequest(http.MethodPost, url, strings.NewReader(`{"email":"pilot@example.com"}`))
	require.NoError(t, err)
	request.Header.Set("Content-Type", "application/json")

	response, err := (&http.Client{Timeout: time.Second}).Do(request)
	require.NoError(t, err)

	return response
}
|
||||
|
||||
// unusedTCPAddr binds an ephemeral loopback port, closes it, and returns its
// address. NOTE(review): the port is released before the caller rebinds it, so
// this is inherently racy against other processes; acceptable for tests.
func unusedTCPAddr(t *testing.T) string {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	addr := listener.Addr().String()
	require.NoError(t, listener.Close())

	return addr
}
|
||||
@@ -0,0 +1,260 @@
|
||||
// Package grpcapi exposes the authenticated gRPC surface of the gateway.
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/ratelimit"
|
||||
"galaxy/gateway/internal/replay"
|
||||
"galaxy/gateway/internal/session"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// ServerDependencies describes the optional collaborators used by the
// authenticated gRPC server. The zero value is valid and keeps the process
// runnable with the built-in unimplemented service stub; see
// normalizeServerDependencies for the fallback substituted for each nil field.
type ServerDependencies struct {
	// Service optionally handles the post-bootstrap SubscribeEvents lifecycle
	// after the initial authenticated service event has been sent. When nil, the
	// gateway keeps authenticated SubscribeEvents streams open until the client
	// cancels them, the server shuts down, or a later stream send fails.
	Service gatewayv1.EdgeGatewayServer

	// Router resolves the exact downstream unary client for the verified
	// message_type value. When nil, the authenticated unary surface uses an
	// empty exact-match router and returns UNIMPLEMENTED for unrouted commands.
	Router downstream.Router

	// ResponseSigner signs authenticated unary responses after downstream
	// execution succeeds. When nil, the unary surface fails closed once it needs
	// to sign a routed response.
	ResponseSigner authn.ResponseSigner

	// SessionCache resolves authenticated device sessions after the envelope
	// gate succeeds. When nil, the authenticated gRPC surface remains runnable
	// but valid envelopes fail closed as session-cache unavailable.
	SessionCache session.Cache

	// Clock provides current server time for freshness checks. When nil, the
	// authenticated gRPC surface uses the system clock.
	Clock clock.Clock

	// ReplayStore reserves authenticated request identifiers after signature
	// verification. When nil, valid requests fail closed as replay-store
	// unavailable.
	ReplayStore replay.Store

	// Limiter applies authenticated rate limits after the request passes the
	// transport authenticity checks. When nil, the authenticated gRPC surface
	// uses a process-local in-memory limiter.
	Limiter AuthenticatedRequestLimiter

	// Policy evaluates later authenticated edge policy after rate limits pass.
	// When nil, the authenticated gRPC surface applies a no-op allow policy.
	Policy AuthenticatedRequestPolicy

	// Logger writes structured logs for authenticated gRPC traffic. When nil, a
	// no-op logger is used and logging is effectively disabled.
	Logger *zap.Logger

	// Telemetry records low-cardinality gRPC metrics. May be nil.
	Telemetry *telemetry.Runtime

	// PushHub is the active authenticated push-stream hub. When present, the
	// server closes active streams before GracefulStop during shutdown.
	PushHub *push.Hub
}
|
||||
|
||||
// Server owns the authenticated gRPC listener exposed by the gateway.
type Server struct {
	cfg     config.AuthenticatedGRPCConfig // listener address, timeouts, anti-abuse settings
	service gatewayv1.EdgeGatewayServer    // full auth pipeline assembled in NewServer
	logger  *zap.Logger
	pushHub *push.Hub          // optional; drained before GracefulStop in Shutdown
	metrics *telemetry.Runtime // optional; passed to the observability interceptors

	// stateMu guards server and listener, which are non-nil only while Run is
	// actively serving.
	stateMu  sync.RWMutex
	server   *grpc.Server
	listener net.Listener
}
|
||||
|
||||
// NewServer constructs an authenticated gRPC server for the supplied listener
// configuration and dependency bundle. Nil dependencies are replaced with safe
// defaults so the gateway can expose the documented transport surface with the
// full auth pipeline wired from built-in fallbacks.
func NewServer(cfg config.AuthenticatedGRPCConfig, deps ServerDependencies) *Server {
	deps = normalizeServerDependencies(deps)

	// finalService runs last: it routes verified commands downstream and owns
	// the authenticated push-stream bootstrap.
	finalService := newCommandRoutingService(
		newAuthenticatedPushStreamService(deps.Service, deps.ResponseSigner, deps.Clock),
		deps.Router,
		deps.ResponseSigner,
		deps.Clock,
		cfg.DownstreamTimeout,
	)

	// The wrappers below nest outermost-first, so a request flows:
	//   envelope validation -> session lookup -> payload-hash check ->
	//   signature check -> freshness + replay reservation -> rate limit ->
	//   finalService.
	// The order is load-bearing; each stage assumes the earlier ones ran.
	return &Server{
		cfg: cfg,
		service: newEnvelopeValidatingService(
			newSessionLookupService(
				newPayloadHashVerifyingService(
					newSignatureVerifyingService(
						newFreshnessAndReplayService(
							newAuthenticatedRateLimitService(
								finalService,
								deps.Limiter,
								deps.Policy,
								cfg.AntiAbuse,
							),
							deps.Clock,
							deps.ReplayStore,
							cfg.FreshnessWindow,
						),
					),
				),
				deps.SessionCache,
			),
		),
		logger:  deps.Logger.Named("authenticated_grpc"),
		pushHub: deps.PushHub,
		metrics: deps.Telemetry,
	}
}
|
||||
|
||||
// Run binds the configured listener and serves the authenticated gRPC surface
|
||||
// until Shutdown closes the server.
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("run authenticated gRPC server: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", s.cfg.Addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("run authenticated gRPC server: listen on %q: %w", s.cfg.Addr, err)
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer(
|
||||
grpc.ConnectionTimeout(s.cfg.ConnectionTimeout),
|
||||
grpc.StatsHandler(otelgrpc.NewServerHandler()),
|
||||
grpc.ChainUnaryInterceptor(observabilityUnaryInterceptor(s.logger, s.metrics)),
|
||||
grpc.ChainStreamInterceptor(observabilityStreamInterceptor(s.logger, s.metrics)),
|
||||
)
|
||||
gatewayv1.RegisterEdgeGatewayServer(grpcServer, s.service)
|
||||
|
||||
s.stateMu.Lock()
|
||||
s.server = grpcServer
|
||||
s.listener = listener
|
||||
s.stateMu.Unlock()
|
||||
|
||||
s.logger.Info("authenticated gRPC server started", zap.String("addr", listener.Addr().String()))
|
||||
|
||||
defer func() {
|
||||
s.stateMu.Lock()
|
||||
s.server = nil
|
||||
s.listener = nil
|
||||
s.stateMu.Unlock()
|
||||
}()
|
||||
|
||||
err = grpcServer.Serve(listener)
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, grpc.ErrServerStopped):
|
||||
s.logger.Info("authenticated gRPC server stopped")
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("run authenticated gRPC server: serve on %q: %w", s.cfg.Addr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the authenticated gRPC server within ctx. When the
|
||||
// graceful stop exceeds ctx, the server is force-stopped before returning the
|
||||
// timeout to the caller.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("shutdown authenticated gRPC server: nil context")
|
||||
}
|
||||
|
||||
s.stateMu.RLock()
|
||||
server := s.server
|
||||
s.stateMu.RUnlock()
|
||||
|
||||
if server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.pushHub != nil {
|
||||
s.pushHub.Shutdown()
|
||||
}
|
||||
|
||||
stopped := make(chan struct{})
|
||||
go func() {
|
||||
server.GracefulStop()
|
||||
close(stopped)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stopped:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
server.Stop()
|
||||
<-stopped
|
||||
return fmt.Errorf("shutdown authenticated gRPC server: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) listenAddr() string {
|
||||
s.stateMu.RLock()
|
||||
defer s.stateMu.RUnlock()
|
||||
|
||||
if s.listener == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return s.listener.Addr().String()
|
||||
}
|
||||
|
||||
func normalizeServerDependencies(deps ServerDependencies) ServerDependencies {
|
||||
if deps.Router == nil {
|
||||
deps.Router = downstream.NewStaticRouter(nil)
|
||||
}
|
||||
if deps.ResponseSigner == nil {
|
||||
deps.ResponseSigner = unavailableResponseSigner{}
|
||||
}
|
||||
if deps.SessionCache == nil {
|
||||
deps.SessionCache = unavailableSessionCache{}
|
||||
}
|
||||
if deps.Clock == nil {
|
||||
deps.Clock = clock.System{}
|
||||
}
|
||||
if deps.ReplayStore == nil {
|
||||
deps.ReplayStore = unavailableReplayStore{}
|
||||
}
|
||||
if deps.Limiter == nil {
|
||||
deps.Limiter = ratelimit.NewInMemory()
|
||||
}
|
||||
if deps.Policy == nil {
|
||||
deps.Policy = noopAuthenticatedRequestPolicy{}
|
||||
}
|
||||
if deps.Logger == nil {
|
||||
deps.Logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return deps
|
||||
}
|
||||
@@ -0,0 +1,332 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRejectsMalformedEnvelope(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsMalformedEnvelope(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, &gatewayv1.SubscribeEventsRequest{})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.InvalidArgument, status.Code(err))
|
||||
}
|
||||
|
||||
func TestExecuteCommandRejectsUnsupportedProtocolVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{
|
||||
ProtocolVersion: "v2",
|
||||
DeviceSessionId: "device-session-123",
|
||||
MessageType: "fleet.move",
|
||||
TimestampMs: 123456789,
|
||||
RequestId: "request-123",
|
||||
PayloadBytes: []byte("payload"),
|
||||
PayloadHash: []byte("hash"),
|
||||
Signature: []byte("signature"),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, `unsupported protocol_version "v2"`, status.Convert(err).Message())
|
||||
}
|
||||
|
||||
func TestExecuteCommandValidEnvelopeStillReturnsUnimplemented(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return newActiveSessionRecord(), nil
|
||||
},
|
||||
},
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unimplemented, status.Code(err))
|
||||
}
|
||||
|
||||
func TestExecuteCommandMissingReplayStoreFailsClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return newActiveSessionRecord(), nil
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
// TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation
// verifies the authenticated stream lifecycle: a valid envelope yields one
// signed server-time bootstrap event, the stream then stays open (no further
// sends and no close), and client-side cancellation terminates it with
// CANCELED. The Never/Eventually polling below is timing-sensitive; keep the
// windows as-is.
func TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation(t *testing.T) {
	t.Parallel()

	// ctx is the client stream context; cancelling it is the event under test.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	stream, err := client.SubscribeEvents(ctx, newValidSubscribeEventsRequest())
	require.NoError(t, err)

	// The first (and only) event must be the signed server-time bootstrap.
	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())

	// Run the blocking Recv in a goroutine so the test can observe whether it
	// completes; recvResult is buffered so the goroutine never leaks.
	recvResult := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvResult <- recvErr
	}()

	// For 100ms, Recv must NOT have returned: the stream stays open while the
	// client has not cancelled.
	require.Never(t, func() bool {
		select {
		case <-recvResult:
			return true
		default:
			return false
		}
	}, 100*time.Millisecond, 10*time.Millisecond, "stream closed before cancellation")

	cancel()

	// After cancellation, Recv must return promptly with a CANCELED status.
	var recvErr error
	require.Eventually(t, func() bool {
		select {
		case recvErr = <-recvResult:
			return true
		default:
			return false
		}
	}, time.Second, 10*time.Millisecond, "stream did not stop after client cancellation")
	require.Error(t, recvErr)
	assert.Equal(t, codes.Canceled, status.Code(recvErr))
}
|
||||
|
||||
func TestSubscribeEventsMissingReplayStoreFailsClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return newActiveSessionRecord(), nil
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
func TestServerLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{})
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
require.NoError(t, conn.Close())
|
||||
|
||||
runGateway.stop(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := grpc.DialContext(
|
||||
ctx,
|
||||
addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// runningGateway tracks a gateway started in a background goroutine so tests
// can cancel it and wait for Run's result.
type runningGateway struct {
	cancel   context.CancelFunc // cancels the application's Run context
	resultCh chan error         // receives Run's final error exactly once
}
|
||||
|
||||
func newTestGateway(t *testing.T, deps ServerDependencies) (*Server, runningGateway) {
|
||||
t.Helper()
|
||||
|
||||
grpcCfg := config.DefaultAuthenticatedGRPCConfig()
|
||||
grpcCfg.Addr = "127.0.0.1:0"
|
||||
grpcCfg.FreshnessWindow = testFreshnessWindow
|
||||
|
||||
return newTestGatewayWithGRPCConfig(t, grpcCfg, deps)
|
||||
}
|
||||
|
||||
func newTestGatewayWithGRPCConfig(t *testing.T, grpcCfg config.AuthenticatedGRPCConfig, deps ServerDependencies) (*Server, runningGateway) {
|
||||
t.Helper()
|
||||
|
||||
cfg := config.Config{
|
||||
ShutdownTimeout: time.Second,
|
||||
AuthenticatedGRPC: grpcCfg,
|
||||
}
|
||||
|
||||
if deps.Clock == nil {
|
||||
deps.Clock = fixedClock{now: testCurrentTime}
|
||||
}
|
||||
if deps.ResponseSigner == nil {
|
||||
deps.ResponseSigner = newTestResponseSigner()
|
||||
}
|
||||
if deps.Router == nil && deps.Service != nil {
|
||||
deps.Router = executeCommandAdapterRouter{service: deps.Service}
|
||||
}
|
||||
|
||||
server := NewServer(cfg.AuthenticatedGRPC, deps)
|
||||
application := app.New(cfg, server)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
resultCh <- application.Run(ctx)
|
||||
}()
|
||||
|
||||
return server, runningGateway{
|
||||
cancel: cancel,
|
||||
resultCh: resultCh,
|
||||
}
|
||||
}
|
||||
|
||||
func (g runningGateway) stop(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
g.cancel()
|
||||
|
||||
var err error
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case err = <-g.resultCh:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, 2*time.Second, 10*time.Millisecond, "gateway did not stop after cancellation")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func waitForListenAddr(t *testing.T, server *Server) string {
|
||||
t.Helper()
|
||||
|
||||
var addr string
|
||||
require.Eventually(t, func() bool {
|
||||
addr = server.listenAddr()
|
||||
return addr != ""
|
||||
}, time.Second, 10*time.Millisecond, "server did not start listening")
|
||||
return addr
|
||||
}
|
||||
|
||||
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
ctx,
|
||||
addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return conn
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// resolvedSessionFromContext returns the session record previously attached to
|
||||
// ctx by the session-lookup gateway wrapper.
|
||||
func resolvedSessionFromContext(ctx context.Context) (session.Record, bool) {
|
||||
if ctx == nil {
|
||||
return session.Record{}, false
|
||||
}
|
||||
|
||||
record, ok := ctx.Value(resolvedSessionContextKey{}).(session.Record)
|
||||
if !ok {
|
||||
return session.Record{}, false
|
||||
}
|
||||
|
||||
return cloneSessionRecord(record), true
|
||||
}
|
||||
|
||||
// sessionLookupService resolves the authenticated session from SessionCache
// after envelope parsing succeeds and before later auth steps run. It embeds
// the unimplemented server so unhandled RPCs fail with UNIMPLEMENTED.
type sessionLookupService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	// delegate receives the request once the session is resolved and attached.
	delegate gatewayv1.EdgeGatewayServer
	// cache maps device-session IDs to session records.
	cache session.Cache
}
|
||||
|
||||
// ExecuteCommand resolves the cached session for req and only then forwards it
|
||||
// to the configured delegate with the resolved session attached to ctx.
|
||||
func (s sessionLookupService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
record, err := s.lookupSession(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.delegate.ExecuteCommand(context.WithValue(ctx, resolvedSessionContextKey{}, cloneSessionRecord(record)), req)
|
||||
}
|
||||
|
||||
// SubscribeEvents resolves the cached session for req and only then forwards it
|
||||
// to the configured delegate with the resolved session attached to the stream
|
||||
// context.
|
||||
func (s sessionLookupService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
record, err := s.lookupSession(stream.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.delegate.SubscribeEvents(req, resolvedSessionContextStream{
|
||||
ServerStreamingServer: stream,
|
||||
ctx: context.WithValue(stream.Context(), resolvedSessionContextKey{}, cloneSessionRecord(record)),
|
||||
})
|
||||
}
|
||||
|
||||
// newSessionLookupService wraps delegate with the session-cache lookup gate.
|
||||
func newSessionLookupService(delegate gatewayv1.EdgeGatewayServer, cache session.Cache) gatewayv1.EdgeGatewayServer {
|
||||
return sessionLookupService{
|
||||
delegate: delegate,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
func (s sessionLookupService) lookupSession(ctx context.Context) (session.Record, error) {
|
||||
envelope, ok := parsedEnvelopeFromContext(ctx)
|
||||
if !ok {
|
||||
return session.Record{}, status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
record, err := s.cache.Lookup(ctx, envelope.DeviceSessionID)
|
||||
switch {
|
||||
case err == nil:
|
||||
case errors.Is(err, session.ErrNotFound):
|
||||
return session.Record{}, status.Error(codes.Unauthenticated, "unknown device session")
|
||||
default:
|
||||
return session.Record{}, status.Error(codes.Unavailable, "session cache is unavailable")
|
||||
}
|
||||
|
||||
if record.Status == session.StatusRevoked {
|
||||
return session.Record{}, status.Error(codes.FailedPrecondition, "device session is revoked")
|
||||
}
|
||||
|
||||
return cloneSessionRecord(record), nil
|
||||
}
|
||||
|
||||
func cloneSessionRecord(record session.Record) session.Record {
|
||||
cloned := record
|
||||
if record.RevokedAtMS != nil {
|
||||
value := *record.RevokedAtMS
|
||||
cloned.RevokedAtMS = &value
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
// resolvedSessionContextKey is the private context key under which the
// resolved session record is stored.
type resolvedSessionContextKey struct{}
|
||||
|
||||
// resolvedSessionContextStream overrides Context on the embedded stream so the
// delegate observes the session-enriched context.
type resolvedSessionContextStream struct {
	grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
	// ctx carries the resolved session record; see the Context method.
	ctx context.Context
}
|
||||
|
||||
func (s resolvedSessionContextStream) Context() context.Context {
|
||||
if s.ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
// unavailableSessionCache is the fail-closed fallback used when no real
// session cache is configured.
type unavailableSessionCache struct{}

// Lookup always fails, so otherwise-valid envelopes are rejected as
// session-cache unavailable rather than silently authenticated.
func (unavailableSessionCache) Lookup(context.Context, string) (session.Record, error) {
	return session.Record{}, errors.New("session cache is unavailable")
}
|
||||
|
||||
// Compile-time check that the wrapper satisfies the full service interface.
var _ gatewayv1.EdgeGatewayServer = sessionLookupService{}
|
||||
@@ -0,0 +1,294 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestExecuteCommandRejectsUnknownSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return session.Record{}, session.ErrNotFound
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "unknown device session", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsUnknownSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return session.Record{}, session.ErrNotFound
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unauthenticated, status.Code(err))
|
||||
assert.Equal(t, "unknown device session", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRejectsRevokedSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newRevokedSessionRecord(), nil }},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsRevokedSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newRevokedSessionRecord(), nil }},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandRejectsSessionCacheUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return session.Record{}, errors.New("redis down")
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.executeCalls)
|
||||
}
|
||||
|
||||
func TestSubscribeEventsRejectsSessionCacheUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{}
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{
|
||||
lookupFunc: func(context.Context, string) (session.Record, error) {
|
||||
return session.Record{}, errors.New("redis down")
|
||||
},
|
||||
},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.Unavailable, status.Code(err))
|
||||
assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
|
||||
assert.Zero(t, delegate.subscribeCalls)
|
||||
}
|
||||
|
||||
func TestExecuteCommandAttachesResolvedSession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
delegate := &recordingEdgeGatewayService{
|
||||
executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
record, ok := resolvedSessionFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, newActiveSessionRecord(), record)
|
||||
return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
|
||||
},
|
||||
}
|
||||
|
||||
server, runGateway := newTestGateway(t, ServerDependencies{
|
||||
Service: delegate,
|
||||
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
|
||||
ReplayStore: staticReplayStore{},
|
||||
})
|
||||
defer runGateway.stop(t)
|
||||
|
||||
addr := waitForListenAddr(t, server)
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "request-123", response.GetRequestId())
|
||||
}
|
||||
|
||||
// TestSubscribeEventsAttachesResolvedSession verifies that the gateway looks
// up the device session in the session cache and attaches the resolved record
// to the stream context before the delegate service runs.
func TestSubscribeEventsAttachesResolvedSession(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{
		subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
			// The resolved session must already be on the stream context
			// by the time the delegate is invoked.
			record, ok := resolvedSessionFromContext(stream.Context())
			require.True(t, ok)
			assert.Equal(t, newActiveSessionRecord(), record)
			return nil
		},
	}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
	require.NoError(t, err)

	// The gateway emits a signed server-time bootstrap event first.
	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())

	// The delegate returns nil, so the stream then terminates cleanly.
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)
}

// TestSubscribeEventsAttachesAuthenticatedStreamBinding verifies that the
// gateway attaches the verified stream identity (user, session, message type,
// and correlation IDs taken from the request envelope) to the stream context.
func TestSubscribeEventsAttachesAuthenticatedStreamBinding(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{
		subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
			binding, ok := authenticatedStreamBindingFromContext(stream.Context())
			require.True(t, ok)
			// Values mirror newValidSubscribeEventsRequest and the
			// active session record returned by the cache stub.
			assert.Equal(t, authenticatedStreamBinding{
				UserID:          "user-123",
				DeviceSessionID: "device-session-123",
				MessageType:     "gateway.subscribe",
				RequestID:       "request-123",
				TraceID:         "trace-123",
			}, binding)
			return nil
		},
	}

	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
	require.NoError(t, err)

	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())

	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)
}
|
||||
|
||||
type staticSessionCache struct {
|
||||
lookupFunc func(context.Context, string) (session.Record, error)
|
||||
}
|
||||
|
||||
func (c staticSessionCache) Lookup(ctx context.Context, deviceSessionID string) (session.Record, error) {
|
||||
return c.lookupFunc(ctx, deviceSessionID)
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// signatureVerifyingService applies client-signature verification after
|
||||
// payload integrity checks and before later auth or routing steps run.
|
||||
type signatureVerifyingService struct {
|
||||
gatewayv1.UnimplementedEdgeGatewayServer
|
||||
|
||||
delegate gatewayv1.EdgeGatewayServer
|
||||
}
|
||||
|
||||
// ExecuteCommand verifies req client signature before delegating to the
|
||||
// configured service implementation.
|
||||
func (s signatureVerifyingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
|
||||
if err := verifyRequestSignature(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.delegate.ExecuteCommand(ctx, req)
|
||||
}
|
||||
|
||||
// SubscribeEvents verifies req client signature before delegating to the
|
||||
// configured service implementation.
|
||||
func (s signatureVerifyingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
|
||||
if err := verifyRequestSignature(stream.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.delegate.SubscribeEvents(req, stream)
|
||||
}
|
||||
|
||||
// newSignatureVerifyingService wraps delegate with the client-signature
|
||||
// verification gate.
|
||||
func newSignatureVerifyingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer {
|
||||
return signatureVerifyingService{delegate: delegate}
|
||||
}
|
||||
|
||||
func verifyRequestSignature(ctx context.Context) error {
|
||||
envelope, ok := parsedEnvelopeFromContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
record, ok := resolvedSessionFromContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Internal, "authenticated request context is incomplete")
|
||||
}
|
||||
|
||||
err := authn.VerifyRequestSignature(record.ClientPublicKey, envelope.Signature, authn.RequestSigningFields{
|
||||
ProtocolVersion: envelope.ProtocolVersion,
|
||||
DeviceSessionID: envelope.DeviceSessionID,
|
||||
MessageType: envelope.MessageType,
|
||||
TimestampMS: envelope.TimestampMS,
|
||||
RequestID: envelope.RequestID,
|
||||
PayloadHash: envelope.PayloadHash,
|
||||
})
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, authn.ErrInvalidClientPublicKey):
|
||||
return status.Error(codes.Unavailable, "session cache is unavailable")
|
||||
case errors.Is(err, authn.ErrInvalidRequestSignature):
|
||||
return status.Error(codes.Unauthenticated, "invalid request signature")
|
||||
default:
|
||||
return status.Error(codes.Internal, "request signature verification failed")
|
||||
}
|
||||
}
|
||||
|
||||
var _ gatewayv1.EdgeGatewayServer = signatureVerifyingService{}
|
||||
@@ -0,0 +1,188 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// TestExecuteCommandRejectsInvalidSignature verifies that a corrupted request
// signature is rejected with Unauthenticated and never reaches the delegate.
func TestExecuteCommandRejectsInvalidSignature(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	// Flip one byte of an otherwise valid signature.
	req := newValidExecuteCommandRequest()
	req.Signature[0] ^= 0xff

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), req)
	require.Error(t, err)
	assert.Equal(t, codes.Unauthenticated, status.Code(err))
	assert.Equal(t, "invalid request signature", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}

// TestExecuteCommandRejectsWrongKey verifies that a request signed with a key
// other than the session's cached public key is rejected as Unauthenticated.
func TestExecuteCommandRejectsWrongKey(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service: delegate,
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				// Valid record, but with the "alternate" key instead of
				// the "primary" key the request was signed with.
				record := newActiveSessionRecord()
				record.ClientPublicKey = alternateTestClientPublicKeyBase64()
				return record, nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unauthenticated, status.Code(err))
	assert.Equal(t, "invalid request signature", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}

// TestExecuteCommandRejectsInvalidCachedPublicKey verifies that an unparsable
// cached public key is surfaced as Unavailable (a session-cache problem,
// not a client authentication failure).
func TestExecuteCommandRejectsInvalidCachedPublicKey(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service: delegate,
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				record := newActiveSessionRecord()
				record.ClientPublicKey = "%%%not-base64%%%"
				return record, nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}
|
||||
|
||||
// TestSubscribeEventsRejectsInvalidSignature mirrors the ExecuteCommand case
// for the streaming RPC: a corrupted signature yields Unauthenticated and the
// delegate's stream handler is never invoked.
func TestSubscribeEventsRejectsInvalidSignature(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	// Flip one byte of an otherwise valid signature.
	req := newValidSubscribeEventsRequest()
	req.Signature[0] ^= 0xff

	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, req)
	require.Error(t, err)
	assert.Equal(t, codes.Unauthenticated, status.Code(err))
	assert.Equal(t, "invalid request signature", status.Convert(err).Message())
	assert.Zero(t, delegate.subscribeCalls)
}

// TestSubscribeEventsRejectsWrongKey verifies that a stream request signed
// with a key other than the session's cached key is rejected.
func TestSubscribeEventsRejectsWrongKey(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service: delegate,
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				// Cached key does not match the request's "primary" signer.
				record := newActiveSessionRecord()
				record.ClientPublicKey = alternateTestClientPublicKeyBase64()
				return record, nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unauthenticated, status.Code(err))
	assert.Equal(t, "invalid request signature", status.Convert(err).Message())
	assert.Zero(t, delegate.subscribeCalls)
}

// TestSubscribeEventsRejectsInvalidCachedPublicKey verifies that an
// unparsable cached public key surfaces as Unavailable on the stream.
func TestSubscribeEventsRejectsInvalidCachedPublicKey(t *testing.T) {
	t.Parallel()

	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service: delegate,
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				record := newActiveSessionRecord()
				record.ClientPublicKey = "%%%not-base64%%%"
				return record, nil
			},
		},
	})
	defer runGateway.stop(t)

	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()

	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
	assert.Zero(t, delegate.subscribeCalls)
}
|
||||
@@ -0,0 +1,298 @@
|
||||
package grpcapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
gatewayfbs "galaxy/schema/fbs/gateway"
|
||||
|
||||
flatbuffers "github.com/google/flatbuffers/go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
var (
	// testCurrentTime is the fixed instant injected into the gateway under
	// test so request timestamps are deterministic.
	testCurrentTime = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)

	// testFreshnessWindow is the request-timestamp freshness window used in
	// tests. NOTE(review): its consumer is not in this chunk — presumably
	// newTestGateway's timestamp validation config; confirm there.
	testFreshnessWindow = 5 * time.Minute
)
|
||||
|
||||
// newValidExecuteCommandRequest returns a fully signed ExecuteCommand request
// with the default test session and request identifiers.
func newValidExecuteCommandRequest() *gatewayv1.ExecuteCommandRequest {
	return newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-123")
}

// newValidExecuteCommandRequestWithSessionAndRequestID builds a signed request
// timestamped at the fixed test clock instant.
func newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.ExecuteCommandRequest {
	return newValidExecuteCommandRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli())
}

// newValidExecuteCommandRequestWithTimestamp builds an ExecuteCommand request
// whose payload hash matches its payload bytes and whose signature covers all
// envelope fields, signed with the "primary" test key.
func newValidExecuteCommandRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.ExecuteCommandRequest {
	payloadBytes := []byte("payload")
	payloadHash := sha256.Sum256(payloadBytes)

	req := &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: supportedProtocolVersion,
		DeviceSessionId: deviceSessionID,
		MessageType:     "fleet.move",
		TimestampMs:     timestampMS,
		RequestId:       requestID,
		PayloadBytes:    payloadBytes,
		PayloadHash:     payloadHash[:],
		TraceId:         "trace-123",
	}
	// Sign last so the signature covers the final envelope field values.
	req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash())

	return req
}
|
||||
|
||||
// newValidSubscribeEventsRequest returns a fully signed SubscribeEvents
// request with the default test session and request identifiers.
func newValidSubscribeEventsRequest() *gatewayv1.SubscribeEventsRequest {
	return newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-123")
}

// newValidSubscribeEventsRequestWithSessionAndRequestID builds a signed
// subscribe request timestamped at the fixed test clock instant.
func newValidSubscribeEventsRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
	return newValidSubscribeEventsRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli())
}

// newValidSubscribeEventsRequestWithTimestamp builds a SubscribeEvents request
// with an empty payload (its hash is SHA-256 of no bytes) and a signature over
// all envelope fields, signed with the "primary" test key.
func newValidSubscribeEventsRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.SubscribeEventsRequest {
	payloadHash := sha256.Sum256(nil)

	req := &gatewayv1.SubscribeEventsRequest{
		ProtocolVersion: supportedProtocolVersion,
		DeviceSessionId: deviceSessionID,
		MessageType:     "gateway.subscribe",
		TimestampMs:     timestampMS,
		RequestId:       requestID,
		PayloadHash:     payloadHash[:],
		TraceId:         "trace-123",
	}
	// Sign last so the signature covers the final envelope field values.
	req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash())

	return req
}
|
||||
|
||||
// newActiveSessionRecord returns an active session record for the default
// test device session.
func newActiveSessionRecord() session.Record {
	return newActiveSessionRecordWithSessionID("device-session-123")
}

// newActiveSessionRecordWithSessionID returns an active session record whose
// cached public key is the "primary" test key — the same key signRequest
// signs with, so valid test requests verify against this record.
func newActiveSessionRecordWithSessionID(deviceSessionID string) session.Record {
	return session.Record{
		DeviceSessionID: deviceSessionID,
		UserID:          "user-123",
		ClientPublicKey: testClientPublicKeyBase64(),
		Status:          session.StatusActive,
	}
}

// newRevokedSessionRecord returns a revoked session record with a fixed
// revocation timestamp, for exercising revoked-session rejection paths.
func newRevokedSessionRecord() session.Record {
	revokedAtMS := int64(123456789)

	return session.Record{
		DeviceSessionID: "device-session-123",
		UserID:          "user-123",
		ClientPublicKey: testClientPublicKeyBase64(),
		Status:          session.StatusRevoked,
		RevokedAtMS:     &revokedAtMS,
	}
}
|
||||
|
||||
func alternateTestClientPublicKeyBase64() string {
|
||||
return base64.StdEncoding.EncodeToString(newTestPrivateKey("alternate").Public().(ed25519.PublicKey))
|
||||
}
|
||||
|
||||
func testClientPublicKeyBase64() string {
|
||||
return base64.StdEncoding.EncodeToString(newTestPrivateKey("primary").Public().(ed25519.PublicKey))
|
||||
}
|
||||
|
||||
// signRequest produces an Ed25519 signature over the canonical request
// signing input built from the given envelope fields, using the "primary"
// test key — the key whose public half newActiveSessionRecord caches.
func signRequest(protocolVersion, deviceSessionID, messageType string, timestampMS int64, requestID string, payloadHash []byte) []byte {
	return ed25519.Sign(newTestPrivateKey("primary"), authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: protocolVersion,
		DeviceSessionID: deviceSessionID,
		MessageType:     messageType,
		TimestampMS:     timestampMS,
		RequestID:       requestID,
		PayloadHash:     payloadHash,
	}))
}
|
||||
|
||||
func newTestPrivateKey(label string) ed25519.PrivateKey {
|
||||
seed := sha256.Sum256([]byte("gateway-grpcapi-signature-test-" + label))
|
||||
return ed25519.NewKeyFromSeed(seed[:])
|
||||
}
|
||||
|
||||
// newTestEd25519ResponseSigner builds a deterministic response signer by
// PEM-encoding the "response-signer" test key as PKCS #8 and feeding it
// through the production authn parser, so tests exercise the real load path.
func newTestEd25519ResponseSigner() *authn.Ed25519ResponseSigner {
	pemBytes := pem.EncodeToMemory(&pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: mustMarshalPKCS8PrivateKey(newTestPrivateKey("response-signer")),
	})

	signer, err := authn.ParseEd25519ResponseSignerPEM(pemBytes)
	if err != nil {
		// Test-fixture construction: failure here is a programming error.
		panic(err)
	}

	return signer
}

// newTestResponseSigner returns the deterministic test signer as the
// authn.ResponseSigner interface type.
func newTestResponseSigner() authn.ResponseSigner {
	return newTestEd25519ResponseSigner()
}

// newTestResponseSignerPublicKey exposes the test signer's public key so
// tests can verify gateway event signatures.
func newTestResponseSignerPublicKey() ed25519.PublicKey {
	return newTestEd25519ResponseSigner().PublicKey()
}
|
||||
|
||||
func mustMarshalPKCS8PrivateKey(privateKey ed25519.PrivateKey) []byte {
|
||||
encoded, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return encoded
|
||||
}
|
||||
|
||||
type fixedClock struct {
|
||||
now time.Time
|
||||
}
|
||||
|
||||
func (c fixedClock) Now() time.Time {
|
||||
return c.now
|
||||
}
|
||||
|
||||
// recvBootstrapEvent receives the first event from stream, failing the test
// on any receive error. Callers use it to consume the gateway's initial
// server-time bootstrap event.
func recvBootstrapEvent(t interface {
	require.TestingT
	Helper()
}, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
	t.Helper()

	event, err := stream.Recv()
	require.NoError(t, err)

	return event
}

// subscribeEventsError opens a SubscribeEvents stream and returns the error
// surfaced either at stream creation or on the first Recv — where gRPC
// delivers server-side stream rejections.
func subscribeEventsError(t interface {
	require.TestingT
	Helper()
}, ctx context.Context, client gatewayv1.EdgeGatewayClient, req *gatewayv1.SubscribeEventsRequest) error {
	t.Helper()

	stream, err := client.SubscribeEvents(ctx, req)
	if err != nil {
		return err
	}

	_, err = stream.Recv()
	return err
}
|
||||
|
||||
// assertServerTimeBootstrapEvent asserts that event is the signed server-time
// bootstrap event: correct type and correlation IDs, a payload hash matching
// the payload bytes, a valid response-signer signature over the envelope
// fields, and a flatbuffer payload carrying wantTimestampMS.
func assertServerTimeBootstrapEvent(t interface {
	require.TestingT
	Helper()
}, event *gatewayv1.GatewayEvent, publicKey ed25519.PublicKey, wantRequestID string, wantTraceID string, wantTimestampMS int64) {
	t.Helper()

	require.NotNil(t, event)
	assert.Equal(t, serverTimeEventType, event.GetEventType())
	// The bootstrap event reuses the request ID as its event ID.
	assert.Equal(t, wantRequestID, event.GetEventId())
	assert.Equal(t, wantRequestID, event.GetRequestId())
	assert.Equal(t, wantTraceID, event.GetTraceId())
	assert.Equal(t, wantTimestampMS, event.GetTimestampMs())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	require.NoError(t, authn.VerifyEventSignature(publicKey, event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))

	// The payload itself is a flatbuffer ServerTimeEvent echoing the
	// envelope timestamp.
	payload := gatewayfbs.GetRootAsServerTimeEvent(event.GetPayloadBytes(), flatbuffers.UOffsetT(0))
	assert.Equal(t, wantTimestampMS, payload.ServerTimeMs())
}
|
||||
|
||||
type staticReplayStore struct {
|
||||
reserveFunc func(context.Context, string, string, time.Duration) error
|
||||
}
|
||||
|
||||
func (s staticReplayStore) Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
|
||||
if s.reserveFunc != nil {
|
||||
return s.reserveFunc(ctx, deviceSessionID, requestID, ttl)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeCommandAdapterRouter is a downstream router test double that routes
// every message type to an adapter over the configured EdgeGatewayServer.
type executeCommandAdapterRouter struct {
	service gatewayv1.EdgeGatewayServer
}

// Route ignores the message type and always returns an adapter client backed
// by the router's service; it never fails.
func (r executeCommandAdapterRouter) Route(string) (downstream.Client, error) {
	return executeCommandAdapterClient{service: r.service}, nil
}

// executeCommandAdapterClient adapts a gatewayv1.EdgeGatewayServer to the
// downstream.Client interface.
type executeCommandAdapterClient struct {
	service gatewayv1.EdgeGatewayServer
}
|
||||
|
||||
func (c executeCommandAdapterClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
||||
response, err := c.service.ExecuteCommand(ctx, &gatewayv1.ExecuteCommandRequest{
|
||||
ProtocolVersion: command.ProtocolVersion,
|
||||
DeviceSessionId: command.DeviceSessionID,
|
||||
MessageType: command.MessageType,
|
||||
TimestampMs: command.TimestampMS,
|
||||
RequestId: command.RequestID,
|
||||
PayloadBytes: command.PayloadBytes,
|
||||
TraceId: command.TraceID,
|
||||
})
|
||||
if err != nil {
|
||||
return downstream.UnaryResult{}, err
|
||||
}
|
||||
|
||||
resultCode := response.GetResultCode()
|
||||
if resultCode == "" {
|
||||
resultCode = "ok"
|
||||
}
|
||||
|
||||
return downstream.UnaryResult{
|
||||
ResultCode: resultCode,
|
||||
PayloadBytes: response.GetPayloadBytes(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// recordingDownstreamClient is a downstream.Client test double that records
// every command it receives (with a defensive payload copy) and answers via
// executeFunc when set, or a canned "ok" result otherwise.
type recordingDownstreamClient struct {
	executeCalls int
	commands     []downstream.AuthenticatedCommand
	executeFunc  func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error)
}

// ExecuteCommand records command and then delegates to executeFunc when one
// is configured; otherwise it returns a fixed successful result.
func (c *recordingDownstreamClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	c.executeCalls++
	// Record a field-by-field copy with PayloadBytes cloned so later caller
	// mutations cannot alter the recorded command.
	c.commands = append(c.commands, downstream.AuthenticatedCommand{
		ProtocolVersion: command.ProtocolVersion,
		UserID:          command.UserID,
		DeviceSessionID: command.DeviceSessionID,
		MessageType:     command.MessageType,
		TimestampMS:     command.TimestampMS,
		RequestID:       command.RequestID,
		TraceID:         command.TraceID,
		PayloadBytes:    append([]byte(nil), command.PayloadBytes...),
	})
	if c.executeFunc != nil {
		return c.executeFunc(ctx, command)
	}

	return downstream.UnaryResult{
		ResultCode:   "ok",
		PayloadBytes: []byte("response"),
	}, nil
}
|
||||
@@ -0,0 +1,84 @@
|
||||
// Package logging configures the gateway structured logger and provides
|
||||
// context-aware helpers for attaching OpenTelemetry trace identifiers.
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// New constructs the process-wide JSON logger from cfg.
|
||||
func New(cfg config.LoggingConfig) (*zap.Logger, error) {
|
||||
level := zap.NewAtomicLevel()
|
||||
if err := level.UnmarshalText([]byte(strings.TrimSpace(cfg.Level))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zapCfg := zap.NewProductionConfig()
|
||||
zapCfg.Level = level
|
||||
zapCfg.Sampling = nil
|
||||
zapCfg.Encoding = "json"
|
||||
zapCfg.EncoderConfig.TimeKey = "timestamp"
|
||||
zapCfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
zapCfg.OutputPaths = []string{"stdout"}
|
||||
zapCfg.ErrorOutputPaths = []string{"stderr"}
|
||||
|
||||
return zapCfg.Build()
|
||||
}
|
||||
|
||||
// TraceFieldsFromContext returns zap fields for the active OpenTelemetry span
|
||||
// when ctx carries a valid span context.
|
||||
func TraceFieldsFromContext(ctx context.Context) []zap.Field {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
spanContext := trace.SpanContextFromContext(ctx)
|
||||
if !spanContext.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []zap.Field{
|
||||
zap.String("otel_trace_id", spanContext.TraceID().String()),
|
||||
zap.String("otel_span_id", spanContext.SpanID().String()),
|
||||
}
|
||||
}
|
||||
|
||||
// Sync flushes logger and ignores the benign stdout or stderr sync errors
// commonly returned by containerized or redirected process outputs.
// A nil logger is a no-op.
func Sync(logger *zap.Logger) error {
	if logger == nil {
		return nil
	}

	err := logger.Sync()
	// Swallow only the known-benign terminal/pipe sync failures; anything
	// else is a real flush error worth surfacing.
	if err == nil || isIgnorableSyncError(err) {
		return nil
	}

	return err
}
|
||||
|
||||
func isIgnorableSyncError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
message := strings.ToLower(err.Error())
|
||||
switch {
|
||||
case strings.Contains(message, "invalid argument"):
|
||||
return true
|
||||
case strings.Contains(message, "bad file descriptor"):
|
||||
return true
|
||||
case strings.Contains(message, "inappropriate ioctl for device"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,385 @@
|
||||
// Package push provides the in-memory hub used to fan out internal
|
||||
// client-facing events to active authenticated push streams.
|
||||
package push
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// defaultSubscriptionQueueCapacity is the bounded per-subscription event
// queue size used when the hub is constructed with a non-positive capacity.
const defaultSubscriptionQueueCapacity = 64

var (
	// ErrSubscriptionOverflow reports that one push stream stopped consuming
	// events quickly enough and its bounded queue overflowed.
	ErrSubscriptionOverflow = errors.New("push stream overflowed")

	// ErrSubscriptionRevoked reports that the authenticated device session bound
	// to the push stream was revoked and the stream must terminate.
	ErrSubscriptionRevoked = errors.New("device session is revoked")

	// ErrHubShuttingDown reports that the gateway is shutting down and all
	// active push streams must terminate promptly.
	ErrHubShuttingDown = errors.New("gateway is shutting down")
)
|
||||
|
||||
// StreamBinding identifies one authenticated push stream tracked by Hub.
// Hub.Register trims both fields and rejects empty values.
type StreamBinding struct {
	// UserID is the verified authenticated user bound to the stream.
	UserID string

	// DeviceSessionID is the verified authenticated device session bound to the
	// stream.
	DeviceSessionID string
}

// Event is the internal client-facing event delivered from internal pub/sub to
// active push streams.
type Event struct {
	// UserID identifies the authenticated user that should receive the event.
	UserID string

	// DeviceSessionID optionally narrows delivery to one device session.
	DeviceSessionID string

	// EventType identifies the stable client-facing event category.
	EventType string

	// EventID is the stable event correlation identifier.
	EventID string

	// PayloadBytes carries the opaque event payload bytes.
	PayloadBytes []byte

	// RequestID optionally correlates the event to an earlier client request.
	RequestID string

	// TraceID optionally carries tracing correlation.
	TraceID string
}
|
||||
|
||||
// Subscription represents one active push stream registered in Hub.
type Subscription struct {
	hub     *Hub          // owning hub; used by Close to unregister
	id      uint64        // hub-assigned unique identifier
	binding StreamBinding // trimmed, validated stream identity
	events  chan Event    // bounded per-stream delivery queue
	done    chan struct{} // closed exactly once when the subscription ends

	closeOnce sync.Once    // guards the single close of done
	stateMu   sync.RWMutex // guards err
	err       error        // terminal reason; nil for a plain Close
}

// Observer receives push stream lifecycle notifications suitable for metrics
// bookkeeping.
type Observer interface {
	// Registered reports one active push stream binding.
	Registered(binding StreamBinding)

	// Unregistered reports that binding stopped with err. A nil err means the
	// stream ended without a hub-enforced terminal reason.
	Unregistered(binding StreamBinding, err error)
}
|
||||
|
||||
// Events returns the ordered event queue for the subscription.
|
||||
func (s *Subscription) Events() <-chan Event {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.events
|
||||
}
|
||||
|
||||
// Done closes when the subscription has been removed from the hub.
|
||||
func (s *Subscription) Done() <-chan struct{} {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.done
|
||||
}
|
||||
|
||||
// Err returns the terminal subscription error, if any.
|
||||
func (s *Subscription) Err() error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.stateMu.RLock()
|
||||
defer s.stateMu.RUnlock()
|
||||
|
||||
return s.err
|
||||
}
|
||||
|
||||
// Close unregisters the subscription from its hub.
|
||||
func (s *Subscription) Close() {
|
||||
if s == nil || s.hub == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.hub.unregister(s.id, nil)
|
||||
}
|
||||
|
||||
// enqueue attempts a non-blocking delivery of event to the subscription
// queue. It returns true when the event was queued OR the subscription is
// already closed (closed streams silently drop events); false signals queue
// overflow so the caller can terminate this subscription.
func (s *Subscription) enqueue(event Event) bool {
	if s == nil {
		return true
	}

	// Copy before queuing so the consumer never aliases the publisher's
	// event. cloneEvent is defined elsewhere in this package — presumably it
	// deep-copies PayloadBytes; confirm there.
	cloned := cloneEvent(event)

	// Fast path: drop immediately when the subscription is already closed.
	select {
	case <-s.done:
		return true
	default:
	}

	// Non-blocking send; the done case guards against a close racing the
	// send, and the default case reports a full queue as overflow.
	select {
	case s.events <- cloned:
		return true
	case <-s.done:
		return true
	default:
		return false
	}
}
|
||||
|
||||
// closeWithError records err as the subscription's terminal reason and closes
// done, exactly once; later calls are no-ops and the first error wins.
func (s *Subscription) closeWithError(err error) {
	if s == nil {
		return
	}

	s.closeOnce.Do(func() {
		// Store err before closing done so a goroutine woken by Done()
		// observes the terminal reason via Err().
		s.stateMu.Lock()
		s.err = err
		s.stateMu.Unlock()
		close(s.done)
	})
}
|
||||
|
||||
// Hub tracks active authenticated push streams and fans out client-facing
// events to the matching subscriptions.
type Hub struct {
	mu            sync.RWMutex // guards nextID and the three indexes below
	nextID        uint64       // monotonically increasing subscription ID source
	queueCapacity int          // bounded queue size applied to each new subscription
	observer      Observer     // optional lifecycle observer; may be nil
	byID          map[uint64]*Subscription
	byUser        map[string]map[uint64]*Subscription // user ID -> subscriptions
	bySession     map[string]map[uint64]*Subscription // device session ID -> subscriptions
}
|
||||
|
||||
// NewHub constructs a push hub with one bounded in-memory queue per
|
||||
// subscription. Non-positive queueCapacity falls back to the package default.
|
||||
func NewHub(queueCapacity int) *Hub {
|
||||
return NewHubWithObserver(queueCapacity, nil)
|
||||
}
|
||||
|
||||
// NewHubWithObserver constructs a push hub that also reports stream lifecycle
|
||||
// changes to observer.
|
||||
func NewHubWithObserver(queueCapacity int, observer Observer) *Hub {
|
||||
if queueCapacity <= 0 {
|
||||
queueCapacity = defaultSubscriptionQueueCapacity
|
||||
}
|
||||
|
||||
return &Hub{
|
||||
queueCapacity: queueCapacity,
|
||||
observer: observer,
|
||||
byID: make(map[uint64]*Subscription),
|
||||
byUser: make(map[string]map[uint64]*Subscription),
|
||||
bySession: make(map[string]map[uint64]*Subscription),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds one authenticated push stream to the hub and returns its
// subscription handle. Both binding identifiers are trimmed and must be
// non-empty.
func (h *Hub) Register(binding StreamBinding) (*Subscription, error) {
	if h == nil {
		return nil, errors.New("register push subscription: nil hub")
	}

	userID := strings.TrimSpace(binding.UserID)
	if userID == "" {
		return nil, errors.New("register push subscription: user id must not be empty")
	}

	deviceSessionID := strings.TrimSpace(binding.DeviceSessionID)
	if deviceSessionID == "" {
		return nil, errors.New("register push subscription: device session id must not be empty")
	}

	h.mu.Lock()

	// IDs are allocated under the lock, so they are unique per hub.
	h.nextID++
	subscription := &Subscription{
		hub: h,
		id:  h.nextID,
		binding: StreamBinding{
			UserID:          userID,
			DeviceSessionID: deviceSessionID,
		},
		events: make(chan Event, h.queueCapacity),
		done:   make(chan struct{}),
	}
	h.byID[subscription.id] = subscription
	addIndexedSubscription(h.byUser, userID, subscription)
	addIndexedSubscription(h.bySession, deviceSessionID, subscription)
	h.mu.Unlock()

	// Notify the observer outside the lock so observer implementations
	// cannot deadlock the hub by calling back into it.
	if h.observer != nil {
		h.observer.Registered(subscription.binding)
	}

	return subscription, nil
}
|
||||
|
||||
// Publish fans out event to the matching active subscriptions. When one
|
||||
// subscription queue overflows, only that subscription is closed.
|
||||
func (h *Hub) Publish(event Event) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
|
||||
targets := h.targets(event)
|
||||
for _, target := range targets {
|
||||
if target.enqueue(event) {
|
||||
continue
|
||||
}
|
||||
|
||||
h.unregister(target.id, ErrSubscriptionOverflow)
|
||||
}
|
||||
}
|
||||
|
||||
// RevokeDeviceSession closes all active subscriptions bound to the exact
|
||||
// authenticated device session identifier.
|
||||
func (h *Hub) RevokeDeviceSession(deviceSessionID string) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
|
||||
deviceSessionID = strings.TrimSpace(deviceSessionID)
|
||||
if deviceSessionID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
targets := cloneSubscriptions(h.bySession[deviceSessionID])
|
||||
h.mu.RUnlock()
|
||||
|
||||
for _, target := range targets {
|
||||
h.unregister(target.id, ErrSubscriptionRevoked)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown closes every active subscription because the gateway is shutting
|
||||
// down.
|
||||
func (h *Hub) Shutdown() {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
targets := cloneSubscriptions(h.byID)
|
||||
h.mu.RUnlock()
|
||||
|
||||
for _, target := range targets {
|
||||
h.unregister(target.id, ErrHubShuttingDown)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) targets(event Event) []*Subscription {
|
||||
userID := strings.TrimSpace(event.UserID)
|
||||
eventType := strings.TrimSpace(event.EventType)
|
||||
eventID := strings.TrimSpace(event.EventID)
|
||||
if h == nil || userID == "" || eventType == "" || eventID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
deviceSessionID := strings.TrimSpace(event.DeviceSessionID)
|
||||
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
if deviceSessionID == "" {
|
||||
return cloneSubscriptions(h.byUser[userID])
|
||||
}
|
||||
|
||||
sessionMatches := cloneSubscriptions(h.bySession[deviceSessionID])
|
||||
filtered := sessionMatches[:0]
|
||||
for _, subscription := range sessionMatches {
|
||||
if subscription.binding.UserID == userID {
|
||||
filtered = append(filtered, subscription)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
func (h *Hub) unregister(id uint64, err error) {
|
||||
if h == nil || id == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
subscription, ok := h.byID[id]
|
||||
if !ok {
|
||||
h.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
delete(h.byID, id)
|
||||
removeIndexedSubscription(h.byUser, subscription.binding.UserID, id)
|
||||
removeIndexedSubscription(h.bySession, subscription.binding.DeviceSessionID, id)
|
||||
h.mu.Unlock()
|
||||
|
||||
subscription.closeWithError(err)
|
||||
if h.observer != nil {
|
||||
h.observer.Unregistered(subscription.binding, err)
|
||||
}
|
||||
}
|
||||
|
||||
func addIndexedSubscription(index map[string]map[uint64]*Subscription, key string, subscription *Subscription) {
|
||||
if _, ok := index[key]; !ok {
|
||||
index[key] = make(map[uint64]*Subscription)
|
||||
}
|
||||
index[key][subscription.id] = subscription
|
||||
}
|
||||
|
||||
func removeIndexedSubscription(index map[string]map[uint64]*Subscription, key string, id uint64) {
|
||||
bucket, ok := index[key]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(bucket, id)
|
||||
if len(bucket) == 0 {
|
||||
delete(index, key)
|
||||
}
|
||||
}
|
||||
|
||||
func cloneSubscriptions(bucket map[uint64]*Subscription) []*Subscription {
|
||||
if len(bucket) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]*Subscription, 0, len(bucket))
|
||||
for _, subscription := range bucket {
|
||||
cloned = append(cloned, subscription)
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
// cloneEvent returns a copy of event whose payload bytes are detached from
// the caller's buffer, so later mutation by the publisher cannot leak into
// subscriber queues. Fields are copied explicitly; any Event field not
// listed here is deliberately zeroed in the clone.
func cloneEvent(event Event) Event {
	return Event{
		UserID:          event.UserID,
		DeviceSessionID: event.DeviceSessionID,
		EventType:       event.EventType,
		EventID:         event.EventID,
		// bytes.Clone returns nil for a nil input, preserving nil-ness.
		PayloadBytes: bytes.Clone(event.PayloadBytes),
		RequestID:    event.RequestID,
		TraceID:      event.TraceID,
	}
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package push_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
"galaxy/gateway/internal/testutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHubObserverClassifiesClosureReasons drives one subscription into each
// closure path (queue overflow, session revocation, hub shutdown) and then
// asserts that the telemetry observer exported a closure counter labeled
// with each distinct reason.
func TestHubObserverClassifiesClosureReasons(t *testing.T) {
	t.Parallel()

	logger, _ := testutil.NewObservedLogger(t)
	telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)
	// Queue capacity 1 so a second publish overflows the slow stream.
	hub := push.NewHubWithObserver(1, telemetry.NewPushObserver(telemetryRuntime))

	overflow, err := hub.Register(push.StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-overflow",
	})
	require.NoError(t, err)
	revoked, err := hub.Register(push.StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-revoked",
	})
	require.NoError(t, err)
	shutdown, err := hub.Register(push.StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-shutdown",
	})
	require.NoError(t, err)

	// Two session-targeted publishes: the first fills the capacity-1 queue,
	// the second overflows it and must close the stream with reason
	// "overflow".
	hub.Publish(push.Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-overflow",
		EventType:       "fleet.updated",
		EventID:         "event-1",
		PayloadBytes:    []byte("payload-1"),
	})
	hub.Publish(push.Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-overflow",
		EventType:       "fleet.updated",
		EventID:         "event-2",
		PayloadBytes:    []byte("payload-2"),
	})
	hub.RevokeDeviceSession("device-session-revoked")
	hub.Shutdown()

	// All three streams must observe closure within the timeout budget.
	select {
	case <-overflow.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "overflow subscription did not close")
	}
	select {
	case <-revoked.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "revoked subscription did not close")
	}
	select {
	case <-shutdown.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "shutdown subscription did not close")
	}

	// The scraped metrics text must contain one closure label per reason.
	metricsText := testutil.ScrapeMetrics(t, telemetryRuntime.Handler())
	assert.Contains(t, metricsText, `gateway_push_stream_closures_total`)
	assert.Contains(t, metricsText, `reason="overflow"`)
	assert.Contains(t, metricsText, `reason="revoked"`)
	assert.Contains(t, metricsText, `reason="shutdown"`)
	assert.Contains(t, metricsText, `gateway_push_active_streams`)
}
|
||||
@@ -0,0 +1,270 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHubDeliversSessionTargetedEvent verifies that an event carrying a
// device session id reaches only the subscription bound to that exact
// session — not sibling sessions of the same user and not other users.
func TestHubDeliversSessionTargetedEvent(t *testing.T) {
	t.Parallel()

	hub := NewHub(4)
	target, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)
	otherSession, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
	})
	require.NoError(t, err)
	unrelatedUser, err := hub.Register(StreamBinding{
		UserID:          "user-999",
		DeviceSessionID: "device-session-3",
	})
	require.NoError(t, err)

	hub.Publish(Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
		EventType:       "fleet.updated",
		EventID:         "event-1",
		PayloadBytes:    []byte("payload-1"),
	})

	// Only the matching session receives the event.
	assertEvent(t, target.Events(), Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
		EventType:       "fleet.updated",
		EventID:         "event-1",
		PayloadBytes:    []byte("payload-1"),
	})
	assertNoEvent(t, otherSession.Events())
	assertNoEvent(t, unrelatedUser.Events())
}
|
||||
|
||||
// TestHubDeliversUserTargetedEventToAllUserSessions verifies that an event
// without a device session id fans out to every session of the target user
// while leaving other users untouched.
func TestHubDeliversUserTargetedEventToAllUserSessions(t *testing.T) {
	t.Parallel()

	hub := NewHub(4)
	first, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)
	second, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
	})
	require.NoError(t, err)
	unrelated, err := hub.Register(StreamBinding{
		UserID:          "user-999",
		DeviceSessionID: "device-session-3",
	})
	require.NoError(t, err)

	// No DeviceSessionID: this is a user-level broadcast.
	hub.Publish(Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-1",
		PayloadBytes: []byte("payload-1"),
		RequestID:    "request-1",
		TraceID:      "trace-1",
	})

	want := Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-1",
		PayloadBytes: []byte("payload-1"),
		RequestID:    "request-1",
		TraceID:      "trace-1",
	}
	assertEvent(t, first.Events(), want)
	assertEvent(t, second.Events(), want)
	assertNoEvent(t, unrelated.Events())
}
|
||||
|
||||
// TestSubscriptionCloseUnregistersStream verifies that a client-initiated
// Close detaches the stream from the hub: Done fires, later publishes are
// not delivered, and no terminal error is recorded.
func TestSubscriptionCloseUnregistersStream(t *testing.T) {
	t.Parallel()

	hub := NewHub(4)
	subscription, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)

	subscription.Close()

	select {
	case <-subscription.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "subscription did not close")
	}

	// A publish after Close must not reach the detached stream.
	hub.Publish(Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-1",
		PayloadBytes: []byte("payload-1"),
	})

	assertNoEvent(t, subscription.Events())
	// Voluntary closure carries no error.
	assert.NoError(t, subscription.Err())
}
|
||||
|
||||
// TestHubOverflowClosesOnlySlowSubscription verifies the per-stream
// overflow policy: with queue capacity 1, a subscriber that never drains is
// closed with ErrSubscriptionOverflow while a subscriber that keeps up
// continues to receive events.
func TestHubOverflowClosesOnlySlowSubscription(t *testing.T) {
	t.Parallel()

	hub := NewHub(1)
	slow, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)
	fast, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
	})
	require.NoError(t, err)

	// First broadcast fills slow's capacity-1 queue; fast drains its copy.
	hub.Publish(Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-1",
		PayloadBytes: []byte("payload-1"),
	})
	assertEvent(t, fast.Events(), Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-1",
		PayloadBytes: []byte("payload-1"),
	})

	// Second broadcast overflows slow (still holding event-1).
	hub.Publish(Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-2",
		PayloadBytes: []byte("payload-2"),
	})

	select {
	case <-slow.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "slow subscription did not close after overflow")
	}

	assert.ErrorIs(t, slow.Err(), ErrSubscriptionOverflow)
	// The fast subscriber is unaffected and still receives event-2.
	assertEvent(t, fast.Events(), Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-2",
		PayloadBytes: []byte("payload-2"),
	})
}
|
||||
|
||||
// TestHubRevokeDeviceSessionClosesOnlyMatchingSubscriptions verifies that
// revoking a device session closes every subscription bound to that session
// (across users) with ErrSubscriptionRevoked, while subscriptions on other
// sessions stay open and continue receiving events.
func TestHubRevokeDeviceSessionClosesOnlyMatchingSubscriptions(t *testing.T) {
	t.Parallel()

	hub := NewHub(4)
	targetOne, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)
	// Same session id, different user: revocation is keyed by session only.
	targetTwo, err := hub.Register(StreamBinding{
		UserID:          "user-456",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)
	otherSession, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
	})
	require.NoError(t, err)

	hub.RevokeDeviceSession("device-session-1")

	select {
	case <-targetOne.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "first matching subscription did not close after revoke")
	}

	select {
	case <-targetTwo.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "second matching subscription did not close after revoke")
	}

	assert.ErrorIs(t, targetOne.Err(), ErrSubscriptionRevoked)
	assert.ErrorIs(t, targetTwo.Err(), ErrSubscriptionRevoked)

	// The unrelated session must still be open after a short grace period.
	select {
	case <-otherSession.Done():
		require.FailNow(t, "unrelated session subscription closed after revoke")
	case <-time.After(50 * time.Millisecond):
	}

	// ...and must still receive session-targeted events.
	hub.Publish(Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
		EventType:       "fleet.updated",
		EventID:         "event-1",
		PayloadBytes:    []byte("payload-1"),
	})

	assertEvent(t, otherSession.Events(), Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
		EventType:       "fleet.updated",
		EventID:         "event-1",
		PayloadBytes:    []byte("payload-1"),
	})
}
|
||||
|
||||
func TestHubRevokeDeviceSessionIgnoresUnknownOrEmptySession(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
hub := NewHub(4)
|
||||
subscription, err := hub.Register(StreamBinding{
|
||||
UserID: "user-123",
|
||||
DeviceSessionID: "device-session-1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
hub.RevokeDeviceSession("")
|
||||
hub.RevokeDeviceSession("missing-session")
|
||||
|
||||
select {
|
||||
case <-subscription.Done():
|
||||
require.FailNow(t, "subscription closed for empty or unknown session revoke")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func assertEvent(t *testing.T, eventCh <-chan Event, want Event) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case got := <-eventCh:
|
||||
assert.Equal(t, want, got)
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "event was not delivered")
|
||||
}
|
||||
}
|
||||
|
||||
func assertNoEvent(t *testing.T, eventCh <-chan Event) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case got := <-eventCh:
|
||||
require.FailNowf(t, "unexpected event delivered", "%+v", got)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
// Package ratelimit provides small process-local rate-limit primitives used by
|
||||
// the gateway edge policy layers.
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Policy describes one token-bucket budget enforced for a concrete key.
// All three fields must be positive for the policy to take effect; a
// non-positive value causes Reserve to deny with a zero Decision.
type Policy struct {
	// Requests is the number of accepted requests replenished per Window.
	Requests int

	// Window is the interval over which Requests are replenished.
	Window time.Duration

	// Burst is the maximum number of immediately available tokens.
	Burst int
}
|
||||
|
||||
// Decision describes the result of one limiter reservation attempt. The
// zero value denies the request with no retry hint.
type Decision struct {
	// Allowed reports whether the request may proceed immediately.
	Allowed bool

	// RetryAfter is the minimum delay the caller should wait before retrying
	// when Allowed is false. It may be zero when no useful hint exists.
	RetryAfter time.Duration
}
|
||||
|
||||
// Limiter applies a policy to one concrete key.
type Limiter interface {
	// Reserve evaluates key under policy and reports whether the request may
	// proceed immediately. Implementations decide how per-key state is
	// stored and expired.
	Reserve(key string, policy Policy) Decision
}
|
||||
|
||||
// InMemory is a process-local Limiter backed by x/time/rate token buckets.
// It keeps one bucket per key and sweeps idle buckets periodically.
type InMemory struct {
	now             func() time.Time // clock source; time.Now in production, overridable for tests
	cleanupInterval time.Duration    // minimum spacing between expired-bucket sweeps

	mu          sync.Mutex        // guards entries and nextCleanup
	entries     map[string]*entry // per-key token buckets
	nextCleanup time.Time         // earliest instant the next sweep may run
}
|
||||
|
||||
// entry is one per-key token bucket plus the policy parameters it was built
// with, so a policy change can be detected and the bucket rebuilt.
type entry struct {
	limiter   *rate.Limiter
	limit     rate.Limit // refill rate the limiter was constructed with
	burst     int        // burst the limiter was constructed with
	expiresAt time.Time  // idle-expiry deadline; refreshed on every Reserve
}
|
||||
|
||||
// NewInMemory constructs a process-local limiter suitable for one gateway
|
||||
// process instance.
|
||||
func NewInMemory() *InMemory {
|
||||
return &InMemory{
|
||||
now: time.Now,
|
||||
cleanupInterval: time.Minute,
|
||||
entries: make(map[string]*entry),
|
||||
}
|
||||
}
|
||||
|
||||
// Reserve evaluates key against policy and reports whether the request may
|
||||
// proceed immediately.
|
||||
func (l *InMemory) Reserve(key string, policy Policy) Decision {
|
||||
if policy.Requests <= 0 || policy.Window <= 0 || policy.Burst <= 0 {
|
||||
return Decision{}
|
||||
}
|
||||
|
||||
now := l.now()
|
||||
limit := rate.Limit(float64(policy.Requests) / policy.Window.Seconds())
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.cleanupExpiredBucketsLocked(now)
|
||||
|
||||
current, ok := l.entries[key]
|
||||
if !ok || current.limit != limit || current.burst != policy.Burst {
|
||||
current = &entry{
|
||||
limiter: rate.NewLimiter(limit, policy.Burst),
|
||||
limit: limit,
|
||||
burst: policy.Burst,
|
||||
}
|
||||
l.entries[key] = current
|
||||
}
|
||||
|
||||
current.expiresAt = now.Add(entryTTL(policy.Window))
|
||||
|
||||
reservation := current.limiter.ReserveN(now, 1)
|
||||
if !reservation.OK() {
|
||||
return Decision{
|
||||
Allowed: false,
|
||||
RetryAfter: policy.Window,
|
||||
}
|
||||
}
|
||||
|
||||
retryAfter := reservation.DelayFrom(now)
|
||||
if retryAfter > 0 {
|
||||
return Decision{
|
||||
Allowed: false,
|
||||
RetryAfter: retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
return Decision{Allowed: true}
|
||||
}
|
||||
|
||||
func (l *InMemory) cleanupExpiredBucketsLocked(now time.Time) {
|
||||
if !l.nextCleanup.IsZero() && now.Before(l.nextCleanup) {
|
||||
return
|
||||
}
|
||||
|
||||
for key, current := range l.entries {
|
||||
if !current.expiresAt.After(now) {
|
||||
delete(l.entries, key)
|
||||
}
|
||||
}
|
||||
|
||||
l.nextCleanup = now.Add(l.cleanupInterval)
|
||||
}
|
||||
|
||||
func entryTTL(window time.Duration) time.Duration {
|
||||
if window < time.Minute {
|
||||
return time.Minute
|
||||
}
|
||||
|
||||
return 2 * window
|
||||
}
|
||||
|
||||
var _ Limiter = (*InMemory)(nil)
|
||||
@@ -0,0 +1,49 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInMemoryReserve(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
limiter := NewInMemory()
|
||||
policy := Policy{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
|
||||
first := limiter.Reserve("bucket-1", policy)
|
||||
second := limiter.Reserve("bucket-1", policy)
|
||||
otherBucket := limiter.Reserve("bucket-2", policy)
|
||||
|
||||
assert.True(t, first.Allowed)
|
||||
assert.False(t, second.Allowed)
|
||||
assert.Positive(t, second.RetryAfter)
|
||||
assert.True(t, otherBucket.Allowed)
|
||||
}
|
||||
|
||||
func TestInMemoryReserveResetsOnPolicyChange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
limiter := NewInMemory()
|
||||
|
||||
initialPolicy := Policy{
|
||||
Requests: 1,
|
||||
Window: time.Hour,
|
||||
Burst: 1,
|
||||
}
|
||||
updatedPolicy := Policy{
|
||||
Requests: 2,
|
||||
Window: time.Hour,
|
||||
Burst: 2,
|
||||
}
|
||||
|
||||
assert.True(t, limiter.Reserve("bucket-1", initialPolicy).Allowed)
|
||||
assert.False(t, limiter.Reserve("bucket-1", initialPolicy).Allowed)
|
||||
assert.True(t, limiter.Reserve("bucket-1", updatedPolicy).Allowed)
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package replay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisStore implements Store with Redis SETNX reservations over a dedicated
// key namespace.
type RedisStore struct {
	client         *redis.Client // connection pool to the shared session-cache Redis
	keyPrefix      string        // namespace prepended to every reservation key
	reserveTimeout time.Duration // per-operation deadline for Reserve and Ping
}
|
||||
|
||||
// NewRedisStore constructs a Redis-backed replay store that reuses the
|
||||
// SessionCache Redis deployment settings and applies the replay-specific key
|
||||
// namespace and timeout controls from replayCfg.
|
||||
func NewRedisStore(sessionCfg config.SessionCacheRedisConfig, replayCfg config.ReplayRedisConfig) (*RedisStore, error) {
|
||||
if strings.TrimSpace(sessionCfg.Addr) == "" {
|
||||
return nil, errors.New("new redis replay store: redis addr must not be empty")
|
||||
}
|
||||
if sessionCfg.DB < 0 {
|
||||
return nil, errors.New("new redis replay store: redis db must not be negative")
|
||||
}
|
||||
if strings.TrimSpace(replayCfg.KeyPrefix) == "" {
|
||||
return nil, errors.New("new redis replay store: replay key prefix must not be empty")
|
||||
}
|
||||
if replayCfg.ReserveTimeout <= 0 {
|
||||
return nil, errors.New("new redis replay store: reserve timeout must be positive")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: sessionCfg.Addr,
|
||||
Username: sessionCfg.Username,
|
||||
Password: sessionCfg.Password,
|
||||
DB: sessionCfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if sessionCfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
|
||||
return &RedisStore{
|
||||
client: redis.NewClient(options),
|
||||
keyPrefix: replayCfg.KeyPrefix,
|
||||
reserveTimeout: replayCfg.ReserveTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (s *RedisStore) Close() error {
|
||||
if s == nil || s.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
// Ping verifies that the configured Redis backend is reachable within the
|
||||
// replay reserve timeout budget.
|
||||
func (s *RedisStore) Ping(ctx context.Context) error {
|
||||
if s == nil || s.client == nil {
|
||||
return errors.New("ping redis replay store: nil store")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("ping redis replay store: nil context")
|
||||
}
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, s.reserveTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := s.client.Ping(pingCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis replay store: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reserve records the authenticated deviceSessionID and requestID pair for
|
||||
// ttl. It rejects duplicates while the reservation remains active.
|
||||
func (s *RedisStore) Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
|
||||
if s == nil || s.client == nil {
|
||||
return errors.New("reserve replay request in redis: nil store")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("reserve replay request in redis: nil context")
|
||||
}
|
||||
if strings.TrimSpace(deviceSessionID) == "" {
|
||||
return errors.New("reserve replay request in redis: empty device session id")
|
||||
}
|
||||
if strings.TrimSpace(requestID) == "" {
|
||||
return errors.New("reserve replay request in redis: empty request id")
|
||||
}
|
||||
if ttl <= 0 {
|
||||
return errors.New("reserve replay request in redis: ttl must be positive")
|
||||
}
|
||||
|
||||
reserveCtx, cancel := context.WithTimeout(ctx, s.reserveTimeout)
|
||||
defer cancel()
|
||||
|
||||
reserved, err := s.client.SetNX(reserveCtx, s.reservationKey(deviceSessionID, requestID), "1", ttl).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("reserve replay request in redis: %w", err)
|
||||
}
|
||||
if !reserved {
|
||||
return fmt.Errorf("reserve replay request in redis: %w", ErrDuplicate)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RedisStore) reservationKey(deviceSessionID string, requestID string) string {
|
||||
return s.keyPrefix + encodeKeyComponent(deviceSessionID) + ":" + encodeKeyComponent(requestID)
|
||||
}
|
||||
|
||||
// encodeKeyComponent makes value safe for embedding in a Redis key by
// applying unpadded URL-safe base64, so arbitrary identifiers cannot
// collide with the key-namespace separator.
func encodeKeyComponent(value string) string {
	encoded := base64.RawURLEncoding.EncodeToString([]byte(value))

	return encoded
}
|
||||
|
||||
var _ Store = (*RedisStore)(nil)
|
||||
@@ -0,0 +1,254 @@
|
||||
package replay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewRedisStore exercises constructor validation: one happy-path case
// plus one case per rejected configuration field, matched by error text.
func TestNewRedisStore(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)

	tests := []struct {
		name       string
		sessionCfg config.SessionCacheRedisConfig
		replayCfg  config.ReplayRedisConfig
		wantErr    string // substring expected in the error; empty means success
	}{
		{
			name: "valid config",
			sessionCfg: config.SessionCacheRedisConfig{
				Addr: server.Addr(),
				DB:   2,
			},
			replayCfg: config.ReplayRedisConfig{
				KeyPrefix:      "gateway:replay:",
				ReserveTimeout: 250 * time.Millisecond,
			},
		},
		{
			// Zero-value sessionCfg leaves Addr empty.
			name: "empty redis addr",
			replayCfg: config.ReplayRedisConfig{
				KeyPrefix:      "gateway:replay:",
				ReserveTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative redis db",
			sessionCfg: config.SessionCacheRedisConfig{
				Addr: server.Addr(),
				DB:   -1,
			},
			replayCfg: config.ReplayRedisConfig{
				KeyPrefix:      "gateway:replay:",
				ReserveTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty replay key prefix",
			sessionCfg: config.SessionCacheRedisConfig{
				Addr: server.Addr(),
			},
			replayCfg: config.ReplayRedisConfig{
				ReserveTimeout: 250 * time.Millisecond,
			},
			wantErr: "replay key prefix must not be empty",
		},
		{
			// Zero ReserveTimeout is rejected as non-positive.
			name: "non-positive reserve timeout",
			sessionCfg: config.SessionCacheRedisConfig{
				Addr: server.Addr(),
			},
			replayCfg: config.ReplayRedisConfig{
				KeyPrefix: "gateway:replay:",
			},
			wantErr: "reserve timeout must be positive",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			store, err := NewRedisStore(tt.sessionCfg, tt.replayCfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				require.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			t.Cleanup(func() {
				assert.NoError(t, store.Close())
			})
		})
	}
}
|
||||
|
||||
// TestRedisStorePing verifies that Ping succeeds against a reachable
// (in-process miniredis) backend using the test defaults.
func TestRedisStorePing(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := newTestRedisStore(t, server, config.SessionCacheRedisConfig{}, config.ReplayRedisConfig{})

	require.NoError(t, store.Ping(context.Background()))
}
|
||||
|
||||
// TestRedisStoreReserve covers the reservation contract: first reservation
// succeeds, an identical second reservation is rejected with ErrDuplicate,
// the same request id under a different session does not collide, and blank
// identifiers or a non-positive ttl are rejected up front.
func TestRedisStoreReserve(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name            string
		sessionCfg      config.SessionCacheRedisConfig
		replayCfg       config.ReplayRedisConfig
		deviceSessionID string
		requestID       string
		ttl             time.Duration
		secondReserve   func(*testing.T, Store) // optional follow-up reservation after a successful first one
		wantErrIs       error
		wantErrText     string
	}{
		{
			name:            "first reservation succeeds",
			deviceSessionID: "device-session-123",
			requestID:       "request-123",
			ttl:             5 * time.Second,
		},
		{
			name:            "duplicate reservation is rejected",
			deviceSessionID: "device-session-123",
			requestID:       "request-123",
			ttl:             5 * time.Second,
			secondReserve: func(t *testing.T, store Store) {
				t.Helper()
				// Same pair inside the ttl window must report ErrDuplicate.
				err := store.Reserve(context.Background(), "device-session-123", "request-123", 5*time.Second)
				require.ErrorIs(t, err, ErrDuplicate)
			},
		},
		{
			name:            "same request id in distinct sessions does not collide",
			deviceSessionID: "device-session-123",
			requestID:       "request-123",
			ttl:             5 * time.Second,
			secondReserve: func(t *testing.T, store Store) {
				t.Helper()
				require.NoError(t, store.Reserve(context.Background(), "device-session-456", "request-123", 5*time.Second))
			},
		},
		{
			name:        "empty device session id",
			requestID:   "request-123",
			ttl:         5 * time.Second,
			wantErrText: "empty device session id",
		},
		{
			name:            "empty request id",
			deviceSessionID: "device-session-123",
			ttl:             5 * time.Second,
			wantErrText:     "empty request id",
		},
		{
			// Zero ttl (field left unset) must be rejected.
			name:            "non-positive ttl",
			deviceSessionID: "device-session-123",
			requestID:       "request-123",
			wantErrText:     "ttl must be positive",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Fresh miniredis per subtest keeps reservations isolated.
			server := miniredis.RunT(t)
			store := newTestRedisStore(t, server, tt.sessionCfg, tt.replayCfg)

			err := store.Reserve(context.Background(), tt.deviceSessionID, tt.requestID, tt.ttl)
			if tt.wantErrIs != nil || tt.wantErrText != "" {
				require.Error(t, err)
				if tt.wantErrIs != nil {
					require.ErrorIs(t, err, tt.wantErrIs)
				}
				if tt.wantErrText != "" {
					require.ErrorContains(t, err, tt.wantErrText)
				}
				return
			}

			require.NoError(t, err)
			if tt.secondReserve != nil {
				tt.secondReserve(t, store)
			}
		})
	}
}
|
||||
|
||||
func TestRedisStoreReserveReturnsBackendError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, err := NewRedisStore(
|
||||
config.SessionCacheRedisConfig{Addr: unusedTCPAddr(t)},
|
||||
config.ReplayRedisConfig{
|
||||
KeyPrefix: "gateway:replay:",
|
||||
ReserveTimeout: 100 * time.Millisecond,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
|
||||
err = store.Reserve(context.Background(), "device-session-123", "request-123", 5*time.Second)
|
||||
require.Error(t, err)
|
||||
assert.False(t, errors.Is(err, ErrDuplicate))
|
||||
assert.ErrorContains(t, err, "reserve replay request in redis")
|
||||
}
|
||||
|
||||
func newTestRedisStore(t *testing.T, server *miniredis.Miniredis, sessionCfg config.SessionCacheRedisConfig, replayCfg config.ReplayRedisConfig) *RedisStore {
|
||||
t.Helper()
|
||||
|
||||
if sessionCfg.Addr == "" {
|
||||
sessionCfg.Addr = server.Addr()
|
||||
}
|
||||
if replayCfg.KeyPrefix == "" {
|
||||
replayCfg.KeyPrefix = "gateway:replay:"
|
||||
}
|
||||
if replayCfg.ReserveTimeout == 0 {
|
||||
replayCfg.ReserveTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
store, err := NewRedisStore(sessionCfg, replayCfg)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, store.Close())
|
||||
})
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func unusedTCPAddr(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := listener.Addr().String()
|
||||
require.NoError(t, listener.Close())
|
||||
|
||||
return addr
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
// Package replay defines the authenticated replay-reservation contract used by
|
||||
// the gateway transport pipeline.
|
||||
package replay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
	// ErrDuplicate reports that the request identifier has already been
	// reserved for the same device session within the active replay window.
	// Implementations wrap it so callers can detect replays via errors.Is.
	ErrDuplicate = errors.New("replay reservation already exists")
)

// Store reserves authenticated transport request identifiers for a bounded
// replay window.
type Store interface {
	// Reserve marks the deviceSessionID and requestID pair as seen for ttl.
	// Implementations must wrap ErrDuplicate when the same pair is reserved
	// again before ttl expires.
	Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/logging"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// withPublicObservability returns gin middleware that, after the handler chain
// completes, records one public-request metric sample and one structured log
// line describing the request's route, class, duration, and edge outcome.
// A nil logger is replaced with a no-op logger; metrics is assumed non-nil —
// TODO(review): confirm callers never pass a nil *telemetry.Runtime.
func withPublicObservability(logger *zap.Logger, metrics *telemetry.Runtime) gin.HandlerFunc {
	if logger == nil {
		logger = zap.NewNop()
	}

	return func(c *gin.Context) {
		start := time.Now()
		c.Next()

		statusCode := c.Writer.Status()
		// Prefer the registered route pattern; fall back to the raw URL path
		// for unmatched requests.
		route := c.FullPath()
		if route == "" {
			route = c.Request.URL.Path
		}

		class, ok := PublicRouteClassFromContext(c.Request.Context())
		if !ok {
			class = PublicRouteClassPublicMisc
		}

		// The handler chain may have stored a public error code; absence
		// yields the empty string.
		errorCode, _ := c.Get(publicErrorCodeContextKey)
		errorCodeValue, _ := errorCode.(string)

		outcome := telemetry.OutcomeFromPublicErrorCode(statusCode, errorCodeValue)
		rejectReason := telemetry.RejectReason(outcome)
		duration := time.Since(start)

		// Metric attributes are low-cardinality: class, route pattern,
		// method, and outcome (plus reject reason when present).
		attrs := []attribute.KeyValue{
			attribute.String("route_class", string(class)),
			attribute.String("route", route),
			attribute.String("method", c.Request.Method),
			attribute.String("edge_outcome", string(outcome)),
		}
		if rejectReason != "" {
			attrs = append(attrs, attribute.String("reject_reason", rejectReason))
		}
		metrics.RecordPublicRequest(c.Request.Context(), attrs, duration)

		fields := []zap.Field{
			zap.String("component", "public_http"),
			zap.String("transport", "http"),
			zap.String("route", route),
			zap.String("route_class", string(class)),
			zap.String("method", c.Request.Method),
			zap.Int("status_code", statusCode),
			zap.Float64("duration_ms", float64(duration.Microseconds())/1000),
			zap.String("edge_outcome", string(outcome)),
		}
		if rejectReason != "" {
			fields = append(fields, zap.String("reject_reason", rejectReason))
		}
		fields = append(fields, logging.TraceFieldsFromContext(c.Request.Context())...)

		// Log level tracks the outcome: success → info, backend/internal
		// failures → error, everything else (client rejections) → warn.
		switch outcome {
		case telemetry.EdgeOutcomeSuccess:
			logger.Info("public request completed", fields...)
		case telemetry.EdgeOutcomeBackendUnavailable, telemetry.EdgeOutcomeInternalError:
			logger.Error("public request failed", fields...)
		default:
			logger.Warn("public request rejected", fields...)
		}
	}
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/getkin/kin-openapi/openapi3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPublicOpenAPISpecValidates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, thisFile, _, ok := runtime.Caller(0)
|
||||
require.True(t, ok)
|
||||
|
||||
specPath := filepath.Join(filepath.Dir(thisFile), "..", "..", "openapi.yaml")
|
||||
ctx := context.Background()
|
||||
|
||||
loader := openapi3.NewLoader()
|
||||
doc, err := loader.LoadFromFile(specPath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, doc)
|
||||
require.NotNil(t, doc.Info)
|
||||
require.Equal(t, "v1", doc.Info.Version)
|
||||
|
||||
require.NoError(t, doc.Validate(ctx))
|
||||
}
|
||||
@@ -0,0 +1,378 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const (
	// errorCodeRequestTooLarge is the stable public error code returned when
	// a request body exceeds its class's configured limit.
	errorCodeRequestTooLarge = "request_too_large"
	// errorCodeRateLimited is the stable public error code returned with
	// 429 responses.
	errorCodeRateLimited = "rate_limited"

	// publicRESTIPBucketKeySegment joins a route class's base bucket key and
	// the client IP when forming per-IP limiter keys.
	publicRESTIPBucketKeySegment = "/ip="
)

// errRequestBodyTooLarge is the internal sentinel returned by
// bufferRequestBody when the body exceeds the configured limit.
var errRequestBodyTooLarge = errors.New("request body exceeds the configured limit")
|
||||
|
||||
// PublicMalformedRequestReason identifies the stable malformed-request counter
// dimension recorded by the public REST anti-abuse middleware. Values are
// low-cardinality strings suitable for metric labels.
type PublicMalformedRequestReason string

const (
	// PublicMalformedRequestReasonEmptyBody records a missing request body.
	PublicMalformedRequestReasonEmptyBody PublicMalformedRequestReason = "empty_body"

	// PublicMalformedRequestReasonMalformedJSON records syntactically malformed
	// JSON.
	PublicMalformedRequestReasonMalformedJSON PublicMalformedRequestReason = "malformed_json"

	// PublicMalformedRequestReasonInvalidJSONValue records JSON values whose
	// types do not match the expected request schema.
	PublicMalformedRequestReasonInvalidJSONValue PublicMalformedRequestReason = "invalid_json_value"

	// PublicMalformedRequestReasonUnknownField records JSON objects with fields
	// outside the documented schema.
	PublicMalformedRequestReasonUnknownField PublicMalformedRequestReason = "unknown_field"

	// PublicMalformedRequestReasonMultipleJSONObjects records requests that
	// contain more than one JSON object.
	PublicMalformedRequestReasonMultipleJSONObjects PublicMalformedRequestReason = "multiple_json_objects"

	// PublicMalformedRequestReasonOversizedBody records requests whose bodies
	// exceed the configured class limit.
	PublicMalformedRequestReasonOversizedBody PublicMalformedRequestReason = "oversized_body"
)
|
||||
|
||||
// PublicRateLimitDecision describes the outcome returned by a public REST
// limiter for one request bucket reservation attempt.
type PublicRateLimitDecision struct {
	// Allowed reports whether the request may proceed immediately.
	Allowed bool

	// RetryAfter is the minimum delay the client should wait before retrying
	// when Allowed is false. It is surfaced to clients via the Retry-After
	// response header.
	RetryAfter time.Duration
}

// PublicRequestLimiter applies public REST rate-limit policy to a concrete
// bucket key.
type PublicRequestLimiter interface {
	// Reserve evaluates key under policy and returns whether the request may
	// proceed immediately.
	Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision
}

// PublicRequestObserver captures low-cardinality public REST anti-abuse
// telemetry.
type PublicRequestObserver interface {
	// RecordMalformedRequest records one malformed request in class for reason.
	RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason)
}
|
||||
|
||||
// noopPublicRequestObserver discards every observation; it serves as the
// default PublicRequestObserver when none is configured.
type noopPublicRequestObserver struct{}

// RecordMalformedRequest implements PublicRequestObserver as a no-op.
func (noopPublicRequestObserver) RecordMalformedRequest(PublicRouteClass, PublicMalformedRequestReason) {
}
|
||||
|
||||
// inMemoryPublicRequestLimiter is a process-local PublicRequestLimiter that
// keeps one token bucket per key and lazily prunes idle buckets.
type inMemoryPublicRequestLimiter struct {
	now             func() time.Time // clock hook, overridden in tests
	cleanupInterval time.Duration    // minimum gap between cleanup sweeps

	mu          sync.Mutex // guards entries and nextCleanup
	entries     map[string]*publicRateLimiterEntry
	nextCleanup time.Time // earliest time the next sweep may run
}

// publicRateLimiterEntry pairs one token bucket with the policy parameters it
// was built for (so policy changes invalidate it) and an idle-cleanup expiry.
type publicRateLimiterEntry struct {
	limiter   *rate.Limiter
	limit     rate.Limit
	burst     int
	expiresAt time.Time
}
|
||||
|
||||
func newInMemoryPublicRequestLimiter() *inMemoryPublicRequestLimiter {
|
||||
return &inMemoryPublicRequestLimiter{
|
||||
now: time.Now,
|
||||
cleanupInterval: time.Minute,
|
||||
entries: make(map[string]*publicRateLimiterEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *inMemoryPublicRequestLimiter) Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision {
|
||||
now := l.now()
|
||||
limit := rate.Limit(float64(policy.Requests) / policy.Window.Seconds())
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
l.cleanupExpiredBucketsLocked(now)
|
||||
|
||||
entry, ok := l.entries[key]
|
||||
if !ok || entry.limit != limit || entry.burst != policy.Burst {
|
||||
entry = &publicRateLimiterEntry{
|
||||
limiter: rate.NewLimiter(limit, policy.Burst),
|
||||
limit: limit,
|
||||
burst: policy.Burst,
|
||||
}
|
||||
l.entries[key] = entry
|
||||
}
|
||||
|
||||
entry.expiresAt = now.Add(publicRateLimiterEntryTTL(policy.Window))
|
||||
|
||||
reservation := entry.limiter.ReserveN(now, 1)
|
||||
if !reservation.OK() {
|
||||
return PublicRateLimitDecision{
|
||||
Allowed: false,
|
||||
RetryAfter: policy.Window,
|
||||
}
|
||||
}
|
||||
|
||||
retryAfter := reservation.DelayFrom(now)
|
||||
if retryAfter > 0 {
|
||||
return PublicRateLimitDecision{
|
||||
Allowed: false,
|
||||
RetryAfter: retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
return PublicRateLimitDecision{Allowed: true}
|
||||
}
|
||||
|
||||
func (l *inMemoryPublicRequestLimiter) cleanupExpiredBucketsLocked(now time.Time) {
|
||||
if !l.nextCleanup.IsZero() && now.Before(l.nextCleanup) {
|
||||
return
|
||||
}
|
||||
|
||||
for key, entry := range l.entries {
|
||||
if !entry.expiresAt.After(now) {
|
||||
delete(l.entries, key)
|
||||
}
|
||||
}
|
||||
|
||||
l.nextCleanup = now.Add(l.cleanupInterval)
|
||||
}
|
||||
|
||||
func publicRateLimiterEntryTTL(window time.Duration) time.Duration {
|
||||
if window < time.Minute {
|
||||
return time.Minute
|
||||
}
|
||||
|
||||
return 2 * window
|
||||
}
|
||||
|
||||
// withPublicAntiAbuse returns gin middleware that enforces the public REST
// anti-abuse policy in a fixed order: method allow-list, body-size limit,
// per-IP rate limit, then (for applicable routes) per-identity rate limit.
// Malformed-body observations are recorded through observer; rejected
// requests are aborted before reaching the handler.
func withPublicAntiAbuse(policy config.PublicHTTPAntiAbuseConfig, limiter PublicRequestLimiter, observer PublicRequestObserver) gin.HandlerFunc {
	return func(c *gin.Context) {
		class, ok := PublicRouteClassFromContext(c.Request.Context())
		if !ok {
			class = PublicRouteClassPublicMisc
		}

		// 1) Reject disallowed methods for recognized request shapes, with
		// an Allow header listing the permitted methods.
		allowedMethods := allowedMethodsForRequestShape(c.Request)
		if len(allowedMethods) > 0 && !isAllowedMethod(c.Request.Method, allowedMethods) {
			c.Header("Allow", strings.Join(allowedMethods, ", "))
			abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route")
			return
		}

		// 2) Buffer the body under the class's size limit so later stages
		// (and the handler) can re-read it.
		classPolicy := publicRoutePolicyForClass(policy, class)
		bodyBytes, err := bufferRequestBody(c.Request, classPolicy.MaxBodyBytes)
		if err != nil {
			switch {
			case errors.Is(err, errRequestBodyTooLarge):
				observer.RecordMalformedRequest(class, PublicMalformedRequestReasonOversizedBody)
				abortWithError(c, http.StatusRequestEntityTooLarge, errorCodeRequestTooLarge, "request body exceeds the configured limit")
			default:
				abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
			}
			return
		}

		// 3) Per-IP limit, keyed by route class + client IP.
		clientIP := clientIPFromRemoteAddr(c.Request.RemoteAddr)
		if decision := limiter.Reserve(publicRESTIPBucketKey(class, clientIP), classPolicy.RateLimit); !decision.Allowed {
			abortRateLimited(c, decision.RetryAfter)
			return
		}

		// 4) Per-identity limit for public auth routes. An inapplicable
		// route passes through; an extraction failure only records telemetry
		// (the handler performs its own validation and error response).
		identity, err := extractPublicAuthIdentity(c.Request.URL.Path, bodyBytes)
		switch {
		case err == nil:
			identityPolicy := publicAuthIdentityPolicyForPath(c.Request.URL.Path, policy)
			if decision := limiter.Reserve(publicAuthIdentityBucketKey(class, identity.kind, identity.value), identityPolicy.RateLimit); !decision.Allowed {
				abortRateLimited(c, decision.RetryAfter)
				return
			}
		case errors.Is(err, errPublicAuthIdentityNotApplicable):
		default:
			if reason, malformed := malformedRequestReasonFromError(err); malformed {
				observer.RecordMalformedRequest(class, reason)
			}
		}

		c.Next()
	}
}
|
||||
|
||||
func publicRoutePolicyForClass(policy config.PublicHTTPAntiAbuseConfig, class PublicRouteClass) config.PublicRoutePolicyConfig {
|
||||
switch class.Normalized() {
|
||||
case PublicRouteClassPublicAuth:
|
||||
return policy.PublicAuth
|
||||
case PublicRouteClassBrowserBootstrap:
|
||||
return policy.BrowserBootstrap
|
||||
case PublicRouteClassBrowserAsset:
|
||||
return policy.BrowserAsset
|
||||
default:
|
||||
return policy.PublicMisc
|
||||
}
|
||||
}
|
||||
|
||||
func publicAuthIdentityPolicyForPath(requestPath string, policy config.PublicHTTPAntiAbuseConfig) config.PublicAuthIdentityPolicyConfig {
|
||||
switch requestPath {
|
||||
case "/api/v1/public/auth/send-email-code":
|
||||
return policy.SendEmailCodeIdentity
|
||||
case "/api/v1/public/auth/confirm-email-code":
|
||||
return policy.ConfirmEmailCodeIdentity
|
||||
default:
|
||||
return config.PublicAuthIdentityPolicyConfig{}
|
||||
}
|
||||
}
|
||||
|
||||
func allowedMethodsForRequestShape(r *http.Request) []string {
|
||||
switch {
|
||||
case isPublicAuthPath(r.URL.Path):
|
||||
return []string{http.MethodPost}
|
||||
case isProbePath(r.URL.Path):
|
||||
return []string{http.MethodGet}
|
||||
case matchesBrowserAssetRequestShape(r):
|
||||
return []string{http.MethodGet, http.MethodHead}
|
||||
case matchesBrowserBootstrapRequestShape(r):
|
||||
return []string{http.MethodGet, http.MethodHead}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// isAllowedMethod reports whether method appears in allowedMethods.
func isAllowedMethod(method string, allowedMethods []string) bool {
	for i := range allowedMethods {
		if allowedMethods[i] == method {
			return true
		}
	}

	return false
}
|
||||
|
||||
// isPublicAuthPath reports whether requestPath is one of the public email
// authentication endpoints.
func isPublicAuthPath(requestPath string) bool {
	return requestPath == "/api/v1/public/auth/send-email-code" ||
		requestPath == "/api/v1/public/auth/confirm-email-code"
}
|
||||
|
||||
// isProbePath reports whether requestPath is a liveness or readiness probe.
func isProbePath(requestPath string) bool {
	return requestPath == "/healthz" || requestPath == "/readyz"
}
|
||||
|
||||
func matchesBrowserBootstrapRequestShape(r *http.Request) bool {
|
||||
if r.URL.Path == "/" {
|
||||
return true
|
||||
}
|
||||
|
||||
return strings.Contains(strings.ToLower(r.Header.Get("Accept")), "text/html")
|
||||
}
|
||||
|
||||
func matchesBrowserAssetRequestShape(r *http.Request) bool {
|
||||
if strings.HasPrefix(r.URL.Path, "/assets/") {
|
||||
return true
|
||||
}
|
||||
|
||||
switch strings.ToLower(path.Ext(r.URL.Path)) {
|
||||
case ".js", ".mjs", ".css", ".map", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".json", ".webmanifest":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func bufferRequestBody(r *http.Request, maxBodyBytes int64) ([]byte, error) {
|
||||
if r == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if r.Body == nil {
|
||||
r.Body = io.NopCloser(bytes.NewReader(nil))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(io.LimitReader(r.Body, maxBodyBytes+1))
|
||||
closeErr := r.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if closeErr != nil {
|
||||
return nil, closeErr
|
||||
}
|
||||
if int64(len(bodyBytes)) > maxBodyBytes {
|
||||
return nil, errRequestBodyTooLarge
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
return bodyBytes, nil
|
||||
}
|
||||
|
||||
// abortRateLimited writes a 429 response carrying a Retry-After hint and
// stops the handler chain.
func abortRateLimited(c *gin.Context, retryAfter time.Duration) {
	c.Header("Retry-After", retryAfterHeaderValue(retryAfter))
	abortWithError(c, http.StatusTooManyRequests, errorCodeRateLimited, "request rate limit exceeded")
}
|
||||
|
||||
func retryAfterHeaderValue(delay time.Duration) string {
|
||||
seconds := int64(math.Ceil(delay.Seconds()))
|
||||
if seconds < 1 {
|
||||
seconds = 1
|
||||
}
|
||||
|
||||
return strconv.FormatInt(seconds, 10)
|
||||
}
|
||||
|
||||
// clientIPFromRemoteAddr extracts the host portion of a remote address. When
// the address carries no port, the trimmed address itself is returned; an
// empty address yields "unknown" so limiter keys are never blank.
func clientIPFromRemoteAddr(remoteAddr string) string {
	// Trim once up front; the original trimmed the same string twice.
	remoteAddr = strings.TrimSpace(remoteAddr)

	if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
		return host
	}
	if remoteAddr == "" {
		return "unknown"
	}

	return remoteAddr
}
|
||||
|
||||
// publicRESTIPBucketKey derives the per-IP limiter bucket key for class.
func publicRESTIPBucketKey(class PublicRouteClass, clientIP string) string {
	return class.BaseBucketKey() + publicRESTIPBucketKeySegment + clientIP
}

// publicAuthIdentityBucketKey derives the per-identity limiter bucket key for
// class from the identity's kind and value.
func publicAuthIdentityBucketKey(class PublicRouteClass, identityKind string, identityValue string) string {
	return class.BaseBucketKey() + "/" + identityKind + "=" + identityValue
}
|
||||
@@ -0,0 +1,455 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPublicAntiAbuseRejectsOversizedBodies verifies that bodies above the
// class limit produce a 413 with the stable error payload, never reach the
// auth backend, and record one oversized-body observation.
func TestPublicAntiAbuseRejectsOversizedBodies(t *testing.T) {
	t.Parallel()

	// Payloads sized just over the default per-class body limits.
	oversizedJSONBody := `{"email":"` + strings.Repeat("a", 8200) + `@example.com"}`
	oversizedConfirmJSONBody := `{"challenge_id":"` + strings.Repeat("c", 8300) + `","code":"123456","client_public_key":"key"}`

	tests := []struct {
		name      string
		method    string
		target    string
		body      string
		wantClass PublicRouteClass
	}{
		{
			name:      "send email",
			method:    http.MethodPost,
			target:    "/api/v1/public/auth/send-email-code",
			body:      oversizedJSONBody,
			wantClass: PublicRouteClassPublicAuth,
		},
		{
			name:      "confirm email",
			method:    http.MethodPost,
			target:    "/api/v1/public/auth/confirm-email-code",
			body:      oversizedConfirmJSONBody,
			wantClass: PublicRouteClassPublicAuth,
		},
		{
			name:      "healthz body",
			method:    http.MethodGet,
			target:    "/healthz",
			body:      `x`,
			wantClass: PublicRouteClassPublicMisc,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			observer := &recordingPublicRequestObserver{}
			authService := &recordingAuthServiceClient{
				sendEmailCodeResult: SendEmailCodeResult{ChallengeID: "challenge-123"},
				confirmEmailCodeResult: ConfirmEmailCodeResult{
					DeviceSessionID: "device-session-123",
				},
			}
			handler := newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), ServerDependencies{
				AuthService: authService,
				Observer:    observer,
			})

			req := httptest.NewRequest(tt.method, tt.target, strings.NewReader(tt.body))
			req.Header.Set("Content-Type", "application/json")
			recorder := httptest.NewRecorder()

			handler.ServeHTTP(recorder, req)

			assert.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.Equal(t, `{"error":{"code":"request_too_large","message":"request body exceeds the configured limit"}}`, recorder.Body.String())
			// The backend must never be consulted for a rejected request.
			assert.Equal(t, 0, authService.sendEmailCodeCalls)
			assert.Equal(t, 0, authService.confirmEmailCodeCalls)
			assert.Equal(t, []malformedObservation{{
				class:  tt.wantClass,
				reason: PublicMalformedRequestReasonOversizedBody,
			}}, observer.snapshot())
		})
	}
}
|
||||
|
||||
// TestPublicAntiAbuseRejectsInvalidMethodsForBrowserShapes verifies that
// recognized request shapes reject disallowed methods with a 405, an Allow
// header, and the stable JSON error payload.
func TestPublicAntiAbuseRejectsInvalidMethodsForBrowserShapes(t *testing.T) {
	t.Parallel()

	handler := newPublicHandler(ServerDependencies{})

	tests := []struct {
		name      string
		method    string
		target    string
		accept    string
		wantAllow string
	}{
		{
			name:      "asset path",
			method:    http.MethodPost,
			target:    "/assets/app.js",
			wantAllow: "GET, HEAD",
		},
		{
			name:      "bootstrap request",
			method:    http.MethodPost,
			target:    "/",
			accept:    "text/html",
			wantAllow: "GET, HEAD",
		},
		{
			// Probes allow only GET, so even HEAD is rejected.
			name:      "head probe rejected",
			method:    http.MethodHead,
			target:    "/healthz",
			wantAllow: http.MethodGet,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			req := httptest.NewRequest(tt.method, tt.target, nil)
			if tt.accept != "" {
				req.Header.Set("Accept", tt.accept)
			}
			recorder := httptest.NewRecorder()

			handler.ServeHTTP(recorder, req)

			assert.Equal(t, http.StatusMethodNotAllowed, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.Equal(t, tt.wantAllow, recorder.Header().Get("Allow"))
			assert.Equal(t, `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`, recorder.Body.String())
		})
	}
}
|
||||
|
||||
// TestPublicAntiAbuseBrowserClassBucketsStayIsolatedFromPublicAuth verifies
// that exhausting a browser-class rate bucket for one client IP does not
// affect the public-auth bucket for the same IP.
func TestPublicAntiAbuseBrowserClassBucketsStayIsolatedFromPublicAuth(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name         string
		burstRequest *http.Request
	}{
		{
			name:         "browser asset",
			burstRequest: httptest.NewRequest(http.MethodGet, "/assets/app.js", nil),
		},
		{
			name: "browser bootstrap",
			burstRequest: func() *http.Request {
				req := httptest.NewRequest(http.MethodGet, "/", nil)
				req.Header.Set("Accept", "text/html")
				return req
			}(),
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Browser classes get a 1-request budget; public auth gets a
			// generous budget so only class isolation is under test.
			cfg := config.DefaultPublicHTTPConfig()
			cfg.AntiAbuse.BrowserAsset.RateLimit = config.PublicRateLimitConfig{
				Requests: 1,
				Window:   time.Hour,
				Burst:    1,
			}
			cfg.AntiAbuse.BrowserBootstrap.RateLimit = config.PublicRateLimitConfig{
				Requests: 1,
				Window:   time.Hour,
				Burst:    1,
			}
			cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
				Requests: 100,
				Window:   time.Hour,
				Burst:    100,
			}
			authService := &recordingAuthServiceClient{
				sendEmailCodeResult: SendEmailCodeResult{
					ChallengeID: "challenge-123",
				},
			}
			handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})

			// Same client IP for every request so the buckets share an IP key.
			tt.burstRequest.RemoteAddr = "192.0.2.10:1234"

			firstBurst := httptest.NewRecorder()
			handler.ServeHTTP(firstBurst, tt.burstRequest.Clone(tt.burstRequest.Context()))

			secondBurst := httptest.NewRecorder()
			handler.ServeHTTP(secondBurst, tt.burstRequest.Clone(tt.burstRequest.Context()))

			authReq := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/send-email-code", strings.NewReader(`{"email":"pilot@example.com"}`))
			authReq.Header.Set("Content-Type", "application/json")
			authReq.RemoteAddr = "192.0.2.10:1234"
			authResp := httptest.NewRecorder()

			handler.ServeHTTP(authResp, authReq)

			// First browser hit passes the limiter (404: no such asset/route);
			// the second exhausts the bucket; auth remains unaffected.
			assert.Equal(t, http.StatusNotFound, firstBurst.Code)
			assert.Equal(t, http.StatusTooManyRequests, secondBurst.Code)
			assert.Equal(t, http.StatusOK, authResp.Code)
			assert.Equal(t, `{"challenge_id":"challenge-123"}`, authResp.Body.String())
			assert.Equal(t, 1, authService.sendEmailCodeCalls)
		})
	}
}
|
||||
|
||||
// TestPublicAntiAbuseSendEmailIdentityThrottle verifies the per-email
// identity throttle on send-email-code: a repeat for the same email is
// rejected with a Retry-After, while a different email passes.
func TestPublicAntiAbuseSendEmailIdentityThrottle(t *testing.T) {
	t.Parallel()

	// Generous IP budget so only the identity throttle can reject.
	cfg := config.DefaultPublicHTTPConfig()
	cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}
	cfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
		Requests: 1,
		Window:   time.Hour,
		Burst:    1,
	}

	authService := &recordingAuthServiceClient{
		sendEmailCodeResult: SendEmailCodeResult{
			ChallengeID: "challenge-123",
		},
	}
	handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})

	first := sendEmailCodeRequest(`{"email":"pilot@example.com"}`)
	second := sendEmailCodeRequest(`{"email":"pilot@example.com"}`)
	third := sendEmailCodeRequest(`{"email":"other@example.com"}`)

	firstResp := httptest.NewRecorder()
	handler.ServeHTTP(firstResp, first)

	secondResp := httptest.NewRecorder()
	handler.ServeHTTP(secondResp, second)

	thirdResp := httptest.NewRecorder()
	handler.ServeHTTP(thirdResp, third)

	assert.Equal(t, http.StatusOK, firstResp.Code)
	assert.Equal(t, http.StatusTooManyRequests, secondResp.Code)
	// One-hour window → 3600-second Retry-After.
	assert.Equal(t, "3600", secondResp.Header().Get("Retry-After"))
	assert.Equal(t, http.StatusOK, thirdResp.Code)
	assert.Equal(t, 2, authService.sendEmailCodeCalls)
	thirdInput := authService.sendEmailCodeInput
	assert.Equal(t, "other@example.com", thirdInput.Email)
}
|
||||
|
||||
// TestPublicAntiAbuseConfirmEmailIdentityThrottle verifies the per-challenge
// identity throttle on confirm-email-code: a repeat for the same challenge is
// rejected with a Retry-After, while a different challenge passes.
func TestPublicAntiAbuseConfirmEmailIdentityThrottle(t *testing.T) {
	t.Parallel()

	// Generous IP budget so only the identity throttle can reject.
	cfg := config.DefaultPublicHTTPConfig()
	cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}
	cfg.AntiAbuse.ConfirmEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
		Requests: 1,
		Window:   time.Hour,
		Burst:    1,
	}

	authService := &recordingAuthServiceClient{
		confirmEmailCodeResult: ConfirmEmailCodeResult{
			DeviceSessionID: "device-session-123",
		},
	}
	handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})

	first := confirmEmailCodeRequest(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`)
	second := confirmEmailCodeRequest(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`)
	third := confirmEmailCodeRequest(`{"challenge_id":"challenge-456","code":"123456","client_public_key":"public-key-material"}`)

	firstResp := httptest.NewRecorder()
	handler.ServeHTTP(firstResp, first)

	secondResp := httptest.NewRecorder()
	handler.ServeHTTP(secondResp, second)

	thirdResp := httptest.NewRecorder()
	handler.ServeHTTP(thirdResp, third)

	assert.Equal(t, http.StatusOK, firstResp.Code)
	assert.Equal(t, http.StatusTooManyRequests, secondResp.Code)
	// One-hour window → 3600-second Retry-After.
	assert.Equal(t, "3600", secondResp.Header().Get("Retry-After"))
	assert.Equal(t, http.StatusOK, thirdResp.Code)
	assert.Equal(t, 2, authService.confirmEmailCodeCalls)
	assert.Equal(t, "challenge-456", authService.confirmEmailCodeInput.ChallengeID)
}
|
||||
|
||||
// TestPublicAntiAbuseMalformedTelemetry verifies that each malformed-body
// category records exactly one observation with the right reason, and that a
// schema-valid but semantically invalid request records nothing.
func TestPublicAntiAbuseMalformedTelemetry(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		body        string
		wantReason  PublicMalformedRequestReason
		wantRecords int
	}{
		{
			name:        "empty body",
			body:        ``,
			wantReason:  PublicMalformedRequestReasonEmptyBody,
			wantRecords: 1,
		},
		{
			name:        "malformed json",
			body:        `{"email":`,
			wantReason:  PublicMalformedRequestReasonMalformedJSON,
			wantRecords: 1,
		},
		{
			name:        "invalid json value",
			body:        `{"email":123}`,
			wantReason:  PublicMalformedRequestReasonInvalidJSONValue,
			wantRecords: 1,
		},
		{
			name:        "unknown field",
			body:        `{"email":"pilot@example.com","extra":"x"}`,
			wantReason:  PublicMalformedRequestReasonUnknownField,
			wantRecords: 1,
		},
		{
			name:        "multiple objects",
			body:        `{"email":"pilot@example.com"}{"email":"pilot@example.com"}`,
			wantReason:  PublicMalformedRequestReasonMultipleJSONObjects,
			wantRecords: 1,
		},
		{
			// Well-formed JSON that fails validation is the handler's concern,
			// not malformed-request telemetry.
			name:        "validation error does not count as malformed",
			body:        `{"email":"not-an-email"}`,
			wantRecords: 0,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			observer := &recordingPublicRequestObserver{}
			authService := &recordingAuthServiceClient{}
			handler := newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), ServerDependencies{
				AuthService: authService,
				Observer:    observer,
			})

			req := sendEmailCodeRequest(tt.body)
			recorder := httptest.NewRecorder()

			handler.ServeHTTP(recorder, req)

			assert.Equal(t, http.StatusBadRequest, recorder.Code)
			assert.Equal(t, tt.wantRecords, len(observer.snapshot()))
			assert.Equal(t, 0, authService.sendEmailCodeCalls)
			if tt.wantRecords == 1 {
				assert.Equal(t, malformedObservation{
					class:  PublicRouteClassPublicAuth,
					reason: tt.wantReason,
				}, observer.snapshot()[0])
			}
		})
	}
}
|
||||
|
||||
// TestInMemoryPublicRequestLimiterCleansExpiredBuckets verifies that idle
// buckets are pruned once the fake clock advances past their TTL while a
// freshly used bucket survives.
func TestInMemoryPublicRequestLimiterCleansExpiredBuckets(t *testing.T) {
	t.Parallel()

	// Inject a controllable clock and a short cleanup interval so the sweep
	// runs on the next Reserve call after time advances.
	now := time.Unix(1000, 0)
	limiter := newInMemoryPublicRequestLimiter()
	limiter.now = func() time.Time {
		return now
	}
	limiter.cleanupInterval = time.Second

	policy := config.PublicRateLimitConfig{
		Requests: 1,
		Window:   time.Minute,
		Burst:    1,
	}

	firstDecision := limiter.Reserve("bucket-1", policy)
	secondDecision := limiter.Reserve("bucket-2", policy)
	require.True(t, firstDecision.Allowed)
	require.True(t, secondDecision.Allowed)
	require.Len(t, limiter.entries, 2)

	// Jump beyond the one-minute-window TTL so buckets 1 and 2 expire.
	now = now.Add(3 * time.Minute)

	thirdDecision := limiter.Reserve("bucket-3", policy)
	require.True(t, thirdDecision.Allowed)
	assert.Len(t, limiter.entries, 1)
	_, exists := limiter.entries["bucket-3"]
	assert.True(t, exists)
}
|
||||
|
||||
// sendEmailCodeRequest builds a JSON POST request against the public
// send-email-code route with a fixed client address for rate-limit tests.
func sendEmailCodeRequest(body string) *http.Request {
	req := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/public/auth/send-email-code",
		strings.NewReader(body),
	)
	req.Header.Set("Content-Type", "application/json")
	req.RemoteAddr = "192.0.2.10:1234"

	return req
}
|
||||
|
||||
// confirmEmailCodeRequest builds a JSON POST request against the public
// confirm-email-code route with the same fixed client address as
// sendEmailCodeRequest.
func confirmEmailCodeRequest(body string) *http.Request {
	req := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/public/auth/confirm-email-code",
		strings.NewReader(body),
	)
	req.Header.Set("Content-Type", "application/json")
	req.RemoteAddr = "192.0.2.10:1234"

	return req
}
|
||||
|
||||
// malformedObservation is the value recorded for each observer callback,
// pairing the public route class with the malformed-request reason reported
// to RecordMalformedRequest.
type malformedObservation struct {
	class  PublicRouteClass
	reason PublicMalformedRequestReason
}
|
||||
|
||||
type recordingPublicRequestObserver struct {
|
||||
mu sync.Mutex
|
||||
observations []malformedObservation
|
||||
}
|
||||
|
||||
func (o *recordingPublicRequestObserver) RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
o.observations = append(o.observations, malformedObservation{
|
||||
class: class,
|
||||
reason: reason,
|
||||
})
|
||||
}
|
||||
|
||||
func (o *recordingPublicRequestObserver) snapshot() []malformedObservation {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
return append([]malformedObservation(nil), o.observations...)
|
||||
}
|
||||
@@ -0,0 +1,446 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// errPublicAuthIdentityNotApplicable signals that the request path is not one
// of the public auth routes and therefore carries no extractable identity.
var errPublicAuthIdentityNotApplicable = errors.New("public auth identity does not apply to this route")
|
||||
|
||||
type malformedJSONRequestError struct {
|
||||
message string
|
||||
reason PublicMalformedRequestReason
|
||||
}
|
||||
|
||||
func (e *malformedJSONRequestError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return e.message
|
||||
}
|
||||
|
||||
// publicAuthIdentity is the identity extracted from a public auth request
// body: kind is "email" or "challenge" (see extractPublicAuthIdentity) and
// value is the corresponding validated field.
type publicAuthIdentity struct {
	kind  string
	value string
}
|
||||
|
||||
// AuthServiceClient defines the consumer-side contract used by public auth
// REST handlers to delegate unauthenticated authentication commands to the
// Auth / Session Service. Implementations may return *AuthServiceError to
// control the HTTP status, code, and message projected to clients.
type AuthServiceClient interface {
	// SendEmailCode starts a login challenge for input.Email and returns the
	// challenge identifier that the client must later confirm.
	SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error)

	// ConfirmEmailCode completes a previously issued challenge, registers
	// input.ClientPublicKey for the new device session, and returns the created
	// device session identifier.
	ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error)
}
|
||||
|
||||
// SendEmailCodeInput describes the public REST and adapter payload used to
// request a login code for a single e-mail address.
type SendEmailCodeInput struct {
	// Email is the single client e-mail address that should receive the login
	// code challenge. Handlers trim and validate it before delegation.
	Email string `json:"email"`
}

// SendEmailCodeResult describes the public REST and adapter payload returned
// after the Auth / Session Service creates a login challenge.
type SendEmailCodeResult struct {
	// ChallengeID identifies the issued challenge that must be confirmed by the
	// client in the next public auth step. Must be non-empty after trimming.
	ChallengeID string `json:"challenge_id"`
}

// ConfirmEmailCodeInput describes the public REST and adapter payload used to
// complete a previously issued login challenge.
type ConfirmEmailCodeInput struct {
	// ChallengeID identifies the challenge previously returned by
	// SendEmailCode.
	ChallengeID string `json:"challenge_id"`

	// Code is the verification code delivered to the client by the Auth /
	// Session Service.
	Code string `json:"code"`

	// ClientPublicKey is the standard base64-encoded raw 32-byte Ed25519 public
	// key that should be registered for the created device session.
	ClientPublicKey string `json:"client_public_key"`
}

// ConfirmEmailCodeResult describes the public REST and adapter payload
// returned after the Auth / Session Service creates a device session.
type ConfirmEmailCodeResult struct {
	// DeviceSessionID is the stable identifier of the created device session.
	// Must be non-empty after trimming.
	DeviceSessionID string `json:"device_session_id"`
}
|
||||
|
||||
// AuthServiceError allows an auth adapter to project a stable public REST
|
||||
// error without teaching the gateway transport layer about upstream business
|
||||
// rules.
|
||||
type AuthServiceError struct {
|
||||
// StatusCode is the HTTP status that the public REST handler should expose.
|
||||
StatusCode int
|
||||
|
||||
// Code is the stable edge-level error code written into the JSON envelope.
|
||||
Code string
|
||||
|
||||
// Message is the human-readable client-safe error description.
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error returns a readable representation of the projected auth service error.
|
||||
func (e *AuthServiceError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(e.Code) == "" && strings.TrimSpace(e.Message) == "":
|
||||
return http.StatusText(e.normalizedStatusCode())
|
||||
case strings.TrimSpace(e.Code) == "":
|
||||
return e.Message
|
||||
case strings.TrimSpace(e.Message) == "":
|
||||
return e.Code
|
||||
default:
|
||||
return e.Code + ": " + e.Message
|
||||
}
|
||||
}
|
||||
|
||||
func (e *AuthServiceError) normalizedStatusCode() int {
|
||||
if e == nil || e.StatusCode < 400 || e.StatusCode > 599 {
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
return e.StatusCode
|
||||
}
|
||||
|
||||
func (e *AuthServiceError) normalizedCode() string {
|
||||
if e == nil {
|
||||
return errorCodeInternalError
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(e.Code)
|
||||
if code == "" {
|
||||
switch e.normalizedStatusCode() {
|
||||
case http.StatusServiceUnavailable:
|
||||
return errorCodeServiceUnavailable
|
||||
case http.StatusBadRequest:
|
||||
return errorCodeInvalidRequest
|
||||
default:
|
||||
return errorCodeInternalError
|
||||
}
|
||||
}
|
||||
|
||||
return code
|
||||
}
|
||||
|
||||
func (e *AuthServiceError) normalizedMessage() string {
|
||||
if e == nil {
|
||||
return "internal server error"
|
||||
}
|
||||
|
||||
message := strings.TrimSpace(e.Message)
|
||||
if message == "" {
|
||||
switch e.normalizedStatusCode() {
|
||||
case http.StatusServiceUnavailable:
|
||||
return "auth service is unavailable"
|
||||
case http.StatusBadRequest:
|
||||
return "request is invalid"
|
||||
default:
|
||||
return "internal server error"
|
||||
}
|
||||
}
|
||||
|
||||
return message
|
||||
}
|
||||
|
||||
// unavailableAuthServiceClient keeps the public auth surface mounted until a
|
||||
// concrete upstream adapter is wired into the gateway process.
|
||||
type unavailableAuthServiceClient struct{}
|
||||
|
||||
func (unavailableAuthServiceClient) SendEmailCode(context.Context, SendEmailCodeInput) (SendEmailCodeResult, error) {
|
||||
return SendEmailCodeResult{}, &AuthServiceError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Code: errorCodeServiceUnavailable,
|
||||
Message: "auth service is unavailable",
|
||||
}
|
||||
}
|
||||
|
||||
func (unavailableAuthServiceClient) ConfirmEmailCode(context.Context, ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
|
||||
return ConfirmEmailCodeResult{}, &AuthServiceError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Code: errorCodeServiceUnavailable,
|
||||
Message: "auth service is unavailable",
|
||||
}
|
||||
}
|
||||
|
||||
func handleSendEmailCode(authService AuthServiceClient, timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var input SendEmailCodeInput
|
||||
if err := decodeJSONRequest(c.Request, &input); err != nil {
|
||||
abortInvalidRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
if err := validateSendEmailCodeInput(&input); err != nil {
|
||||
abortInvalidRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := authService.SendEmailCode(callCtx, input)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
abortWithError(c, http.StatusServiceUnavailable, errorCodeServiceUnavailable, "auth service is unavailable")
|
||||
return
|
||||
}
|
||||
abortWithAuthServiceError(c, err)
|
||||
return
|
||||
}
|
||||
if err := validateSendEmailCodeResult(&result); err != nil {
|
||||
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
}
|
||||
|
||||
func handleConfirmEmailCode(authService AuthServiceClient, timeout time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var input ConfirmEmailCodeInput
|
||||
if err := decodeJSONRequest(c.Request, &input); err != nil {
|
||||
abortInvalidRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
if err := validateConfirmEmailCodeInput(&input); err != nil {
|
||||
abortInvalidRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
result, err := authService.ConfirmEmailCode(callCtx, input)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
abortWithError(c, http.StatusServiceUnavailable, errorCodeServiceUnavailable, "auth service is unavailable")
|
||||
return
|
||||
}
|
||||
abortWithAuthServiceError(c, err)
|
||||
return
|
||||
}
|
||||
if err := validateConfirmEmailCodeResult(&result); err != nil {
|
||||
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
}
|
||||
|
||||
// abortInvalidRequest writes the standard 400 invalid_request JSON error
// envelope with the given client-safe message via abortWithError.
func abortInvalidRequest(c *gin.Context, message string) {
	abortWithError(c, http.StatusBadRequest, errorCodeInvalidRequest, message)
}
|
||||
|
||||
func abortWithAuthServiceError(c *gin.Context, err error) {
|
||||
var authErr *AuthServiceError
|
||||
if errors.As(err, &authErr) {
|
||||
abortWithError(c, authErr.normalizedStatusCode(), authErr.normalizedCode(), authErr.normalizedMessage())
|
||||
return
|
||||
}
|
||||
|
||||
abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
|
||||
}
|
||||
|
||||
func decodeJSONRequest(r *http.Request, target any) error {
|
||||
if r == nil || r.Body == nil {
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body must not be empty",
|
||||
reason: PublicMalformedRequestReasonEmptyBody,
|
||||
}
|
||||
}
|
||||
|
||||
return decodeJSONReader(r.Body, target)
|
||||
}
|
||||
|
||||
// decodeJSONBytes decodes bodyBytes into target with the same strict
// single-object, unknown-field-rejecting semantics as decodeJSONReader.
func decodeJSONBytes(bodyBytes []byte, target any) error {
	return decodeJSONReader(bytes.NewReader(bodyBytes), target)
}
|
||||
|
||||
func decodeJSONReader(reader io.Reader, target any) error {
|
||||
decoder := json.NewDecoder(reader)
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
if err := decoder.Decode(target); err != nil {
|
||||
return describeJSONDecodeError(err)
|
||||
}
|
||||
|
||||
if err := decoder.Decode(&struct{}{}); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body must contain a single JSON object",
|
||||
reason: PublicMalformedRequestReasonMultipleJSONObjects,
|
||||
}
|
||||
}
|
||||
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body must contain a single JSON object",
|
||||
reason: PublicMalformedRequestReasonMultipleJSONObjects,
|
||||
}
|
||||
}
|
||||
|
||||
func describeJSONDecodeError(err error) error {
|
||||
var syntaxErr *json.SyntaxError
|
||||
var typeErr *json.UnmarshalTypeError
|
||||
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body must not be empty",
|
||||
reason: PublicMalformedRequestReasonEmptyBody,
|
||||
}
|
||||
case errors.As(err, &syntaxErr):
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body contains malformed JSON",
|
||||
reason: PublicMalformedRequestReasonMalformedJSON,
|
||||
}
|
||||
case errors.Is(err, io.ErrUnexpectedEOF):
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body contains malformed JSON",
|
||||
reason: PublicMalformedRequestReasonMalformedJSON,
|
||||
}
|
||||
case errors.As(err, &typeErr):
|
||||
if strings.TrimSpace(typeErr.Field) != "" {
|
||||
return &malformedJSONRequestError{
|
||||
message: fmt.Sprintf("request body contains an invalid value for %q", typeErr.Field),
|
||||
reason: PublicMalformedRequestReasonInvalidJSONValue,
|
||||
}
|
||||
}
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body contains an invalid JSON value",
|
||||
reason: PublicMalformedRequestReasonInvalidJSONValue,
|
||||
}
|
||||
case strings.HasPrefix(err.Error(), "json: unknown field "):
|
||||
return &malformedJSONRequestError{
|
||||
message: fmt.Sprintf("request body contains unknown field %s", strings.TrimPrefix(err.Error(), "json: unknown field ")),
|
||||
reason: PublicMalformedRequestReasonUnknownField,
|
||||
}
|
||||
default:
|
||||
return &malformedJSONRequestError{
|
||||
message: "request body contains invalid JSON",
|
||||
reason: PublicMalformedRequestReasonMalformedJSON,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func validateSendEmailCodeInput(input *SendEmailCodeInput) error {
|
||||
input.Email = strings.TrimSpace(input.Email)
|
||||
if input.Email == "" {
|
||||
return errors.New("email must not be empty")
|
||||
}
|
||||
|
||||
parsedAddress, err := mail.ParseAddress(input.Email)
|
||||
if err != nil || parsedAddress.Name != "" || parsedAddress.Address != input.Email {
|
||||
return errors.New("email must be a single valid email address")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSendEmailCodeResult(result *SendEmailCodeResult) error {
|
||||
result.ChallengeID = strings.TrimSpace(result.ChallengeID)
|
||||
if result.ChallengeID == "" {
|
||||
return errors.New("auth service returned an empty challenge_id")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateConfirmEmailCodeInput(input *ConfirmEmailCodeInput) error {
|
||||
input.ChallengeID = strings.TrimSpace(input.ChallengeID)
|
||||
if input.ChallengeID == "" {
|
||||
return errors.New("challenge_id must not be empty")
|
||||
}
|
||||
|
||||
input.Code = strings.TrimSpace(input.Code)
|
||||
if input.Code == "" {
|
||||
return errors.New("code must not be empty")
|
||||
}
|
||||
|
||||
input.ClientPublicKey = strings.TrimSpace(input.ClientPublicKey)
|
||||
if input.ClientPublicKey == "" {
|
||||
return errors.New("client_public_key must not be empty")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateConfirmEmailCodeResult(result *ConfirmEmailCodeResult) error {
|
||||
result.DeviceSessionID = strings.TrimSpace(result.DeviceSessionID)
|
||||
if result.DeviceSessionID == "" {
|
||||
return errors.New("auth service returned an empty device_session_id")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func malformedRequestReasonFromError(err error) (PublicMalformedRequestReason, bool) {
|
||||
var malformedErr *malformedJSONRequestError
|
||||
if !errors.As(err, &malformedErr) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return malformedErr.reason, true
|
||||
}
|
||||
|
||||
func extractPublicAuthIdentity(requestPath string, bodyBytes []byte) (publicAuthIdentity, error) {
|
||||
switch requestPath {
|
||||
case "/api/v1/public/auth/send-email-code":
|
||||
var input SendEmailCodeInput
|
||||
if err := decodeJSONBytes(bodyBytes, &input); err != nil {
|
||||
return publicAuthIdentity{}, err
|
||||
}
|
||||
if err := validateSendEmailCodeInput(&input); err != nil {
|
||||
return publicAuthIdentity{}, err
|
||||
}
|
||||
|
||||
return publicAuthIdentity{
|
||||
kind: "email",
|
||||
value: input.Email,
|
||||
}, nil
|
||||
case "/api/v1/public/auth/confirm-email-code":
|
||||
var input ConfirmEmailCodeInput
|
||||
if err := decodeJSONBytes(bodyBytes, &input); err != nil {
|
||||
return publicAuthIdentity{}, err
|
||||
}
|
||||
if err := validateConfirmEmailCodeInput(&input); err != nil {
|
||||
return publicAuthIdentity{}, err
|
||||
}
|
||||
|
||||
return publicAuthIdentity{
|
||||
kind: "challenge",
|
||||
value: input.ChallengeID,
|
||||
}, nil
|
||||
default:
|
||||
return publicAuthIdentity{}, errPublicAuthIdentityNotApplicable
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,377 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/testutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSendEmailCodeHandlerSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := &recordingAuthServiceClient{
|
||||
sendEmailCodeResult: SendEmailCodeResult{
|
||||
ChallengeID: "challenge-123",
|
||||
},
|
||||
}
|
||||
handler := newPublicHandler(ServerDependencies{AuthService: authService})
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/public/auth/send-email-code",
|
||||
strings.NewReader(`{"email":" pilot@example.com "}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
|
||||
assert.Equal(t, `{"challenge_id":"challenge-123"}`, recorder.Body.String())
|
||||
assert.Equal(t, 1, authService.sendEmailCodeCalls)
|
||||
assert.Equal(t, 0, authService.confirmEmailCodeCalls)
|
||||
assert.Equal(t, SendEmailCodeInput{Email: "pilot@example.com"}, authService.sendEmailCodeInput)
|
||||
assert.True(t, authService.sendEmailCodeRouteClassOK)
|
||||
assert.Equal(t, PublicRouteClassPublicAuth, authService.sendEmailCodeRouteClass)
|
||||
}
|
||||
|
||||
func TestConfirmEmailCodeHandlerSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := &recordingAuthServiceClient{
|
||||
confirmEmailCodeResult: ConfirmEmailCodeResult{
|
||||
DeviceSessionID: "device-session-123",
|
||||
},
|
||||
}
|
||||
handler := newPublicHandler(ServerDependencies{AuthService: authService})
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/public/auth/confirm-email-code",
|
||||
strings.NewReader(`{"challenge_id":" challenge-123 ","code":" 123456 ","client_public_key":" public-key-material "}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
|
||||
assert.Equal(t, `{"device_session_id":"device-session-123"}`, recorder.Body.String())
|
||||
assert.Equal(t, 0, authService.sendEmailCodeCalls)
|
||||
assert.Equal(t, 1, authService.confirmEmailCodeCalls)
|
||||
assert.Equal(t, ConfirmEmailCodeInput{
|
||||
ChallengeID: "challenge-123",
|
||||
Code: "123456",
|
||||
ClientPublicKey: "public-key-material",
|
||||
}, authService.confirmEmailCodeInput)
|
||||
assert.True(t, authService.confirmEmailCodeRouteClassOK)
|
||||
assert.Equal(t, PublicRouteClassPublicAuth, authService.confirmEmailCodeRouteClass)
|
||||
}
|
||||
|
||||
// TestPublicAuthHandlersRejectInvalidRequests verifies that malformed or
// invalid public auth bodies produce a 400 with the exact JSON error envelope
// and never reach the upstream adapter (both call counters stay zero).
func TestPublicAuthHandlersRejectInvalidRequests(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name             string
		target           string
		body             string
		wantStatus       int
		wantBody         string
		wantSendCalls    int
		wantConfirmCalls int
	}{
		{
			// Truncated JSON: rejected by the decoder, not by field validation.
			name:             "send email malformed json",
			target:           "/api/v1/public/auth/send-email-code",
			body:             `{"email":`,
			wantStatus:       http.StatusBadRequest,
			wantBody:         `{"error":{"code":"invalid_request","message":"request body contains malformed JSON"}}`,
			wantSendCalls:    0,
			wantConfirmCalls: 0,
		},
		{
			// Decodes fine but fails mail.ParseAddress-based validation.
			name:             "send email validation error",
			target:           "/api/v1/public/auth/send-email-code",
			body:             `{"email":"not-an-email"}`,
			wantStatus:       http.StatusBadRequest,
			wantBody:         `{"error":{"code":"invalid_request","message":"email must be a single valid email address"}}`,
			wantSendCalls:    0,
			wantConfirmCalls: 0,
		},
		{
			// Whitespace-only code must be treated as empty after trimming.
			name:             "confirm email empty code",
			target:           "/api/v1/public/auth/confirm-email-code",
			body:             `{"challenge_id":"challenge-123","code":" ","client_public_key":"public-key-material"}`,
			wantStatus:       http.StatusBadRequest,
			wantBody:         `{"error":{"code":"invalid_request","message":"code must not be empty"}}`,
			wantSendCalls:    0,
			wantConfirmCalls: 0,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			authService := &recordingAuthServiceClient{}
			handler := newPublicHandler(ServerDependencies{AuthService: authService})

			req := httptest.NewRequest(http.MethodPost, tt.target, strings.NewReader(tt.body))
			req.Header.Set("Content-Type", "application/json")
			recorder := httptest.NewRecorder()

			handler.ServeHTTP(recorder, req)

			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.Equal(t, tt.wantBody, recorder.Body.String())
			assert.Equal(t, tt.wantSendCalls, authService.sendEmailCodeCalls)
			assert.Equal(t, tt.wantConfirmCalls, authService.confirmEmailCodeCalls)
		})
	}
}
|
||||
|
||||
// TestPublicAuthHandlersMapAdapterErrors verifies how adapter failures are
// projected to clients: *AuthServiceError keeps its status/code/message
// (with blank fields normalized to internal_error defaults), and any other
// error collapses to an opaque 500.
func TestPublicAuthHandlersMapAdapterErrors(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		target     string
		body       string
		authClient *recordingAuthServiceClient
		wantStatus int
		wantBody   string
	}{
		{
			name:   "auth service projected bad request",
			target: "/api/v1/public/auth/confirm-email-code",
			body:   `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
			authClient: &recordingAuthServiceClient{
				confirmEmailCodeErr: &AuthServiceError{
					StatusCode: http.StatusBadRequest,
					Code:       errorCodeInvalidRequest,
					Message:    "confirmation code is invalid",
				},
			},
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"confirmation code is invalid"}}`,
		},
		{
			// Non-edge code/status chosen by the adapter passes through intact.
			name:   "auth service projected custom too many requests",
			target: "/api/v1/public/auth/send-email-code",
			body:   `{"email":"pilot@example.com"}`,
			authClient: &recordingAuthServiceClient{
				sendEmailCodeErr: &AuthServiceError{
					StatusCode: http.StatusTooManyRequests,
					Code:       "upstream_rate_limited",
					Message:    "too many attempts for this email",
				},
			},
			wantStatus: http.StatusTooManyRequests,
			wantBody:   `{"error":{"code":"upstream_rate_limited","message":"too many attempts for this email"}}`,
		},
		{
			// Blank Code/Message fall back to internal_error defaults while the
			// explicit status is preserved.
			name:   "auth service projected gateway normalizes blank gateway error fields",
			target: "/api/v1/public/auth/confirm-email-code",
			body:   `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
			authClient: &recordingAuthServiceClient{
				confirmEmailCodeErr: &AuthServiceError{
					StatusCode: http.StatusBadGateway,
				},
			},
			wantStatus: http.StatusBadGateway,
			wantBody:   `{"error":{"code":"internal_error","message":"internal server error"}}`,
		},
		{
			// Non-AuthServiceError values must never leak their message.
			name:   "unexpected auth service error",
			target: "/api/v1/public/auth/send-email-code",
			body:   `{"email":"pilot@example.com"}`,
			authClient: &recordingAuthServiceClient{
				sendEmailCodeErr: errors.New("boom"),
			},
			wantStatus: http.StatusInternalServerError,
			wantBody:   `{"error":{"code":"internal_error","message":"internal server error"}}`,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			handler := newPublicHandler(ServerDependencies{AuthService: tt.authClient})
			req := httptest.NewRequest(http.MethodPost, tt.target, strings.NewReader(tt.body))
			req.Header.Set("Content-Type", "application/json")
			recorder := httptest.NewRecorder()

			handler.ServeHTTP(recorder, req)

			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.Equal(t, tt.wantBody, recorder.Body.String())
		})
	}
}
|
||||
|
||||
// TestDefaultAuthServiceReturnsServiceUnavailable verifies that a handler
// built without an AuthService keeps the public auth routes mounted but
// answers them with a stable 503, while non-auth routes such as /healthz keep
// working.
func TestDefaultAuthServiceReturnsServiceUnavailable(t *testing.T) {
	t.Parallel()

	// No AuthService wired in: the default (unavailable) client is used.
	handler := newPublicHandler(ServerDependencies{})

	tests := []struct {
		name       string
		method     string
		target     string
		body       string
		wantStatus int
		wantBody   string
	}{
		{
			name:       "send email code",
			method:     http.MethodPost,
			target:     "/api/v1/public/auth/send-email-code",
			body:       `{"email":"pilot@example.com"}`,
			wantStatus: http.StatusServiceUnavailable,
			wantBody:   `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`,
		},
		{
			name:       "confirm email code",
			method:     http.MethodPost,
			target:     "/api/v1/public/auth/confirm-email-code",
			body:       `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
			wantStatus: http.StatusServiceUnavailable,
			wantBody:   `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`,
		},
		{
			name:       "healthz remains available",
			method:     http.MethodGet,
			target:     "/healthz",
			wantStatus: http.StatusOK,
			wantBody:   `{"status":"ok"}`,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			req := httptest.NewRequest(tt.method, tt.target, strings.NewReader(tt.body))
			if tt.body != "" {
				req.Header.Set("Content-Type", "application/json")
			}
			recorder := httptest.NewRecorder()

			handler.ServeHTTP(recorder, req)

			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.Equal(t, tt.wantBody, recorder.Body.String())
		})
	}
}
|
||||
|
||||
func TestPublicAuthHandlerTimeoutMapsToServiceUnavailable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authService := &recordingAuthServiceClient{
|
||||
sendEmailCodeErr: context.DeadlineExceeded,
|
||||
}
|
||||
cfg := config.DefaultPublicHTTPConfig()
|
||||
cfg.AuthUpstreamTimeout = 5 * time.Millisecond
|
||||
handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/public/auth/send-email-code",
|
||||
strings.NewReader(`{"email":"pilot@example.com"}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
assert.Equal(t, `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`, recorder.Body.String())
|
||||
}
|
||||
|
||||
func TestPublicAuthLogsDoNotContainSensitiveFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger, buffer := testutil.NewObservedLogger(t)
|
||||
handler := newPublicHandler(ServerDependencies{
|
||||
Logger: logger,
|
||||
AuthService: &recordingAuthServiceClient{
|
||||
confirmEmailCodeResult: ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"},
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/v1/public/auth/confirm-email-code",
|
||||
strings.NewReader(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
logOutput := buffer.String()
|
||||
assert.NotContains(t, logOutput, "challenge-123")
|
||||
assert.NotContains(t, logOutput, "123456")
|
||||
assert.NotContains(t, logOutput, "public-key-material")
|
||||
assert.NotContains(t, logOutput, "pilot@example.com")
|
||||
}
|
||||
|
||||
// recordingAuthServiceClient captures handler inputs and route classification
|
||||
// so tests can assert the exact adapter delegation contract.
|
||||
type recordingAuthServiceClient struct {
|
||||
sendEmailCodeResult SendEmailCodeResult
|
||||
sendEmailCodeErr error
|
||||
sendEmailCodeInput SendEmailCodeInput
|
||||
sendEmailCodeRouteClass PublicRouteClass
|
||||
sendEmailCodeRouteClassOK bool
|
||||
sendEmailCodeCalls int
|
||||
|
||||
confirmEmailCodeResult ConfirmEmailCodeResult
|
||||
confirmEmailCodeErr error
|
||||
confirmEmailCodeInput ConfirmEmailCodeInput
|
||||
confirmEmailCodeRouteClass PublicRouteClass
|
||||
confirmEmailCodeRouteClassOK bool
|
||||
confirmEmailCodeCalls int
|
||||
}
|
||||
|
||||
func (c *recordingAuthServiceClient) SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error) {
|
||||
c.sendEmailCodeCalls++
|
||||
c.sendEmailCodeInput = input
|
||||
|
||||
c.sendEmailCodeRouteClass, c.sendEmailCodeRouteClassOK = PublicRouteClassFromContext(ctx)
|
||||
|
||||
return c.sendEmailCodeResult, c.sendEmailCodeErr
|
||||
}
|
||||
|
||||
func (c *recordingAuthServiceClient) ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
|
||||
c.confirmEmailCodeCalls++
|
||||
c.confirmEmailCodeInput = input
|
||||
|
||||
c.confirmEmailCodeRouteClass, c.confirmEmailCodeRouteClassOK = PublicRouteClassFromContext(ctx)
|
||||
|
||||
return c.confirmEmailCodeResult, c.confirmEmailCodeErr
|
||||
}
|
||||
@@ -0,0 +1,388 @@
|
||||
// Package restapi exposes the unauthenticated public REST surface of the
|
||||
// gateway.
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Wire-format constants shared by the public REST handlers.
const (
	// jsonContentType is the Content-Type written on every JSON response.
	jsonContentType = "application/json; charset=utf-8"

	// Stable machine-readable error codes carried in the public error envelope.
	errorCodeInvalidRequest     = "invalid_request"
	errorCodeNotFound           = "not_found"
	errorCodeMethodNotAllowed   = "method_not_allowed"
	errorCodeInternalError      = "internal_error"
	errorCodeServiceUnavailable = "service_unavailable"

	// publicRESTBaseBucketKeyPrefix namespaces rate-limit bucket keys by the
	// normalized public route class.
	publicRESTBaseBucketKeyPrefix = "public_rest/class="
)

// PublicRouteClass identifies the public traffic class assigned to an incoming
// REST request before route handling and edge policy evaluation.
type PublicRouteClass string

const (
	// PublicRouteClassPublicAuth identifies public authentication commands.
	PublicRouteClassPublicAuth PublicRouteClass = "public_auth"

	// PublicRouteClassBrowserBootstrap identifies browser bootstrap traffic such
	// as the main document request.
	PublicRouteClassBrowserBootstrap PublicRouteClass = "browser_bootstrap"

	// PublicRouteClassBrowserAsset identifies browser asset requests.
	PublicRouteClassBrowserAsset PublicRouteClass = "browser_asset"

	// PublicRouteClassPublicMisc identifies public traffic that does not match a
	// more specific class.
	PublicRouteClassPublicMisc PublicRouteClass = "public_misc"
)

// configureGinModeOnce ensures gin's package-level mode is configured exactly
// once per process, no matter how many public handlers are built.
var configureGinModeOnce sync.Once

// Normalized returns c when it belongs to the stable public route class set.
// Unknown or empty values collapse to PublicRouteClassPublicMisc so edge policy
// code can rely on a fixed anti-abuse namespace.
func (c PublicRouteClass) Normalized() PublicRouteClass {
	known := c == PublicRouteClassPublicAuth ||
		c == PublicRouteClassBrowserBootstrap ||
		c == PublicRouteClassBrowserAsset ||
		c == PublicRouteClassPublicMisc
	if known {
		return c
	}

	return PublicRouteClassPublicMisc
}

// BaseBucketKey returns the canonical base rate-limit namespace for c. The key
// stays scoped only by the normalized public route class; callers may append
// subject dimensions such as IP or identity without redefining the class
// namespace.
func (c PublicRouteClass) BaseBucketKey() string {
	return publicRESTBaseBucketKeyPrefix + string(c.Normalized())
}
|
||||
|
||||
// PublicTrafficClassifier maps public REST requests to the public anti-abuse
// class used by the gateway edge. The server normalizes classifier outputs to
// the stable class set before storing them in request context, so
// implementations may return any value.
type PublicTrafficClassifier interface {
	// Classify returns the traffic class for the given request.
	Classify(*http.Request) PublicRouteClass
}
|
||||
|
||||
// ServerDependencies describes the optional collaborators used by the public
// REST server. The zero value is valid and keeps the process runnable with the
// built-in defaults (see normalizeServerDependencies).
type ServerDependencies struct {
	// Classifier assigns the public anti-abuse class before route handling.
	// When nil, the gateway default classifier is used.
	Classifier PublicTrafficClassifier

	// AuthService delegates public auth commands to the Auth / Session Service.
	// When nil, public auth routes remain mounted and return a stable
	// service-unavailable response.
	AuthService AuthServiceClient

	// Limiter applies the public REST rate-limit policy. When nil, a default
	// process-local in-memory limiter is used.
	Limiter PublicRequestLimiter

	// Observer records malformed-request telemetry for the public REST layer.
	// When nil, a no-op observer is used.
	Observer PublicRequestObserver

	// Logger writes structured transport logs for public REST traffic. When nil,
	// a no-op logger is used.
	Logger *zap.Logger

	// Telemetry records low-cardinality edge metrics. When nil, metrics are
	// disabled.
	Telemetry *telemetry.Runtime
}
|
||||
|
||||
// Server owns the public unauthenticated REST listener exposed by the gateway.
type Server struct {
	cfg config.PublicHTTPConfig

	handler http.Handler
	logger  *zap.Logger

	// stateMu guards the mutable lifecycle state below, which is published by
	// Run and read by Shutdown and listenAddr.
	stateMu  sync.RWMutex
	server   *http.Server
	listener net.Listener
}
|
||||
|
||||
// NewServer constructs a public REST server for the supplied listener
|
||||
// configuration and dependency bundle. Nil dependencies are replaced with safe
|
||||
// defaults so the gateway can still expose the documented public surface.
|
||||
func NewServer(cfg config.PublicHTTPConfig, deps ServerDependencies) *Server {
|
||||
deps = normalizeServerDependencies(deps)
|
||||
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
handler: newPublicHandlerWithConfig(cfg, deps),
|
||||
logger: deps.Logger.Named("public_http"),
|
||||
}
|
||||
}
|
||||
|
||||
// Run binds the configured listener and serves the public REST surface until
|
||||
// Shutdown closes the server.
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("run public REST server: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", s.cfg.Addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("run public REST server: listen on %q: %w", s.cfg.Addr, err)
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: s.handler,
|
||||
ReadHeaderTimeout: s.cfg.ReadHeaderTimeout,
|
||||
ReadTimeout: s.cfg.ReadTimeout,
|
||||
IdleTimeout: s.cfg.IdleTimeout,
|
||||
}
|
||||
|
||||
s.stateMu.Lock()
|
||||
s.server = server
|
||||
s.listener = listener
|
||||
s.stateMu.Unlock()
|
||||
|
||||
s.logger.Info("public REST server started", zap.String("addr", listener.Addr().String()))
|
||||
|
||||
defer func() {
|
||||
s.stateMu.Lock()
|
||||
s.server = nil
|
||||
s.listener = nil
|
||||
s.stateMu.Unlock()
|
||||
}()
|
||||
|
||||
err = server.Serve(listener)
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, http.ErrServerClosed):
|
||||
s.logger.Info("public REST server stopped")
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("run public REST server: serve on %q: %w", s.cfg.Addr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the public REST server within ctx.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("shutdown public REST server: nil context")
|
||||
}
|
||||
|
||||
s.stateMu.RLock()
|
||||
server := s.server
|
||||
s.stateMu.RUnlock()
|
||||
|
||||
if server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("shutdown public REST server: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublicRouteClassFromContext returns the previously classified normalized
|
||||
// public route class stored in ctx.
|
||||
func PublicRouteClassFromContext(ctx context.Context) (PublicRouteClass, bool) {
|
||||
if ctx == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
class, ok := ctx.Value(publicRouteClassContextKey{}).(PublicRouteClass)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return class.Normalized(), true
|
||||
}
|
||||
|
||||
// publicRouteClassContextKey is the unexported context key type under which
// withPublicRouteClass stores the normalized PublicRouteClass.
type publicRouteClassContextKey struct{}
|
||||
|
||||
type defaultPublicTrafficClassifier struct{}
|
||||
|
||||
// Classify maps the incoming request into a stable public route class that can
|
||||
// later drive anti-abuse policy and rate limiting.
|
||||
func (defaultPublicTrafficClassifier) Classify(r *http.Request) PublicRouteClass {
|
||||
switch {
|
||||
case isPublicAuthRequest(r):
|
||||
return PublicRouteClassPublicAuth
|
||||
case isBrowserBootstrapRequest(r):
|
||||
return PublicRouteClassBrowserBootstrap
|
||||
case isBrowserAssetRequest(r):
|
||||
return PublicRouteClassBrowserAsset
|
||||
default:
|
||||
return PublicRouteClassPublicMisc
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeServerDependencies(deps ServerDependencies) ServerDependencies {
|
||||
if deps.Classifier == nil {
|
||||
deps.Classifier = defaultPublicTrafficClassifier{}
|
||||
}
|
||||
if deps.AuthService == nil {
|
||||
deps.AuthService = unavailableAuthServiceClient{}
|
||||
}
|
||||
if deps.Limiter == nil {
|
||||
deps.Limiter = newInMemoryPublicRequestLimiter()
|
||||
}
|
||||
if deps.Observer == nil {
|
||||
deps.Observer = noopPublicRequestObserver{}
|
||||
}
|
||||
if deps.Logger == nil {
|
||||
deps.Logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return deps
|
||||
}
|
||||
|
||||
// newPublicHandler builds the public router using the default listener
// configuration; tests use it to exercise routing without binding a listener.
func newPublicHandler(deps ServerDependencies) http.Handler {
	return newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), deps)
}
|
||||
|
||||
// newPublicHandlerWithConfig wires the gin router for the public REST surface:
// recovery, tracing, observability, route classification, and anti-abuse
// middleware (in that order), followed by the probe and public auth routes and
// stable JSON fallbacks for unmatched methods/routes.
func newPublicHandlerWithConfig(cfg config.PublicHTTPConfig, deps ServerDependencies) http.Handler {
	// gin.SetMode mutates package-level state, so run it at most once per
	// process regardless of how many handlers are constructed.
	configureGinModeOnce.Do(func() {
		gin.SetMode(gin.ReleaseMode)
	})

	deps = normalizeServerDependencies(deps)

	router := gin.New()
	router.HandleMethodNotAllowed = true
	// Recovery is installed first so a panic anywhere below still produces
	// the stable JSON internal-error envelope.
	router.Use(gin.CustomRecovery(func(c *gin.Context, _ any) {
		abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
	}))
	router.Use(otelgin.Middleware("galaxy-edge-gateway-public"))
	router.Use(withPublicObservability(deps.Logger.Named("public_http"), deps.Telemetry))
	// Classification runs before anti-abuse so the limiter can read the route
	// class already stored in the request context.
	router.Use(withPublicRouteClass(deps.Classifier))
	router.Use(withPublicAntiAbuse(cfg.AntiAbuse, deps.Limiter, deps.Observer))

	router.GET("/healthz", handleHealthz)
	router.GET("/readyz", handleReadyz)
	router.POST("/api/v1/public/auth/send-email-code", handleSendEmailCode(deps.AuthService, cfg.AuthUpstreamTimeout))
	router.POST("/api/v1/public/auth/confirm-email-code", handleConfirmEmailCode(deps.AuthService, cfg.AuthUpstreamTimeout))

	// Wrong method on a known path: advertise the allowed methods when the
	// path is one of the fixed public routes.
	router.NoMethod(func(c *gin.Context) {
		allowMethods := allowedMethodsForPath(c.Request.URL.Path)
		if allowMethods != "" {
			c.Header("Allow", allowMethods)
		}

		abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route")
	})
	router.NoRoute(func(c *gin.Context) {
		abortWithError(c, http.StatusNotFound, errorCodeNotFound, "resource was not found")
	})

	return router
}
|
||||
|
||||
// handleHealthz reports process liveness with a constant JSON payload.
func handleHealthz(c *gin.Context) {
	c.JSON(http.StatusOK, statusResponse{Status: "ok"})
}

// handleReadyz reports readiness with a constant JSON payload.
func handleReadyz(c *gin.Context) {
	c.JSON(http.StatusOK, statusResponse{Status: "ready"})
}
|
||||
|
||||
func withPublicRouteClass(classifier PublicTrafficClassifier) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
class := classifier.Classify(c.Request).Normalized()
|
||||
ctx := context.WithValue(c.Request.Context(), publicRouteClassContextKey{}, class)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func isPublicAuthRequest(r *http.Request) bool {
|
||||
return r.Method == http.MethodPost && isPublicAuthPath(r.URL.Path)
|
||||
}
|
||||
|
||||
func isBrowserBootstrapRequest(r *http.Request) bool {
|
||||
if r.Method == http.MethodGet && r.URL.Path == "/" {
|
||||
return true
|
||||
}
|
||||
|
||||
return matchesBrowserBootstrapRequestShape(r)
|
||||
}
|
||||
|
||||
func isBrowserAssetRequest(r *http.Request) bool {
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||
return false
|
||||
}
|
||||
|
||||
return matchesBrowserAssetRequestShape(r)
|
||||
}
|
||||
|
||||
// statusResponse is the JSON body returned by the health and readiness probes.
type statusResponse struct {
	Status string `json:"status"`
}

// errorResponse is the canonical public JSON error envelope.
type errorResponse struct {
	Error errorBody `json:"error"`
}

// errorBody carries the machine-readable code and human-readable message of a
// public error.
type errorBody struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}
|
||||
|
||||
func abortWithError(c *gin.Context, statusCode int, code string, message string) {
|
||||
if c != nil {
|
||||
c.Set(publicErrorCodeContextKey, code)
|
||||
}
|
||||
c.AbortWithStatusJSON(statusCode, errorResponse{
|
||||
Error: errorBody{
|
||||
Code: code,
|
||||
Message: message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// publicErrorCodeContextKey is the gin context key under which abortWithError
// records the machine-readable error code for logging/metrics middleware.
const publicErrorCodeContextKey = "public_error_code"
|
||||
|
||||
// allowedMethodsForPath returns the Allow header value advertised for the
// fixed public routes, or "" when the path has no fixed method set.
func allowedMethodsForPath(requestPath string) string {
	if requestPath == "/healthz" || requestPath == "/readyz" {
		return http.MethodGet
	}

	if requestPath == "/api/v1/public/auth/send-email-code" ||
		requestPath == "/api/v1/public/auth/confirm-email-code" {
		return http.MethodPost
	}

	return ""
}
|
||||
|
||||
func (s *Server) listenAddr() string {
|
||||
s.stateMu.RLock()
|
||||
defer s.stateMu.RUnlock()
|
||||
|
||||
if s.listener == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return s.listener.Addr().String()
|
||||
}
|
||||
@@ -0,0 +1,459 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPublicHandlerHealthEndpoints pins the exact status, content type, body,
// and Allow header for the probe routes plus the not-found and
// method-not-allowed fallbacks.
func TestPublicHandlerHealthEndpoints(t *testing.T) {
	t.Parallel()

	handler := newPublicHandler(ServerDependencies{})

	tests := []struct {
		name       string
		method     string
		target     string
		wantStatus int
		wantType   string
		wantBody   string
		wantAllow  string
	}{
		{
			name:       "healthz",
			method:     http.MethodGet,
			target:     "/healthz",
			wantStatus: http.StatusOK,
			wantType:   jsonContentType,
			wantBody:   `{"status":"ok"}`,
		},
		{
			name:       "readyz",
			method:     http.MethodGet,
			target:     "/readyz",
			wantStatus: http.StatusOK,
			wantType:   jsonContentType,
			wantBody:   `{"status":"ready"}`,
		},
		{
			name:       "wrong method on known route",
			method:     http.MethodPost,
			target:     "/healthz",
			wantStatus: http.StatusMethodNotAllowed,
			wantType:   jsonContentType,
			wantBody:   `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`,
			wantAllow:  http.MethodGet,
		},
		{
			name:       "unknown route",
			method:     http.MethodGet,
			target:     "/unknown",
			wantStatus: http.StatusNotFound,
			wantType:   jsonContentType,
			wantBody:   `{"error":{"code":"not_found","message":"resource was not found"}}`,
		},
		{
			name:       "wrong method on public auth route",
			method:     http.MethodGet,
			target:     "/api/v1/public/auth/send-email-code",
			wantStatus: http.StatusMethodNotAllowed,
			wantType:   jsonContentType,
			wantBody:   `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`,
			wantAllow:  http.MethodPost,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			req := httptest.NewRequest(tt.method, tt.target, nil)
			recorder := httptest.NewRecorder()

			handler.ServeHTTP(recorder, req)

			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, tt.wantType, recorder.Header().Get("Content-Type"))
			assert.Equal(t, tt.wantBody, recorder.Body.String())
			assert.Equal(t, tt.wantAllow, recorder.Header().Get("Allow"))
		})
	}
}
|
||||
|
||||
// TestDefaultPublicTrafficClassifier exercises the built-in classifier across
// auth, bootstrap, asset, and misc request shapes, including Accept-header
// driven cases and priority of auth over browser heuristics.
func TestDefaultPublicTrafficClassifier(t *testing.T) {
	t.Parallel()

	classifier := defaultPublicTrafficClassifier{}

	tests := []struct {
		name      string
		method    string
		target    string
		accept    string
		wantClass PublicRouteClass
	}{
		{
			name:      "public auth route",
			method:    http.MethodPost,
			target:    "/api/v1/public/auth/send-email-code",
			wantClass: PublicRouteClassPublicAuth,
		},
		{
			name:      "public auth confirm route",
			method:    http.MethodPost,
			target:    "/api/v1/public/auth/confirm-email-code",
			wantClass: PublicRouteClassPublicAuth,
		},
		{
			name:      "browser bootstrap route",
			method:    http.MethodGet,
			target:    "/",
			wantClass: PublicRouteClassBrowserBootstrap,
		},
		{
			name:      "browser asset route",
			method:    http.MethodGet,
			target:    "/assets/app.js",
			wantClass: PublicRouteClassBrowserAsset,
		},
		{
			name:      "browser asset head request",
			method:    http.MethodHead,
			target:    "/assets/app.js",
			wantClass: PublicRouteClassBrowserAsset,
		},
		{
			name:      "browser asset extension request",
			method:    http.MethodGet,
			target:    "/manifest.webmanifest",
			wantClass: PublicRouteClassBrowserAsset,
		},
		{
			name:      "public misc route",
			method:    http.MethodPost,
			target:    "/api/v1/public/unknown",
			wantClass: PublicRouteClassPublicMisc,
		},
		{
			name:      "html accept bootstrap route",
			method:    http.MethodGet,
			target:    "/app",
			accept:    "application/json, text/html;q=0.9",
			wantClass: PublicRouteClassBrowserBootstrap,
		},
		{
			name:      "public auth wins over browser accept header",
			method:    http.MethodPost,
			target:    "/api/v1/public/auth/confirm-email-code",
			accept:    "text/html",
			wantClass: PublicRouteClassPublicAuth,
		},
		{
			name:      "probe with html accept is bootstrap",
			method:    http.MethodGet,
			target:    "/healthz",
			accept:    "text/html",
			wantClass: PublicRouteClassBrowserBootstrap,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			req := httptest.NewRequest(tt.method, tt.target, nil)
			if tt.accept != "" {
				req.Header.Set("Accept", tt.accept)
			}

			assert.Equal(t, tt.wantClass, classifier.Classify(req))
		})
	}
}
|
||||
|
||||
// TestPublicRouteClassNormalized pins the normalization contract: the four
// known classes pass through; anything else collapses to public_misc.
func TestPublicRouteClassNormalized(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		input PublicRouteClass
		want  PublicRouteClass
	}{
		{
			name:  "public auth",
			input: PublicRouteClassPublicAuth,
			want:  PublicRouteClassPublicAuth,
		},
		{
			name:  "browser bootstrap",
			input: PublicRouteClassBrowserBootstrap,
			want:  PublicRouteClassBrowserBootstrap,
		},
		{
			name:  "browser asset",
			input: PublicRouteClassBrowserAsset,
			want:  PublicRouteClassBrowserAsset,
		},
		{
			name:  "public misc",
			input: PublicRouteClassPublicMisc,
			want:  PublicRouteClassPublicMisc,
		},
		{
			name:  "unknown collapses to misc",
			input: PublicRouteClass("unexpected"),
			want:  PublicRouteClassPublicMisc,
		},
		{
			name:  "empty collapses to misc",
			input: PublicRouteClass(""),
			want:  PublicRouteClassPublicMisc,
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			assert.Equal(t, tt.want, tt.input.Normalized())
		})
	}
}
|
||||
|
||||
// TestPublicRouteClassBaseBucketKeyIsolation pins the exact bucket-key wire
// format per class and proves that no two distinct normalized classes share a
// rate-limit namespace.
func TestPublicRouteClassBaseBucketKeyIsolation(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		class   PublicRouteClass
		wantKey string
	}{
		{
			name:    "public auth",
			class:   PublicRouteClassPublicAuth,
			wantKey: "public_rest/class=public_auth",
		},
		{
			name:    "browser bootstrap",
			class:   PublicRouteClassBrowserBootstrap,
			wantKey: "public_rest/class=browser_bootstrap",
		},
		{
			name:    "browser asset",
			class:   PublicRouteClassBrowserAsset,
			wantKey: "public_rest/class=browser_asset",
		},
		{
			name:    "public misc",
			class:   PublicRouteClassPublicMisc,
			wantKey: "public_rest/class=public_misc",
		},
		{
			name:    "unknown collapses to misc namespace",
			class:   PublicRouteClass("unexpected"),
			wantKey: "public_rest/class=public_misc",
		},
	}

	seenKeys := make(map[string]PublicRouteClass, len(tests))

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			assert.Equal(t, tt.wantKey, tt.class.BaseBucketKey())
		})

		// Collision bookkeeping runs in the outer (sequential) loop, not the
		// parallel subtest, so it is evaluated deterministically per entry.
		// Classes that intentionally collapse to misc are excluded from the
		// collision check.
		normalizedClass := tt.class.Normalized()
		if normalizedClass == PublicRouteClassPublicMisc && tt.class != PublicRouteClassPublicMisc {
			continue
		}

		if previousClass, exists := seenKeys[tt.wantKey]; exists {
			require.FailNowf(t, "bucket key collision", "class %q collides with %q on key %q", tt.class, previousClass, tt.wantKey)
		}

		seenKeys[tt.wantKey] = tt.class
	}

	assert.NotEqual(t, PublicRouteClassPublicAuth.BaseBucketKey(), PublicRouteClassBrowserBootstrap.BaseBucketKey())
	assert.NotEqual(t, PublicRouteClassPublicAuth.BaseBucketKey(), PublicRouteClassBrowserAsset.BaseBucketKey())
}
|
||||
|
||||
// TestWithPublicRouteClassStoresClassInContext verifies the middleware stores
// the classifier's class in the request context before the handler runs.
func TestWithPublicRouteClassStoresClassInContext(t *testing.T) {
	t.Parallel()

	router := gin.New()
	router.Use(withPublicRouteClass(staticClassifier{class: PublicRouteClassBrowserAsset}))
	router.GET("/assets/app.js", func(c *gin.Context) {
		class, ok := PublicRouteClassFromContext(c.Request.Context())
		require.True(t, ok)
		assert.Equal(t, PublicRouteClassBrowserAsset, class)

		c.JSON(http.StatusOK, statusResponse{Status: "ok"})
	})

	req := httptest.NewRequest(http.MethodGet, "/assets/app.js", nil)
	recorder := httptest.NewRecorder()

	router.ServeHTTP(recorder, req)

	// A 200 with the handler body proves the in-handler assertions ran.
	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, `{"status":"ok"}`, recorder.Body.String())
}
|
||||
|
||||
// TestWithPublicRouteClassNormalizesUnsupportedClassToPublicMisc proves the
// middleware collapses unknown and empty classifier outputs to public_misc
// before storing them in the request context.
func TestWithPublicRouteClassNormalizesUnsupportedClassToPublicMisc(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name  string
		class PublicRouteClass
	}{
		{
			name:  "unknown class",
			class: PublicRouteClass("unexpected"),
		},
		{
			name:  "empty class",
			class: PublicRouteClass(""),
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			router := gin.New()
			router.Use(withPublicRouteClass(staticClassifier{class: tt.class}))
			router.GET("/", func(c *gin.Context) {
				class, ok := PublicRouteClassFromContext(c.Request.Context())
				require.True(t, ok)
				assert.Equal(t, PublicRouteClassPublicMisc, class)

				c.JSON(http.StatusOK, statusResponse{Status: "ok"})
			})

			req := httptest.NewRequest(http.MethodGet, "/", nil)
			recorder := httptest.NewRecorder()

			router.ServeHTTP(recorder, req)

			// A 200 with the handler body proves the in-handler assertions ran.
			assert.Equal(t, http.StatusOK, recorder.Code)
			assert.Equal(t, `{"status":"ok"}`, recorder.Body.String())
		})
	}
}
|
||||
|
||||
// TestServerLifecycle boots the full application around a real listener on an
// ephemeral port, waits until /healthz answers, then cancels the run context
// and requires a clean shutdown within two seconds.
func TestServerLifecycle(t *testing.T) {
	t.Parallel()

	cfg := config.Config{
		ShutdownTimeout: time.Second,
		PublicHTTP: func() config.PublicHTTPConfig {
			publicHTTPCfg := config.DefaultPublicHTTPConfig()
			// Port 0 lets the kernel pick a free port so parallel runs cannot
			// collide.
			publicHTTPCfg.Addr = "127.0.0.1:0"
			// Generous limits so repeated health polls are presumably not
			// throttled by the anti-abuse limiter.
			publicHTTPCfg.AntiAbuse.PublicMisc.RateLimit = config.PublicRateLimitConfig{
				Requests: 1000,
				Window:   time.Minute,
				Burst:    1000,
			}
			return publicHTTPCfg
		}(),
	}

	server := NewServer(cfg.PublicHTTP, ServerDependencies{})
	application := app.New(cfg, server)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	addr := waitForListenAddr(t, server)
	waitForHealthResponse(t, addr)

	cancel()

	select {
	case err := <-resultCh:
		require.NoError(t, err)
	case <-time.After(2 * time.Second):
		require.FailNow(t, "Run() did not return after cancellation")
	}
}
|
||||
|
||||
// staticClassifier is a test double that returns a fixed class for any request.
type staticClassifier struct {
	class PublicRouteClass
}

// Classify ignores the request and returns the configured class.
func (c staticClassifier) Classify(*http.Request) PublicRouteClass {
	return c.class
}
|
||||
|
||||
func waitForListenAddr(t *testing.T, server *Server) string {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if addr := server.listenAddr(); addr != "" {
|
||||
return addr
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
require.FailNow(t, "server did not start listening")
|
||||
return ""
|
||||
}
|
||||
|
||||
// waitForHealthResponse polls the /healthz endpoint at addr until it answers
// 200 with the expected body, failing the test after a one-second deadline.
func waitForHealthResponse(t *testing.T, addr string) {
	t.Helper()

	client := &http.Client{Timeout: 100 * time.Millisecond}
	url := "http://" + addr + "/healthz"
	deadline := time.Now().Add(time.Second)

	for time.Now().Before(deadline) {
		resp, err := client.Get(url)
		if err != nil {
			// Listener may not be accepting yet; retry until the deadline.
			time.Sleep(10 * time.Millisecond)
			continue
		}

		// Read then close the body before asserting so the connection is
		// released even when an assertion fails the test.
		body, readErr := io.ReadAll(resp.Body)
		closeErr := resp.Body.Close()
		require.NoError(t, readErr)
		require.NoError(t, closeErr)
		require.Equal(t, http.StatusOK, resp.StatusCode)
		assert.Equal(t, `{"status":"ok"}`, strings.TrimSpace(string(body)))

		return
	}

	require.FailNowf(t, "health check timed out", "url=%s", url)
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MemoryCache stores session record snapshots in process-local memory. It is
// intended for the authenticated gateway hot path and deliberately keeps no
// TTL or size-based eviction policy.
type MemoryCache struct {
	mu      sync.RWMutex      // guards records
	records map[string]Record // keyed by device session id
}
|
||||
|
||||
// NewMemoryCache constructs an empty process-local session snapshot store.
|
||||
func NewMemoryCache() *MemoryCache {
|
||||
return &MemoryCache{
|
||||
records: make(map[string]Record),
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID from the process-local snapshot map.
|
||||
func (c *MemoryCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil {
|
||||
return Record{}, errors.New("lookup session from in-memory cache: nil cache")
|
||||
}
|
||||
if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
|
||||
return Record{}, errors.New("lookup session from in-memory cache: nil context")
|
||||
}
|
||||
if strings.TrimSpace(deviceSessionID) == "" {
|
||||
return Record{}, errors.New("lookup session from in-memory cache: empty device session id")
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
record, ok := c.records[deviceSessionID]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return Record{}, fmt.Errorf("lookup session from in-memory cache: %w", ErrNotFound)
|
||||
}
|
||||
|
||||
return cloneRecord(record), nil
|
||||
}
|
||||
|
||||
// Upsert stores record in the process-local snapshot map after validating the
|
||||
// same session invariants expected from the Redis-backed cache.
|
||||
func (c *MemoryCache) Upsert(record Record) error {
|
||||
if c == nil {
|
||||
return errors.New("upsert session into in-memory cache: nil cache")
|
||||
}
|
||||
if err := validateRecord(record.DeviceSessionID, record); err != nil {
|
||||
return fmt.Errorf("upsert session into in-memory cache: %w", err)
|
||||
}
|
||||
|
||||
cloned := cloneRecord(record)
|
||||
|
||||
c.mu.Lock()
|
||||
c.records[record.DeviceSessionID] = cloned
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes the local snapshot for deviceSessionID when one exists.
|
||||
func (c *MemoryCache) Delete(deviceSessionID string) {
|
||||
if c == nil || strings.TrimSpace(deviceSessionID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
delete(c.records, deviceSessionID)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func cloneRecord(record Record) Record {
|
||||
cloned := record
|
||||
if record.RevokedAtMS != nil {
|
||||
value := *record.RevokedAtMS
|
||||
cloned.RevokedAtMS = &value
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
// Compile-time check that MemoryCache satisfies SnapshotStore.
var _ SnapshotStore = (*MemoryCache)(nil)
|
||||
@@ -0,0 +1,68 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ReadThroughCache resolves authenticated sessions from a process-local
// SnapshotStore first and falls back to another Cache only on a local miss.
type ReadThroughCache struct {
	local    SnapshotStore // consulted first; seeded after a fallback hit
	fallback Cache         // consulted once per local miss
}
|
||||
|
||||
// NewReadThroughCache constructs a hot-path cache that seeds local snapshots
|
||||
// from fallback on demand.
|
||||
func NewReadThroughCache(local SnapshotStore, fallback Cache) (*ReadThroughCache, error) {
|
||||
if local == nil {
|
||||
return nil, errors.New("new read-through session cache: nil local cache")
|
||||
}
|
||||
if fallback == nil {
|
||||
return nil, errors.New("new read-through session cache: nil fallback cache")
|
||||
}
|
||||
|
||||
return &ReadThroughCache{
|
||||
local: local,
|
||||
fallback: fallback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID from local first, then performs one fallback
|
||||
// lookup on a local miss and seeds the local cache with the returned snapshot.
|
||||
func (c *ReadThroughCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil {
|
||||
return Record{}, errors.New("lookup session from read-through cache: nil cache")
|
||||
}
|
||||
|
||||
record, err := c.local.Lookup(ctx, deviceSessionID)
|
||||
switch {
|
||||
case err == nil:
|
||||
return record, nil
|
||||
case !errors.Is(err, ErrNotFound):
|
||||
return Record{}, fmt.Errorf("lookup session from read-through cache: %w", err)
|
||||
}
|
||||
|
||||
record, err = c.fallback.Lookup(ctx, deviceSessionID)
|
||||
if err != nil {
|
||||
return Record{}, err
|
||||
}
|
||||
|
||||
if err := c.local.Upsert(record); err != nil {
|
||||
return Record{}, fmt.Errorf("lookup session from read-through cache: seed local cache: %w", err)
|
||||
}
|
||||
|
||||
return cloneRecord(record), nil
|
||||
}
|
||||
|
||||
// Local returns the mutable process-local snapshot store used by c.
|
||||
func (c *ReadThroughCache) Local() SnapshotStore {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.local
|
||||
}
|
||||
|
||||
// Compile-time check that ReadThroughCache satisfies Cache.
var _ Cache = (*ReadThroughCache)(nil)
|
||||
@@ -0,0 +1,176 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMemoryCacheLookupReturnsClonedRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := NewMemoryCache()
|
||||
revokedAtMS := int64(123456789)
|
||||
|
||||
require.NoError(t, cache.Upsert(Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
}))
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
|
||||
*record.RevokedAtMS = 1
|
||||
|
||||
stored, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stored.RevokedAtMS)
|
||||
assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheLocalHitSkipsFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
local := NewMemoryCache()
|
||||
require.NoError(t, local.Upsert(Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
}))
|
||||
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{}, errors.New("fallback should not be called")
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
}, record)
|
||||
assert.Equal(t, 0, fallback.lookupCalls)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheFallbackSeedsLocalCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
local := NewMemoryCache()
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
assert.Equal(t, "user-123", record.UserID)
|
||||
|
||||
record, err = cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
assert.Equal(t, "user-123", record.UserID)
|
||||
}
|
||||
|
||||
// TestReadThroughCacheKeepsRevokedSnapshotLocal verifies that a revoked
// record fetched from the fallback is seeded locally and that subsequent
// lookups serve the revoked snapshot without another fallback call.
func TestReadThroughCacheKeepsRevokedSnapshotLocal(t *testing.T) {
	t.Parallel()

	revokedAtMS := int64(123456789)
	local := NewMemoryCache()
	fallback := &recordingCache{
		lookupFunc: func(context.Context, string) (Record, error) {
			return Record{
				DeviceSessionID: "device-session-123",
				UserID:          "user-123",
				ClientPublicKey: "public-key-123",
				Status:          StatusRevoked,
				RevokedAtMS:     &revokedAtMS,
			}, nil
		},
	}

	cache, err := NewReadThroughCache(local, fallback)
	require.NoError(t, err)

	// First lookup: local miss, one fallback call, revoked record returned.
	record, err := cache.Lookup(context.Background(), "device-session-123")
	require.NoError(t, err)
	require.NotNil(t, record.RevokedAtMS)
	assert.Equal(t, StatusRevoked, record.Status)
	assert.Equal(t, 1, fallback.lookupCalls)

	// Second lookup: served from the seeded local snapshot, still revoked,
	// fallback call count unchanged.
	record, err = cache.Lookup(context.Background(), "device-session-123")
	require.NoError(t, err)
	require.NotNil(t, record.RevokedAtMS)
	assert.Equal(t, StatusRevoked, record.Status)
	assert.Equal(t, revokedAtMS, *record.RevokedAtMS)
	assert.Equal(t, 1, fallback.lookupCalls)
}
|
||||
|
||||
// TestReadThroughCacheReturnsClonedFallbackRecord verifies that mutating the
// record returned from a fallback-sourced lookup does not corrupt the
// snapshot that was seeded into the local cache.
func TestReadThroughCacheReturnsClonedFallbackRecord(t *testing.T) {
	t.Parallel()

	revokedAtMS := int64(123456789)
	local := NewMemoryCache()
	fallback := &recordingCache{
		lookupFunc: func(context.Context, string) (Record, error) {
			return Record{
				DeviceSessionID: "device-session-123",
				UserID:          "user-123",
				ClientPublicKey: "public-key-123",
				Status:          StatusRevoked,
				RevokedAtMS:     &revokedAtMS,
			}, nil
		},
	}

	cache, err := NewReadThroughCache(local, fallback)
	require.NoError(t, err)

	record, err := cache.Lookup(context.Background(), "device-session-123")
	require.NoError(t, err)
	require.NotNil(t, record.RevokedAtMS)

	// Mutate the caller-visible copy; the local snapshot must be unaffected.
	*record.RevokedAtMS = 1

	stored, err := local.Lookup(context.Background(), "device-session-123")
	require.NoError(t, err)
	require.NotNil(t, stored.RevokedAtMS)
	assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
}
|
||||
|
||||
// recordingCache is a test double for Cache: it counts Lookup invocations and
// delegates to the injected lookupFunc.
type recordingCache struct {
	// lookupCalls counts how many times Lookup has been invoked.
	lookupCalls int
	// lookupFunc supplies Lookup's result; when nil, Lookup returns an error.
	lookupFunc  func(context.Context, string) (Record, error)
}
|
||||
|
||||
func (c *recordingCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
c.lookupCalls++
|
||||
if c.lookupFunc != nil {
|
||||
return c.lookupFunc(ctx, deviceSessionID)
|
||||
}
|
||||
|
||||
return Record{}, errors.New("lookup is not implemented")
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisCache implements Cache with Redis GET lookups over strict JSON session
// records.
type RedisCache struct {
	// client is the underlying go-redis client; this type only issues GET and
	// PING commands against it.
	client        *redis.Client
	// keyPrefix is prepended to the device session ID to form the Redis key.
	keyPrefix     string
	// lookupTimeout bounds each Lookup and Ping round trip.
	lookupTimeout time.Duration
}
|
||||
|
||||
// redisRecord is the strict JSON wire form of a cached session record as
// stored in Redis; unknown fields are rejected at decode time.
type redisRecord struct {
	DeviceSessionID string `json:"device_session_id"`
	UserID          string `json:"user_id"`
	ClientPublicKey string `json:"client_public_key"`
	Status          Status `json:"status"`
	RevokedAtMS     *int64 `json:"revoked_at_ms,omitempty"`
}
|
||||
|
||||
// NewRedisCache constructs a Redis-backed SessionCache from cfg. The returned
|
||||
// cache is read-only from the gateway perspective and does not write or mutate
|
||||
// Redis state.
|
||||
func NewRedisCache(cfg config.SessionCacheRedisConfig) (*RedisCache, error) {
|
||||
if strings.TrimSpace(cfg.Addr) == "" {
|
||||
return nil, errors.New("new redis session cache: redis addr must not be empty")
|
||||
}
|
||||
if cfg.DB < 0 {
|
||||
return nil, errors.New("new redis session cache: redis db must not be negative")
|
||||
}
|
||||
if cfg.LookupTimeout <= 0 {
|
||||
return nil, errors.New("new redis session cache: lookup timeout must be positive")
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Username: cfg.Username,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
}
|
||||
if cfg.TLSEnabled {
|
||||
options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
|
||||
return &RedisCache{
|
||||
client: redis.NewClient(options),
|
||||
keyPrefix: cfg.KeyPrefix,
|
||||
lookupTimeout: cfg.LookupTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis client resources.
|
||||
func (c *RedisCache) Close() error {
|
||||
if c == nil || c.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
// Ping verifies that the configured Redis backend is reachable within the
|
||||
// cache lookup timeout budget.
|
||||
func (c *RedisCache) Ping(ctx context.Context) error {
|
||||
if c == nil || c.client == nil {
|
||||
return errors.New("ping redis session cache: nil cache")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("ping redis session cache: nil context")
|
||||
}
|
||||
|
||||
pingCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := c.client.Ping(pingCtx).Err(); err != nil {
|
||||
return fmt.Errorf("ping redis session cache: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID from Redis, validates the cached JSON
// payload strictly, and returns the decoded session record. A missing key is
// reported as a wrapped ErrNotFound.
func (c *RedisCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
	if c == nil || c.client == nil {
		return Record{}, errors.New("lookup session from redis: nil cache")
	}
	// NOTE(review): the fmt.Sprint comparison rejects context.TODO() by
	// matching its String() output, and reports it as "nil context". This is
	// fragile (depends on the context package's string form) and the message
	// is misleading for a non-nil TODO context. TestRedisCacheLookupNilContext
	// relies on this behavior — confirm intent before changing.
	if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
		return Record{}, errors.New("lookup session from redis: nil context")
	}
	if strings.TrimSpace(deviceSessionID) == "" {
		return Record{}, errors.New("lookup session from redis: empty device session id")
	}

	// Bound the Redis round trip by the configured lookup budget.
	lookupCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout)
	defer cancel()

	payload, err := c.client.Get(lookupCtx, c.lookupKey(deviceSessionID)).Bytes()
	switch {
	case errors.Is(err, redis.Nil):
		// Translate the go-redis miss sentinel into the package sentinel.
		return Record{}, fmt.Errorf("lookup session from redis: %w", ErrNotFound)
	case err != nil:
		return Record{}, fmt.Errorf("lookup session from redis: %w", err)
	}

	record, err := decodeRedisRecord(deviceSessionID, payload)
	if err != nil {
		return Record{}, fmt.Errorf("lookup session from redis: %w", err)
	}

	return record, nil
}
|
||||
|
||||
// lookupKey builds the Redis key for deviceSessionID by prepending the
// configured key prefix.
func (c *RedisCache) lookupKey(deviceSessionID string) string {
	return c.keyPrefix + deviceSessionID
}
|
||||
|
||||
func decodeRedisRecord(expectedDeviceSessionID string, payload []byte) (Record, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(payload))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
var stored redisRecord
|
||||
if err := decoder.Decode(&stored); err != nil {
|
||||
return Record{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
return Record{}, errors.New("decode redis session record: unexpected trailing JSON input")
|
||||
}
|
||||
return Record{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
|
||||
record := Record{
|
||||
DeviceSessionID: stored.DeviceSessionID,
|
||||
UserID: stored.UserID,
|
||||
ClientPublicKey: stored.ClientPublicKey,
|
||||
Status: stored.Status,
|
||||
RevokedAtMS: cloneOptionalInt64(stored.RevokedAtMS),
|
||||
}
|
||||
|
||||
if err := validateRecord(expectedDeviceSessionID, record); err != nil {
|
||||
return Record{}, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func validateRecord(expectedDeviceSessionID string, record Record) error {
|
||||
if record.DeviceSessionID == "" {
|
||||
return errors.New("session record device_session_id must not be empty")
|
||||
}
|
||||
if record.DeviceSessionID != expectedDeviceSessionID {
|
||||
return fmt.Errorf("session record device_session_id %q does not match requested %q", record.DeviceSessionID, expectedDeviceSessionID)
|
||||
}
|
||||
if record.UserID == "" {
|
||||
return errors.New("session record user_id must not be empty")
|
||||
}
|
||||
if record.ClientPublicKey == "" {
|
||||
return errors.New("session record client_public_key must not be empty")
|
||||
}
|
||||
if !record.Status.IsKnown() {
|
||||
return fmt.Errorf("session record status %q is unsupported", record.Status)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cloneOptionalInt64 returns an independent copy of value, or nil when value
// is nil, so callers can mutate the result without aliasing the source.
func cloneOptionalInt64(value *int64) *int64 {
	if value == nil {
		return nil
	}

	copied := *value
	return &copied
}
|
||||
|
||||
// Compile-time check that RedisCache satisfies the Cache contract.
var _ Cache = (*RedisCache)(nil)
|
||||
@@ -0,0 +1,331 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewRedisCache exercises constructor validation: a non-empty addr, a
// non-negative DB index, and a positive lookup timeout are required.
func TestNewRedisCache(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)

	tests := []struct {
		name    string
		cfg     config.SessionCacheRedisConfig
		wantErr string
	}{
		{
			name: "valid config",
			cfg: config.SessionCacheRedisConfig{
				Addr:          server.Addr(),
				DB:            2,
				KeyPrefix:     "gateway:session:",
				LookupTimeout: 250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: config.SessionCacheRedisConfig{
				LookupTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: config.SessionCacheRedisConfig{
				Addr:          server.Addr(),
				DB:            -1,
				LookupTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "non-positive lookup timeout",
			cfg: config.SessionCacheRedisConfig{
				Addr: server.Addr(),
			},
			wantErr: "lookup timeout must be positive",
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			cache, err := NewRedisCache(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				require.ErrorContains(t, err, tt.wantErr)
				return
			}

			require.NoError(t, err)
			// Close the successfully constructed client when the subtest ends.
			t.Cleanup(func() {
				assert.NoError(t, cache.Close())
			})
		})
	}
}
|
||||
|
||||
func TestRedisCachePing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
cache := newTestRedisCache(t, server, config.SessionCacheRedisConfig{})
|
||||
|
||||
require.NoError(t, cache.Ping(context.Background()))
|
||||
}
|
||||
|
||||
// TestRedisCacheLookup table-tests the Redis lookup contract: active and
// revoked cache hits, ErrNotFound on misses, strict JSON decode failures,
// record validation failures, and key-prefix isolation.
func TestRedisCacheLookup(t *testing.T) {
	t.Parallel()

	revokedAtMS := int64(123456789)

	tests := []struct {
		name        string
		cfg         config.SessionCacheRedisConfig
		requestID   string
		seed        func(*testing.T, *miniredis.Miniredis, config.SessionCacheRedisConfig)
		want        Record
		wantErrIs   error
		wantErrText string
		// NOTE(review): assertErrText is checked the same way as wantErrText
		// (ErrorContains); the two fields look redundant — consider merging.
		assertErrText string
	}{
		{
			name:      "active cache hit",
			requestID: "device-session-123",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-123", redisRecord{
					DeviceSessionID: "device-session-123",
					UserID:          "user-123",
					ClientPublicKey: "public-key-123",
					Status:          StatusActive,
				})
			},
			want: Record{
				DeviceSessionID: "device-session-123",
				UserID:          "user-123",
				ClientPublicKey: "public-key-123",
				Status:          StatusActive,
			},
		},
		{
			name:      "missing session",
			requestID: "device-session-404",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			wantErrIs:     ErrNotFound,
			assertErrText: "session cache record not found",
		},
		{
			name:      "revoked session",
			requestID: "device-session-revoked",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-revoked", redisRecord{
					DeviceSessionID: "device-session-revoked",
					UserID:          "user-777",
					ClientPublicKey: "public-key-777",
					Status:          StatusRevoked,
					RevokedAtMS:     &revokedAtMS,
				})
			},
			want: Record{
				DeviceSessionID: "device-session-revoked",
				UserID:          "user-777",
				ClientPublicKey: "public-key-777",
				Status:          StatusRevoked,
				RevokedAtMS:     &revokedAtMS,
			},
		},
		{
			name:      "malformed json",
			requestID: "device-session-bad-json",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				server.Set(cfg.KeyPrefix+"device-session-bad-json", "{")
			},
			wantErrText: "decode redis session record",
		},
		{
			name:      "unknown status",
			requestID: "device-session-unknown-status",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-unknown-status", redisRecord{
					DeviceSessionID: "device-session-unknown-status",
					UserID:          "user-1",
					ClientPublicKey: "public-key-1",
					Status:          Status("paused"),
				})
			},
			wantErrText: `status "paused" is unsupported`,
		},
		{
			name:      "missing required field",
			requestID: "device-session-missing-user",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-missing-user", redisRecord{
					DeviceSessionID: "device-session-missing-user",
					ClientPublicKey: "public-key-1",
					Status:          StatusActive,
				})
			},
			wantErrText: "user_id must not be empty",
		},
		{
			name:      "device session id mismatch",
			requestID: "device-session-requested",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-requested", redisRecord{
					DeviceSessionID: "device-session-other",
					UserID:          "user-1",
					ClientPublicKey: "public-key-1",
					Status:          StatusActive,
				})
			},
			wantErrText: `does not match requested "device-session-requested"`,
		},
		{
			name:      "key prefix is honored",
			requestID: "device-session-prefixed",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "custom:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-prefixed", redisRecord{
					DeviceSessionID: "device-session-prefixed",
					UserID:          "user-prefixed",
					ClientPublicKey: "public-key-prefixed",
					Status:          StatusActive,
				})
				// Seed a decoy under the default prefix to prove the custom
				// prefix is the one consulted.
				setRedisSessionRecord(t, server, "gateway:session:device-session-prefixed", redisRecord{
					DeviceSessionID: "device-session-prefixed",
					UserID:          "wrong-user",
					ClientPublicKey: "wrong-key",
					Status:          StatusRevoked,
				})
			},
			want: Record{
				DeviceSessionID: "device-session-prefixed",
				UserID:          "user-prefixed",
				ClientPublicKey: "public-key-prefixed",
				Status:          StatusActive,
			},
		},
	}

	for _, tt := range tests {
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			server := miniredis.RunT(t)

			// Each subtest gets its own server; only prefix and seed vary.
			cfg := tt.cfg
			cfg.Addr = server.Addr()
			cfg.DB = 0
			cfg.LookupTimeout = 250 * time.Millisecond

			if tt.seed != nil {
				tt.seed(t, server, cfg)
			}

			cache := newTestRedisCache(t, server, cfg)
			record, err := cache.Lookup(context.Background(), tt.requestID)
			if tt.wantErrIs != nil || tt.wantErrText != "" {
				require.Error(t, err)
				if tt.wantErrIs != nil {
					assert.ErrorIs(t, err, tt.wantErrIs)
				}
				if tt.wantErrText != "" {
					assert.ErrorContains(t, err, tt.wantErrText)
				}
				if tt.assertErrText != "" {
					assert.ErrorContains(t, err, tt.assertErrText)
				}
				return
			}

			require.NoError(t, err)
			assert.Equal(t, tt.want, record)
		})
	}
}
|
||||
|
||||
func newTestRedisCache(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) *RedisCache {
|
||||
t.Helper()
|
||||
|
||||
if cfg.Addr == "" {
|
||||
cfg.Addr = server.Addr()
|
||||
}
|
||||
if cfg.LookupTimeout == 0 {
|
||||
cfg.LookupTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
cache, err := NewRedisCache(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, cache.Close())
|
||||
})
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
func setRedisSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record redisRecord) {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(record)
|
||||
require.NoError(t, err)
|
||||
|
||||
server.Set(key, string(payload))
|
||||
}
|
||||
|
||||
// TestRedisCacheLookupNilContext documents that Lookup rejects context.TODO()
// with the "nil context" error and that the failure is not an ErrNotFound.
// NOTE(review): this behavior is backed by a string comparison against
// fmt.Sprint(ctx) inside RedisCache.Lookup — fragile; confirm it is
// intentional before relying on it further.
func TestRedisCacheLookupNilContext(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	cache := newTestRedisCache(t, server, config.SessionCacheRedisConfig{})

	_, err := cache.Lookup(context.TODO(), "device-session-123")
	require.Error(t, err)
	assert.False(t, errors.Is(err, ErrNotFound))
	assert.ErrorContains(t, err, "nil context")
}
|
||||
@@ -0,0 +1,80 @@
|
||||
// Package session defines the authenticated session-cache contract used by the
|
||||
// gateway hot path.
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
	// ErrNotFound reports that a Cache implementation does not currently
	// contain the requested device session identifier. Implementations wrap
	// this sentinel so callers can detect misses with errors.Is.
	ErrNotFound = errors.New("session cache record not found")
)
|
||||
|
||||
// Cache resolves authenticated device-session state from the gateway hot-path
// cache.
type Cache interface {
	// Lookup returns the cached record for deviceSessionID. Implementations must
	// wrap ErrNotFound when the cache does not contain the requested record.
	Lookup(ctx context.Context, deviceSessionID string) (Record, error)
}
|
||||
|
||||
// SnapshotStore stores mutable session record snapshots inside one gateway
// process and exposes the same read contract as Cache for the hot path.
type SnapshotStore interface {
	Cache

	// Upsert stores record under record.DeviceSessionID, replacing any previous
	// snapshot for that session.
	Upsert(record Record) error

	// Delete removes the local snapshot for deviceSessionID when it exists.
	// It is a no-op for unknown identifiers.
	Delete(deviceSessionID string)
}
|
||||
|
||||
// Status identifies the cached lifecycle state of a device session. The zero
// value is not a valid status (see IsKnown).
type Status string

const (
	// StatusActive reports that the cached device session may continue through
	// later authenticated gateway checks.
	StatusActive Status = "active"

	// StatusRevoked reports that the cached device session has been revoked and
	// must be rejected before later auth steps run.
	StatusRevoked Status = "revoked"
)
|
||||
|
||||
// Record is the minimum authenticated session state required by the gateway
// before signature verification begins.
type Record struct {
	// DeviceSessionID is the stable device-session identifier resolved from the
	// hot-path cache.
	DeviceSessionID string

	// UserID is the authenticated user identity bound to DeviceSessionID.
	UserID string

	// ClientPublicKey is the standard base64-encoded raw Ed25519 public key
	// material used for request-signature verification.
	ClientPublicKey string

	// Status reports whether the cached session is active or revoked.
	Status Status

	// RevokedAtMS optionally records when the device session was revoked.
	// Presumably populated only when Status is StatusRevoked — confirm with
	// the cache writers.
	RevokedAtMS *int64
}
|
||||
|
||||
// IsKnown reports whether s is one of the session states supported by the
|
||||
// gateway.
|
||||
func (s Status) IsKnown() bool {
|
||||
switch s {
|
||||
case StatusActive, StatusRevoked:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
// Package telemetry provides shared edge observability helpers used by the
|
||||
// gateway transports and internal event consumers.
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
)
|
||||
|
||||
// EdgeOutcome is the stable low-cardinality outcome vocabulary shared by REST,
// gRPC, push shutdown, and observability backends.
type EdgeOutcome string

// Edge outcomes. The string values double as the stable reject-reason labels
// returned by RejectReason, so keep them low-cardinality and backward
// compatible once emitted to observability backends.
const (
	EdgeOutcomeSuccess               EdgeOutcome = "success"
	EdgeOutcomeMalformedRequest      EdgeOutcome = "malformed_request"
	EdgeOutcomeRequestTooLarge       EdgeOutcome = "request_too_large"
	EdgeOutcomeUnsupportedProtocol   EdgeOutcome = "unsupported_protocol"
	EdgeOutcomeUnknownSession        EdgeOutcome = "unknown_session"
	EdgeOutcomeRevokedSession        EdgeOutcome = "revoked_session"
	EdgeOutcomeInvalidSignature      EdgeOutcome = "invalid_signature"
	EdgeOutcomeStaleRequest          EdgeOutcome = "stale_request"
	EdgeOutcomeReplayDetected        EdgeOutcome = "replay_detected"
	EdgeOutcomeRateLimited           EdgeOutcome = "rate_limited"
	EdgeOutcomePolicyDenied          EdgeOutcome = "policy_denied"
	EdgeOutcomeDownstreamUnavailable EdgeOutcome = "downstream_unavailable"
	EdgeOutcomeBackendUnavailable    EdgeOutcome = "backend_unavailable"
	EdgeOutcomeInternalError         EdgeOutcome = "internal_error"
	EdgeOutcomeGatewayShuttingDown   EdgeOutcome = "gateway_shutting_down"
)
|
||||
|
||||
// RejectReason returns the stable reject reason for outcome. Success does not
|
||||
// produce a reject reason.
|
||||
func RejectReason(outcome EdgeOutcome) string {
|
||||
if outcome == EdgeOutcomeSuccess {
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(outcome)
|
||||
}
|
||||
|
||||
// OutcomeFromPublicErrorCode maps the stable public REST error envelope into
|
||||
// the shared edge-outcome vocabulary.
|
||||
func OutcomeFromPublicErrorCode(statusCode int, code string) EdgeOutcome {
|
||||
switch strings.TrimSpace(code) {
|
||||
case "":
|
||||
if statusCode < http.StatusBadRequest {
|
||||
return EdgeOutcomeSuccess
|
||||
}
|
||||
return EdgeOutcomeInternalError
|
||||
case "invalid_request", "method_not_allowed", "not_found":
|
||||
return EdgeOutcomeMalformedRequest
|
||||
case "request_too_large":
|
||||
return EdgeOutcomeRequestTooLarge
|
||||
case "rate_limited":
|
||||
return EdgeOutcomeRateLimited
|
||||
case "service_unavailable":
|
||||
return EdgeOutcomeBackendUnavailable
|
||||
default:
|
||||
if statusCode >= http.StatusInternalServerError {
|
||||
return EdgeOutcomeInternalError
|
||||
}
|
||||
return EdgeOutcomeMalformedRequest
|
||||
}
|
||||
}
|
||||
|
||||
// OutcomeFromGRPCStatus maps the stable authenticated gRPC reject contract
|
||||
// into the shared edge-outcome vocabulary.
|
||||
func OutcomeFromGRPCStatus(code codes.Code, message string) EdgeOutcome {
|
||||
switch {
|
||||
case code == codes.OK:
|
||||
return EdgeOutcomeSuccess
|
||||
case code == codes.InvalidArgument:
|
||||
return EdgeOutcomeMalformedRequest
|
||||
case code == codes.FailedPrecondition && strings.Contains(message, "unsupported protocol_version"):
|
||||
return EdgeOutcomeUnsupportedProtocol
|
||||
case code == codes.Unauthenticated && message == "unknown device session":
|
||||
return EdgeOutcomeUnknownSession
|
||||
case code == codes.FailedPrecondition && message == "device session is revoked":
|
||||
return EdgeOutcomeRevokedSession
|
||||
case code == codes.Unauthenticated && message == "invalid request signature":
|
||||
return EdgeOutcomeInvalidSignature
|
||||
case code == codes.FailedPrecondition && message == "request timestamp is outside the freshness window":
|
||||
return EdgeOutcomeStaleRequest
|
||||
case code == codes.FailedPrecondition && message == "request replay detected":
|
||||
return EdgeOutcomeReplayDetected
|
||||
case code == codes.ResourceExhausted && message == "authenticated request rate limit exceeded":
|
||||
return EdgeOutcomeRateLimited
|
||||
case code == codes.PermissionDenied && message == "authenticated request rejected by edge policy":
|
||||
return EdgeOutcomePolicyDenied
|
||||
case code == codes.Unavailable && message == "downstream service is unavailable":
|
||||
return EdgeOutcomeDownstreamUnavailable
|
||||
case code == codes.Unavailable && message == "gateway is shutting down":
|
||||
return EdgeOutcomeGatewayShuttingDown
|
||||
case code == codes.Unavailable:
|
||||
return EdgeOutcomeBackendUnavailable
|
||||
default:
|
||||
return EdgeOutcomeInternalError
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"galaxy/gateway/internal/push"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
)
|
||||
|
||||
// PushObserver adapts Runtime to the push.Observer interface.
type PushObserver struct {
	// runtime receives the active-stream gauge updates and closure counters;
	// a nil runtime turns every callback into a no-op.
	runtime *Runtime
}
|
||||
|
||||
// NewPushObserver constructs a push stream observer backed by runtime.
|
||||
func NewPushObserver(runtime *Runtime) *PushObserver {
|
||||
return &PushObserver{runtime: runtime}
|
||||
}
|
||||
|
||||
// Registered records one active push stream.
|
||||
func (o *PushObserver) Registered(_ push.StreamBinding) {
|
||||
if o == nil || o.runtime == nil {
|
||||
return
|
||||
}
|
||||
|
||||
o.runtime.AddActivePushStream(context.Background(), 1)
|
||||
}
|
||||
|
||||
// Unregistered records one active-stream decrement and one closure reason for
|
||||
// hub-enforced shutdown, overflow, or revocation.
|
||||
func (o *PushObserver) Unregistered(_ push.StreamBinding, err error) {
|
||||
if o == nil || o.runtime == nil {
|
||||
return
|
||||
}
|
||||
|
||||
o.runtime.AddActivePushStream(context.Background(), -1)
|
||||
|
||||
switch {
|
||||
case errors.Is(err, push.ErrSubscriptionOverflow):
|
||||
o.runtime.RecordPushStreamClosure(context.Background(), attribute.String("reason", "overflow"))
|
||||
case errors.Is(err, push.ErrSubscriptionRevoked):
|
||||
o.runtime.RecordPushStreamClosure(context.Background(), attribute.String("reason", "revoked"))
|
||||
case errors.Is(err, push.ErrHubShuttingDown):
|
||||
o.runtime.RecordPushStreamClosure(context.Background(), attribute.String("reason", "shutdown"))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
|
||||
otelprom "go.opentelemetry.io/otel/exporters/prometheus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const defaultServiceName = "galaxy-edge-gateway"
|
||||
|
||||
// Runtime owns the shared OpenTelemetry providers, the Prometheus metrics
// handler, and the custom low-cardinality edge instruments.
//
// Construct with New; the zero value is not usable. All Record*/Add* methods
// are nil-receiver safe so callers may hold an optional *Runtime.
type Runtime struct {
	logger *zap.Logger

	// Providers flushed by Shutdown.
	tracerProvider *sdktrace.TracerProvider
	meterProvider  *sdkmetric.MeterProvider
	// promHandler serves the exposition returned by Handler.
	promHandler http.Handler

	// Public REST instruments.
	publicRequests metric.Int64Counter
	publicDuration metric.Float64Histogram

	// Authenticated gRPC instruments.
	grpcRequests metric.Int64Counter
	grpcDuration metric.Float64Histogram

	// Push instruments.
	pushActiveStreams metric.Int64UpDownCounter
	// pushStreamClosers backs the "gateway.push.stream_closures" counter.
	pushStreamClosers metric.Int64Counter

	// Internal event consumer instruments.
	internalEventDrops metric.Int64Counter
}
|
||||
|
||||
// New constructs the gateway telemetry runtime, registers global providers,
|
||||
// and returns the Prometheus handler used by the admin listener.
|
||||
func New(ctx context.Context, logger *zap.Logger) (*Runtime, error) {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
serviceName := strings.TrimSpace(os.Getenv("OTEL_SERVICE_NAME"))
|
||||
if serviceName == "" {
|
||||
serviceName = defaultServiceName
|
||||
}
|
||||
|
||||
res, err := resource.New(
|
||||
ctx,
|
||||
resource.WithAttributes(attribute.String("service.name", serviceName)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tracerProvider, err := newTracerProvider(ctx, res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
registry := prometheus.NewRegistry()
|
||||
exporter, err := otelprom.New(otelprom.WithRegisterer(registry))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meterProvider := sdkmetric.NewMeterProvider(
|
||||
sdkmetric.WithResource(res),
|
||||
sdkmetric.WithReader(exporter),
|
||||
)
|
||||
|
||||
otel.SetTracerProvider(tracerProvider)
|
||||
otel.SetMeterProvider(meterProvider)
|
||||
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
|
||||
propagation.TraceContext{},
|
||||
propagation.Baggage{},
|
||||
))
|
||||
|
||||
meter := meterProvider.Meter("galaxy/gateway")
|
||||
|
||||
publicRequests, err := meter.Int64Counter("gateway.public_http.requests")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
publicDuration, err := meter.Float64Histogram("gateway.public_http.duration", metric.WithUnit("ms"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
grpcRequests, err := meter.Int64Counter("gateway.authenticated_grpc.requests")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
grpcDuration, err := meter.Float64Histogram("gateway.authenticated_grpc.duration", metric.WithUnit("ms"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pushActiveStreams, err := meter.Int64UpDownCounter("gateway.push.active_streams")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pushStreamClosers, err := meter.Int64Counter("gateway.push.stream_closures")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
internalEventDrops, err := meter.Int64Counter("gateway.internal_event_drops")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Runtime{
|
||||
logger: logger,
|
||||
tracerProvider: tracerProvider,
|
||||
meterProvider: meterProvider,
|
||||
promHandler: promhttp.HandlerFor(registry, promhttp.HandlerOpts{}),
|
||||
publicRequests: publicRequests,
|
||||
publicDuration: publicDuration,
|
||||
grpcRequests: grpcRequests,
|
||||
grpcDuration: grpcDuration,
|
||||
pushActiveStreams: pushActiveStreams,
|
||||
pushStreamClosers: pushStreamClosers,
|
||||
internalEventDrops: internalEventDrops,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Handler returns the Prometheus handler that should be mounted on the admin
|
||||
// listener.
|
||||
func (r *Runtime) Handler() http.Handler {
|
||||
if r == nil || r.promHandler == nil {
|
||||
return http.NotFoundHandler()
|
||||
}
|
||||
|
||||
return r.promHandler
|
||||
}
|
||||
|
||||
// Shutdown flushes the configured telemetry providers.
|
||||
func (r *Runtime) Shutdown(ctx context.Context) error {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var shutdownErr error
|
||||
if r.meterProvider != nil {
|
||||
shutdownErr = errors.Join(shutdownErr, r.meterProvider.Shutdown(ctx))
|
||||
}
|
||||
if r.tracerProvider != nil {
|
||||
shutdownErr = errors.Join(shutdownErr, r.tracerProvider.Shutdown(ctx))
|
||||
}
|
||||
|
||||
return shutdownErr
|
||||
}
|
||||
|
||||
// RecordPublicRequest records one public REST request outcome.
|
||||
func (r *Runtime) RecordPublicRequest(ctx context.Context, attrs []attribute.KeyValue, duration time.Duration) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
options := metric.WithAttributes(attrs...)
|
||||
r.publicRequests.Add(ctx, 1, options)
|
||||
r.publicDuration.Record(ctx, duration.Seconds()*1000, options)
|
||||
}
|
||||
|
||||
// RecordAuthenticatedGRPC records one authenticated gRPC request or stream
|
||||
// outcome.
|
||||
func (r *Runtime) RecordAuthenticatedGRPC(ctx context.Context, attrs []attribute.KeyValue, duration time.Duration) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
options := metric.WithAttributes(attrs...)
|
||||
r.grpcRequests.Add(ctx, 1, options)
|
||||
r.grpcDuration.Record(ctx, duration.Seconds()*1000, options)
|
||||
}
|
||||
|
||||
// AddActivePushStream records one active-stream delta.
|
||||
func (r *Runtime) AddActivePushStream(ctx context.Context, delta int64, attrs ...attribute.KeyValue) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
r.pushActiveStreams.Add(ctx, delta, metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
// RecordPushStreamClosure records one push-stream closure reason.
|
||||
func (r *Runtime) RecordPushStreamClosure(ctx context.Context, attrs ...attribute.KeyValue) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
r.pushStreamClosers.Add(ctx, 1, metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
// RecordInternalEventDrop records one malformed or rejected internal event.
|
||||
func (r *Runtime) RecordInternalEventDrop(ctx context.Context, attrs ...attribute.KeyValue) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
r.internalEventDrops.Add(ctx, 1, metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
func newTracerProvider(ctx context.Context, res *resource.Resource) (*sdktrace.TracerProvider, error) {
|
||||
exporterName := strings.TrimSpace(os.Getenv("OTEL_TRACES_EXPORTER"))
|
||||
if exporterName == "" || exporterName == "none" {
|
||||
return sdktrace.NewTracerProvider(sdktrace.WithResource(res)), nil
|
||||
}
|
||||
|
||||
if exporterName != "otlp" {
|
||||
return nil, errors.New("unsupported OTEL_TRACES_EXPORTER value")
|
||||
}
|
||||
|
||||
protocol := strings.TrimSpace(os.Getenv("OTEL_EXPORTER_OTLP_TRACES_PROTOCOL"))
|
||||
if protocol == "" {
|
||||
protocol = strings.TrimSpace(os.Getenv("OTEL_EXPORTER_OTLP_PROTOCOL"))
|
||||
}
|
||||
|
||||
var (
|
||||
exporter sdktrace.SpanExporter
|
||||
err error
|
||||
)
|
||||
switch protocol {
|
||||
case "", "http/protobuf":
|
||||
exporter, err = otlptracehttp.New(ctx)
|
||||
case "grpc":
|
||||
exporter, err = otlptracegrpc.New(ctx)
|
||||
default:
|
||||
return nil, errors.New("unsupported OTEL exporter protocol")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return sdktrace.NewTracerProvider(
|
||||
sdktrace.WithBatcher(exporter),
|
||||
sdktrace.WithResource(res),
|
||||
), nil
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/logging"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// LogBuffer is a concurrency-safe in-memory buffer used by observability
// tests. The zero value is ready to use.
type LogBuffer struct {
	mu  sync.Mutex
	buf bytes.Buffer
}

// Write appends p to the buffer. It implements io.Writer.
func (b *LogBuffer) Write(p []byte) (int, error) {
	b.mu.Lock()
	n, err := b.buf.Write(p)
	b.mu.Unlock()

	return n, err
}

// String returns the current buffer contents.
func (b *LogBuffer) String() string {
	b.mu.Lock()
	defer b.mu.Unlock()

	return b.buf.String()
}
|
||||
|
||||
// NewObservedLogger constructs a JSON zap logger that writes into an in-memory
|
||||
// buffer suitable for log assertions.
|
||||
func NewObservedLogger(t *testing.T) (*zap.Logger, *LogBuffer) {
|
||||
t.Helper()
|
||||
|
||||
buffer := &LogBuffer{}
|
||||
encoderConfig := zap.NewProductionEncoderConfig()
|
||||
encoderConfig.TimeKey = "timestamp"
|
||||
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
|
||||
core := zapcore.NewCore(
|
||||
zapcore.NewJSONEncoder(encoderConfig),
|
||||
zapcore.Lock(zapcore.AddSync(buffer)),
|
||||
zap.DebugLevel,
|
||||
)
|
||||
|
||||
logger := zap.New(core)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, logging.Sync(logger))
|
||||
})
|
||||
|
||||
return logger, buffer
|
||||
}
|
||||
|
||||
// NewTelemetryRuntime constructs a telemetry runtime for tests and shuts it
|
||||
// down automatically.
|
||||
func NewTelemetryRuntime(t *testing.T, logger *zap.Logger) *telemetry.Runtime {
|
||||
t.Helper()
|
||||
|
||||
runtime, err := telemetry.New(context.Background(), logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
require.NoError(t, runtime.Shutdown(ctx))
|
||||
})
|
||||
|
||||
return runtime
|
||||
}
|
||||
|
||||
// ScrapeMetrics returns the Prometheus exposition produced by handler.
|
||||
func ScrapeMetrics(t *testing.T, handler http.Handler) string {
|
||||
t.Helper()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
return recorder.Body.String()
|
||||
}
|
||||
Reference in New Issue
Block a user