feat: edge gateway service

This commit is contained in:
Ilia Denisov
2026-04-02 19:18:42 +02:00
committed by GitHub
parent 8cde99936c
commit 436c97a38b
95 changed files with 20504 additions and 57 deletions
+133
View File
@@ -0,0 +1,133 @@
// Package adminapi exposes the optional private admin HTTP listener used for
// operational endpoints such as Prometheus metrics.
package adminapi
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"sync"
"galaxy/gateway/internal/config"
"go.uber.org/zap"
)
// Server owns the optional admin HTTP listener exposed by the gateway.
type Server struct {
	cfg     config.AdminHTTPConfig // listener configuration captured at construction
	handler http.Handler           // admin surface handler; never nil after NewServer
	logger  *zap.Logger            // named "admin_http"; never nil after NewServer

	// stateMu guards server and listener, which are published for the
	// duration of Run and read by Shutdown and listenAddr.
	stateMu  sync.RWMutex
	server   *http.Server
	listener net.Listener
}
// NewServer constructs an admin HTTP server for cfg and handler. A nil handler
// is replaced with http.NotFoundHandler and a nil logger with a no-op logger,
// so the returned server is always safe to use.
func NewServer(cfg config.AdminHTTPConfig, handler http.Handler, logger *zap.Logger) *Server {
	srv := &Server{cfg: cfg}
	srv.handler = handler
	if srv.handler == nil {
		srv.handler = http.NotFoundHandler()
	}
	if logger == nil {
		logger = zap.NewNop()
	}
	srv.logger = logger.Named("admin_http")
	return srv
}
// Enabled reports whether the admin listener should run. A nil receiver or an
// empty configured address disables the listener.
func (s *Server) Enabled() bool {
	if s == nil {
		return false
	}
	return s.cfg.Addr != ""
}
// Run binds the configured listener and serves the admin HTTP surface until
// Shutdown closes the server. A disabled admin server returns when ctx is
// canceled.
//
// NOTE(review): once Serve is blocking, ctx cancellation alone does not stop
// the listener — the owning application is expected to call Shutdown after
// canceling ctx (the app lifecycle visible in this commit does exactly that).
func (s *Server) Run(ctx context.Context) error {
	if ctx == nil {
		return errors.New("run admin HTTP server: nil context")
	}
	// Refuse to start if the caller's context is already canceled.
	if err := ctx.Err(); err != nil {
		return err
	}
	// A disabled admin surface still "runs" (blocks until cancellation) so the
	// application lifecycle stays uniform across components.
	if !s.Enabled() {
		<-ctx.Done()
		return nil
	}
	listener, err := net.Listen("tcp", s.cfg.Addr)
	if err != nil {
		return fmt.Errorf("run admin HTTP server: listen on %q: %w", s.cfg.Addr, err)
	}
	server := &http.Server{
		Handler:           s.handler,
		ReadHeaderTimeout: s.cfg.ReadHeaderTimeout,
		ReadTimeout:       s.cfg.ReadTimeout,
		IdleTimeout:       s.cfg.IdleTimeout,
	}
	// Publish the server/listener under the lock so Shutdown and listenAddr
	// can observe them while Serve is blocking.
	s.stateMu.Lock()
	s.server = server
	s.listener = listener
	s.stateMu.Unlock()
	s.logger.Info("admin HTTP server started", zap.String("addr", listener.Addr().String()))
	defer func() {
		// Clear the published state once Serve returns so late Shutdown calls
		// become no-ops.
		s.stateMu.Lock()
		s.server = nil
		s.listener = nil
		s.stateMu.Unlock()
	}()
	err = server.Serve(listener)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, http.ErrServerClosed):
		// Normal termination: Shutdown closed the server.
		s.logger.Info("admin HTTP server stopped")
		return nil
	default:
		return fmt.Errorf("run admin HTTP server: serve on %q: %w", s.cfg.Addr, err)
	}
}
// Shutdown gracefully stops the admin HTTP server within ctx. Calling it when
// the server is not running (or already stopped) is a no-op.
func (s *Server) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown admin HTTP server: nil context")
	}
	// Snapshot the running server under the read lock; Run owns mutation.
	s.stateMu.RLock()
	srv := s.server
	s.stateMu.RUnlock()
	if srv == nil {
		return nil
	}
	err := srv.Shutdown(ctx)
	if err == nil || errors.Is(err, http.ErrServerClosed) {
		return nil
	}
	return fmt.Errorf("shutdown admin HTTP server: %w", err)
}
// listenAddr reports the bound listener address, or "" when the server is not
// currently running.
func (s *Server) listenAddr() string {
	s.stateMu.RLock()
	defer s.stateMu.RUnlock()
	if l := s.listener; l != nil {
		return l.Addr().String()
	}
	return ""
}
+102
View File
@@ -0,0 +1,102 @@
package adminapi
import (
"context"
"net"
"net/http"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/restapi"
"galaxy/gateway/internal/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMetricsAreReachableOnlyOnAdminListener boots the full application with a
// public and an admin listener and verifies that /metrics is served only by
// the admin one.
func TestMetricsAreReachableOnlyOnAdminListener(t *testing.T) {
	t.Parallel()
	logger, _ := testutil.NewObservedLogger(t)
	telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)
	// Reserve two distinct loopback addresses so the listeners cannot collide.
	publicAddr := unusedTCPAddr(t)
	adminAddr := unusedTCPAddr(t)
	publicCfg := config.DefaultPublicHTTPConfig()
	publicCfg.Addr = publicAddr
	adminCfg := config.DefaultAdminHTTPConfig()
	adminCfg.Addr = adminAddr
	restServer := restapi.NewServer(publicCfg, restapi.ServerDependencies{
		Logger:    logger,
		Telemetry: telemetryRuntime,
	})
	adminServer := NewServer(adminCfg, telemetryRuntime.Handler(), logger)
	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			PublicHTTP:        publicCfg,
			AdminHTTP:         adminCfg,
			AuthenticatedGRPC: config.DefaultAuthenticatedGRPCConfig(),
		},
		restServer,
		adminServer,
	)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	// Always stop the app and assert a clean exit, even if an assertion below
	// fails first.
	defer func() {
		cancel()
		select {
		case err := <-resultCh:
			require.NoError(t, err)
		case <-time.After(time.Second):
			require.FailNow(t, "application did not stop")
		}
	}()
	// Wait until both listeners actually serve before probing isolation.
	waitForHTTPStatus(t, "http://"+publicAddr+"/healthz", http.StatusOK)
	waitForHTTPStatus(t, "http://"+adminAddr+"/metrics", http.StatusOK)
	// The public surface must not expose /metrics.
	publicMetricsResp, err := http.Get("http://" + publicAddr + "/metrics")
	require.NoError(t, err)
	defer func() {
		require.NoError(t, publicMetricsResp.Body.Close())
	}()
	assert.Equal(t, http.StatusNotFound, publicMetricsResp.StatusCode)
}
// waitForHTTPStatus polls rawURL until it answers with wantStatus, failing the
// test if that does not happen within one second.
func waitForHTTPStatus(t *testing.T, rawURL string, wantStatus int) {
	t.Helper()
	probe := func() bool {
		resp, err := http.Get(rawURL)
		if err != nil {
			return false
		}
		defer func() {
			_ = resp.Body.Close()
		}()
		return resp.StatusCode == wantStatus
	}
	require.Eventually(t, probe, time.Second, 10*time.Millisecond)
}
// unusedTCPAddr reserves a loopback TCP address by binding port 0 and then
// releasing it so the caller can bind the same address.
//
// NOTE(review): closing the listener before the caller rebinds is inherently
// racy — another process could claim the port in between. Acceptable for
// tests, but a possible source of rare flakes under heavy parallelism.
func unusedTCPAddr(t *testing.T) string {
	t.Helper()
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	addr := listener.Addr().String()
	require.NoError(t, listener.Close())
	return addr
}
+178
View File
@@ -0,0 +1,178 @@
// Package app wires the gateway process lifecycle and coordinates component
// startup and graceful shutdown.
package app
import (
"context"
"errors"
"fmt"
"sync"
"galaxy/gateway/internal/config"
)
// Component is a long-lived gateway subsystem that participates in coordinated
// startup and graceful shutdown. Implementations must tolerate Shutdown being
// called while Run is still blocking.
type Component interface {
	// Run starts the component and blocks until it stops.
	Run(context.Context) error
	// Shutdown stops the component within the provided timeout-bounded context.
	Shutdown(context.Context) error
}
// App owns the process-level lifecycle of the gateway and its registered
// components.
type App struct {
	cfg        config.Config // cfg.ShutdownTimeout bounds shutdown and final wait
	components []Component   // private copy made by New; validated in Run
}
// New constructs an App with a defensive copy of the supplied components so
// later mutation of the caller's slice cannot affect the app.
func New(cfg config.Config, components ...Component) *App {
	cloned := make([]Component, len(components))
	copy(cloned, components)
	return &App{
		cfg:        cfg,
		components: cloned,
	}
}
// Run starts all configured components, waits for cancellation or the first
// component failure, and then executes best-effort graceful shutdown for every
// component.
func (a *App) Run(ctx context.Context) error {
	if ctx == nil {
		return errors.New("run gateway app: nil context")
	}
	// Validate before any goroutine starts so a bad config cannot leak
	// half-started components.
	if err := a.validate(); err != nil {
		return err
	}
	// Nothing to supervise: mirror the caller's context lifetime.
	if len(a.components) == 0 {
		<-ctx.Done()
		return nil
	}
	runCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	// Buffered to len(components) so every goroutine can report its exit
	// without blocking, even after Run stops receiving.
	results := make(chan componentResult, len(a.components))
	var runWG sync.WaitGroup
	for idx, component := range a.components {
		runWG.Add(1)
		go func(index int, component Component) {
			defer runWG.Done()
			results <- componentResult{
				index: index,
				err:   component.Run(runCtx),
			}
		}(idx, component)
	}
	// Block until external cancellation or the first component exit.
	var runErr error
	select {
	case <-ctx.Done():
	case result := <-results:
		runErr = classifyComponentResult(ctx, result)
	}
	// Stop the rest, then give them the shutdown budget to actually return.
	cancel()
	shutdownErr := a.shutdownComponents()
	waitErr := a.waitForComponents(&runWG)
	return errors.Join(runErr, shutdownErr, waitErr)
}
// componentResult captures the first observed exit from a running component.
type componentResult struct {
	index int   // position of the component in App.components, used in errors
	err   error // value returned by the component's Run
}
// validate confirms that the App has a safe shutdown budget and no nil
// components before goroutines are started.
func (a *App) validate() error {
	if a.cfg.ShutdownTimeout <= 0 {
		return fmt.Errorf("run gateway app: shutdown timeout must be positive, got %s", a.cfg.ShutdownTimeout)
	}
	for idx := range a.components {
		if a.components[idx] == nil {
			return fmt.Errorf("run gateway app: component %d is nil", idx)
		}
	}
	return nil
}
// classifyComponentResult maps the first component exit into the error that
// should control the application result.
//
// Any context error (Canceled or DeadlineExceeded) observed while the parent
// context is already done is treated as a clean, externally driven stop.
// Previously only context.Canceled was excused, so a component surfacing
// context.DeadlineExceeded from a deadline-bounded parent context was
// misreported as a failure even though Run's ctx.Done branch would have
// treated the very same stop as clean.
func classifyComponentResult(parentCtx context.Context, result componentResult) error {
	switch {
	case result.err == nil:
		if parentCtx.Err() != nil {
			return nil
		}
		// Components must run until shutdown; a silent early exit is a fault.
		return fmt.Errorf("run gateway app: component %d exited without error before shutdown", result.index)
	case (errors.Is(result.err, context.Canceled) || errors.Is(result.err, context.DeadlineExceeded)) && parentCtx.Err() != nil:
		return nil
	default:
		return fmt.Errorf("run gateway app: component %d: %w", result.index, result.err)
	}
}
// shutdownComponents calls Shutdown on every registered component using a fresh
// timeout-bounded context per component and joins any shutdown failures.
func (a *App) shutdownComponents() error {
	errs := make([]error, len(a.components))
	var wg sync.WaitGroup
	for idx, component := range a.components {
		wg.Add(1)
		go func(i int, c Component) {
			defer wg.Done()
			// Each component gets its own full shutdown budget.
			shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), a.cfg.ShutdownTimeout)
			defer cancelShutdown()
			if err := c.Shutdown(shutdownCtx); err != nil {
				errs[i] = fmt.Errorf("shutdown gateway component %d: %w", i, err)
			}
		}(idx, component)
	}
	wg.Wait()
	// errors.Join discards nil entries, so untouched slots are harmless.
	return errors.Join(errs...)
}
// waitForComponents waits for running components to return after shutdown and
// reports when they outlive the configured shutdown budget.
func (a *App) waitForComponents(runWG *sync.WaitGroup) error {
	waitCtx, cancel := context.WithTimeout(context.Background(), a.cfg.ShutdownTimeout)
	defer cancel()
	done := make(chan struct{})
	go func() {
		runWG.Wait()
		close(done)
	}()
	select {
	case <-waitCtx.Done():
		return fmt.Errorf("wait for gateway components: %w", waitCtx.Err())
	case <-done:
		return nil
	}
}
+268
View File
@@ -0,0 +1,268 @@
package app
import (
"context"
"errors"
"sync"
"testing"
"time"
"galaxy/gateway/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestAppRunWaitsForCancellationWithoutComponents verifies that an app with no
// components blocks in Run until its context is canceled, then exits cleanly.
func TestAppRunWaitsForCancellationWithoutComponents(t *testing.T) {
	t.Parallel()
	application := New(config.Config{ShutdownTimeout: 50 * time.Millisecond})
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	// Run must keep blocking while ctx is live, even with nothing to supervise.
	select {
	case err := <-resultCh:
		require.FailNowf(t, "Run() returned early", "error=%v", err)
	case <-time.After(50 * time.Millisecond):
	}
	cancel()
	// After cancellation the empty app exits without error.
	select {
	case err := <-resultCh:
		require.NoError(t, err)
	case <-time.After(time.Second):
		require.FailNow(t, "Run() did not return after cancellation")
	}
}
// TestAppRunCancelsComponentsAndCallsShutdownOnce verifies that canceling the
// app context stops every component and that each receives exactly one
// Shutdown call.
func TestAppRunCancelsComponentsAndCallsShutdownOnce(t *testing.T) {
	t.Parallel()
	first := newLifecycleComponent()
	second := newLifecycleComponent()
	application := New(
		config.Config{ShutdownTimeout: time.Second},
		first,
		second,
	)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	// Ensure both components are live before triggering shutdown.
	first.waitStarted(t)
	second.waitStarted(t)
	cancel()
	select {
	case err := <-resultCh:
		require.NoError(t, err)
	case <-time.After(time.Second):
		require.FailNow(t, "Run() did not return after cancellation")
	}
	first.waitRunExited(t)
	second.waitRunExited(t)
	// The app must call Shutdown exactly once per component.
	assert.Equal(t, 1, first.shutdownCalls())
	assert.Equal(t, 1, second.shutdownCalls())
}
// TestAppRunReturnsComponentErrorAndStillShutsDown verifies that a component
// failure surfaces from Run and that the remaining components are still shut
// down gracefully.
func TestAppRunReturnsComponentErrorAndStillShutsDown(t *testing.T) {
	t.Parallel()
	runErr := errors.New("boom")
	failing := newFailingComponent(runErr)
	blocking := newLifecycleComponent()
	application := New(
		config.Config{ShutdownTimeout: time.Second},
		failing,
		blocking,
	)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	failing.waitStarted(t)
	blocking.waitStarted(t)
	// Let the failing component return its error while the other still runs.
	failing.releaseRun()
	select {
	case err := <-resultCh:
		require.Error(t, err)
		assert.ErrorIs(t, err, runErr)
	case <-time.After(time.Second):
		require.FailNow(t, "Run() did not return after component failure")
	}
	failing.waitRunExited(t)
	blocking.waitRunExited(t)
	// Both components — including the failed one — get a Shutdown call.
	assert.Equal(t, 1, failing.shutdownCalls())
	assert.Equal(t, 1, blocking.shutdownCalls())
}
// lifecycleComponent blocks in Run until the application calls Shutdown.
// It is a test double for Component that exposes start, exit, and shutdown
// ordering to tests through channels.
type lifecycleComponent struct {
	startedCh   chan struct{} // closed once Run begins
	runDoneCh   chan struct{} // closed when Run returns
	stopCh      chan struct{} // closed by the first Shutdown call
	shutdownMu  sync.Mutex    // guards shutdownCnt and the one-time close of stopCh
	shutdownCnt int           // number of Shutdown invocations observed
}

// newLifecycleComponent builds a component that exits Run only after Shutdown
// signals its stop channel.
func newLifecycleComponent() *lifecycleComponent {
	return &lifecycleComponent{
		startedCh: make(chan struct{}),
		runDoneCh: make(chan struct{}),
		stopCh:    make(chan struct{}),
	}
}

// Run marks the component as started, waits for cancellation, and then blocks
// until Shutdown releases the stop channel.
func (c *lifecycleComponent) Run(ctx context.Context) error {
	close(c.startedCh)
	defer close(c.runDoneCh)
	<-ctx.Done()
	// Blocking here proves the app calls Shutdown after canceling the run
	// context rather than relying on cancellation alone.
	<-c.stopCh
	return nil
}

// Shutdown records the call and releases the run loop.
func (c *lifecycleComponent) Shutdown(context.Context) error {
	c.shutdownMu.Lock()
	defer c.shutdownMu.Unlock()
	c.shutdownCnt++
	// Close stopCh exactly once so duplicate Shutdown calls cannot panic.
	if c.shutdownCnt == 1 {
		close(c.stopCh)
	}
	return nil
}

// waitStarted blocks until Run has started or fails the test on timeout.
func (c *lifecycleComponent) waitStarted(t *testing.T) {
	t.Helper()
	select {
	case <-c.startedCh:
	case <-time.After(time.Second):
		require.FailNow(t, "component did not start")
	}
}

// waitRunExited blocks until Run exits or fails the test on timeout.
func (c *lifecycleComponent) waitRunExited(t *testing.T) {
	t.Helper()
	select {
	case <-c.runDoneCh:
	case <-time.After(time.Second):
		require.FailNow(t, "component run did not exit")
	}
}

// shutdownCalls returns the number of observed Shutdown invocations.
func (c *lifecycleComponent) shutdownCalls() int {
	c.shutdownMu.Lock()
	defer c.shutdownMu.Unlock()
	return c.shutdownCnt
}
// failingComponent returns a predefined error once released by the test and
// still tracks shutdown calls. Unlike lifecycleComponent, its Run ignores the
// context and exits only when the test releases it.
type failingComponent struct {
	startedCh   chan struct{} // closed once Run begins
	releaseCh   chan struct{} // closed by releaseRun to let Run return
	runDoneCh   chan struct{} // closed when Run returns
	shutdownMu  sync.Mutex    // guards shutdownCnt
	shutdownCnt int           // number of Shutdown invocations observed
	err         error         // value Run returns after release
}

// newFailingComponent builds a component whose Run returns err after release.
func newFailingComponent(err error) *failingComponent {
	return &failingComponent{
		startedCh: make(chan struct{}),
		releaseCh: make(chan struct{}),
		runDoneCh: make(chan struct{}),
		err:       err,
	}
}

// Run waits until the test releases it and then returns the configured error.
func (c *failingComponent) Run(context.Context) error {
	close(c.startedCh)
	defer close(c.runDoneCh)
	<-c.releaseCh
	return c.err
}

// Shutdown records that the application attempted graceful shutdown.
func (c *failingComponent) Shutdown(context.Context) error {
	c.shutdownMu.Lock()
	defer c.shutdownMu.Unlock()
	c.shutdownCnt++
	return nil
}

// waitStarted blocks until Run has started or fails the test on timeout.
func (c *failingComponent) waitStarted(t *testing.T) {
	t.Helper()
	select {
	case <-c.startedCh:
	case <-time.After(time.Second):
		require.FailNow(t, "failing component did not start")
	}
}

// releaseRun allows Run to return its configured error. It must be called at
// most once; a second call would close a closed channel and panic.
func (c *failingComponent) releaseRun() {
	close(c.releaseCh)
}

// waitRunExited blocks until Run exits or fails the test on timeout.
func (c *failingComponent) waitRunExited(t *testing.T) {
	t.Helper()
	select {
	case <-c.runDoneCh:
	case <-time.After(time.Second):
		require.FailNow(t, "failing component run did not exit")
	}
}

// shutdownCalls returns the number of observed Shutdown invocations.
func (c *failingComponent) shutdownCalls() int {
	c.shutdownMu.Lock()
	defer c.shutdownMu.Unlock()
	return c.shutdownCnt
}
+80
View File
@@ -0,0 +1,80 @@
package authn
import (
"crypto/ed25519"
"encoding/binary"
"errors"
)
const (
	// EventDomainMarkerV1 binds the v1 server event signature to the Galaxy
	// gateway transport contract. Prefixing the signing input with a domain
	// marker prevents cross-protocol signature reuse.
	EventDomainMarkerV1 = "galaxy-event-v1"
)

var (
	// ErrInvalidEventSignature reports that a gateway stream event signature is
	// not a raw Ed25519 signature for the canonical event signing input.
	ErrInvalidEventSignature = errors.New("invalid event signature")
)

// EventSigningFields contains the canonical v1 stream-event fields that are
// bound into the server signing input.
type EventSigningFields struct {
	// EventType identifies the stable client-facing event category.
	EventType string
	// EventID is the stable event correlation identifier.
	EventID string
	// TimestampMS carries the server event timestamp in milliseconds.
	TimestampMS int64
	// RequestID optionally correlates the event to the opening client request.
	RequestID string
	// TraceID optionally carries the client-supplied tracing correlation value.
	TraceID string
	// PayloadHash is the raw SHA-256 digest of event payload bytes.
	PayloadHash []byte
}

// BuildEventSigningInput returns the canonical byte sequence the v1 gateway
// stream-event signature covers. String and byte fields are length-prefixed
// with uvarint(len(field)) followed by raw bytes, while TimestampMS is
// appended as an 8-byte big-endian uint64.
func BuildEventSigningInput(fields EventSigningFields) []byte {
	// Capacity upper bound: six uvarint length prefixes, one 8-byte timestamp,
	// plus the raw field bytes themselves.
	capacity := 8 + 6*binary.MaxVarintLen64 +
		len(EventDomainMarkerV1) +
		len(fields.EventType) +
		len(fields.EventID) +
		len(fields.RequestID) +
		len(fields.TraceID) +
		len(fields.PayloadHash)
	out := make([]byte, 0, capacity)
	writeField := func(value []byte) {
		out = binary.AppendUvarint(out, uint64(len(value)))
		out = append(out, value...)
	}
	writeField([]byte(EventDomainMarkerV1))
	writeField([]byte(fields.EventType))
	writeField([]byte(fields.EventID))
	out = binary.BigEndian.AppendUint64(out, uint64(fields.TimestampMS))
	writeField([]byte(fields.RequestID))
	writeField([]byte(fields.TraceID))
	writeField(fields.PayloadHash)
	return out
}

// VerifyEventSignature verifies that signature authenticates fields under
// publicKey using the canonical v1 event signing input.
func VerifyEventSignature(publicKey ed25519.PublicKey, signature []byte, fields EventSigningFields) error {
	switch {
	case len(publicKey) != ed25519.PublicKeySize:
		return ErrInvalidEventSignature
	case len(signature) != ed25519.SignatureSize:
		return ErrInvalidEventSignature
	case !ed25519.Verify(publicKey, BuildEventSigningInput(fields), signature):
		return ErrInvalidEventSignature
	}
	return nil
}
+111
View File
@@ -0,0 +1,111 @@
package authn
import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestBuildEventSigningInputChangesWhenSignedFieldChanges mutates each signed
// event field in isolation and asserts the canonical input changes; a field
// whose mutation left the input unchanged would be unauthenticated.
func TestBuildEventSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
	t.Parallel()
	base := EventSigningFields{
		EventType:   "gateway.server_time",
		EventID:     "request-123",
		TimestampMS: 123456789,
		RequestID:   "request-123",
		TraceID:     "trace-123",
		PayloadHash: mustSHA256([]byte("payload")),
	}
	baseInput := BuildEventSigningInput(base)
	tests := []struct {
		name   string
		mutate func(EventSigningFields) EventSigningFields
	}{
		{
			name: "event type",
			mutate: func(fields EventSigningFields) EventSigningFields {
				fields.EventType = "gateway.other"
				return fields
			},
		},
		{
			name: "event id",
			mutate: func(fields EventSigningFields) EventSigningFields {
				fields.EventID = "request-456"
				return fields
			},
		},
		{
			name: "timestamp",
			mutate: func(fields EventSigningFields) EventSigningFields {
				fields.TimestampMS++
				return fields
			},
		},
		{
			name: "request id",
			mutate: func(fields EventSigningFields) EventSigningFields {
				fields.RequestID = "request-456"
				return fields
			},
		},
		{
			name: "trace id",
			mutate: func(fields EventSigningFields) EventSigningFields {
				fields.TraceID = "trace-456"
				return fields
			},
		},
		{
			name: "payload hash",
			mutate: func(fields EventSigningFields) EventSigningFields {
				fields.PayloadHash = mustSHA256([]byte("other"))
				return fields
			},
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			// The mutated input must differ from the base input byte-for-byte.
			mutated := BuildEventSigningInput(tt.mutate(base))
			assert.False(t, bytes.Equal(baseInput, mutated))
		})
	}
}
// TestSignAndVerifyEventSignature round-trips an event signature through the
// response signer and checks that verification fails once a signed field is
// altered.
func TestSignAndVerifyEventSignature(t *testing.T) {
	t.Parallel()
	_, privateKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)
	signer, err := NewEd25519ResponseSigner(privateKey)
	require.NoError(t, err)
	fields := EventSigningFields{
		EventType:   "gateway.server_time",
		EventID:     "request-123",
		TimestampMS: 123456789,
		RequestID:   "request-123",
		TraceID:     "trace-123",
		PayloadHash: mustSHA256([]byte("payload")),
	}
	signature, err := signer.SignEvent(fields)
	require.NoError(t, err)
	require.NoError(t, VerifyEventSignature(signer.PublicKey(), signature, fields))
	// Tampering with any signed field must invalidate the signature.
	fields.TraceID = "changed"
	require.ErrorIs(t, VerifyEventSignature(signer.PublicKey(), signature, fields), ErrInvalidEventSignature)
}
+101
View File
@@ -0,0 +1,101 @@
// Package authn defines authenticated transport helpers shared by the gateway
// edge verification pipeline.
package authn
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"errors"
)
const (
	// RequestDomainMarkerV1 binds the v1 client request signature to the Galaxy
	// gateway transport contract. Prefixing the signing input with a domain
	// marker prevents cross-protocol signature reuse.
	RequestDomainMarkerV1 = "galaxy-request-v1"
)

var (
	// ErrInvalidPayloadHash reports that payloadHash is not a raw SHA-256 digest.
	ErrInvalidPayloadHash = errors.New("payload_hash must be a 32-byte SHA-256 digest")
	// ErrPayloadHashMismatch reports that payloadHash does not match payloadBytes.
	ErrPayloadHashMismatch = errors.New("payload_hash does not match payload_bytes")
)

// RequestSigningFields contains the canonical v1 request fields that are bound
// into the client signing input after the gateway validates and normalizes the
// request envelope.
type RequestSigningFields struct {
	// ProtocolVersion identifies the transport envelope version.
	ProtocolVersion string
	// DeviceSessionID identifies the authenticated device session bound to the
	// request.
	DeviceSessionID string
	// MessageType is the stable downstream routing key.
	MessageType string
	// TimestampMS carries the client request timestamp in milliseconds.
	TimestampMS int64
	// RequestID is the transport correlation and anti-replay identifier.
	RequestID string
	// PayloadHash is the raw SHA-256 digest of payload bytes.
	PayloadHash []byte
}

// BuildRequestSigningInput returns the canonical byte sequence the v1 client
// request signature covers. String and byte fields are length-prefixed with
// uvarint(len(field)) followed by raw bytes, while TimestampMS is appended as
// an 8-byte big-endian uint64. The caller is expected to pass fields that have
// already passed earlier envelope validation.
func BuildRequestSigningInput(fields RequestSigningFields) []byte {
	// Capacity upper bound: six uvarint length prefixes, one 8-byte timestamp,
	// plus the raw field bytes themselves.
	capacity := 8 + 6*binary.MaxVarintLen64 +
		len(RequestDomainMarkerV1) +
		len(fields.ProtocolVersion) +
		len(fields.DeviceSessionID) +
		len(fields.MessageType) +
		len(fields.RequestID) +
		len(fields.PayloadHash)
	out := make([]byte, 0, capacity)
	out = appendLengthPrefixedString(out, RequestDomainMarkerV1)
	out = appendLengthPrefixedString(out, fields.ProtocolVersion)
	out = appendLengthPrefixedString(out, fields.DeviceSessionID)
	out = appendLengthPrefixedString(out, fields.MessageType)
	out = binary.BigEndian.AppendUint64(out, uint64(fields.TimestampMS))
	out = appendLengthPrefixedString(out, fields.RequestID)
	out = appendLengthPrefixedBytes(out, fields.PayloadHash)
	return out
}

// VerifyPayloadHash checks that payloadHash is the raw SHA-256 digest of
// payloadBytes. Empty payloadBytes are valid and must use sha256.Sum256(nil).
func VerifyPayloadHash(payloadBytes, payloadHash []byte) error {
	if len(payloadHash) != sha256.Size {
		return ErrInvalidPayloadHash
	}
	want := sha256.Sum256(payloadBytes)
	if bytes.Equal(want[:], payloadHash) {
		return nil
	}
	return ErrPayloadHashMismatch
}

// appendLengthPrefixedString appends value to dst as uvarint(len) || bytes.
func appendLengthPrefixedString(dst []byte, value string) []byte {
	return appendLengthPrefixedBytes(dst, []byte(value))
}

// appendLengthPrefixedBytes appends value to dst as uvarint(len) || bytes.
func appendLengthPrefixedBytes(dst []byte, value []byte) []byte {
	out := binary.AppendUvarint(dst, uint64(len(value)))
	return append(out, value...)
}
+163
View File
@@ -0,0 +1,163 @@
package authn
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestVerifyPayloadHash covers the accept and both reject paths of
// VerifyPayloadHash: matching digests, a digest of the wrong length, and a
// well-formed but mismatching digest.
func TestVerifyPayloadHash(t *testing.T) {
	t.Parallel()
	payloadSum := sha256.Sum256([]byte("payload"))
	emptySum := sha256.Sum256(nil)
	otherSum := sha256.Sum256([]byte("other"))
	tests := []struct {
		name        string
		payload     []byte
		payloadHash []byte
		wantErr     error
	}{
		{
			name:        "matches non-empty payload",
			payload:     []byte("payload"),
			payloadHash: payloadSum[:],
		},
		{
			name:        "matches empty payload",
			payload:     nil,
			payloadHash: emptySum[:],
		},
		{
			name:        "rejects digest with invalid length",
			payload:     []byte("payload"),
			payloadHash: []byte("short"),
			wantErr:     ErrInvalidPayloadHash,
		},
		{
			name:        "rejects digest mismatch",
			payload:     []byte("payload"),
			payloadHash: otherSum[:],
			wantErr:     ErrPayloadHashMismatch,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			err := VerifyPayloadHash(tt.payload, tt.payloadHash)
			if tt.wantErr == nil {
				require.NoError(t, err)
				return
			}
			require.ErrorIs(t, err, tt.wantErr)
		})
	}
}
// TestBuildRequestSigningInput pins the exact canonical byte encoding of the
// v1 request signing input against a golden hex vector, so any accidental
// change to field ordering or framing breaks this test.
func TestBuildRequestSigningInput(t *testing.T) {
	t.Parallel()
	fields := RequestSigningFields{
		ProtocolVersion: "v1",
		DeviceSessionID: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMS:     123456789,
		RequestID:       "request-123",
		PayloadHash:     mustSHA256([]byte("payload")),
	}
	got := BuildRequestSigningInput(fields)
	// Golden vector: marker, version, session, type, 8-byte timestamp,
	// request id, payload digest — each (except the timestamp) uvarint-length
	// prefixed.
	want, err := hex.DecodeString("1167616c6178792d726571756573742d7631027631126465766963652d73657373696f6e2d3132330a666c6565742e6d6f766500000000075bcd150b726571756573742d31323320239f59ed55e737c77147cf55ad0c1b030b6d7ee748a7426952f9b852d5a935e5")
	require.NoError(t, err)
	assert.Equal(t, want, got)
}
// TestBuildRequestSigningInputChangesWhenSignedFieldChanges mutates each
// signed request field in isolation and asserts the canonical input changes;
// a field whose mutation left the input unchanged would be unauthenticated.
func TestBuildRequestSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
	t.Parallel()
	base := RequestSigningFields{
		ProtocolVersion: "v1",
		DeviceSessionID: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMS:     123456789,
		RequestID:       "request-123",
		PayloadHash:     mustSHA256([]byte("payload")),
	}
	baseInput := BuildRequestSigningInput(base)
	tests := []struct {
		name   string
		mutate func(RequestSigningFields) RequestSigningFields
	}{
		{
			name: "protocol version",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.ProtocolVersion = "v2"
				return fields
			},
		},
		{
			name: "device session id",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.DeviceSessionID = "device-session-456"
				return fields
			},
		},
		{
			name: "message type",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.MessageType = "fleet.attack"
				return fields
			},
		},
		{
			name: "timestamp",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.TimestampMS++
				return fields
			},
		},
		{
			name: "request id",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.RequestID = "request-456"
				return fields
			},
		},
		{
			name: "payload hash",
			mutate: func(fields RequestSigningFields) RequestSigningFields {
				fields.PayloadHash = mustSHA256([]byte("other"))
				return fields
			},
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			// The mutated input must differ from the base input byte-for-byte.
			mutated := BuildRequestSigningInput(tt.mutate(base))
			assert.False(t, bytes.Equal(baseInput, mutated))
		})
	}
}
// mustSHA256 returns the raw SHA-256 digest of payload as a freestanding byte
// slice.
func mustSHA256(payload []byte) []byte {
	digest := sha256.Sum256(payload)
	out := make([]byte, len(digest))
	copy(out, digest[:])
	return out
}
+189
View File
@@ -0,0 +1,189 @@
package authn
import (
"bytes"
"crypto/ed25519"
"crypto/x509"
"encoding/binary"
"encoding/pem"
"errors"
"fmt"
"os"
)
const (
	// ResponseDomainMarkerV1 binds the v1 server response signature to the
	// Galaxy gateway transport contract. Prefixing the signing input with a
	// domain marker prevents cross-protocol signature reuse.
	ResponseDomainMarkerV1 = "galaxy-response-v1"
)

var (
	// ErrInvalidResponsePrivateKeyPEM reports that the configured response
	// signer private key is not a strict PKCS#8 PEM-encoded private key.
	ErrInvalidResponsePrivateKeyPEM = errors.New("response signer private key is not a valid PKCS#8 PEM block")
	// ErrInvalidResponsePrivateKey reports that the configured response signer
	// private key is not an Ed25519 private key.
	ErrInvalidResponsePrivateKey = errors.New("response signer private key must be an Ed25519 PKCS#8 private key")
	// ErrInvalidResponseSignature reports that a server response signature is
	// not a raw Ed25519 signature for the canonical response signing input.
	ErrInvalidResponseSignature = errors.New("invalid response signature")
)
// ResponseSigningFields contains the canonical v1 response fields that are
// bound into the server signing input.
type ResponseSigningFields struct {
	// ProtocolVersion identifies the transport envelope version.
	ProtocolVersion string
	// RequestID is the transport correlation identifier copied from the
	// authenticated request; signing it binds the response to that request.
	RequestID string
	// TimestampMS carries the server response timestamp in milliseconds.
	TimestampMS int64
	// ResultCode is the opaque downstream result code returned to the client.
	ResultCode string
	// PayloadHash is the raw SHA-256 digest of response payload bytes.
	PayloadHash []byte
}
// ResponseSigner signs authenticated unary responses and client-facing stream
// events with one server-side key.
type ResponseSigner interface {
	// SignResponse returns the raw Ed25519 signature for the canonical response
	// signing input built from fields.
	SignResponse(fields ResponseSigningFields) ([]byte, error)
	// SignEvent returns the raw Ed25519 signature for the canonical event
	// signing input built from fields.
	SignEvent(fields EventSigningFields) ([]byte, error)
}
// Ed25519ResponseSigner signs authenticated responses with one Ed25519 private
// key loaded during process startup. The zero value is unusable; construct it
// with NewEd25519ResponseSigner or the PEM loaders.
type Ed25519ResponseSigner struct {
	privateKey ed25519.PrivateKey // defensive copy made at construction
}
// NewEd25519ResponseSigner validates privateKey and constructs a signer using
// a defensive key copy, so later mutation of the caller's slice cannot affect
// signing.
func NewEd25519ResponseSigner(privateKey ed25519.PrivateKey) (*Ed25519ResponseSigner, error) {
	if len(privateKey) != ed25519.PrivateKeySize {
		return nil, ErrInvalidResponsePrivateKey
	}
	keyCopy := bytes.Clone(privateKey)
	return &Ed25519ResponseSigner{privateKey: keyCopy}, nil
}
// LoadEd25519ResponseSignerFromPEMFile loads a strict PKCS#8 PEM-encoded
// Ed25519 private key from path and constructs a signer.
func LoadEd25519ResponseSignerFromPEMFile(path string) (*Ed25519ResponseSigner, error) {
	pemBytes, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("read response signer private key PEM: %w", err)
	}
	// Parsing errors already carry the relevant sentinel; pass them through.
	return ParseEd25519ResponseSignerPEM(pemBytes)
}
// ParseEd25519ResponseSignerPEM parses one strict PKCS#8 PEM-encoded Ed25519
// private key and constructs a signer from it. Trailing non-whitespace data
// after the PEM block is rejected.
func ParseEd25519ResponseSignerPEM(pemBytes []byte) (*Ed25519ResponseSigner, error) {
	block, rest := pem.Decode(pemBytes)
	if block == nil {
		return nil, ErrInvalidResponsePrivateKeyPEM
	}
	if block.Type != "PRIVATE KEY" || len(bytes.TrimSpace(rest)) > 0 {
		return nil, ErrInvalidResponsePrivateKeyPEM
	}
	parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		return nil, ErrInvalidResponsePrivateKeyPEM
	}
	key, ok := parsed.(ed25519.PrivateKey)
	if !ok {
		return nil, ErrInvalidResponsePrivateKey
	}
	return NewEd25519ResponseSigner(key)
}
// PublicKey returns a copy of the Ed25519 public key that corresponds to the
// configured response signer private key, or nil on a nil receiver.
func (s *Ed25519ResponseSigner) PublicKey() ed25519.PublicKey {
	if s == nil {
		return nil
	}
	pub, ok := s.privateKey.Public().(ed25519.PublicKey)
	if !ok {
		return nil
	}
	return bytes.Clone(pub)
}
// SignResponse signs the canonical v1 response signing input built from
// fields, returning a fresh signature slice.
func (s *Ed25519ResponseSigner) SignResponse(fields ResponseSigningFields) ([]byte, error) {
	if s == nil || len(s.privateKey) != ed25519.PrivateKeySize {
		return nil, ErrInvalidResponsePrivateKey
	}
	input := BuildResponseSigningInput(fields)
	return bytes.Clone(ed25519.Sign(s.privateKey, input)), nil
}
// SignEvent signs the canonical v1 stream-event signing input built from
// fields, returning a fresh signature slice.
func (s *Ed25519ResponseSigner) SignEvent(fields EventSigningFields) ([]byte, error) {
	if s == nil || len(s.privateKey) != ed25519.PrivateKeySize {
		return nil, ErrInvalidResponsePrivateKey
	}
	input := BuildEventSigningInput(fields)
	return bytes.Clone(ed25519.Sign(s.privateKey, input)), nil
}
// BuildResponseSigningInput returns the canonical byte sequence the v1 server
// response signature covers. String and byte fields are length-prefixed with
// uvarint(len(field)) followed by raw bytes, while TimestampMS is appended as
// an 8-byte big-endian uint64.
func BuildResponseSigningInput(fields ResponseSigningFields) []byte {
	// Capacity upper bound: five uvarint length prefixes, one 8-byte
	// timestamp, plus the raw field bytes themselves.
	capacity := 8 + 5*binary.MaxVarintLen64 +
		len(ResponseDomainMarkerV1) +
		len(fields.ProtocolVersion) +
		len(fields.RequestID) +
		len(fields.ResultCode) +
		len(fields.PayloadHash)
	out := make([]byte, 0, capacity)
	out = appendLengthPrefixedString(out, ResponseDomainMarkerV1)
	out = appendLengthPrefixedString(out, fields.ProtocolVersion)
	out = appendLengthPrefixedString(out, fields.RequestID)
	out = binary.BigEndian.AppendUint64(out, uint64(fields.TimestampMS))
	out = appendLengthPrefixedString(out, fields.ResultCode)
	out = appendLengthPrefixedBytes(out, fields.PayloadHash)
	return out
}
// VerifyResponseSignature verifies that signature authenticates fields under
// publicKey using the canonical v1 response signing input.
func VerifyResponseSignature(publicKey ed25519.PublicKey, signature []byte, fields ResponseSigningFields) error {
	// Reject structurally invalid inputs before paying for verification.
	if len(publicKey) != ed25519.PublicKeySize {
		return ErrInvalidResponseSignature
	}
	if len(signature) != ed25519.SignatureSize {
		return ErrInvalidResponseSignature
	}
	if ed25519.Verify(publicKey, BuildResponseSigningInput(fields), signature) {
		return nil
	}
	return ErrInvalidResponseSignature
}
+146
View File
@@ -0,0 +1,146 @@
package authn
import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestBuildResponseSigningInputChangesWhenSignedFieldChanges asserts that
// every covered field participates in the canonical signing input: mutating
// any single field must change the produced bytes.
func TestBuildResponseSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
	t.Parallel()
	base := ResponseSigningFields{
		ProtocolVersion: "v1",
		RequestID:       "request-123",
		TimestampMS:     123456789,
		ResultCode:      "ok",
		PayloadHash:     mustSHA256([]byte("payload")),
	}
	reference := BuildResponseSigningInput(base)
	cases := []struct {
		name   string
		mutate func(ResponseSigningFields) ResponseSigningFields
	}{
		{
			name: "protocol version",
			mutate: func(fields ResponseSigningFields) ResponseSigningFields {
				fields.ProtocolVersion = "v2"
				return fields
			},
		},
		{
			name: "request id",
			mutate: func(fields ResponseSigningFields) ResponseSigningFields {
				fields.RequestID = "request-456"
				return fields
			},
		},
		{
			name: "timestamp",
			mutate: func(fields ResponseSigningFields) ResponseSigningFields {
				fields.TimestampMS++
				return fields
			},
		},
		{
			name: "result code",
			mutate: func(fields ResponseSigningFields) ResponseSigningFields {
				fields.ResultCode = "denied"
				return fields
			},
		},
		{
			name: "payload hash",
			mutate: func(fields ResponseSigningFields) ResponseSigningFields {
				fields.PayloadHash = mustSHA256([]byte("other"))
				return fields
			},
		},
	}
	for _, tc := range cases {
		tc := tc // capture for parallel subtests (pre-Go 1.22 loop semantics)
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			mutated := BuildResponseSigningInput(tc.mutate(base))
			assert.False(t, bytes.Equal(reference, mutated))
		})
	}
}
// TestParseEd25519ResponseSignerPEMRejectsMalformedPEM verifies that bytes
// containing no PEM block at all map to the dedicated PEM parse error.
func TestParseEd25519ResponseSignerPEMRejectsMalformedPEM(t *testing.T) {
	t.Parallel()
	_, err := ParseEd25519ResponseSignerPEM([]byte("not-pem"))
	require.ErrorIs(t, err, ErrInvalidResponsePrivateKeyPEM)
}
// TestParseEd25519ResponseSignerPEMRejectsNonPKCS8PEM verifies the strict
// block-type check: even a valid PKCS#8 Ed25519 payload is rejected when its
// PEM label is not exactly "PRIVATE KEY".
func TestParseEd25519ResponseSignerPEMRejectsNonPKCS8PEM(t *testing.T) {
	t.Parallel()
	_, privateKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)
	der, err := x509.MarshalPKCS8PrivateKey(privateKey)
	require.NoError(t, err)
	encoded := pem.EncodeToMemory(&pem.Block{
		Type:  "ED25519 PRIVATE KEY",
		Bytes: der,
	})
	_, err = ParseEd25519ResponseSignerPEM(encoded)
	require.ErrorIs(t, err, ErrInvalidResponsePrivateKeyPEM)
}
// TestParseEd25519ResponseSignerPEMRejectsNonEd25519Key verifies that a valid
// PKCS#8 "PRIVATE KEY" block carrying a non-Ed25519 key is rejected with the
// key-type error rather than the PEM error.
func TestParseEd25519ResponseSignerPEMRejectsNonEd25519Key(t *testing.T) {
	t.Parallel()
	// 1024-bit RSA keeps key generation fast; key strength is irrelevant here.
	rsaKey, err := rsa.GenerateKey(rand.Reader, 1024)
	require.NoError(t, err)
	der, err := x509.MarshalPKCS8PrivateKey(rsaKey)
	require.NoError(t, err)
	encoded := pem.EncodeToMemory(&pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: der,
	})
	_, err = ParseEd25519ResponseSignerPEM(encoded)
	require.ErrorIs(t, err, ErrInvalidResponsePrivateKey)
}
// TestSignAndVerifyResponseSignature covers the signer round trip: a signed
// response verifies under the signer's public key and fails verification once
// a covered field is altered.
func TestSignAndVerifyResponseSignature(t *testing.T) {
	t.Parallel()
	_, privateKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)
	signer, err := NewEd25519ResponseSigner(privateKey)
	require.NoError(t, err)
	signedFields := ResponseSigningFields{
		ProtocolVersion: "v1",
		RequestID:       "request-123",
		TimestampMS:     123456789,
		ResultCode:      "ok",
		PayloadHash:     mustSHA256([]byte("payload")),
	}
	signature, err := signer.SignResponse(signedFields)
	require.NoError(t, err)
	require.NoError(t, VerifyResponseSignature(signer.PublicKey(), signature, signedFields))
	// Any change to a signed field must invalidate the existing signature.
	signedFields.ResultCode = "changed"
	require.ErrorIs(t, VerifyResponseSignature(signer.PublicKey(), signature, signedFields), ErrInvalidResponseSignature)
}
+47
View File
@@ -0,0 +1,47 @@
package authn
import (
"crypto/ed25519"
"encoding/base64"
"errors"
)
var (
	// ErrInvalidClientPublicKey reports that cached client public key material
	// is not a base64-encoded raw Ed25519 public key. Compare with errors.Is.
	ErrInvalidClientPublicKey = errors.New("client_public_key is not a valid base64-encoded Ed25519 public key")
	// ErrInvalidRequestSignature reports that a request signature is not a raw
	// Ed25519 signature for the canonical request signing input. Compare with
	// errors.Is.
	ErrInvalidRequestSignature = errors.New("invalid request signature")
)
// VerifyRequestSignature validates the base64-encoded raw Ed25519 public key
// from session cache, builds the canonical v1 signing input from fields, and
// verifies that signature authenticates the request.
func VerifyRequestSignature(clientPublicKey string, signature []byte, fields RequestSigningFields) error {
	publicKey, err := decodeClientPublicKey(clientPublicKey)
	if err != nil {
		return err
	}
	// Reject obviously malformed signatures before paying for verification.
	if len(signature) != ed25519.SignatureSize {
		return ErrInvalidRequestSignature
	}
	if ok := ed25519.Verify(publicKey, BuildRequestSigningInput(fields), signature); !ok {
		return ErrInvalidRequestSignature
	}
	return nil
}
// decodeClientPublicKey decodes value as strict standard base64 and requires
// the raw result to be exactly one Ed25519 public key.
func decodeClientPublicKey(value string) (ed25519.PublicKey, error) {
	raw, err := base64.StdEncoding.Strict().DecodeString(value)
	switch {
	case err != nil:
		return nil, ErrInvalidClientPublicKey
	case len(raw) != ed25519.PublicKeySize:
		return nil, ErrInvalidClientPublicKey
	default:
		return ed25519.PublicKey(raw), nil
	}
}
+137
View File
@@ -0,0 +1,137 @@
package authn
import (
"crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"testing"
"github.com/stretchr/testify/require"
)
// TestVerifyRequestSignature drives the full verification matrix: a correct
// signature verifies; any mutation of a signed field, the key, or the
// signature bytes is rejected; malformed key material maps to
// ErrInvalidClientPublicKey.
func TestVerifyRequestSignature(t *testing.T) {
	t.Parallel()
	clientPrivateKey := newTestPrivateKey("primary")
	clientPublicKey := clientPrivateKey.Public().(ed25519.PublicKey)
	otherPrivateKey := newTestPrivateKey("other")
	fields := RequestSigningFields{
		ProtocolVersion: "v1",
		DeviceSessionID: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMS:     123456789,
		RequestID:       "request-123",
		PayloadHash:     mustSHA256([]byte("payload")),
	}
	// signature is valid for fields under clientPublicKey.
	signature := ed25519.Sign(clientPrivateKey, BuildRequestSigningInput(fields))
	tests := []struct {
		name            string
		clientPublicKey string
		signature       []byte
		fields          RequestSigningFields
		wantErr         error
	}{
		{
			name:            "valid signature",
			clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
			signature:       signature,
			fields:          fields,
		},
		{
			name:            "message type change rejects signature",
			clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
			signature:       signature,
			fields: func() RequestSigningFields {
				mutated := fields
				mutated.MessageType = "fleet.attack"
				return mutated
			}(),
			wantErr: ErrInvalidRequestSignature,
		},
		{
			name:            "request id change rejects signature",
			clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
			signature:       signature,
			fields: func() RequestSigningFields {
				mutated := fields
				mutated.RequestID = "request-456"
				return mutated
			}(),
			wantErr: ErrInvalidRequestSignature,
		},
		{
			name:            "payload hash change rejects signature",
			clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
			signature:       signature,
			fields: func() RequestSigningFields {
				mutated := fields
				mutated.PayloadHash = mustSHA256([]byte("other"))
				return mutated
			}(),
			wantErr: ErrInvalidRequestSignature,
		},
		{
			name:            "wrong key rejects signature",
			clientPublicKey: base64.StdEncoding.EncodeToString(otherPrivateKey.Public().(ed25519.PublicKey)),
			signature:       signature,
			fields:          fields,
			wantErr:         ErrInvalidRequestSignature,
		},
		{
			name:            "bit flipped signature rejects",
			clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
			signature: func() []byte {
				// Corrupt a copy so the shared signature stays intact.
				corrupted := append([]byte(nil), signature...)
				corrupted[0] ^= 0xff
				return corrupted
			}(),
			fields:  fields,
			wantErr: ErrInvalidRequestSignature,
		},
		{
			name:            "invalid signature length rejects",
			clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
			signature:       signature[:len(signature)-1],
			fields:          fields,
			wantErr:         ErrInvalidRequestSignature,
		},
		{
			name:            "invalid base64 public key rejects",
			clientPublicKey: "%%%not-base64%%%",
			signature:       signature,
			fields:          fields,
			wantErr:         ErrInvalidClientPublicKey,
		},
		{
			name:            "invalid public key length rejects",
			clientPublicKey: base64.StdEncoding.EncodeToString([]byte("short")),
			signature:       signature,
			fields:          fields,
			wantErr:         ErrInvalidClientPublicKey,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			err := VerifyRequestSignature(tt.clientPublicKey, tt.signature, tt.fields)
			if tt.wantErr == nil {
				require.NoError(t, err)
				return
			}
			require.ErrorIs(t, err, tt.wantErr)
		})
	}
}
func newTestPrivateKey(label string) ed25519.PrivateKey {
seed := sha256.Sum256([]byte("gateway-authn-signature-test-" + label))
return ed25519.NewKeyFromSeed(seed[:])
}
+20
View File
@@ -0,0 +1,20 @@
// Package clock provides the gateway time source abstraction used by
// authenticated transport checks.
package clock
import "time"
// Clock returns current server time for freshness checks and time-dependent
// transport behavior. Abstracting the time source lets tests inject a fixed
// instant instead of the wall clock.
type Clock interface {
	// Now returns the current server time.
	Now() time.Time
}
// System returns the current process time using the local system clock.
type System struct{}
// Now returns the current UTC time from the system clock.
func (System) Now() time.Time {
return time.Now().UTC()
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+108
View File
@@ -0,0 +1,108 @@
// Package downstream defines the verified internal command contract used by the
// gateway after the authenticated edge pipeline succeeds.
package downstream
import (
"context"
"errors"
)
var (
	// ErrRouteNotFound reports that Router does not have an exact-match handler
	// for the supplied authenticated message type. Compare with errors.Is.
	ErrRouteNotFound = errors.New("downstream route not found")
	// ErrDownstreamUnavailable reports that the resolved downstream dependency is
	// temporarily unavailable. Compare with errors.Is.
	ErrDownstreamUnavailable = errors.New("downstream service is unavailable")
)
// AuthenticatedCommand is the minimum verified unary command context the
// gateway may forward to downstream business services. Fields have already
// passed the edge verification pipeline; downstream code treats them as
// trusted input.
type AuthenticatedCommand struct {
	// ProtocolVersion is the authenticated transport protocol version accepted
	// by the gateway.
	ProtocolVersion string
	// UserID is the authenticated user identity resolved from SessionCache.
	UserID string
	// DeviceSessionID is the authenticated device session that originated the
	// command.
	DeviceSessionID string
	// MessageType is the stable exact-match downstream routing key.
	MessageType string
	// TimestampMS is the client-supplied request timestamp that already passed
	// freshness verification.
	TimestampMS int64
	// RequestID is the transport correlation and anti-replay identifier.
	RequestID string
	// TraceID is the optional client-supplied correlation identifier.
	TraceID string
	// PayloadBytes carries the verified opaque business payload bytes.
	PayloadBytes []byte
}
// UnaryResult is the minimum downstream unary result the gateway needs in
// order to build a signed authenticated client response.
type UnaryResult struct {
	// ResultCode is the stable opaque downstream result code returned to the
	// client without business reinterpretation by the gateway.
	ResultCode string
	// PayloadBytes carries the opaque downstream response payload bytes.
	PayloadBytes []byte
}
// Client executes a verified authenticated unary command against one concrete
// downstream service or adapter.
type Client interface {
	// ExecuteCommand executes command and returns the downstream unary result.
	// NOTE(review): implementations presumably wrap ErrDownstreamUnavailable
	// for transient dependency failures — confirm against existing adapters.
	ExecuteCommand(ctx context.Context, command AuthenticatedCommand) (UnaryResult, error)
}
// Router resolves the downstream unary client for one exact authenticated
// message_type value. Matching is literal; no prefix or wildcard semantics.
type Router interface {
	// Route returns the downstream client for messageType. Implementations must
	// wrap ErrRouteNotFound when the route table does not contain messageType.
	Route(messageType string) (Client, error)
}
// StaticRouter resolves exact message_type literals from an immutable route
// map supplied at construction time. Because the map is never mutated after
// NewStaticRouter, Route is safe for concurrent use.
type StaticRouter struct {
	// routes is the exact-match table; immutable after construction.
	routes map[string]Client
}
// NewStaticRouter constructs a StaticRouter with a defensive copy of routes.
// Entries with nil clients are dropped so Route never yields a nil Client.
func NewStaticRouter(routes map[string]Client) *StaticRouter {
	table := make(map[string]Client, len(routes))
	for messageType, client := range routes {
		if client != nil {
			table[messageType] = client
		}
	}
	return &StaticRouter{routes: table}
}
// Route returns the exact-match client for messageType, or ErrRouteNotFound
// when the receiver is nil or the table has no usable entry.
func (r *StaticRouter) Route(messageType string) (Client, error) {
	if r == nil {
		return nil, ErrRouteNotFound
	}
	// A missing key yields the nil zero value, which is folded into the same
	// not-found result as an explicitly nil entry.
	if client := r.routes[messageType]; client != nil {
		return client, nil
	}
	return nil, ErrRouteNotFound
}
@@ -0,0 +1,39 @@
package downstream
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestStaticRouterRoutesExactMessageType asserts Route returns the exact
// client instance registered under the matching message type key.
func TestStaticRouterRoutesExactMessageType(t *testing.T) {
	t.Parallel()
	registered := &stubClient{}
	router := NewStaticRouter(map[string]Client{"fleet.move": registered})
	resolved, err := router.Route("fleet.move")
	require.NoError(t, err)
	assert.Same(t, registered, resolved)
}
// TestStaticRouterRejectsUnknownMessageType asserts that near-miss message
// types do not match: routing is exact, never prefix-based.
func TestStaticRouterRejectsUnknownMessageType(t *testing.T) {
	t.Parallel()
	router := NewStaticRouter(map[string]Client{"fleet.move": &stubClient{}})
	_, err := router.Route("fleet.rename")
	require.ErrorIs(t, err, ErrRouteNotFound)
}
// stubClient is a no-op downstream Client used to assert routing identity.
type stubClient struct{}

// ExecuteCommand satisfies Client and always returns an empty result.
func (*stubClient) ExecuteCommand(context.Context, AuthenticatedCommand) (UnaryResult, error) {
	return UnaryResult{}, nil
}
@@ -0,0 +1,341 @@
package events
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"strings"
"sync"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/telemetry"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
const clientEventReadCount int64 = 128
// ClientEventPublisher accepts decoded client-facing events from the internal
// event subscriber.
type ClientEventPublisher interface {
	// Publish fans out event to the currently active push streams. It is
	// invoked from the subscriber's read-loop goroutine, so implementations
	// that block will stall stream consumption.
	Publish(event push.Event)
}
// RedisClientEventSubscriber consumes client-facing events from one Redis
// Stream and forwards them to the configured publisher.
type RedisClientEventSubscriber struct {
	client           *redis.Client        // dedicated connection for stream reads
	stream           string               // Redis Stream key holding client events
	pingTimeout      time.Duration        // budget applied to health-check pings
	readBlockTimeout time.Duration        // XRead block duration per poll
	publisher        ClientEventPublisher // fan-out target for decoded events
	logger           *zap.Logger          // never nil after construction
	metrics          *telemetry.Runtime   // nil when built without observability
	closeOnce        sync.Once            // guards client.Close
	startedOnce      sync.Once            // guards close(started)
	started          chan struct{}        // closed once Run enters its read loop
}
// NewRedisClientEventSubscriber constructs a Redis Stream subscriber that
// reuses the SessionCache Redis connection settings and forwards decoded
// client-facing events to publisher. Logging defaults to a no-op logger and
// metrics reporting is disabled.
func NewRedisClientEventSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher) (*RedisClientEventSubscriber, error) {
	var (
		noLogger  *zap.Logger        // delegate substitutes zap.NewNop
		noMetrics *telemetry.Runtime // drop accounting is skipped when nil
	)
	return NewRedisClientEventSubscriberWithObservability(sessionCfg, eventsCfg, publisher, noLogger, noMetrics)
}
// NewRedisClientEventSubscriberWithObservability constructs a Redis Stream
// subscriber that also records malformed or dropped internal events. A nil
// logger is replaced with a no-op logger; a nil metrics runtime disables drop
// accounting.
func NewRedisClientEventSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisClientEventSubscriber, error) {
	// Validate configuration up front so a misconfigured subscriber never
	// opens a Redis connection.
	switch {
	case strings.TrimSpace(sessionCfg.Addr) == "":
		return nil, errors.New("new redis client event subscriber: redis addr must not be empty")
	case sessionCfg.DB < 0:
		return nil, errors.New("new redis client event subscriber: redis db must not be negative")
	case sessionCfg.LookupTimeout <= 0:
		return nil, errors.New("new redis client event subscriber: lookup timeout must be positive")
	case strings.TrimSpace(eventsCfg.Stream) == "":
		return nil, errors.New("new redis client event subscriber: stream must not be empty")
	case eventsCfg.ReadBlockTimeout <= 0:
		return nil, errors.New("new redis client event subscriber: read block timeout must be positive")
	case publisher == nil:
		return nil, errors.New("new redis client event subscriber: nil publisher")
	}
	options := &redis.Options{
		Addr:            sessionCfg.Addr,
		Username:        sessionCfg.Username,
		Password:        sessionCfg.Password,
		DB:              sessionCfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if sessionCfg.TLSEnabled {
		// Enforce a modern TLS floor on the Redis connection.
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}
	if logger == nil {
		logger = zap.NewNop()
	}
	return &RedisClientEventSubscriber{
		client:           redis.NewClient(options),
		stream:           eventsCfg.Stream,
		pingTimeout:      sessionCfg.LookupTimeout,
		readBlockTimeout: eventsCfg.ReadBlockTimeout,
		publisher:        publisher,
		logger:           logger.Named("client_event_subscriber"),
		metrics:          metrics,
		started:          make(chan struct{}),
	}, nil
}
// Ping verifies that the Redis backend used for client-facing event fan-out is
// reachable within the configured timeout budget.
func (s *RedisClientEventSubscriber) Ping(ctx context.Context) error {
	switch {
	case s == nil || s.client == nil:
		return errors.New("ping redis client event subscriber: nil subscriber")
	case ctx == nil:
		return errors.New("ping redis client event subscriber: nil context")
	}
	// Cap the round trip so a slow backend cannot stall health checks.
	pingCtx, cancel := context.WithTimeout(ctx, s.pingTimeout)
	defer cancel()
	err := s.client.Ping(pingCtx).Err()
	if err == nil {
		return nil
	}
	return fmt.Errorf("ping redis client event subscriber: %w", err)
}
// Run consumes client-facing events until ctx is canceled or Redis returns an
// unexpected error. It resolves the current stream tail first (so only events
// published after startup are forwarded), signals readiness via the started
// channel, then loops on blocking XRead calls.
func (s *RedisClientEventSubscriber) Run(ctx context.Context) error {
	if s == nil || s.client == nil {
		return errors.New("run redis client event subscriber: nil subscriber")
	}
	if ctx == nil {
		return errors.New("run redis client event subscriber: nil context")
	}
	if err := ctx.Err(); err != nil {
		return err
	}
	lastID, err := s.resolveStartID(ctx)
	if err != nil {
		return err
	}
	s.signalStarted()
	for {
		streams, err := s.client.XRead(ctx, &redis.XReadArgs{
			Streams: []string{s.stream, lastID},
			Count:   clientEventReadCount,
			Block:   s.readBlockTimeout,
		}).Result()
		switch {
		case err == nil:
			for _, stream := range streams {
				for _, message := range stream.Messages {
					s.publishMessage(message)
					// Advance the cursor so this entry is never re-read.
					lastID = message.ID
				}
			}
		case errors.Is(err, redis.Nil):
			// Block timeout elapsed with no new entries; poll again.
		case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
			// The read failed because our own context ended (or Shutdown closed
			// the client after cancellation); report a clean context stop.
			return ctx.Err()
		default:
			// Every other failure — including cancellation-style errors that
			// arrive while ctx is still live — is an unexpected Redis error.
			// (The original code spelled this arm out twice; the duplicate
			// case was byte-identical to this default and has been removed.)
			return fmt.Errorf("run redis client event subscriber: %w", err)
		}
	}
}
// resolveStartID returns the ID of the newest existing stream entry so reads
// begin at the current tail; an empty or missing stream yields the sentinel
// "0-0" start ID.
func (s *RedisClientEventSubscriber) resolveStartID(ctx context.Context) (string, error) {
	tail, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
	if errors.Is(err, redis.Nil) {
		return "0-0", nil
	}
	if err != nil {
		return "", fmt.Errorf("run redis client event subscriber: resolve stream tail: %w", err)
	}
	if len(tail) == 0 {
		return "0-0", nil
	}
	return tail[0].ID, nil
}
// Shutdown closes the Redis client so a blocking stream read can terminate
// promptly during gateway shutdown. The context is only validated; closing
// the client is the mechanism that interrupts a blocked XRead.
func (s *RedisClientEventSubscriber) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown redis client event subscriber: nil context")
	}
	return s.Close()
}
// Close releases the underlying Redis client resources exactly once; repeated
// calls (and calls on a nil subscriber) are no-ops returning nil.
func (s *RedisClientEventSubscriber) Close() error {
	if s == nil || s.client == nil {
		return nil
	}
	var closeErr error
	s.closeOnce.Do(func() { closeErr = s.client.Close() })
	return closeErr
}
// signalStarted closes the started channel exactly once so callers waiting on
// it observe that the read loop is active.
func (s *RedisClientEventSubscriber) signalStarted() {
	s.startedOnce.Do(func() { close(s.started) })
}
// publishMessage decodes one raw stream entry and fans it out to the
// publisher. Malformed entries are logged, counted, and dropped so a single
// bad producer cannot stall stream consumption.
func (s *RedisClientEventSubscriber) publishMessage(message redis.XMessage) {
	event, err := decodeClientEvent(message.Values)
	if err != nil {
		s.logger.Warn("dropped malformed client event",
			zap.String("stream", s.stream),
			zap.String("message_id", message.ID),
			zap.Error(err),
		)
		// metrics is nil when the subscriber was built without observability
		// (NewRedisClientEventSubscriber passes nil); guard explicitly so a
		// malformed event cannot panic the read loop via a nil receiver.
		if s.metrics != nil {
			s.metrics.RecordInternalEventDrop(context.Background(),
				attribute.String("component", "client_event_subscriber"),
				attribute.String("reason", "malformed_event"),
			)
		}
		return
	}
	s.publisher.Publish(event)
}
// decodeClientEvent converts one raw Redis stream entry into a push.Event.
// Unknown fields are rejected outright so schema drift surfaces immediately
// instead of being silently ignored.
func decodeClientEvent(values map[string]any) (push.Event, error) {
	// Reject any key outside the fixed event schema before extracting fields.
	for key := range values {
		switch key {
		case "user_id", "event_type", "event_id", "payload_bytes":
			// required fields
		case "device_session_id", "request_id", "trace_id":
			// optional fields
		default:
			return push.Event{}, fmt.Errorf("decode client event: unsupported field %q", key)
		}
	}
	userID, err := requiredStringField(values, "user_id")
	if err != nil {
		return push.Event{}, err
	}
	eventType, err := requiredStringField(values, "event_type")
	if err != nil {
		return push.Event{}, err
	}
	eventID, err := requiredStringField(values, "event_id")
	if err != nil {
		return push.Event{}, err
	}
	payloadBytes, err := requiredBytesField(values, "payload_bytes")
	if err != nil {
		return push.Event{}, err
	}
	event := push.Event{
		UserID:       userID,
		EventType:    eventType,
		EventID:      eventID,
		PayloadBytes: payloadBytes,
	}
	deviceSessionID, present, err := optionalStringField(values, "device_session_id")
	if err != nil {
		return push.Event{}, err
	}
	if present {
		// Only the device session identifier is whitespace-normalized.
		event.DeviceSessionID = strings.TrimSpace(deviceSessionID)
	}
	requestID, present, err := optionalStringField(values, "request_id")
	if err != nil {
		return push.Event{}, err
	}
	if present {
		event.RequestID = requestID
	}
	traceID, present, err := optionalStringField(values, "trace_id")
	if err != nil {
		return push.Event{}, err
	}
	if present {
		event.TraceID = traceID
	}
	return event, nil
}
// requiredBytesField extracts field from values as raw bytes, failing when the
// field is absent or carries an unsupported dynamic type.
func requiredBytesField(values map[string]any, field string) ([]byte, error) {
	raw, present := values[field]
	if !present {
		return nil, fmt.Errorf("decode client event: missing %s", field)
	}
	decoded, err := coerceBytes(raw)
	if err != nil {
		return nil, fmt.Errorf("decode client event: %s: %w", field, err)
	}
	return decoded, nil
}
// optionalStringField extracts field from values when present. The boolean
// reports presence so callers can distinguish "absent" from "empty string".
func optionalStringField(values map[string]any, field string) (string, bool, error) {
	raw, present := values[field]
	if !present {
		return "", false, nil
	}
	decoded, err := coerceString(raw)
	if err != nil {
		return "", false, fmt.Errorf("decode client event: %s: %w", field, err)
	}
	return decoded, true, nil
}
// coerceBytes normalizes a Redis stream value into an owned byte slice:
// strings are converted (which copies) and byte slices are cloned so the
// caller never aliases driver-owned memory.
func coerceBytes(value any) ([]byte, error) {
	if text, ok := value.(string); ok {
		return []byte(text), nil
	}
	if raw, ok := value.([]byte); ok {
		return bytes.Clone(raw), nil
	}
	return nil, fmt.Errorf("unsupported type %T", value)
}
@@ -0,0 +1,294 @@
package events
import (
"context"
"strings"
"sync"
"testing"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/testutil"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRedisClientEventSubscriberPublishesValidEvent verifies that a fully
// populated stream entry is decoded field-for-field into one push.Event.
func TestRedisClientEventSubscriberPublishesValidEvent(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":           "user-123",
		"device_session_id": "device-session-123",
		"event_type":        "fleet.updated",
		"event_id":          "event-123",
		"payload_bytes":     []byte("payload-123"),
		"request_id":        "request-123",
		"trace_id":          "trace-123",
	})
	// Consumption is asynchronous; wait for exactly one delivered event.
	require.Eventually(t, func() bool {
		return len(publisher.events()) == 1
	}, time.Second, 10*time.Millisecond)
	assert.Equal(t, []push.Event{{
		UserID:          "user-123",
		DeviceSessionID: "device-session-123",
		EventType:       "fleet.updated",
		EventID:         "event-123",
		PayloadBytes:    []byte("payload-123"),
		RequestID:       "request-123",
		TraceID:         "trace-123",
	}}, publisher.events())
}
// TestRedisClientEventSubscriberSkipsMalformedEventAndContinues verifies that
// a malformed entry is dropped while later well-formed entries on the same
// stream are still delivered.
func TestRedisClientEventSubscriberSkipsMalformedEventAndContinues(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)
	// "unexpected" is outside the event schema, so this entry must be dropped.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-bad",
		"payload_bytes": []byte("payload-bad"),
		"unexpected":    "boom",
	})
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-good",
		"payload_bytes": []byte("payload-good"),
	})
	// Only the well-formed event may reach the publisher.
	require.Eventually(t, func() bool {
		events := publisher.events()
		return len(events) == 1 && events[0].EventID == "event-good"
	}, time.Second, 10*time.Millisecond)
}
// TestRedisClientEventSubscriberStartsFromCurrentTail verifies that entries
// already in the stream at startup are never replayed: consumption starts at
// the resolved stream tail.
func TestRedisClientEventSubscriberStartsFromCurrentTail(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	// Published before the subscriber starts, so it must never be delivered.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-old",
		"payload_bytes": []byte("payload-old"),
	})
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)
	assert.Never(t, func() bool {
		return len(publisher.events()) > 0
	}, 100*time.Millisecond, 10*time.Millisecond)
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-new",
		"payload_bytes": []byte("payload-new"),
	})
	require.Eventually(t, func() bool {
		events := publisher.events()
		return len(events) == 1 && events[0].EventID == "event-new"
	}, time.Second, 10*time.Millisecond)
}
// TestRedisClientEventSubscriberShutdownInterruptsBlockingRead verifies that
// context cancellation plus Shutdown promptly unblocks the XRead loop and Run
// returns the context error.
func TestRedisClientEventSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()
	// Wait until Run has resolved the stream tail and entered its read loop.
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}
	cancel()
	require.NoError(t, subscriber.Shutdown(context.Background()))
	select {
	case err := <-resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop after shutdown")
	}
}
// TestRedisClientEventSubscriberLogsAndCountsMalformedEvents verifies the
// observability path: a malformed event produces a warn log line and
// increments the internal-event-drop counter with the expected labels.
func TestRedisClientEventSubscriberLogsAndCountsMalformedEvents(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	logger, logBuffer := testutil.NewObservedLogger(t)
	telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)
	subscriber, err := NewRedisClientEventSubscriberWithObservability(
		config.SessionCacheRedisConfig{
			Addr:          server.Addr(),
			LookupTimeout: 250 * time.Millisecond,
		},
		config.ClientEventsRedisConfig{
			Stream:           "gateway:client_events",
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		publisher,
		logger,
		telemetryRuntime,
	)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, subscriber.Close())
	})
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)
	// "unexpected" is outside the event schema, forcing the drop path.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-bad",
		"payload_bytes": []byte("payload-bad"),
		"unexpected":    "boom",
	})
	require.Eventually(t, func() bool {
		return strings.Contains(logBuffer.String(), "dropped malformed client event")
	}, time.Second, 10*time.Millisecond)
	// Scrape the metrics surface and assert the counter and its labels.
	metricsText := testutil.ScrapeMetrics(t, telemetryRuntime.Handler())
	assert.Contains(t, metricsText, `gateway_internal_event_drops_total`)
	assert.Contains(t, metricsText, `component="client_event_subscriber"`)
	assert.Contains(t, metricsText, `reason="malformed_event"`)
}
// newTestRedisClientEventSubscriber builds a subscriber wired to the miniredis
// instance with short timeouts suitable for tests, registering cleanup that
// closes it.
func newTestRedisClientEventSubscriber(t *testing.T, server *miniredis.Miniredis, publisher ClientEventPublisher) *RedisClientEventSubscriber {
	t.Helper()
	subscriber, err := NewRedisClientEventSubscriber(
		config.SessionCacheRedisConfig{
			Addr:          server.Addr(),
			LookupTimeout: 250 * time.Millisecond,
		},
		config.ClientEventsRedisConfig{
			Stream:           "gateway:client_events",
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		publisher,
	)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, subscriber.Close())
	})
	return subscriber
}
// addClientEvent appends values to stream on server through a short-lived
// Redis client, failing the test on any error.
func addClientEvent(t *testing.T, server *miniredis.Miniredis, stream string, values map[string]any) {
	t.Helper()
	client := redis.NewClient(&redis.Options{
		Addr:            server.Addr(),
		Protocol:        2,
		DisableIdentity: true,
	})
	defer func() {
		assert.NoError(t, client.Close())
	}()
	args := &redis.XAddArgs{
		Stream: stream,
		Values: values,
	}
	require.NoError(t, client.XAdd(context.Background(), args).Err())
}
// runningClientEventSubscriber tracks one subscriber Run goroutine so tests
// can stop it and assert its exit error.
type runningClientEventSubscriber struct {
	cancel   context.CancelFunc // cancels the Run context
	resultCh chan error         // receives Run's return value
}
// runTestClientEventSubscriber starts subscriber.Run on its own goroutine and
// blocks until the subscriber signals readiness, so tests cannot race the
// tail-resolution phase.
func runTestClientEventSubscriber(t *testing.T, subscriber *RedisClientEventSubscriber) runningClientEventSubscriber {
	t.Helper()
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}
	return runningClientEventSubscriber{
		cancel:   cancel,
		resultCh: resultCh,
	}
}
// stop cancels the Run context and asserts the subscriber exits with
// context.Canceled within one second.
func (r runningClientEventSubscriber) stop(t *testing.T) {
	t.Helper()
	r.cancel()
	select {
	case err := <-r.resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop")
	}
}
// recordingClientEventPublisher captures published events for test assertions.
type recordingClientEventPublisher struct {
	mu      sync.Mutex   // guards records across subscriber and test goroutines
	records []push.Event // captured events in publish order
}
// Publish records event under the mutex; it is safe to call from the
// subscriber's read-loop goroutine while the test polls events().
func (p *recordingClientEventPublisher) Publish(event push.Event) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.records = append(p.records, event)
}
// events returns a point-in-time copy of the captured events so callers can
// inspect them without racing the publishing goroutine.
func (p *recordingClientEventPublisher) events() []push.Event {
	p.mu.Lock()
	defer p.mu.Unlock()
	snapshot := make([]push.Event, len(p.records))
	copy(snapshot, p.records)
	return snapshot
}
@@ -0,0 +1,385 @@
package events
import (
"context"
"crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"errors"
"net"
"sync"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/grpcapi"
"galaxy/gateway/internal/replay"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
var testNow = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)
// TestAuthenticatedGatewayWarmsLocalSessionCache verifies read-through
// behavior: the first command pays exactly one fallback lookup, warming the
// local cache so the second command is served without another fallback hit.
func TestAuthenticatedGatewayWarmsLocalSessionCache(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	local := session.NewMemoryCache()
	fallback := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	readThrough, err := session.NewReadThroughCache(local, fallback)
	require.NoError(t, err)
	subscriber := newTestRedisSessionSubscriber(t, server, local)
	downstreamClient := &recordingDownstreamClient{}
	addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
	defer running.stop(t)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	// First command: local miss, one fallback lookup warms the cache.
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())
	// Second command: served from the warmed local cache; count stays at 1.
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())
	assert.Len(t, downstreamClient.commands(), 2)
}
// TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup verifies
// that a session snapshot published on the Redis stream updates the local
// cache directly: the second request sees the new user ID without any
// additional fallback lookup.
func TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	local := session.NewMemoryCache()
	fallback := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	readThrough, err := session.NewReadThroughCache(local, fallback)
	require.NoError(t, err)
	subscriber := newTestRedisSessionSubscriber(t, server, local)
	downstreamClient := &recordingDownstreamClient{}
	addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
	defer running.stop(t)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	// Warm the local cache: exactly one fallback lookup.
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())
	// Publish a snapshot that reassigns the session to user-456.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-456",
		"client_public_key": testClientPublicKeyBase64(),
		"status":            string(session.StatusActive),
	})
	// Wait until the subscriber has applied the event to the local cache.
	require.Eventually(t, func() bool {
		record, lookupErr := local.Lookup(context.Background(), "device-session-123")
		return lookupErr == nil && record.UserID == "user-456"
	}, time.Second, 10*time.Millisecond)
	// Second command uses the updated record — still only one fallback call.
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())
	commands := downstreamClient.commands()
	require.Len(t, commands, 2)
	assert.Equal(t, "user-456", commands[1].UserID)
}
// TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup
// verifies that a revocation snapshot published on the Redis stream causes the
// gateway to reject subsequent commands with FailedPrecondition, again without
// consulting the fallback cache a second time.
func TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	local := session.NewMemoryCache()
	fallback := &countingSessionCache{
		records: map[string]session.Record{
			"device-session-123": newActiveSessionRecord("user-123"),
		},
	}
	readThrough, err := session.NewReadThroughCache(local, fallback)
	require.NoError(t, err)
	subscriber := newTestRedisSessionSubscriber(t, server, local)
	downstreamClient := &recordingDownstreamClient{}
	addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
	defer running.stop(t)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	// Warm the local cache with the active record.
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
	require.NoError(t, err)
	assert.Equal(t, 1, fallback.lookupCalls())
	// Publish a revoked snapshot for the same session.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": testClientPublicKeyBase64(),
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})
	require.Eventually(t, func() bool {
		record, lookupErr := local.Lookup(context.Background(), "device-session-123")
		return lookupErr == nil && record.Status == session.StatusRevoked
	}, time.Second, 10*time.Millisecond)
	// The next command must fail fast on the locally cached revocation.
	_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
	assert.Equal(t, 1, fallback.lookupCalls())
}
// runningAuthenticatedGateway tracks a gateway application started in a
// background goroutine, so tests can stop it and inspect its exit error.
type runningAuthenticatedGateway struct {
	cancel   context.CancelFunc // cancels the application's run context
	resultCh chan error         // receives the single result of application.Run
}
// runAuthenticatedGateway assembles a full gateway application (gRPC server
// routed to downstreamClient for "fleet.move", plus the session subscriber),
// starts it in a goroutine, waits for the subscriber to begin reading, and
// returns the listen address together with a stop handle.
func runAuthenticatedGateway(t *testing.T, sessionCache session.Cache, subscriber *RedisSessionSubscriber, downstreamClient downstream.Client) (string, runningAuthenticatedGateway) {
	t.Helper()
	addr := unusedTCPAddr(t)
	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = addr
	// Wide freshness window so signed request timestamps (testNow) validate.
	grpcCfg.FreshnessWindow = 5 * time.Minute
	router := downstream.NewStaticRouter(map[string]downstream.Client{
		"fleet.move": downstreamClient,
	})
	gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
		Router:         router,
		ResponseSigner: newTestResponseSigner(t),
		SessionCache:   sessionCache,
		ReplayStore:    staticReplayStore{},
		Clock:          fixedClock{now: testNow},
	})
	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			AuthenticatedGRPC: grpcCfg,
		},
		gateway,
		subscriber,
	)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	// Block until the subscriber's read loop is live, so events published by
	// the test cannot race its startup.
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "session subscriber did not start")
	}
	return addr, runningAuthenticatedGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}
// stop cancels the gateway's run context and waits up to two seconds for the
// application to exit cleanly, failing the test otherwise.
func (g runningAuthenticatedGateway) stop(t *testing.T) {
	t.Helper()
	g.cancel()
	deadline := time.After(2 * time.Second)
	select {
	case runErr := <-g.resultCh:
		require.NoError(t, runErr)
	case <-deadline:
		require.FailNow(t, "gateway did not stop after cancellation")
	}
}
// dialGatewayClient opens a blocking, plaintext client connection to the
// gateway under test, failing if the connection is not ready within a second.
// NOTE(review): grpc.DialContext and grpc.WithBlock are deprecated in newer
// grpc-go releases in favor of grpc.NewClient — confirm the pinned grpc
// version before migrating.
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	conn, err := grpc.DialContext(
		ctx,
		addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
	)
	require.NoError(t, err)
	return conn
}
// unusedTCPAddr reserves an ephemeral loopback port by briefly binding it and
// returns the address for the gateway under test to listen on.
// NOTE(review): there is a small window between Close and the gateway's own
// Listen in which another process could grab the port; acceptable in tests.
func unusedTCPAddr(t *testing.T) string {
	t.Helper()
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	addr := probe.Addr().String()
	require.NoError(t, probe.Close())
	return addr
}
// newExecuteCommandRequest builds a signed ExecuteCommand request for the
// shared test device session: a fixed payload is hashed, and the canonical
// request fields are signed with the deterministic test client key.
func newExecuteCommandRequest(requestID string) *gatewayv1.ExecuteCommandRequest {
	payload := []byte("payload")
	digest := sha256.Sum256(payload)
	request := &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: "v1",
		DeviceSessionId: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMs:     testNow.UnixMilli(),
		RequestId:       requestID,
		PayloadBytes:    payload,
		PayloadHash:     digest[:],
		TraceId:         "trace-123",
	}
	signingInput := authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: request.GetProtocolVersion(),
		DeviceSessionID: request.GetDeviceSessionId(),
		MessageType:     request.GetMessageType(),
		TimestampMS:     request.GetTimestampMs(),
		RequestID:       request.GetRequestId(),
		PayloadHash:     request.GetPayloadHash(),
	})
	request.Signature = ed25519.Sign(testClientPrivateKey(), signingInput)
	return request
}
// newActiveSessionRecord builds an active session record for the shared test
// device session, bound to the deterministic test client public key.
func newActiveSessionRecord(userID string) session.Record {
	record := session.Record{
		DeviceSessionID: "device-session-123",
		ClientPublicKey: testClientPublicKeyBase64(),
		Status:          session.StatusActive,
	}
	record.UserID = userID
	return record
}
// testClientPrivateKey derives a deterministic Ed25519 client key by hashing
// a fixed seed string, so every test run signs with the same identity.
func testClientPrivateKey() ed25519.PrivateKey {
	seedBytes := []byte("gateway-events-grpc-test-client")
	digest := sha256.Sum256(seedBytes)
	return ed25519.NewKeyFromSeed(digest[:])
}
// testClientPublicKeyBase64 returns the standard-base64 encoding of the test
// client's Ed25519 public key, matching the session record wire format.
func testClientPublicKeyBase64() string {
	publicKey := testClientPrivateKey().Public().(ed25519.PublicKey)
	return base64.StdEncoding.EncodeToString(publicKey)
}
// newTestResponseSigner builds a deterministic Ed25519 response signer whose
// public half is recoverable from the same fixed seed (see
// pushResponseSignerPublicKey in the push tests).
func newTestResponseSigner(t *testing.T) authn.ResponseSigner {
	t.Helper()
	digest := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
	privateKey := ed25519.NewKeyFromSeed(digest[:])
	signer, err := authn.NewEd25519ResponseSigner(privateKey)
	require.NoError(t, err)
	return signer
}
// fixedClock is a clock.Clock that always reports the same instant, keeping
// freshness-window checks deterministic in tests.
type fixedClock struct {
	now time.Time // the instant every Now call returns
}

// Now returns the fixed instant.
func (c fixedClock) Now() time.Time {
	return c.now
}

// Compile-time interface conformance check.
var _ clock.Clock = fixedClock{}
// staticReplayStore is a replay.Store whose reservations always succeed,
// effectively disabling replay protection for these tests.
type staticReplayStore struct{}

// Reserve always succeeds.
func (staticReplayStore) Reserve(context.Context, string, string, time.Duration) error {
	return nil
}

// Compile-time interface conformance check.
var _ replay.Store = staticReplayStore{}
// countingSessionCache is a session cache stub that serves records from a
// fixed map and counts lookups, letting tests assert how often the gateway
// falls back to it.
type countingSessionCache struct {
	mu          sync.Mutex                // guards records and lookupCount
	records     map[string]session.Record // canned records keyed by device session ID
	lookupCount int                       // number of Lookup invocations
}
// Lookup returns the configured record for deviceSessionID and counts every
// invocation so tests can assert how often the fallback cache is consulted.
//
// Fix: the previous implementation ignored its session-ID argument and always
// read the hard-coded key "device-session-123". Using the argument keeps the
// stub honest about what was requested; this is backward compatible because
// every caller in this file requests exactly that session ID.
func (c *countingSessionCache) Lookup(_ context.Context, deviceSessionID string) (session.Record, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.lookupCount++
	record, ok := c.records[deviceSessionID]
	if !ok {
		return session.Record{}, errors.New("lookup session from counting cache: session cache record not found")
	}
	return record, nil
}
// lookupCalls reports how many times Lookup has been invoked.
func (c *countingSessionCache) lookupCalls() int {
	c.mu.Lock()
	count := c.lookupCount
	c.mu.Unlock()
	return count
}
// recordingDownstreamClient is a downstream.Client stub that records every
// executed command and replies with a canned success.
type recordingDownstreamClient struct {
	mu       sync.Mutex                       // guards captured
	captured []downstream.AuthenticatedCommand // commands in arrival order
}
// ExecuteCommand records the command and returns a canned successful result.
func (c *recordingDownstreamClient) ExecuteCommand(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.captured = append(c.captured, command)
	result := downstream.UnaryResult{
		ResultCode:   "ok",
		PayloadBytes: []byte("response"),
	}
	return result, nil
}
// commands returns a copy of every captured downstream command, safe to read
// without holding the client's lock.
func (c *recordingDownstreamClient) commands() []downstream.AuthenticatedCommand {
	c.mu.Lock()
	defer c.mu.Unlock()
	return append([]downstream.AuthenticatedCommand(nil), c.captured...)
}
@@ -0,0 +1,416 @@
package events
import (
"context"
"crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/grpcapi"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions verifies that
// a client event addressed only to a user ID is delivered to every open
// stream for that user's sessions and to no other user's stream.
func TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	// Two sessions for user-123, one unrelated session for user-999.
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-3", "user-999")))
	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	// Open all three streams and consume each stream's bootstrap event.
	targetOneCtx, cancelTargetOne := context.WithCancel(context.Background())
	defer cancelTargetOne()
	targetOne, err := client.SubscribeEvents(targetOneCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetOne), "request-1", "trace-device-session-1")
	targetTwoCtx, cancelTargetTwo := context.WithCancel(context.Background())
	defer cancelTargetTwo()
	targetTwo, err := client.SubscribeEvents(targetTwoCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetTwo), "request-2", "trace-device-session-2")
	unrelatedCtx, cancelUnrelated := context.WithCancel(context.Background())
	defer cancelUnrelated()
	unrelated, err := client.SubscribeEvents(unrelatedCtx, newPushSubscribeEventsRequest("device-session-3", "request-3"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, unrelated), "request-3", "trace-device-session-3")
	// Publish an event targeting user-123 only (no device_session_id field).
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-123",
		"payload_bytes": []byte("payload-123"),
		"request_id":    "request-123",
		"trace_id":      "trace-123",
	})
	// Both of user-123's streams receive the signed event...
	assertSignedPushEvent(t, recvPushEvent(t, targetOne), push.Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-123",
		PayloadBytes: []byte("payload-123"),
		RequestID:    "request-123",
		TraceID:      "trace-123",
	})
	assertSignedPushEvent(t, recvPushEvent(t, targetTwo), push.Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-123",
		PayloadBytes: []byte("payload-123"),
		RequestID:    "request-123",
		TraceID:      "trace-123",
	})
	// ...while user-999's stream stays silent.
	assertNoPushEvent(t, unrelated, cancelUnrelated)
}
// TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession verifies
// that an event carrying a device_session_id is delivered only to that
// session's stream, even when another stream belongs to the same user.
func TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	// Two sessions owned by the same user.
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))
	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	otherCtx, cancelOther := context.WithCancel(context.Background())
	defer cancelOther()
	otherStream, err := client.SubscribeEvents(otherCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, otherStream), "request-1", "trace-device-session-1")
	targetCtx, cancelTarget := context.WithCancel(context.Background())
	defer cancelTarget()
	targetStream, err := client.SubscribeEvents(targetCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetStream), "request-2", "trace-device-session-2")
	// Event addressed to device-session-2 specifically.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":           "user-123",
		"device_session_id": "device-session-2",
		"event_type":        "fleet.updated",
		"event_id":          "event-456",
		"payload_bytes":     []byte("payload-456"),
	})
	assertSignedPushEvent(t, recvPushEvent(t, targetStream), push.Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
		EventType:       "fleet.updated",
		EventID:         "event-456",
		PayloadBytes:    []byte("payload-456"),
	})
	// The sibling session of the same user must not receive it.
	assertNoPushEvent(t, otherStream, cancelOther)
}
// TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen verifies the
// revocation handler wiring: a revoked-session event closes the session's
// open stream with FailedPrecondition, and a subsequent subscribe attempt for
// the same session is rejected with the same status.
func TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	// The session subscriber is wired with the revocation handler so revoked
	// snapshots also tear down live push streams.
	sessionSubscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, sessionCache, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber, sessionSubscriber)
	defer running.stop(t)
	select {
	case <-sessionSubscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "session subscriber did not start")
	}
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	streamCtx, cancelStream := context.WithCancel(context.Background())
	defer cancelStream()
	stream, err := client.SubscribeEvents(streamCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")
	// Publish a revoked snapshot for the subscribed session.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-1",
		"user_id":           "user-123",
		"client_public_key": pushClientPublicKeyBase64(),
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})
	require.Eventually(t, func() bool {
		record, lookupErr := sessionCache.Lookup(context.Background(), "device-session-1")
		return lookupErr == nil && record.Status == session.StatusRevoked
	}, time.Second, 10*time.Millisecond)
	// The open stream must be closed by the server with FailedPrecondition.
	recvErrCh := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvErrCh <- recvErr
	}()
	select {
	case recvErr := <-recvErrCh:
		require.Error(t, recvErr)
		assert.Equal(t, codes.FailedPrecondition, status.Code(recvErr))
		assert.Equal(t, "device session is revoked", status.Convert(recvErr).Message())
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not close after revoke")
	}
	// Reopening for the revoked session must fail the same way. The error may
	// surface either on SubscribeEvents or on the first Recv.
	reopened, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-2"))
	if err == nil {
		_, err = reopened.Recv()
	}
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
}
// TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown verifies that
// canceling the application's run context closes active push streams with an
// Unavailable "gateway is shutting down" status.
func TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	stream, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")
	recvErrCh := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvErrCh <- recvErr
	}()
	// Trigger gateway shutdown while the stream is blocked in Recv.
	running.cancel()
	select {
	case recvErr := <-recvErrCh:
		require.Error(t, recvErr)
		assert.Equal(t, codes.Unavailable, status.Code(recvErr))
		assert.Equal(t, "gateway is shutting down", status.Convert(recvErr).Message())
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not close after gateway shutdown")
	}
}
// runPushGateway assembles a gateway application wired for push streaming
// (fan-out stream service over pushHub plus the client-event subscriber and
// any extra components), starts it in a goroutine, waits for the client
// subscriber to begin reading, and returns the listen address and a stop
// handle.
func runPushGateway(t *testing.T, sessionCache session.Cache, pushHub *push.Hub, clientSubscriber *RedisClientEventSubscriber, extraComponents ...app.Component) (string, runningAuthenticatedGateway) {
	t.Helper()
	addr := unusedTCPAddr(t)
	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = addr
	// Wide freshness window so signed request timestamps (testNow) validate.
	grpcCfg.FreshnessWindow = 5 * time.Minute
	responseSigner := newTestResponseSigner(t)
	gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
		Service:        grpcapi.NewFanOutPushStreamService(pushHub, responseSigner, fixedClock{now: testNow}, zap.NewNop()),
		ResponseSigner: responseSigner,
		SessionCache:   sessionCache,
		ReplayStore:    staticReplayStore{},
		Clock:          fixedClock{now: testNow},
		PushHub:        pushHub,
	})
	components := []app.Component{gateway, clientSubscriber}
	components = append(components, extraComponents...)
	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			AuthenticatedGRPC: grpcCfg,
		},
		components...,
	)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	// Block until the client-event subscriber is live, so events published by
	// the test cannot race its startup.
	select {
	case <-clientSubscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "client event subscriber did not start")
	}
	return addr, runningAuthenticatedGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}
// newPushActiveSessionRecord builds an active session record for the push
// tests, bound to the deterministic push-test client public key.
func newPushActiveSessionRecord(deviceSessionID string, userID string) session.Record {
	record := session.Record{
		ClientPublicKey: pushClientPublicKeyBase64(),
		Status:          session.StatusActive,
	}
	record.DeviceSessionID = deviceSessionID
	record.UserID = userID
	return record
}
// newPushSubscribeEventsRequest builds a signed SubscribeEvents request for
// deviceSessionID. The payload is empty (hash of nil) and the trace ID is
// derived from the session ID so tests can match bootstrap events to streams.
func newPushSubscribeEventsRequest(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
	emptyHash := sha256.Sum256(nil)
	request := &gatewayv1.SubscribeEventsRequest{
		ProtocolVersion: "v1",
		DeviceSessionId: deviceSessionID,
		MessageType:     "gateway.subscribe",
		TimestampMs:     testNow.UnixMilli(),
		RequestId:       requestID,
		PayloadHash:     emptyHash[:],
		TraceId:         "trace-" + deviceSessionID,
	}
	signingInput := authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: request.GetProtocolVersion(),
		DeviceSessionID: request.GetDeviceSessionId(),
		MessageType:     request.GetMessageType(),
		TimestampMS:     request.GetTimestampMs(),
		RequestID:       request.GetRequestId(),
		PayloadHash:     request.GetPayloadHash(),
	})
	request.Signature = ed25519.Sign(pushClientPrivateKey(), signingInput)
	return request
}
// recvPushEvent receives the next event from stream, failing the test if the
// stream errors.
func recvPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
	t.Helper()
	received, recvErr := stream.Recv()
	require.NoError(t, recvErr)
	return received
}
// assertPushBootstrapEvent checks the synthetic "gateway.server_time" event
// that opens a SubscribeEvents stream: its event and request IDs echo the
// subscribe request ID, the trace ID is propagated, and the payload hash and
// signature verify against the deterministic test response-signer key.
func assertPushBootstrapEvent(t *testing.T, event *gatewayv1.GatewayEvent, wantRequestID string, wantTraceID string) {
	t.Helper()
	require.NotNil(t, event)
	assert.Equal(t, "gateway.server_time", event.GetEventType())
	// The bootstrap event reuses the request ID as its event ID.
	assert.Equal(t, wantRequestID, event.GetEventId())
	assert.Equal(t, wantRequestID, event.GetRequestId())
	assert.Equal(t, wantTraceID, event.GetTraceId())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))
}
// assertSignedPushEvent checks that a delivered gateway event matches the
// expected push.Event fields and that its payload hash and signature verify
// against the deterministic test response-signer key.
func assertSignedPushEvent(t *testing.T, event *gatewayv1.GatewayEvent, want push.Event) {
	t.Helper()
	require.NotNil(t, event)
	assert.Equal(t, want.EventType, event.GetEventType())
	assert.Equal(t, want.EventID, event.GetEventId())
	assert.Equal(t, want.RequestID, event.GetRequestId())
	assert.Equal(t, want.TraceID, event.GetTraceId())
	assert.Equal(t, want.PayloadBytes, event.GetPayloadBytes())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))
}
// assertNoPushEvent verifies that stream delivers nothing for 100ms, then
// cancels the stream's context to release the background Recv goroutine.
// An unexpected event or a premature stream closure fails the test.
// NOTE(review): after cancel, the goroutine's Recv error lands in the
// buffered errCh and is intentionally discarded, so the goroutine does not
// leak.
func assertNoPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent], cancel context.CancelFunc) {
	t.Helper()
	recvCh := make(chan *gatewayv1.GatewayEvent, 1)
	errCh := make(chan error, 1)
	go func() {
		event, err := stream.Recv()
		if err != nil {
			errCh <- err
			return
		}
		recvCh <- event
	}()
	select {
	case event := <-recvCh:
		require.FailNowf(t, "unexpected push event delivered", "%+v", event)
	case <-time.After(100 * time.Millisecond):
		// Quiet for the full window — success; unblock the Recv goroutine.
		cancel()
	case err := <-errCh:
		require.FailNowf(t, "stream closed unexpectedly", "%v", err)
	}
}
// pushClientPrivateKey derives a deterministic Ed25519 client key for the
// push tests by hashing a fixed seed string.
func pushClientPrivateKey() ed25519.PrivateKey {
	seedBytes := []byte("gateway-push-grpc-test-client")
	digest := sha256.Sum256(seedBytes)
	return ed25519.NewKeyFromSeed(digest[:])
}
// pushClientPublicKeyBase64 returns the standard-base64 encoding of the push
// test client's Ed25519 public key, matching the session record wire format.
func pushClientPublicKeyBase64() string {
	publicKey := pushClientPrivateKey().Public().(ed25519.PublicKey)
	return base64.StdEncoding.EncodeToString(publicKey)
}
// pushResponseSignerPublicKey re-derives the public half of the deterministic
// response-signer key (same seed as newTestResponseSigner) so event
// signatures can be verified.
func pushResponseSignerPublicKey() ed25519.PublicKey {
	digest := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
	privateKey := ed25519.NewKeyFromSeed(digest[:])
	return privateKey.Public().(ed25519.PublicKey)
}
+389
View File
@@ -0,0 +1,389 @@
// Package events subscribes to internal session lifecycle streams used to keep
// the gateway hot-path session cache synchronized without per-request upstream
// lookups.
package events
import (
"context"
"crypto/tls"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/session"
"galaxy/gateway/internal/telemetry"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
// sessionEventReadCount bounds how many stream entries a single XREAD call
// may return, keeping per-iteration work and memory predictable.
const sessionEventReadCount int64 = 128
// SessionRevocationHandler reacts to a successfully applied revoked session
// snapshot and may tear down active resources bound to that session (for
// example, open push streams).
type SessionRevocationHandler interface {
	// RevokeDeviceSession tears down active resources bound to deviceSessionID.
	// It is invoked from the subscriber's read loop after the revoked snapshot
	// has been stored locally.
	RevokeDeviceSession(deviceSessionID string)
}
// RedisSessionSubscriber consumes full session snapshots from one Redis Stream
// and applies them to a process-local session snapshot store.
type RedisSessionSubscriber struct {
client *redis.Client
stream string
pingTimeout time.Duration
readBlockTimeout time.Duration
store session.SnapshotStore
revocationHandler SessionRevocationHandler
logger *zap.Logger
metrics *telemetry.Runtime
closeOnce sync.Once
startedOnce sync.Once
started chan struct{}
}
// NewRedisSessionSubscriber constructs a Redis Stream subscriber that reuses
// the SessionCache Redis connection settings and applies updates to store.
// It runs without a revocation handler, with a no-op logger, and with metrics
// disabled.
func NewRedisSessionSubscriber(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore) (*RedisSessionSubscriber, error) {
	return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, nil, nil, nil)
}
// NewRedisSessionSubscriberWithRevocationHandler constructs a Redis Stream
// subscriber that reuses the SessionCache Redis connection settings, applies
// updates to store, and optionally tears down active resources for revoked
// sessions. Logging is a no-op and metrics are disabled.
func NewRedisSessionSubscriberWithRevocationHandler(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler) (*RedisSessionSubscriber, error) {
	return NewRedisSessionSubscriberWithObservability(sessionCfg, eventsCfg, store, revocationHandler, nil, nil)
}
// NewRedisSessionSubscriberWithObservability constructs a Redis Stream
// subscriber that also logs and counts malformed internal session events.
// A nil logger falls back to a no-op logger; metrics may be nil to disable
// counters. revocationHandler may be nil.
func NewRedisSessionSubscriberWithObservability(sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisSessionSubscriber, error) {
	// Validate connection settings before creating any client resources.
	if strings.TrimSpace(sessionCfg.Addr) == "" {
		return nil, errors.New("new redis session subscriber: redis addr must not be empty")
	}
	if sessionCfg.DB < 0 {
		return nil, errors.New("new redis session subscriber: redis db must not be negative")
	}
	if sessionCfg.LookupTimeout <= 0 {
		return nil, errors.New("new redis session subscriber: lookup timeout must be positive")
	}
	if strings.TrimSpace(eventsCfg.Stream) == "" {
		return nil, errors.New("new redis session subscriber: stream must not be empty")
	}
	if eventsCfg.ReadBlockTimeout <= 0 {
		return nil, errors.New("new redis session subscriber: read block timeout must be positive")
	}
	if store == nil {
		return nil, errors.New("new redis session subscriber: nil session snapshot store")
	}
	// NOTE(review): RESP protocol pinned to 2 and client identity disabled —
	// presumably for broad server/proxy compatibility; confirm before changing.
	options := &redis.Options{
		Addr:            sessionCfg.Addr,
		Username:        sessionCfg.Username,
		Password:        sessionCfg.Password,
		DB:              sessionCfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if sessionCfg.TLSEnabled {
		// Enforce a TLS 1.2 floor when encryption is enabled.
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}
	if logger == nil {
		logger = zap.NewNop()
	}
	return &RedisSessionSubscriber{
		client:            redis.NewClient(options),
		stream:            eventsCfg.Stream,
		pingTimeout:       sessionCfg.LookupTimeout,
		readBlockTimeout:  eventsCfg.ReadBlockTimeout,
		store:             store,
		revocationHandler: revocationHandler,
		logger:            logger.Named("session_subscriber"),
		metrics:           metrics,
		started:           make(chan struct{}),
	}, nil
}
// Ping verifies that the Redis backend used for session lifecycle events is
// reachable within the configured timeout budget.
func (s *RedisSessionSubscriber) Ping(ctx context.Context) error {
	if s == nil || s.client == nil {
		return errors.New("ping redis session subscriber: nil subscriber")
	}
	if ctx == nil {
		return errors.New("ping redis session subscriber: nil context")
	}
	pingCtx, cancel := context.WithTimeout(ctx, s.pingTimeout)
	defer cancel()
	pingErr := s.client.Ping(pingCtx).Err()
	if pingErr == nil {
		return nil
	}
	return fmt.Errorf("ping redis session subscriber: %w", pingErr)
}
// Run consumes session lifecycle events until ctx is canceled or Redis returns
// an unexpected error. Reading begins after the current stream tail (see
// resolveStartID), and signalStarted is invoked once the loop is live so
// callers can synchronize on s.started.
//
// Fix: the original switch carried a case listing context.Canceled,
// context.DeadlineExceeded and redis.ErrClosed whose body was identical to
// the default case — dead duplication; it has been folded into default.
func (s *RedisSessionSubscriber) Run(ctx context.Context) error {
	if s == nil || s.client == nil {
		return errors.New("run redis session subscriber: nil subscriber")
	}
	if ctx == nil {
		return errors.New("run redis session subscriber: nil context")
	}
	if err := ctx.Err(); err != nil {
		return err
	}
	lastID, err := s.resolveStartID(ctx)
	if err != nil {
		return err
	}
	s.signalStarted()
	for {
		streams, err := s.client.XRead(ctx, &redis.XReadArgs{
			Streams: []string{s.stream, lastID},
			Count:   sessionEventReadCount,
			Block:   s.readBlockTimeout,
		}).Result()
		switch {
		case err == nil:
			// Apply each snapshot in order and advance the cursor so the next
			// XRead resumes after the last processed entry.
			for _, stream := range streams {
				for _, message := range stream.Messages {
					s.applyMessage(message)
					lastID = message.ID
				}
			}
		case errors.Is(err, redis.Nil):
			// Block timeout elapsed with no new entries; keep polling.
		case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
			// Shutdown path: the caller canceled ctx (or Shutdown closed the
			// client), so surface the context error, not the transport error.
			return ctx.Err()
		default:
			return fmt.Errorf("run redis session subscriber: %w", err)
		}
	}
}
// resolveStartID returns the stream ID to resume reading after: the ID of the
// newest existing entry, or "0-0" when the stream is empty or absent. Starting
// at the tail means historical snapshots are not replayed on startup.
func (s *RedisSessionSubscriber) resolveStartID(ctx context.Context) (string, error) {
	tail, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		return "", fmt.Errorf("run redis session subscriber: resolve stream tail: %w", err)
	}
	if errors.Is(err, redis.Nil) || len(tail) == 0 {
		return "0-0", nil
	}
	return tail[0].ID, nil
}
// Shutdown closes the Redis client so a blocking stream read can terminate
// promptly during gateway shutdown. ctx is only nil-checked; closing the
// client is immediate and does not wait on the context.
func (s *RedisSessionSubscriber) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown redis session subscriber: nil context")
	}
	return s.Close()
}
// Close releases the underlying Redis client resources. It is safe on a nil
// subscriber and idempotent: only the first call actually closes the client.
func (s *RedisSessionSubscriber) Close() error {
	if s == nil || s.client == nil {
		return nil
	}
	var closeErr error
	s.closeOnce.Do(func() {
		closeErr = s.client.Close()
	})
	return closeErr
}
// signalStarted closes the started channel exactly once, letting callers and
// tests block until the subscriber's read loop is live.
func (s *RedisSessionSubscriber) signalStarted() {
	s.startedOnce.Do(func() {
		close(s.started)
	})
}
// applyMessage decodes one stream entry into a session snapshot and applies
// it to the local store. Malformed entries and store failures are dropped
// (logged and counted) rather than aborting the subscriber; in both failure
// modes the session's local record is deleted so the gateway falls back to an
// authoritative lookup instead of serving a stale snapshot.
func (s *RedisSessionSubscriber) applyMessage(message redis.XMessage) {
	record, err := decodeSessionRecordSnapshot(message.Values)
	if err != nil {
		s.logger.Warn("dropped malformed session event",
			zap.String("stream", s.stream),
			zap.String("message_id", message.ID),
			zap.Error(err),
		)
		s.recordInternalEventDrop("malformed_event")
		// Best effort: evict whatever local state exists for the session the
		// malformed event referred to, forcing a fallback lookup.
		if deviceSessionID, ok := extractDeviceSessionID(message.Values); ok {
			s.store.Delete(deviceSessionID)
		}
		return
	}
	if err := s.store.Upsert(record); err != nil {
		s.logger.Warn("dropped session snapshot after store failure",
			zap.String("stream", s.stream),
			zap.String("message_id", message.ID),
			zap.String("device_session_id", record.DeviceSessionID),
			zap.Error(err),
		)
		s.recordInternalEventDrop("store_failure")
		s.store.Delete(record.DeviceSessionID)
		return
	}
	if record.Status == session.StatusRevoked && s.revocationHandler != nil {
		s.revocationHandler.RevokeDeviceSession(record.DeviceSessionID)
	}
}

// recordInternalEventDrop increments the internal-event-drop counter when
// metrics are configured. Fix: NewRedisSessionSubscriber and
// NewRedisSessionSubscriberWithRevocationHandler pass a nil
// *telemetry.Runtime, so the unconditional method call in the original code
// risked a nil-pointer panic on the first dropped event (unless Runtime's
// methods are nil-receiver-safe, which is not visible here).
func (s *RedisSessionSubscriber) recordInternalEventDrop(reason string) {
	if s.metrics == nil {
		return
	}
	s.metrics.RecordInternalEventDrop(context.Background(),
		attribute.String("component", "session_subscriber"),
		attribute.String("reason", reason),
	)
}
// decodeSessionRecordSnapshot converts the raw field map of a session event
// into a session.Record. Unknown fields, missing required fields, blank
// values, and an unparsable revoked_at_ms all yield an error, so malformed
// events are dropped rather than applied.
func decodeSessionRecordSnapshot(values map[string]any) (session.Record, error) {
	// Reject any field outside the known schema (required + optional).
	allowed := map[string]bool{
		"device_session_id": true,
		"user_id":           true,
		"client_public_key": true,
		"status":            true,
		"revoked_at_ms":     true,
	}
	for key := range values {
		if !allowed[key] {
			return session.Record{}, fmt.Errorf("decode session event: unsupported field %q", key)
		}
	}
	// Extract the required string fields in a fixed order so the first
	// missing/blank field determines the error, as before.
	required := []string{"device_session_id", "user_id", "client_public_key", "status"}
	fields := make(map[string]string, len(required))
	for _, name := range required {
		value, err := requiredStringField(values, name)
		if err != nil {
			return session.Record{}, err
		}
		fields[name] = value
	}
	record := session.Record{
		DeviceSessionID: fields["device_session_id"],
		UserID:          fields["user_id"],
		ClientPublicKey: fields["client_public_key"],
		Status:          session.Status(fields["status"]),
	}
	if rawRevokedAtMS, ok := values["revoked_at_ms"]; ok {
		revokedAtMS, err := parseInt64Field(rawRevokedAtMS, "revoked_at_ms")
		if err != nil {
			return session.Record{}, err
		}
		record.RevokedAtMS = &revokedAtMS
	}
	return record, nil
}
// extractDeviceSessionID best-effort pulls a non-blank device_session_id out
// of a raw event field map, reporting whether one was found. It never errors:
// callers use it on already-rejected events to evict stale local state.
func extractDeviceSessionID(values map[string]any) (string, bool) {
	raw, present := values["device_session_id"]
	if !present {
		return "", false
	}
	deviceSessionID, err := coerceString(raw)
	if err != nil || strings.TrimSpace(deviceSessionID) == "" {
		return "", false
	}
	return deviceSessionID, true
}
// requiredStringField coerces values[field] to a string, failing when the
// field is missing, has an unsupported type, or is blank after trimming.
func requiredStringField(values map[string]any, field string) (string, error) {
	raw, present := values[field]
	if !present {
		return "", fmt.Errorf("decode session event: missing %s", field)
	}
	text, err := coerceString(raw)
	if err != nil {
		return "", fmt.Errorf("decode session event: %s: %w", field, err)
	}
	if strings.TrimSpace(text) == "" {
		return "", fmt.Errorf("decode session event: %s must not be empty", field)
	}
	return text, nil
}
// parseInt64Field coerces value to a string and parses it as a base-10
// signed 64-bit integer, ignoring surrounding whitespace. Both coercion and
// parse failures are wrapped with the field name for context.
func parseInt64Field(value any, field string) (int64, error) {
	text, err := coerceString(value)
	if err == nil {
		var parsed int64
		if parsed, err = strconv.ParseInt(strings.TrimSpace(text), 10, 64); err == nil {
			return parsed, nil
		}
	}
	return 0, fmt.Errorf("decode session event: %s: %w", field, err)
}
// coerceString renders a raw event value as a string. Plain strings and byte
// slices pass through, fmt.Stringer values are formatted, and the common
// integer widths are converted in base 10; anything else is an error.
func coerceString(value any) (string, error) {
	if s, ok := value.(string); ok {
		return s, nil
	}
	if b, ok := value.([]byte); ok {
		return string(b), nil
	}
	if stringer, ok := value.(fmt.Stringer); ok {
		return stringer.String(), nil
	}
	if i, ok := value.(int); ok {
		return strconv.Itoa(i), nil
	}
	if i64, ok := value.(int64); ok {
		return strconv.FormatInt(i64, 10), nil
	}
	if u64, ok := value.(uint64); ok {
		return strconv.FormatUint(u64, 10), nil
	}
	return "", fmt.Errorf("unsupported value type %T", value)
}
+366
View File
@@ -0,0 +1,366 @@
package events
import (
"context"
"sync"
"testing"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/session"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRedisSessionSubscriberAppliesActiveSnapshot verifies that an "active"
// session event added to the Redis stream is applied to the snapshot store
// with the user and status carried by the event.
func TestRedisSessionSubscriberAppliesActiveSnapshot(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})
	// The subscriber consumes the stream in a background goroutine, so poll
	// until the snapshot becomes visible in the store.
	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil {
			return false
		}
		return record.UserID == "user-123" && record.Status == session.StatusActive
	}, time.Second, 10*time.Millisecond)
}
// TestRedisSessionSubscriberAppliesRevokedSnapshot verifies that a "revoked"
// event overwrites a previously active cached record, including the
// revoked_at_ms timestamp.
func TestRedisSessionSubscriberAppliesRevokedSnapshot(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	// Seed an active record so the test proves the event replaces it.
	require.NoError(t, store.Upsert(session.Record{
		DeviceSessionID: "device-session-123",
		UserID:          "user-123",
		ClientPublicKey: "public-key-123",
		Status:          session.StatusActive,
	}))
	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})
	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil || record.RevokedAtMS == nil {
			return false
		}
		return record.Status == session.StatusRevoked && *record.RevokedAtMS == 123456789
	}, time.Second, 10*time.Millisecond)
}
// TestRedisSessionSubscriberRevokedSnapshotTriggersRevocationHandler verifies
// that applying a revoked snapshot also notifies the configured revocation
// handler with the device session ID.
func TestRedisSessionSubscriberRevokedSnapshotTriggersRevocationHandler(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	handler := &recordingSessionRevocationHandler{}
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})
	// Poll for both effects: the store holds the revoked record AND the
	// handler saw exactly this one revocation.
	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil || record.Status != session.StatusRevoked {
			return false
		}
		return assert.ObjectsAreEqual([]string{"device-session-123"}, handler.revocations())
	}, time.Second, 10*time.Millisecond)
}
// TestRedisSessionSubscriberActiveSnapshotDoesNotTriggerRevocationHandler
// verifies that active snapshots never reach the revocation handler.
func TestRedisSessionSubscriberActiveSnapshotDoesNotTriggerRevocationHandler(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	handler := &recordingSessionRevocationHandler{}
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})
	// Negative check: give the subscriber a short window to misbehave.
	assert.Never(t, func() bool {
		return len(handler.revocations()) != 0
	}, 100*time.Millisecond, 10*time.Millisecond)
}
// TestRedisSessionSubscriberStoreFailureDoesNotTriggerRevocationHandler
// verifies that when persisting the snapshot fails, the revocation handler is
// not invoked — revocation must only follow a successfully stored snapshot.
func TestRedisSessionSubscriberStoreFailureDoesNotTriggerRevocationHandler(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	handler := &recordingSessionRevocationHandler{}
	// failingSnapshotStore makes every Upsert fail.
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, failingSnapshotStore{}, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})
	assert.Never(t, func() bool {
		return len(handler.revocations()) != 0
	}, 100*time.Millisecond, 10*time.Millisecond)
}
// TestRedisSessionSubscriberLaterEventWins verifies last-writer-wins
// semantics: two events for the same device session leave the store holding
// the second event's snapshot.
func TestRedisSessionSubscriberLaterEventWins(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})
	// Same device session, different user/key — must supersede the first.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-456",
		"client_public_key": "public-key-456",
		"status":            string(session.StatusActive),
	})
	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil {
			return false
		}
		return record.UserID == "user-456" && record.ClientPublicKey == "public-key-456"
	}, time.Second, 10*time.Millisecond)
}
// TestRedisSessionSubscriberMalformedEventEvictsAndContinues verifies the
// fail-safe path: a malformed event (unknown status "paused") evicts the
// cached record rather than being applied, and the subscriber keeps
// processing subsequent valid events.
func TestRedisSessionSubscriberMalformedEventEvictsAndContinues(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	// Seed a record so eviction is observable.
	require.NoError(t, store.Upsert(session.Record{
		DeviceSessionID: "device-session-123",
		UserID:          "user-123",
		ClientPublicKey: "public-key-123",
		Status:          session.StatusActive,
	}))
	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            "paused",
	})
	// Lookup failing means the record was evicted.
	require.Eventually(t, func() bool {
		_, err := store.Lookup(context.Background(), "device-session-123")
		return err != nil
	}, time.Second, 10*time.Millisecond)
	// A later well-formed event must still be applied.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-456",
		"client_public_key": "public-key-456",
		"status":            string(session.StatusActive),
	})
	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil {
			return false
		}
		return record.UserID == "user-456" && record.Status == session.StatusActive
	}, time.Second, 10*time.Millisecond)
}
// TestRedisSessionSubscriberShutdownInterruptsBlockingRead verifies that
// canceling the run context plus calling Shutdown unblocks the subscriber's
// blocking stream read, and that Run reports context.Canceled.
func TestRedisSessionSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, server, store)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()
	// Wait for the subscriber's started signal before canceling so the test
	// exercises an actual in-flight blocking read.
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}
	cancel()
	require.NoError(t, subscriber.Shutdown(context.Background()))
	select {
	case err := <-resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop after shutdown")
	}
}
// newTestRedisSessionSubscriber builds a subscriber wired to the miniredis
// server and store, with no revocation handler.
func newTestRedisSessionSubscriber(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore) *RedisSessionSubscriber {
	t.Helper()
	return newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, nil)
}
// newTestRedisSessionSubscriberWithRevocationHandler builds a subscriber
// against the miniredis server with short test timeouts and registers a
// cleanup that closes it when the test finishes.
func newTestRedisSessionSubscriberWithRevocationHandler(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore, revocationHandler SessionRevocationHandler) *RedisSessionSubscriber {
	t.Helper()
	subscriber, err := NewRedisSessionSubscriberWithRevocationHandler(
		config.SessionCacheRedisConfig{
			Addr:          server.Addr(),
			LookupTimeout: 250 * time.Millisecond,
		},
		config.SessionEventsRedisConfig{
			Stream: "gateway:session_events",
			// Short block timeout keeps shutdown fast in tests.
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		store,
		revocationHandler,
	)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, subscriber.Close())
	})
	return subscriber
}
// recordingSessionRevocationHandler is a thread-safe test double that records
// every device-session ID passed to RevokeDeviceSession.
type recordingSessionRevocationHandler struct {
	mu         sync.Mutex
	revokedIDs []string
}

// RevokeDeviceSession appends deviceSessionID to the recorded revocations.
func (h *recordingSessionRevocationHandler) RevokeDeviceSession(deviceSessionID string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.revokedIDs = append(h.revokedIDs, deviceSessionID)
}

// revocations returns a snapshot copy of the recorded IDs so callers can
// inspect it without holding the lock.
func (h *recordingSessionRevocationHandler) revocations() []string {
	h.mu.Lock()
	defer h.mu.Unlock()
	return append([]string(nil), h.revokedIDs...)
}
// failingSnapshotStore is a test double whose Upsert always fails, letting
// tests exercise the subscriber's store-failure path.
type failingSnapshotStore struct{}

// Lookup always reports a missing record.
func (failingSnapshotStore) Lookup(context.Context, string) (session.Record, error) {
	return session.Record{}, session.ErrNotFound
}

// Upsert always fails, simulating a broken snapshot store.
func (failingSnapshotStore) Upsert(session.Record) error {
	return context.DeadlineExceeded
}

// Delete is a no-op.
func (failingSnapshotStore) Delete(string) {}
// addSessionEvent XADDs fields as a single auto-ID entry on stream in the
// miniredis server, failing the test on error.
func addSessionEvent(t *testing.T, server *miniredis.Miniredis, stream string, fields map[string]string) {
	t.Helper()
	flattened := make([]string, 0, 2*len(fields))
	for key, value := range fields {
		flattened = append(flattened, key, value)
	}
	_, err := server.XAdd(stream, "*", flattened)
	require.NoError(t, err)
}
// runningSubscriber tracks a subscriber started by runTestSubscriber so the
// test can cancel its context and wait for Run's result.
//
// Note: the previous `stopOnce bool` field was never read or written by any
// code in this file, so it has been removed as dead state.
type runningSubscriber struct {
	// cancel stops the subscriber's run context.
	cancel context.CancelFunc
	// resultCh receives the single error returned by Run.
	resultCh chan error
}
// runTestSubscriber starts subscriber.Run in a goroutine, waits for its
// started signal (or fails after a second), and returns the handles needed
// to stop it and collect Run's result.
func runTestSubscriber(t *testing.T, subscriber *RedisSessionSubscriber) runningSubscriber {
	t.Helper()
	ctx, cancel := context.WithCancel(context.Background())
	// Buffered so the goroutine can exit even if nobody reads the result.
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}
	return runningSubscriber{
		cancel:   cancel,
		resultCh: resultCh,
	}
}
// stop cancels the subscriber's context and asserts that Run terminates with
// context.Canceled within one second.
func (r runningSubscriber) stop(t *testing.T) {
	t.Helper()
	r.cancel()
	select {
	case err := <-r.resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop")
	}
}
+145
View File
@@ -0,0 +1,145 @@
package grpcapi
import (
"bytes"
"context"
"crypto/sha256"
"errors"
"strings"
"time"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/downstream"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// commandRoutingService translates the verified authenticated request context
// into an internal downstream command and signs successful unary responses.
type commandRoutingService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	// subscribeDelegate handles SubscribeEvents; unary routing is owned here.
	subscribeDelegate gatewayv1.EdgeGatewayServer
	// router resolves a downstream client by exact message_type.
	router downstream.Router
	// responseSigner signs the unary response envelope fields.
	responseSigner authn.ResponseSigner
	// clock supplies response timestamps (injectable for tests).
	clock clock.Clock
	// downstreamTimeout bounds each downstream ExecuteCommand call.
	downstreamTimeout time.Duration
}
// ExecuteCommand builds a verified downstream command, routes it by exact
// message_type, executes it, and signs the resulting unary response.
//
// Error mapping is deliberate and stable for clients:
//   - unknown message_type               -> Unimplemented
//   - downstream unavailable / timeouts  -> Unavailable
//   - anything else                      -> Internal (no internal detail leaks)
func (s commandRoutingService) ExecuteCommand(ctx context.Context, _ *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
	// The raw request is ignored here: all authenticated fields were parsed
	// and verified by earlier wrappers and travel via ctx.
	command, err := authenticatedCommandFromContext(ctx)
	if err != nil {
		return nil, err
	}
	client, err := s.router.Route(command.MessageType)
	switch {
	case err == nil:
	case errors.Is(err, downstream.ErrRouteNotFound):
		return nil, status.Error(codes.Unimplemented, "message_type is not routed")
	case errors.Is(err, downstream.ErrDownstreamUnavailable):
		return nil, status.Error(codes.Unavailable, "downstream service is unavailable")
	default:
		return nil, status.Error(codes.Internal, "downstream route resolution failed")
	}
	// Bound the downstream call independently of the client's own deadline.
	downstreamCtx, cancel := context.WithTimeout(ctx, s.downstreamTimeout)
	defer cancel()
	result, err := client.ExecuteCommand(downstreamCtx, command)
	switch {
	case err == nil:
	case errors.Is(err, downstream.ErrDownstreamUnavailable),
		errors.Is(err, context.DeadlineExceeded),
		errors.Is(err, context.Canceled):
		return nil, status.Error(codes.Unavailable, "downstream service is unavailable")
	default:
		return nil, status.Error(codes.Internal, "downstream execution failed")
	}
	// A blank result code means the downstream broke its contract.
	if strings.TrimSpace(result.ResultCode) == "" {
		return nil, status.Error(codes.Internal, "downstream response is invalid")
	}
	responseTimestampMS := s.clock.Now().UTC().UnixMilli()
	// Sign over the payload's SHA-256 hash plus the envelope fields so the
	// client can bind the response to its request and verify integrity.
	payloadHash := sha256.Sum256(result.PayloadBytes)
	signature, err := s.responseSigner.SignResponse(authn.ResponseSigningFields{
		ProtocolVersion: command.ProtocolVersion,
		RequestID:       command.RequestID,
		TimestampMS:     responseTimestampMS,
		ResultCode:      result.ResultCode,
		PayloadHash:     payloadHash[:],
	})
	if err != nil {
		return nil, status.Error(codes.Unavailable, "response signer is unavailable")
	}
	// Clone byte fields so later mutation of downstream buffers cannot
	// alter the signed response.
	return &gatewayv1.ExecuteCommandResponse{
		ProtocolVersion: command.ProtocolVersion,
		RequestId:       command.RequestID,
		TimestampMs:     responseTimestampMS,
		ResultCode:      result.ResultCode,
		PayloadBytes:    bytes.Clone(result.PayloadBytes),
		PayloadHash:     bytes.Clone(payloadHash[:]),
		Signature:       signature,
	}, nil
}
// SubscribeEvents delegates to the authenticated streaming service
// implementation selected during server construction; this type only owns
// the unary path.
func (s commandRoutingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	return s.subscribeDelegate.SubscribeEvents(req, stream)
}
// newCommandRoutingService constructs the final authenticated service that
// owns verified unary routing while preserving the delegated streaming path.
// downstreamTimeout bounds each routed unary call.
func newCommandRoutingService(subscribeDelegate gatewayv1.EdgeGatewayServer, router downstream.Router, responseSigner authn.ResponseSigner, clk clock.Clock, downstreamTimeout time.Duration) gatewayv1.EdgeGatewayServer {
	return commandRoutingService{
		subscribeDelegate: subscribeDelegate,
		router:            router,
		responseSigner:    responseSigner,
		clock:             clk,
		downstreamTimeout: downstreamTimeout,
	}
}
// authenticatedCommandFromContext assembles the downstream command from the
// parsed envelope and resolved session previously attached to ctx by the
// auth wrappers. Either missing means the interceptor pipeline is miswired,
// which surfaces as an Internal error.
func authenticatedCommandFromContext(ctx context.Context) (downstream.AuthenticatedCommand, error) {
	envelope, haveEnvelope := parsedEnvelopeFromContext(ctx)
	if !haveEnvelope {
		return downstream.AuthenticatedCommand{}, status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	record, haveRecord := resolvedSessionFromContext(ctx)
	if !haveRecord {
		return downstream.AuthenticatedCommand{}, status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	command := downstream.AuthenticatedCommand{
		ProtocolVersion: envelope.ProtocolVersion,
		UserID:          record.UserID,
		DeviceSessionID: record.DeviceSessionID,
		MessageType:     envelope.MessageType,
		TimestampMS:     envelope.TimestampMS,
		RequestID:       envelope.RequestID,
		TraceID:         envelope.TraceID,
		// Clone so the command owns its payload independently of the
		// envelope's buffer.
		PayloadBytes: bytes.Clone(envelope.PayloadBytes),
	}
	return command, nil
}
// unavailableResponseSigner is the fallback signer used when no real signer
// is configured; every signing attempt fails so responses are refused rather
// than sent unsigned.
type unavailableResponseSigner struct{}

// SignResponse always fails: no signing key material is configured.
func (unavailableResponseSigner) SignResponse(authn.ResponseSigningFields) ([]byte, error) {
	return nil, errors.New("response signer is unavailable")
}

// SignEvent always fails: no signing key material is configured.
func (unavailableResponseSigner) SignEvent(authn.EventSigningFields) ([]byte, error) {
	return nil, errors.New("response signer is unavailable")
}

// Compile-time check that the routing service satisfies the gRPC interface.
var _ gatewayv1.EdgeGatewayServer = commandRoutingService{}
@@ -0,0 +1,296 @@
package grpcapi
import (
"context"
"crypto/sha256"
"fmt"
"testing"
"time"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/testutil"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TestExecuteCommandRoutesVerifiedCommandAndSignsResponse drives a full
// client -> gateway -> downstream round trip and checks that the verified
// command reaches the routed client unchanged, that only the matching route
// is invoked, and that the response carries a verifiable payload hash and
// signature.
func TestExecuteCommandRoutesVerifiedCommandAndSignsResponse(t *testing.T) {
	t.Parallel()
	signer := newTestEd25519ResponseSigner()
	moveClient := &recordingDownstreamClient{
		executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
			// The command must already be fully authenticated: user_id comes
			// from the session cache, not from the request envelope.
			assert.Equal(t, downstream.AuthenticatedCommand{
				ProtocolVersion: "v1",
				UserID:          "user-123",
				DeviceSessionID: "device-session-123",
				MessageType:     "fleet.move",
				TimestampMS:     testCurrentTime.UnixMilli(),
				RequestID:       "request-123",
				TraceID:         "trace-123",
				PayloadBytes:    []byte("payload"),
			}, command)
			return downstream.UnaryResult{
				ResultCode:   "accepted",
				PayloadBytes: []byte("downstream-response"),
			}, nil
		},
	}
	renameClient := &recordingDownstreamClient{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move":   moveClient,
			"fleet.rename": renameClient,
		}),
		ResponseSigner: signer,
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)
	assert.Equal(t, "v1", response.GetProtocolVersion())
	assert.Equal(t, "request-123", response.GetRequestId())
	assert.Equal(t, testCurrentTime.UnixMilli(), response.GetTimestampMs())
	assert.Equal(t, "accepted", response.GetResultCode())
	assert.Equal(t, []byte("downstream-response"), response.GetPayloadBytes())
	// Exactly one routed call; the rename route must stay untouched.
	assert.Equal(t, 1, moveClient.executeCalls)
	assert.Zero(t, renameClient.executeCalls)
	wantHash := sha256.Sum256([]byte("downstream-response"))
	assert.Equal(t, wantHash[:], response.GetPayloadHash())
	require.NoError(t, authn.VerifyPayloadHash(response.GetPayloadBytes(), response.GetPayloadHash()))
	require.NoError(t, authn.VerifyResponseSignature(signer.PublicKey(), response.GetSignature(), authn.ResponseSigningFields{
		ProtocolVersion: response.GetProtocolVersion(),
		RequestID:       response.GetRequestId(),
		TimestampMS:     response.GetTimestampMs(),
		ResultCode:      response.GetResultCode(),
		PayloadHash:     response.GetPayloadHash(),
	}))
}
// TestExecuteCommandRouteMissReturnsUnimplemented verifies that an empty
// routing table maps an unknown message_type to the stable Unimplemented
// status with its fixed message.
func TestExecuteCommandRouteMissReturnsUnimplemented(t *testing.T) {
	t.Parallel()
	server, runGateway := newTestGateway(t, ServerDependencies{
		Router:         downstream.NewStaticRouter(nil),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		ResponseSigner: newTestResponseSigner(),
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unimplemented, status.Code(err))
	assert.Equal(t, "message_type is not routed", status.Convert(err).Message())
}
// TestExecuteCommandMapsDownstreamUnavailableToUnavailable verifies that a
// downstream failure wrapping ErrDownstreamUnavailable surfaces as gRPC
// Unavailable (matched via errors.Is, so wrapping must be tolerated).
func TestExecuteCommandMapsDownstreamUnavailableToUnavailable(t *testing.T) {
	t.Parallel()
	failingClient := &recordingDownstreamClient{
		executeFunc: func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
			// Wrapped error: the gateway must unwrap it when classifying.
			return downstream.UnaryResult{}, fmt.Errorf("rpc transport failed: %w", downstream.ErrDownstreamUnavailable)
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": failingClient,
		}),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		ResponseSigner: newTestResponseSigner(),
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "downstream service is unavailable", status.Convert(err).Message())
	assert.Equal(t, 1, failingClient.executeCalls)
}
// TestExecuteCommandPropagatesOTelSpanContextToDownstream verifies that a
// valid OTel span context reaches the downstream call context and that the
// envelope trace_id is carried on the command.
//
// NOTE(review): seenSpanContext/seenCommand are written in the server
// handler goroutine and read after the client RPC returns; this relies on
// the RPC completion for ordering — confirm it stays clean under -race.
func TestExecuteCommandPropagatesOTelSpanContextToDownstream(t *testing.T) {
	t.Parallel()
	logger := zap.NewNop()
	telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)
	var (
		seenSpanContext trace.SpanContext
		seenCommand     downstream.AuthenticatedCommand
	)
	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": &recordingDownstreamClient{
				executeFunc: func(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
					seenSpanContext = trace.SpanContextFromContext(ctx)
					seenCommand = command
					return downstream.UnaryResult{
						ResultCode:   "accepted",
						PayloadBytes: []byte("downstream-response"),
					}, nil
				},
			},
		}),
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		Logger:         logger,
		Telemetry:      telemetryRuntime,
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)
	assert.True(t, seenSpanContext.IsValid())
	assert.Equal(t, "trace-123", seenCommand.TraceID)
}
// TestExecuteCommandDrainsInFlightUnaryDuringShutdown verifies graceful
// shutdown semantics: canceling the gateway while a unary call is blocked in
// the downstream does not abort the call; it completes successfully once the
// downstream releases.
func TestExecuteCommandDrainsInFlightUnaryDuringShutdown(t *testing.T) {
	t.Parallel()
	started := make(chan struct{})
	release := make(chan struct{})
	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": &recordingDownstreamClient{
				executeFunc: func(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
					// Signal the test, then block until it releases us.
					close(started)
					<-release
					return downstream.UnaryResult{
						ResultCode:   "accepted",
						PayloadBytes: []byte("downstream-response"),
					}, nil
				},
			},
		}),
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	resultCh := make(chan error, 1)
	go func() {
		_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
		resultCh <- err
	}()
	require.Eventually(t, func() bool {
		select {
		case <-started:
			return true
		default:
			return false
		}
	}, time.Second, 10*time.Millisecond, "downstream execution did not start")
	// Begin shutdown while the call is still blocked downstream.
	runGateway.cancel()
	require.Never(t, func() bool {
		select {
		case <-resultCh:
			return true
		default:
			return false
		}
	}, 100*time.Millisecond, 10*time.Millisecond, "unary request returned before downstream release")
	close(release)
	var err error
	require.Eventually(t, func() bool {
		select {
		case err = <-resultCh:
			return true
		default:
			return false
		}
	}, time.Second, 10*time.Millisecond, "unary request did not drain before shutdown timeout")
	require.NoError(t, err)
}
// TestExecuteCommandLogsDoNotContainSensitiveTransportMaterial verifies that
// a successful unary round trip never logs payload, payload hash, or
// signature material.
func TestExecuteCommandLogsDoNotContainSensitiveTransportMaterial(t *testing.T) {
	t.Parallel()
	logger, logBuffer := testutil.NewObservedLogger(t)
	server, runGateway := newTestGateway(t, ServerDependencies{
		Router: downstream.NewStaticRouter(map[string]downstream.Client{
			"fleet.move": &recordingDownstreamClient{},
		}),
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		Logger:         logger,
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)
	// Substring scan over everything the gateway logged for this request.
	logOutput := logBuffer.String()
	assert.NotContains(t, logOutput, "payload_hash")
	assert.NotContains(t, logOutput, "signature")
	assert.NotContains(t, logOutput, `"payload"`)
}
+214
View File
@@ -0,0 +1,214 @@
package grpcapi
import (
"bytes"
"context"
"fmt"
"galaxy/gateway/proto/galaxy/gateway/v1"
"buf.build/go/protovalidate"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const supportedProtocolVersion = "v1"
// parsedEnvelope captures the authenticated transport fields extracted from a
// request envelope after validation succeeds. Later wrappers may enrich this
// structure without changing the raw gRPC request types.
type parsedEnvelope struct {
	ProtocolVersion string
	DeviceSessionID string
	MessageType     string
	// TimestampMS is the client-supplied envelope timestamp in Unix
	// milliseconds; the canonical validation requires it to be > 0.
	TimestampMS int64
	RequestID   string
	TraceID     string
	// PayloadBytes, PayloadHash and Signature are defensive clones made by
	// the parse functions and are safe to retain past the request.
	PayloadBytes []byte
	PayloadHash  []byte
	Signature    []byte
}
// parsedEnvelopeFromContext returns the parsed envelope previously attached
// to ctx by the envelope-validating gRPC service wrapper, reporting false
// when no envelope is present (or ctx itself is nil).
func parsedEnvelopeFromContext(ctx context.Context) (parsedEnvelope, bool) {
	if ctx == nil {
		return parsedEnvelope{}, false
	}
	// A failed assertion leaves envelope at its zero value, matching the
	// explicit zero return of the miss path.
	envelope, ok := ctx.Value(parsedEnvelopeContextKey{}).(parsedEnvelope)
	return envelope, ok
}
// envelopeValidatingService applies envelope parsing and the protocol gate
// before delegating to the configured service implementation.
type envelopeValidatingService struct {
	gatewayv1.UnimplementedEdgeGatewayServer

	// delegate receives only requests that passed envelope validation.
	delegate gatewayv1.EdgeGatewayServer
}
// ExecuteCommand validates req and only then forwards it to the configured
// delegate, with the parsed envelope attached to the request context.
func (s envelopeValidatingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
	envelope, err := parseExecuteCommandRequest(req)
	if err != nil {
		return nil, err
	}
	enriched := context.WithValue(ctx, parsedEnvelopeContextKey{}, envelope)
	return s.delegate.ExecuteCommand(enriched, req)
}
// SubscribeEvents validates req and only then forwards it to the configured
// delegate, wrapping the stream so its Context carries the parsed envelope.
func (s envelopeValidatingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	envelope, err := parseSubscribeEventsRequest(req)
	if err != nil {
		return err
	}
	wrapped := envelopeContextStream{
		ServerStreamingServer: stream,
		ctx:                   context.WithValue(stream.Context(), parsedEnvelopeContextKey{}, envelope),
	}
	return s.delegate.SubscribeEvents(req, wrapped)
}
// parseExecuteCommandRequest validates req against the request-envelope rules
// and, on success, returns a parsed envelope whose byte fields are cloned so
// later auth stages cannot be affected by request-buffer reuse.
func parseExecuteCommandRequest(req *gatewayv1.ExecuteCommandRequest) (parsedEnvelope, error) {
	if req == nil {
		return parsedEnvelope{}, newMalformedEnvelopeError("request envelope must not be nil")
	}
	if protovalidate.Validate(req) != nil {
		// Collapse whatever protovalidate reported into the stable,
		// field-ordered canonical error.
		return parsedEnvelope{}, canonicalExecuteCommandValidationError(req)
	}
	if version := req.GetProtocolVersion(); version != supportedProtocolVersion {
		return parsedEnvelope{}, newUnsupportedProtocolVersionError(version)
	}
	envelope := parsedEnvelope{
		ProtocolVersion: req.GetProtocolVersion(),
		DeviceSessionID: req.GetDeviceSessionId(),
		MessageType:     req.GetMessageType(),
		TimestampMS:     req.GetTimestampMs(),
		RequestID:       req.GetRequestId(),
		TraceID:         req.GetTraceId(),
		PayloadBytes:    bytes.Clone(req.GetPayloadBytes()),
		PayloadHash:     bytes.Clone(req.GetPayloadHash()),
		Signature:       bytes.Clone(req.GetSignature()),
	}
	return envelope, nil
}
// parseSubscribeEventsRequest validates req against the request-envelope
// rules and, on success, returns a parsed envelope whose byte fields are
// cloned so later auth stages cannot be affected by request-buffer reuse.
func parseSubscribeEventsRequest(req *gatewayv1.SubscribeEventsRequest) (parsedEnvelope, error) {
	if req == nil {
		return parsedEnvelope{}, newMalformedEnvelopeError("request envelope must not be nil")
	}
	if protovalidate.Validate(req) != nil {
		// Collapse whatever protovalidate reported into the stable,
		// field-ordered canonical error.
		return parsedEnvelope{}, canonicalSubscribeEventsValidationError(req)
	}
	if version := req.GetProtocolVersion(); version != supportedProtocolVersion {
		return parsedEnvelope{}, newUnsupportedProtocolVersionError(version)
	}
	envelope := parsedEnvelope{
		ProtocolVersion: req.GetProtocolVersion(),
		DeviceSessionID: req.GetDeviceSessionId(),
		MessageType:     req.GetMessageType(),
		TimestampMS:     req.GetTimestampMs(),
		RequestID:       req.GetRequestId(),
		TraceID:         req.GetTraceId(),
		PayloadBytes:    bytes.Clone(req.GetPayloadBytes()),
		PayloadHash:     bytes.Clone(req.GetPayloadHash()),
		Signature:       bytes.Clone(req.GetSignature()),
	}
	return envelope, nil
}
// newEnvelopeValidatingService wraps delegate with the envelope-validation
// gate so only well-formed, protocol-v1 requests reach it.
func newEnvelopeValidatingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer {
	return envelopeValidatingService{delegate: delegate}
}
// canonicalExecuteCommandValidationError maps any ExecuteCommand validation
// failure into the stable canonical error chosen by field order, so clients
// see a deterministic first-failure message regardless of what the proto
// validator reported.
func canonicalExecuteCommandValidationError(req *gatewayv1.ExecuteCommandRequest) error {
	if req.GetProtocolVersion() == "" {
		return newMalformedEnvelopeError("protocol_version must not be empty")
	}
	if req.GetDeviceSessionId() == "" {
		return newMalformedEnvelopeError("device_session_id must not be empty")
	}
	if req.GetMessageType() == "" {
		return newMalformedEnvelopeError("message_type must not be empty")
	}
	if req.GetTimestampMs() <= 0 {
		return newMalformedEnvelopeError("timestamp_ms must be greater than zero")
	}
	if req.GetRequestId() == "" {
		return newMalformedEnvelopeError("request_id must not be empty")
	}
	if len(req.GetPayloadBytes()) == 0 {
		return newMalformedEnvelopeError("payload_bytes must not be empty")
	}
	if len(req.GetPayloadHash()) == 0 {
		return newMalformedEnvelopeError("payload_hash must not be empty")
	}
	if len(req.GetSignature()) == 0 {
		return newMalformedEnvelopeError("signature must not be empty")
	}
	// Validation failed for a rule not covered above.
	return newMalformedEnvelopeError("request envelope is invalid")
}
// canonicalSubscribeEventsValidationError maps any SubscribeEvents validation
// failure into the stable canonical error chosen by field order. Unlike the
// ExecuteCommand variant there is no payload_bytes check.
func canonicalSubscribeEventsValidationError(req *gatewayv1.SubscribeEventsRequest) error {
	if req.GetProtocolVersion() == "" {
		return newMalformedEnvelopeError("protocol_version must not be empty")
	}
	if req.GetDeviceSessionId() == "" {
		return newMalformedEnvelopeError("device_session_id must not be empty")
	}
	if req.GetMessageType() == "" {
		return newMalformedEnvelopeError("message_type must not be empty")
	}
	if req.GetTimestampMs() <= 0 {
		return newMalformedEnvelopeError("timestamp_ms must be greater than zero")
	}
	if req.GetRequestId() == "" {
		return newMalformedEnvelopeError("request_id must not be empty")
	}
	if len(req.GetPayloadHash()) == 0 {
		return newMalformedEnvelopeError("payload_hash must not be empty")
	}
	if len(req.GetSignature()) == 0 {
		return newMalformedEnvelopeError("signature must not be empty")
	}
	// Validation failed for a rule not covered above.
	return newMalformedEnvelopeError("request envelope is invalid")
}
// newMalformedEnvelopeError returns the stable malformed-envelope reject
// (gRPC InvalidArgument) used before the gateway performs any auth or
// routing work.
func newMalformedEnvelopeError(message string) error {
	return status.Error(codes.InvalidArgument, message)
}
// newUnsupportedProtocolVersionError returns the stable reject (gRPC
// FailedPrecondition) for a non-empty but unsupported protocol_version
// literal; the offending version is quoted in the message.
func newUnsupportedProtocolVersionError(version string) error {
	return status.Error(codes.FailedPrecondition, fmt.Sprintf("unsupported protocol_version %q", version))
}
// parsedEnvelopeContextKey is the private context key type under which the
// parsed envelope is stashed for downstream layers to retrieve.
type parsedEnvelopeContextKey struct{}

// envelopeContextStream wraps a server stream so Context() returns a derived
// context (e.g. one carrying the parsed envelope) instead of the stream's own.
type envelopeContextStream struct {
	grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
	ctx context.Context
}

// Context returns the override context, falling back to context.Background()
// when none was set so callers never observe a nil context.
func (s envelopeContextStream) Context() context.Context {
	if s.ctx == nil {
		return context.Background()
	}
	return s.ctx
}

// Compile-time check that envelopeValidatingService satisfies the generated
// gRPC service interface.
var _ gatewayv1.EdgeGatewayServer = envelopeValidatingService{}
+420
View File
@@ -0,0 +1,420 @@
package grpcapi
import (
"context"
"testing"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// TestParseExecuteCommandRequest drives parseExecuteCommandRequest through
// every malformed-envelope reject, the unsupported-protocol reject, and the
// happy path, pinning the exact gRPC code and message for each case.
func TestParseExecuteCommandRequest(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		// mutate breaks exactly one field of an otherwise valid request.
		mutate      func(*gatewayv1.ExecuteCommandRequest)
		wantCode    codes.Code
		wantMessage string
		// assertValid runs only for the happy-path case (wantCode == OK).
		assertValid func(*testing.T, *gatewayv1.ExecuteCommandRequest, parsedEnvelope)
	}{
		{
			name:        "nil request",
			wantCode:    codes.InvalidArgument,
			wantMessage: "request envelope must not be nil",
		},
		{
			name: "empty protocol version",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.ProtocolVersion = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "protocol_version must not be empty",
		},
		{
			name: "empty device session id",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.DeviceSessionId = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "device_session_id must not be empty",
		},
		{
			name: "empty message type",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.MessageType = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "message_type must not be empty",
		},
		{
			name: "zero timestamp",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.TimestampMs = 0
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "timestamp_ms must be greater than zero",
		},
		{
			name: "empty request id",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.RequestId = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "request_id must not be empty",
		},
		{
			name: "empty payload bytes",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.PayloadBytes = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "payload_bytes must not be empty",
		},
		{
			name: "empty payload hash",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.PayloadHash = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "payload_hash must not be empty",
		},
		{
			name: "empty signature",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.Signature = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "signature must not be empty",
		},
		{
			name: "unsupported protocol version",
			mutate: func(req *gatewayv1.ExecuteCommandRequest) {
				req.ProtocolVersion = "v2"
			},
			wantCode:    codes.FailedPrecondition,
			wantMessage: `unsupported protocol_version "v2"`,
		},
		{
			name:     "valid request",
			wantCode: codes.OK,
			assertValid: func(t *testing.T, req *gatewayv1.ExecuteCommandRequest, envelope parsedEnvelope) {
				t.Helper()
				assert.Equal(t, supportedProtocolVersion, envelope.ProtocolVersion)
				assert.Equal(t, req.GetDeviceSessionId(), envelope.DeviceSessionID)
				assert.Equal(t, req.GetMessageType(), envelope.MessageType)
				assert.Equal(t, req.GetTimestampMs(), envelope.TimestampMS)
				assert.Equal(t, req.GetRequestId(), envelope.RequestID)
				assert.Equal(t, req.GetTraceId(), envelope.TraceID)
				assert.Equal(t, req.GetPayloadBytes(), envelope.PayloadBytes)
				assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash)
				assert.Equal(t, req.GetSignature(), envelope.Signature)
				// Mutating the envelope's byte slices must not leak back into
				// the request: the parser is expected to copy payload data.
				originalPayloadBytes := append([]byte(nil), req.GetPayloadBytes()...)
				originalPayloadHash := append([]byte(nil), req.GetPayloadHash()...)
				originalSignature := append([]byte(nil), req.GetSignature()...)
				envelope.PayloadBytes[0] = 'X'
				envelope.PayloadHash[0] = 'Y'
				envelope.Signature[0] = 'Z'
				assert.Equal(t, originalPayloadBytes, req.GetPayloadBytes())
				assert.Equal(t, originalPayloadHash, req.GetPayloadHash())
				assert.Equal(t, originalSignature, req.GetSignature())
			},
		},
	}
	for _, tt := range tests {
		tt := tt // capture range variable for parallel subtests
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var req *gatewayv1.ExecuteCommandRequest
			if tt.name != "nil request" {
				req = newValidExecuteCommandRequest()
				if tt.mutate != nil {
					tt.mutate(req)
				}
			}
			envelope, err := parseExecuteCommandRequest(req)
			if tt.wantCode != codes.OK {
				require.Error(t, err)
				assert.Equal(t, tt.wantCode, status.Code(err))
				assert.Equal(t, tt.wantMessage, status.Convert(err).Message())
				return
			}
			require.NoError(t, err)
			require.NotNil(t, tt.assertValid)
			tt.assertValid(t, req, envelope)
		})
	}
}
// TestParseSubscribeEventsRequest drives parseSubscribeEventsRequest through
// each validation reject and the happy path. Unlike ExecuteCommand, empty
// payload_bytes is allowed here — only payload_hash and signature are required.
func TestParseSubscribeEventsRequest(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		// mutate breaks exactly one field of an otherwise valid request.
		mutate      func(*gatewayv1.SubscribeEventsRequest)
		wantCode    codes.Code
		wantMessage string
		// assertValid runs only for the happy-path case (wantCode == OK).
		assertValid func(*testing.T, *gatewayv1.SubscribeEventsRequest, parsedEnvelope)
	}{
		{
			name:        "nil request",
			wantCode:    codes.InvalidArgument,
			wantMessage: "request envelope must not be nil",
		},
		{
			name: "empty protocol version",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.ProtocolVersion = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "protocol_version must not be empty",
		},
		{
			name: "empty device session id",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.DeviceSessionId = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "device_session_id must not be empty",
		},
		{
			name: "empty message type",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.MessageType = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "message_type must not be empty",
		},
		{
			name: "zero timestamp",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.TimestampMs = 0
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "timestamp_ms must be greater than zero",
		},
		{
			name: "empty request id",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.RequestId = ""
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "request_id must not be empty",
		},
		{
			name: "empty payload hash",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.PayloadHash = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "payload_hash must not be empty",
		},
		{
			name: "empty signature",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.Signature = nil
			},
			wantCode:    codes.InvalidArgument,
			wantMessage: "signature must not be empty",
		},
		{
			name: "unsupported protocol version",
			mutate: func(req *gatewayv1.SubscribeEventsRequest) {
				req.ProtocolVersion = "v2"
			},
			wantCode:    codes.FailedPrecondition,
			wantMessage: `unsupported protocol_version "v2"`,
		},
		{
			name:     "valid request with empty payload bytes",
			wantCode: codes.OK,
			assertValid: func(t *testing.T, req *gatewayv1.SubscribeEventsRequest, envelope parsedEnvelope) {
				t.Helper()
				// Subscribe envelopes may omit the payload body entirely.
				assert.Empty(t, req.GetPayloadBytes())
				assert.Empty(t, envelope.PayloadBytes)
				assert.Equal(t, req.GetPayloadHash(), envelope.PayloadHash)
				assert.Equal(t, req.GetSignature(), envelope.Signature)
			},
		},
	}
	for _, tt := range tests {
		tt := tt // capture range variable for parallel subtests
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			var req *gatewayv1.SubscribeEventsRequest
			if tt.name != "nil request" {
				req = newValidSubscribeEventsRequest()
				if tt.mutate != nil {
					tt.mutate(req)
				}
			}
			envelope, err := parseSubscribeEventsRequest(req)
			if tt.wantCode != codes.OK {
				require.Error(t, err)
				assert.Equal(t, tt.wantCode, status.Code(err))
				assert.Equal(t, tt.wantMessage, status.Convert(err).Message())
				return
			}
			require.NoError(t, err)
			require.NotNil(t, tt.assertValid)
			tt.assertValid(t, req, envelope)
		})
	}
}
// TestEnvelopeValidatingServiceExecuteCommandRejectsInvalidRequestBeforeDelegate
// verifies an empty envelope is rejected with InvalidArgument without the
// wrapped service ever being invoked.
func TestEnvelopeValidatingServiceExecuteCommandRejectsInvalidRequestBeforeDelegate(t *testing.T) {
	t.Parallel()
	inner := &recordingEdgeGatewayService{}
	validating := newEnvelopeValidatingService(inner)
	_, err := validating.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{})
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Zero(t, inner.executeCalls)
}
// TestEnvelopeValidatingServiceSubscribeEventsRejectsInvalidRequestBeforeDelegate
// verifies an empty streaming envelope is rejected with InvalidArgument
// without the wrapped service ever being invoked.
func TestEnvelopeValidatingServiceSubscribeEventsRejectsInvalidRequestBeforeDelegate(t *testing.T) {
	t.Parallel()
	inner := &recordingEdgeGatewayService{}
	validating := newEnvelopeValidatingService(inner)
	err := validating.SubscribeEvents(&gatewayv1.SubscribeEventsRequest{}, stubGatewayEventStream{})
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Zero(t, inner.subscribeCalls)
}
// TestEnvelopeValidatingServiceExecuteCommandAttachesParsedEnvelope verifies a
// valid unary request reaches the delegate with the parsed envelope attached
// to the call context.
func TestEnvelopeValidatingServiceExecuteCommandAttachesParsedEnvelope(t *testing.T) {
	t.Parallel()
	want := newValidExecuteCommandRequest()
	delegate := &recordingEdgeGatewayService{
		executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
			// The delegate must observe the envelope parsed by the wrapper.
			envelope, ok := parsedEnvelopeFromContext(ctx)
			require.True(t, ok)
			assert.Equal(t, want.GetRequestId(), envelope.RequestID)
			assert.Equal(t, want.GetDeviceSessionId(), envelope.DeviceSessionID)
			assert.Equal(t, want.GetMessageType(), envelope.MessageType)
			assert.Equal(t, want.GetPayloadBytes(), envelope.PayloadBytes)
			return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
		},
	}
	service := newEnvelopeValidatingService(delegate)
	response, err := service.ExecuteCommand(context.Background(), want)
	require.NoError(t, err)
	assert.Equal(t, want.GetRequestId(), response.GetRequestId())
	assert.Equal(t, 1, delegate.executeCalls)
}
// TestEnvelopeValidatingServiceSubscribeEventsAttachesParsedEnvelope verifies
// a valid streaming request reaches the delegate with the parsed envelope
// available on the stream's context.
func TestEnvelopeValidatingServiceSubscribeEventsAttachesParsedEnvelope(t *testing.T) {
	t.Parallel()
	want := newValidSubscribeEventsRequest()
	delegate := &recordingEdgeGatewayService{
		subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
			// The delegate must observe the envelope via the wrapped stream.
			envelope, ok := parsedEnvelopeFromContext(stream.Context())
			require.True(t, ok)
			assert.Equal(t, want.GetRequestId(), envelope.RequestID)
			assert.Equal(t, want.GetDeviceSessionId(), envelope.DeviceSessionID)
			assert.Equal(t, want.GetMessageType(), envelope.MessageType)
			assert.Equal(t, want.GetPayloadHash(), envelope.PayloadHash)
			assert.Equal(t, want.GetSignature(), envelope.Signature)
			return nil
		},
	}
	service := newEnvelopeValidatingService(delegate)
	err := service.SubscribeEvents(want, stubGatewayEventStream{})
	require.NoError(t, err)
	assert.Equal(t, 1, delegate.subscribeCalls)
}
// recordingEdgeGatewayService is a test double for the EdgeGateway service. It
// counts invocations and delegates to optional hooks, falling back to benign
// defaults when no hook is configured.
// NOTE(review): the counters are written from handler goroutines and read by
// tests only after the RPC completes; assumed safe via the RPC's
// happens-before edge — confirm under -race.
type recordingEdgeGatewayService struct {
	gatewayv1.UnimplementedEdgeGatewayServer
	executeCalls        int // number of ExecuteCommand invocations observed
	subscribeCalls      int // number of SubscribeEvents invocations observed
	executeCommandFunc  func(context.Context, *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error)
	subscribeEventsFunc func(*gatewayv1.SubscribeEventsRequest, grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error
}

// ExecuteCommand counts the call, then runs the configured hook or returns an
// empty response.
func (s *recordingEdgeGatewayService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
	s.executeCalls++
	if s.executeCommandFunc != nil {
		return s.executeCommandFunc(ctx, req)
	}
	return &gatewayv1.ExecuteCommandResponse{}, nil
}

// SubscribeEvents counts the call, then runs the configured hook or succeeds
// without emitting events.
func (s *recordingEdgeGatewayService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	s.subscribeCalls++
	if s.subscribeEventsFunc != nil {
		return s.subscribeEventsFunc(req, stream)
	}
	return nil
}
// stubGatewayEventStream is a no-op server stream for driving SubscribeEvents
// directly in unit tests, without a real gRPC transport. An optional ctx
// overrides the stream context.
type stubGatewayEventStream struct {
	grpc.ServerStream
	ctx context.Context
}

// Send discards outbound events.
func (s stubGatewayEventStream) Send(*gatewayv1.GatewayEvent) error {
	return nil
}

// SetHeader is a no-op.
func (s stubGatewayEventStream) SetHeader(metadata.MD) error {
	return nil
}

// SendHeader is a no-op.
func (s stubGatewayEventStream) SendHeader(metadata.MD) error {
	return nil
}

// SetTrailer is a no-op.
func (s stubGatewayEventStream) SetTrailer(metadata.MD) {}

// Context returns the override context, or context.Background() when unset so
// callers never see a nil context.
func (s stubGatewayEventStream) Context() context.Context {
	if s.ctx == nil {
		return context.Background()
	}
	return s.ctx
}

// SendMsg discards outbound messages.
func (s stubGatewayEventStream) SendMsg(any) error {
	return nil
}

// RecvMsg is a no-op; nothing is ever received.
func (s stubGatewayEventStream) RecvMsg(any) error {
	return nil
}
@@ -0,0 +1,95 @@
package grpcapi
import (
"context"
"errors"
"time"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/replay"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// minimumReplayReservationTTL is the floor applied to replay reservations so a
// reservation never expires instantly, even at the freshness-window boundary.
const minimumReplayReservationTTL = time.Millisecond

// freshnessAndReplayService applies freshness and anti-replay checks after
// client-signature verification and before later policy or routing steps run.
type freshnessAndReplayService struct {
	gatewayv1.UnimplementedEdgeGatewayServer
	delegate        gatewayv1.EdgeGatewayServer // next service in the decorator chain
	clock           clock.Clock                 // injectable time source for freshness math
	replayStore     replay.Store                // reserves (session, request) pairs to detect replays
	freshnessWindow time.Duration               // allowed clock skew in both directions
}
// ExecuteCommand gates the unary call behind the freshness and anti-replay
// checks, then hands the request to the wrapped implementation.
func (s freshnessAndReplayService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
	err := s.verifyFreshnessAndReplay(ctx)
	if err != nil {
		return nil, err
	}
	return s.delegate.ExecuteCommand(ctx, req)
}
// SubscribeEvents gates the streaming call behind the freshness and
// anti-replay checks, then hands the stream to the wrapped implementation.
func (s freshnessAndReplayService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	err := s.verifyFreshnessAndReplay(stream.Context())
	if err != nil {
		return err
	}
	return s.delegate.SubscribeEvents(req, stream)
}
// newFreshnessAndReplayService decorates delegate with the freshness and
// anti-replay gate backed by clk and replayStore.
func newFreshnessAndReplayService(delegate gatewayv1.EdgeGatewayServer, clk clock.Clock, replayStore replay.Store, freshnessWindow time.Duration) gatewayv1.EdgeGatewayServer {
	service := freshnessAndReplayService{
		delegate:        delegate,
		clock:           clk,
		replayStore:     replayStore,
		freshnessWindow: freshnessWindow,
	}
	return service
}
// verifyFreshnessAndReplay enforces the freshness window on the envelope
// timestamp and then reserves the (session, request) pair in the replay
// store. It returns nil only when the request is fresh and previously unseen.
func (s freshnessAndReplayService) verifyFreshnessAndReplay(ctx context.Context) error {
	envelope, ok := parsedEnvelopeFromContext(ctx)
	if !ok {
		// Earlier layers should always have attached the parsed envelope;
		// its absence indicates a wiring bug, not a client error.
		return status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	now := s.clock.Now().UTC()
	requestTime := time.UnixMilli(envelope.TimestampMS).UTC()
	// Reject timestamps outside [now-window, now+window]; the exact
	// boundaries themselves are accepted.
	if requestTime.Before(now.Add(-s.freshnessWindow)) || requestTime.After(now.Add(s.freshnessWindow)) {
		return status.Error(codes.FailedPrecondition, "request timestamp is outside the freshness window")
	}
	// Hold the reservation until the request would age out of the freshness
	// window, so any replay attempt inside the window is still caught; clamp
	// to the minimum TTL so boundary requests never get a zero reservation.
	ttl := requestTime.Add(s.freshnessWindow).Sub(now)
	if ttl < minimumReplayReservationTTL {
		ttl = minimumReplayReservationTTL
	}
	err := s.replayStore.Reserve(ctx, envelope.DeviceSessionID, envelope.RequestID, ttl)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, replay.ErrDuplicate):
		return status.Error(codes.FailedPrecondition, "request replay detected")
	default:
		// Fail closed, but with a distinct code so clients can retry later.
		return status.Error(codes.Unavailable, "replay store is unavailable")
	}
}
// unavailableReplayStore is a replay.Store implementation that always fails;
// presumably used as a fail-closed default when no real store is configured —
// confirm at the construction site.
type unavailableReplayStore struct{}

// Reserve always reports the store as unavailable.
func (unavailableReplayStore) Reserve(context.Context, string, string, time.Duration) error {
	return errors.New("replay store is unavailable")
}

// Compile-time check that freshnessAndReplayService satisfies the generated
// gRPC service interface.
var _ gatewayv1.EdgeGatewayServer = freshnessAndReplayService{}
@@ -0,0 +1,509 @@
package grpcapi
import (
"context"
"errors"
"io"
"sync"
"testing"
"time"
"galaxy/gateway/internal/replay"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TestExecuteCommandRejectsStaleTimestamp verifies a unary request whose
// timestamp falls one millisecond outside either edge of the freshness window
// is rejected with FailedPrecondition before the delegate runs.
func TestExecuteCommandRejectsStaleTimestamp(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name        string
		timestampMS int64
	}{
		{
			name:        "past window",
			timestampMS: testCurrentTime.Add(-testFreshnessWindow - time.Millisecond).UnixMilli(),
		},
		{
			name:        "future window",
			timestampMS: testCurrentTime.Add(testFreshnessWindow + time.Millisecond).UnixMilli(),
		},
	}
	for _, tt := range tests {
		tt := tt // capture range variable for parallel subtests
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			delegate := &recordingEdgeGatewayService{}
			server, runGateway := newTestGateway(t, ServerDependencies{
				Service:      delegate,
				SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
				ReplayStore:  staticReplayStore{},
			})
			defer runGateway.stop(t)
			addr := waitForListenAddr(t, server)
			conn := dialGatewayClient(t, addr)
			defer func() {
				require.NoError(t, conn.Close())
			}()
			client := gatewayv1.NewEdgeGatewayClient(conn)
			_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS))
			require.Error(t, err)
			assert.Equal(t, codes.FailedPrecondition, status.Code(err))
			assert.Equal(t, "request timestamp is outside the freshness window", status.Convert(err).Message())
			// The freshness gate must fire before the delegate is invoked.
			assert.Zero(t, delegate.executeCalls)
		})
	}
}

// TestSubscribeEventsRejectsStaleTimestamp mirrors the stale-timestamp checks
// for the server-streaming RPC.
func TestSubscribeEventsRejectsStaleTimestamp(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name        string
		timestampMS int64
	}{
		{
			name:        "past window",
			timestampMS: testCurrentTime.Add(-testFreshnessWindow - time.Millisecond).UnixMilli(),
		},
		{
			name:        "future window",
			timestampMS: testCurrentTime.Add(testFreshnessWindow + time.Millisecond).UnixMilli(),
		},
	}
	for _, tt := range tests {
		tt := tt // capture range variable for parallel subtests
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			delegate := &recordingEdgeGatewayService{}
			server, runGateway := newTestGateway(t, ServerDependencies{
				Service:      delegate,
				SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
				ReplayStore:  staticReplayStore{},
			})
			defer runGateway.stop(t)
			addr := waitForListenAddr(t, server)
			conn := dialGatewayClient(t, addr)
			defer func() {
				require.NoError(t, conn.Close())
			}()
			client := gatewayv1.NewEdgeGatewayClient(conn)
			err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithTimestamp("device-session-123", "request-123", tt.timestampMS))
			require.Error(t, err)
			assert.Equal(t, codes.FailedPrecondition, status.Code(err))
			assert.Equal(t, "request timestamp is outside the freshness window", status.Convert(err).Message())
			// The freshness gate must fire before the delegate is invoked.
			assert.Zero(t, delegate.subscribeCalls)
		})
	}
}
// TestExecuteCommandRejectsReplay verifies the first use of a (session,
// request) pair succeeds and the identical second attempt is rejected as a
// replay without reaching the delegate again.
func TestExecuteCommandRejectsReplay(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: replayDuplicateBySessionAndRequest(),
		},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	req := newValidExecuteCommandRequest()
	_, err := client.ExecuteCommand(context.Background(), req)
	require.NoError(t, err)
	// Re-sending the identical request must trip the replay gate.
	_, err = client.ExecuteCommand(context.Background(), req)
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "request replay detected", status.Convert(err).Message())
	assert.Equal(t, 1, delegate.executeCalls)
}

// TestSubscribeEventsRejectsReplay mirrors the replay rejection for the
// server-streaming RPC: the first subscription yields the bootstrap event,
// the second identical envelope is rejected.
func TestSubscribeEventsRejectsReplay(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: replayDuplicateBySessionAndRequest(),
		},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	req := newValidSubscribeEventsRequest()
	stream, err := client.SubscribeEvents(context.Background(), req)
	require.NoError(t, err)
	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)
	// The same envelope again must trip the replay gate.
	err = subscribeEventsError(t, context.Background(), client, req)
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "request replay detected", status.Convert(err).Message())
	assert.Equal(t, 1, delegate.subscribeCalls)
}
// TestExecuteCommandAllowsSameRequestIDAcrossDistinctSessions verifies replay
// detection is scoped per device session: the same request_id is accepted
// when it arrives under two different session IDs.
func TestExecuteCommandAllowsSameRequestIDAcrossDistinctSessions(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{
		executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
			return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service: delegate,
		SessionCache: staticSessionCache{
			lookupFunc: func(ctx context.Context, deviceSessionID string) (session.Record, error) {
				return newActiveSessionRecordWithSessionID(deviceSessionID), nil
			},
		},
		ReplayStore: staticReplayStore{
			reserveFunc: replayDuplicateBySessionAndRequest(),
		},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-shared"))
	require.NoError(t, err)
	// Same request_id, different session: must not count as a replay.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-456", "request-shared"))
	require.NoError(t, err)
	assert.Equal(t, 2, delegate.executeCalls)
}

// TestSubscribeEventsAllowsSameRequestIDAcrossDistinctSessions mirrors the
// per-session replay scoping for the server-streaming RPC.
func TestSubscribeEventsAllowsSameRequestIDAcrossDistinctSessions(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{
		subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
			return nil
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service: delegate,
		SessionCache: staticSessionCache{
			lookupFunc: func(ctx context.Context, deviceSessionID string) (session.Record, error) {
				return newActiveSessionRecordWithSessionID(deviceSessionID), nil
			},
		},
		ReplayStore: staticReplayStore{
			reserveFunc: replayDuplicateBySessionAndRequest(),
		},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-shared"))
	require.NoError(t, err)
	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-shared", "trace-123", testCurrentTime.UnixMilli())
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)
	// Same request_id, different session: must not count as a replay.
	stream, err = client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-456", "request-shared"))
	require.NoError(t, err)
	event = recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-shared", "trace-123", testCurrentTime.UnixMilli())
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)
	assert.Equal(t, 2, delegate.subscribeCalls)
}
// TestExecuteCommandRejectsReplayStoreUnavailable verifies a replay-store
// failure fails closed with Unavailable and never reaches the delegate.
func TestExecuteCommandRejectsReplayStoreUnavailable(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: func(context.Context, string, string, time.Duration) error {
				// Any non-duplicate error must map to "store unavailable".
				return errors.New("redis down")
			},
		},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}

// TestSubscribeEventsRejectsReplayStoreUnavailable mirrors the fail-closed
// replay-store behavior for the server-streaming RPC.
func TestSubscribeEventsRejectsReplayStoreUnavailable(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: func(context.Context, string, string, time.Duration) error {
				// Any non-duplicate error must map to "store unavailable".
				return errors.New("redis down")
			},
		},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
	assert.Zero(t, delegate.subscribeCalls)
}
// TestExecuteCommandFreshRequestReachesDelegateAndUsesDynamicReplayTTL
// verifies a fresh unary request reaches the delegate and that the replay
// reservation uses the session/request identifiers from the envelope with a
// TTL equal to the remaining freshness window (here: the full window, since
// the test clock matches the request timestamp).
func TestExecuteCommandFreshRequestReachesDelegateAndUsesDynamicReplayTTL(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{
		executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
			return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
		},
	}
	var reservedDeviceSessionID string
	var reservedRequestID string
	var reservedTTL time.Duration
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
				// Capture the reservation arguments for later assertions.
				reservedDeviceSessionID = deviceSessionID
				reservedRequestID = requestID
				reservedTTL = ttl
				return nil
			},
		},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	response, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)
	assert.Equal(t, "request-123", response.GetRequestId())
	assert.Equal(t, "device-session-123", reservedDeviceSessionID)
	assert.Equal(t, "request-123", reservedRequestID)
	assert.Equal(t, testFreshnessWindow, reservedTTL)
	assert.Equal(t, 1, delegate.executeCalls)
}

// TestSubscribeEventsFreshRequestReachesDelegateAndUsesDynamicReplayTTL
// mirrors the dynamic-TTL reservation check for the server-streaming RPC.
func TestSubscribeEventsFreshRequestReachesDelegateAndUsesDynamicReplayTTL(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{
		subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
			return nil
		},
	}
	var reservedTTL time.Duration
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
				assert.Equal(t, "device-session-123", deviceSessionID)
				assert.Equal(t, "request-123", requestID)
				reservedTTL = ttl
				return nil
			},
		},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
	require.NoError(t, err)
	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)
	assert.Equal(t, testFreshnessWindow, reservedTTL)
	assert.Equal(t, 1, delegate.subscribeCalls)
}
// TestExecuteCommandFutureSkewUsesExtendedReplayTTL verifies a request
// timestamped ahead of the server clock (but inside the freshness window)
// reserves the replay key for the remaining window measured from that future
// timestamp — here 2 minutes of skew plus the window yields 7 minutes.
func TestExecuteCommandFutureSkewUsesExtendedReplayTTL(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{
		executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
			return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
		},
	}
	var reservedTTL time.Duration
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
				reservedTTL = ttl
				return nil
			},
		},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(
		context.Background(),
		newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(2*time.Minute).UnixMilli()),
	)
	require.NoError(t, err)
	assert.Equal(t, 7*time.Minute, reservedTTL)
	assert.Equal(t, 1, delegate.executeCalls)
}

// TestExecuteCommandBoundaryFreshnessUsesMinimumReplayTTL verifies a request
// sitting exactly on the trailing edge of the freshness window is still
// accepted and its replay reservation is clamped up to the minimum TTL
// instead of expiring immediately.
func TestExecuteCommandBoundaryFreshnessUsesMinimumReplayTTL(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{
		executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
			return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
		},
	}
	var reservedTTL time.Duration
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
		ReplayStore: staticReplayStore{
			reserveFunc: func(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
				reservedTTL = ttl
				return nil
			},
		},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(
		context.Background(),
		newValidExecuteCommandRequestWithTimestamp("device-session-123", "request-123", testCurrentTime.Add(-testFreshnessWindow).UnixMilli()),
	)
	require.NoError(t, err)
	assert.Equal(t, minimumReplayReservationTTL, reservedTTL)
	assert.Equal(t, 1, delegate.executeCalls)
}
// replayDuplicateBySessionAndRequest builds a Reserve function that admits
// each (deviceSessionID, requestID) pair exactly once and reports
// replay.ErrDuplicate on every later attempt. Safe for concurrent use.
func replayDuplicateBySessionAndRequest() func(context.Context, string, string, time.Duration) error {
	var mu sync.Mutex
	seen := map[string]struct{}{}
	return func(_ context.Context, deviceSessionID string, requestID string, _ time.Duration) error {
		// NUL-joined key keeps distinct pairs distinct even when one string
		// is a prefix of the other.
		key := deviceSessionID + "\x00" + requestID
		mu.Lock()
		defer mu.Unlock()
		if _, duplicate := seen[key]; duplicate {
			return replay.ErrDuplicate
		}
		seen[key] = struct{}{}
		return nil
	}
}
+147
View File
@@ -0,0 +1,147 @@
package grpcapi
import (
"context"
"errors"
"path"
"time"
"galaxy/gateway/internal/logging"
"galaxy/gateway/internal/telemetry"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func observabilityUnaryInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) grpc.UnaryServerInterceptor {
if logger == nil {
logger = zap.NewNop()
}
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
start := time.Now()
resp, err := handler(ctx, req)
recordGRPCRequest(logger, metrics, ctx, info.FullMethod, req, resp, err, time.Since(start), "unary")
return resp, err
}
}
// observabilityStreamInterceptor returns a stream interceptor that records one
// metrics sample and one structured log entry for every completed streaming
// RPC. The stream is wrapped so the first received request message can be
// surfaced in the log entry's envelope fields.
func observabilityStreamInterceptor(logger *zap.Logger, metrics *telemetry.Runtime) grpc.StreamServerInterceptor {
	if logger == nil {
		logger = zap.NewNop()
	}
	return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		began := time.Now()
		observed := &observabilityServerStream{ServerStream: stream}
		err := handler(srv, observed)
		recordGRPCRequest(logger, metrics, stream.Context(), info.FullMethod, observed.request, nil, err, time.Since(began), "stream")
		return err
	}
}
// observabilityServerStream wraps a grpc.ServerStream so the interceptor can
// inspect the first request message received on the stream.
type observabilityServerStream struct {
	grpc.ServerStream
	// request holds the first message successfully received, if any.
	request any
}

// RecvMsg delegates to the underlying stream and retains the first
// successfully received message so recordGRPCRequest can extract envelope
// fields from it.
func (s *observabilityServerStream) RecvMsg(m any) error {
	err := s.ServerStream.RecvMsg(m)
	if err == nil && s.request == nil {
		s.request = m
	}
	return err
}
// recordGRPCRequest emits the per-request metrics sample and structured log
// entry shared by the unary and stream interceptors.
// NOTE(review): ctx is conventionally the first parameter in Go; consider
// reordering once both call sites can be updated together.
// NOTE(review): metrics is not nil-checked here — confirm telemetry.Runtime's
// RecordAuthenticatedGRPC tolerates a nil receiver.
func recordGRPCRequest(logger *zap.Logger, metrics *telemetry.Runtime, ctx context.Context, fullMethod string, req any, resp any, err error, duration time.Duration, streamKind string) {
	// Keep only the method name (e.g. "ExecuteCommand") from the full
	// "/package.Service/Method" path.
	rpcMethod := path.Base(fullMethod)
	messageType, requestID, traceID := grpcEnvelopeFields(req)
	resultCode := grpcResultCode(resp)
	grpcCode, grpcMessage, outcome := grpcOutcome(err)
	rejectReason := telemetry.RejectReason(outcome)
	attrs := []attribute.KeyValue{
		attribute.String("rpc_method", rpcMethod),
		attribute.String("message_type", messageType),
		attribute.String("edge_outcome", string(outcome)),
	}
	// Optional attributes are appended only when present to limit metric
	// label cardinality.
	if resultCode != "" {
		attrs = append(attrs, attribute.String("result_code", resultCode))
	}
	if rejectReason != "" {
		attrs = append(attrs, attribute.String("reject_reason", rejectReason))
	}
	metrics.RecordAuthenticatedGRPC(ctx, attrs, duration)
	fields := []zap.Field{
		zap.String("component", "authenticated_grpc"),
		zap.String("transport", "grpc"),
		zap.String("stream_kind", streamKind),
		zap.String("rpc_method", rpcMethod),
		zap.String("message_type", messageType),
		zap.String("grpc_code", grpcCode.String()),
		// Microsecond resolution rendered as fractional milliseconds.
		zap.Float64("duration_ms", float64(duration.Microseconds())/1000),
		zap.String("request_id", requestID),
		zap.String("trace_id", traceID),
		zap.String("peer_ip", peerIPFromContext(ctx)),
		zap.String("edge_outcome", string(outcome)),
	}
	if resultCode != "" {
		fields = append(fields, zap.String("result_code", resultCode))
	}
	if rejectReason != "" {
		fields = append(fields, zap.String("reject_reason", rejectReason))
	}
	if grpcMessage != "" {
		fields = append(fields, zap.String("grpc_message", grpcMessage))
	}
	fields = append(fields, logging.TraceFieldsFromContext(ctx)...)
	// Log severity tracks the outcome class: infrastructure failures log at
	// error, client/policy rejections at warn, successes at info.
	switch outcome {
	case telemetry.EdgeOutcomeSuccess:
		logger.Info("authenticated gRPC request completed", fields...)
	case telemetry.EdgeOutcomeBackendUnavailable, telemetry.EdgeOutcomeDownstreamUnavailable, telemetry.EdgeOutcomeInternalError:
		logger.Error("authenticated gRPC request failed", fields...)
	default:
		logger.Warn("authenticated gRPC request rejected", fields...)
	}
}
// grpcEnvelopeFields extracts the transport envelope identifiers from a known
// request message type; unrecognized (or nil) request types yield empty
// strings.
func grpcEnvelopeFields(req any) (messageType string, requestID string, traceID string) {
	switch typed := req.(type) {
	case *gatewayv1.ExecuteCommandRequest:
		return typed.GetMessageType(), typed.GetRequestId(), typed.GetTraceId()
	case *gatewayv1.SubscribeEventsRequest:
		return typed.GetMessageType(), typed.GetRequestId(), typed.GetTraceId()
	default:
		return "", "", ""
	}
}
// grpcResultCode reports the result code carried by an ExecuteCommand
// response; any other response value yields an empty string.
func grpcResultCode(resp any) string {
	if typed, ok := resp.(*gatewayv1.ExecuteCommandResponse); ok {
		return typed.GetResultCode()
	}
	return ""
}
// grpcOutcome maps a handler error onto the gRPC status code, status message,
// and edge outcome used for metrics and logging.
func grpcOutcome(err error) (codes.Code, string, telemetry.EdgeOutcome) {
	switch {
	case err == nil:
		return codes.OK, "", telemetry.EdgeOutcomeSuccess
	// Context cancellation and deadline expiry are treated as client-driven,
	// successful outcomes. NOTE(review): DeadlineExceeded is reported here as
	// codes.Canceled rather than codes.DeadlineExceeded — confirm intentional.
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return codes.Canceled, err.Error(), telemetry.EdgeOutcomeSuccess
	default:
		grpcStatus := status.Convert(err)
		return grpcStatus.Code(), grpcStatus.Message(), telemetry.OutcomeFromGRPCStatus(grpcStatus.Code(), grpcStatus.Message())
	}
}
+66
View File
@@ -0,0 +1,66 @@
package grpcapi
import (
"context"
"errors"
"galaxy/gateway/internal/authn"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// payloadHashVerifyingService applies payload-hash verification after session
// lookup and before any later auth or routing step runs.
type payloadHashVerifyingService struct {
	gatewayv1.UnimplementedEdgeGatewayServer
	delegate gatewayv1.EdgeGatewayServer
}

// newPayloadHashVerifyingService wraps delegate with the payload-hash
// verification gate.
func newPayloadHashVerifyingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer {
	return payloadHashVerifyingService{delegate: delegate}
}

// ExecuteCommand verifies req payload integrity before delegating to the
// configured service implementation.
func (s payloadHashVerifyingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
	if err := verifyPayloadHash(ctx); err != nil {
		return nil, err
	}
	return s.delegate.ExecuteCommand(ctx, req)
}

// SubscribeEvents verifies req payload integrity before delegating to the
// configured service implementation.
func (s payloadHashVerifyingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	if err := verifyPayloadHash(stream.Context()); err != nil {
		return err
	}
	return s.delegate.SubscribeEvents(req, stream)
}
// verifyPayloadHash checks that the parsed envelope's payload bytes match its
// declared payload hash, translating verification failures into gRPC statuses:
// malformed or mismatched hashes become InvalidArgument, everything else
// Internal.
func verifyPayloadHash(ctx context.Context) error {
	envelope, ok := parsedEnvelopeFromContext(ctx)
	if !ok {
		return status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	err := authn.VerifyPayloadHash(envelope.PayloadBytes, envelope.PayloadHash)
	if err == nil {
		return nil
	}
	if errors.Is(err, authn.ErrInvalidPayloadHash) || errors.Is(err, authn.ErrPayloadHashMismatch) {
		return status.Error(codes.InvalidArgument, err.Error())
	}
	return status.Error(codes.Internal, "payload hash verification failed")
}

var _ gatewayv1.EdgeGatewayServer = payloadHashVerifyingService{}
@@ -0,0 +1,125 @@
package grpcapi
import (
"context"
"crypto/sha256"
"testing"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TestExecuteCommandRejectsPayloadHashWithInvalidLength verifies that a
// payload hash shorter than a SHA-256 digest is rejected with InvalidArgument
// before the delegate ever runs.
func TestExecuteCommandRejectsPayloadHashWithInvalidLength(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	req := newValidExecuteCommandRequest()
	// Any length other than 32 bytes is malformed.
	req.PayloadHash = []byte("short")
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), req)
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", status.Convert(err).Message())
	// The delegate must never observe a malformed request.
	assert.Zero(t, delegate.executeCalls)
}

// TestExecuteCommandRejectsPayloadHashMismatch verifies that a well-formed
// digest computed over different bytes than the payload is rejected with
// InvalidArgument before the delegate runs.
func TestExecuteCommandRejectsPayloadHashMismatch(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	req := newValidExecuteCommandRequest()
	// Valid 32-byte digest, but over the wrong content.
	sum := sha256.Sum256([]byte("other"))
	req.PayloadHash = sum[:]
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), req)
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Equal(t, "payload_hash does not match payload_bytes", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}

// TestSubscribeEventsRejectsPayloadHashWithInvalidLength mirrors the
// ExecuteCommand invalid-length case on the streaming RPC.
func TestSubscribeEventsRejectsPayloadHashWithInvalidLength(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	req := newValidSubscribeEventsRequest()
	req.PayloadHash = []byte("short")
	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, req)
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Equal(t, "payload_hash must be a 32-byte SHA-256 digest", status.Convert(err).Message())
	assert.Zero(t, delegate.subscribeCalls)
}

// TestSubscribeEventsRejectsPayloadHashMismatch mirrors the ExecuteCommand
// digest-mismatch case on the streaming RPC.
func TestSubscribeEventsRejectsPayloadHashMismatch(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	req := newValidSubscribeEventsRequest()
	sum := sha256.Sum256([]byte("other"))
	req.PayloadHash = sum[:]
	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, req)
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
	assert.Equal(t, "payload_hash does not match payload_bytes", status.Convert(err).Message())
	assert.Zero(t, delegate.subscribeCalls)
}
+172
View File
@@ -0,0 +1,172 @@
package grpcapi
import (
"bytes"
"context"
"crypto/sha256"
"errors"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/logging"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/telemetry"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// NewFanOutPushStreamService constructs the authenticated SubscribeEvents tail
// service that registers active streams in hub and forwards client-facing
// events after the bootstrap event has been sent. Nil optional dependencies
// are replaced with safe defaults.
func NewFanOutPushStreamService(hub *push.Hub, responseSigner authn.ResponseSigner, clk clock.Clock, logger *zap.Logger) gatewayv1.EdgeGatewayServer {
	svc := fanOutPushStreamService{
		hub:            hub,
		responseSigner: responseSigner,
		clock:          clk,
	}
	if svc.responseSigner == nil {
		svc.responseSigner = unavailableResponseSigner{}
	}
	if svc.clock == nil {
		svc.clock = clock.System{}
	}
	if logger == nil {
		logger = zap.NewNop()
	}
	svc.logger = logger.Named("push_stream")
	return svc
}

// fanOutPushStreamService owns the post-bootstrap authenticated push-stream
// lifecycle backed by the in-memory push hub.
type fanOutPushStreamService struct {
	gatewayv1.UnimplementedEdgeGatewayServer
	hub            *push.Hub
	responseSigner authn.ResponseSigner
	clock          clock.Clock
	logger         *zap.Logger
}
// SubscribeEvents registers the verified stream in the push hub and forwards
// matching client-facing events until the stream ends.
func (s fanOutPushStreamService) SubscribeEvents(_ *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	// The bootstrap service attaches the binding; its absence indicates a
	// wiring bug, not a client error.
	binding, ok := authenticatedStreamBindingFromContext(stream.Context())
	if !ok {
		return status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	if s.hub == nil {
		return status.Error(codes.Internal, "push hub is unavailable")
	}
	subscription, err := s.hub.Register(push.StreamBinding{
		UserID:          binding.UserID,
		DeviceSessionID: binding.DeviceSessionID,
	})
	if err != nil {
		return status.Error(codes.Internal, "push stream registration failed")
	}
	defer subscription.Close()
	// openFields is reused for the close log entries so open/close lines
	// correlate on the same identifiers.
	openFields := []zap.Field{
		zap.String("component", "authenticated_grpc"),
		zap.String("transport", "grpc"),
		zap.String("rpc_method", authenticatedRPCSubscribeEvents),
		zap.String("message_type", binding.MessageType),
		zap.String("request_id", binding.RequestID),
		zap.String("trace_id", binding.TraceID),
		zap.String("device_session_id", binding.DeviceSessionID),
		zap.String("user_id", binding.UserID),
	}
	openFields = append(openFields, logging.TraceFieldsFromContext(stream.Context())...)
	s.logger.Info("push stream opened", openFields...)
	for {
		select {
		// Client disconnected or the RPC deadline fired: log the close and
		// surface the context error.
		case <-stream.Context().Done():
			s.logger.Info("push stream closed", append(openFields, zap.String("edge_outcome", string(mapSubscriptionOutcome(stream.Context().Err()))))...)
			return stream.Context().Err()
		// Hub-side termination (revocation, overflow, shutdown).
		case <-subscription.Done():
			subscriptionErr := subscription.Err()
			s.logger.Warn("push stream closed", append(openFields,
				zap.String("edge_outcome", string(mapSubscriptionOutcome(subscriptionErr))),
				zap.String("reject_reason", string(mapSubscriptionOutcome(subscriptionErr))),
			)...)
			return mapSubscriptionError(subscriptionErr)
		case event := <-subscription.Events():
			signedEvent, err := s.buildGatewayEvent(event)
			if err != nil {
				// NOTE(review): build/send failures return without a
				// "push stream closed" log entry, unlike the other exit
				// paths — confirm the asymmetry is intended.
				return err
			}
			if err := stream.Send(signedEvent); err != nil {
				return err
			}
		}
	}
}
// buildGatewayEvent stamps, hashes, and signs one hub event, producing the
// wire-format GatewayEvent sent to the client. Payload slices are cloned so
// the response never aliases hub-owned memory.
func (s fanOutPushStreamService) buildGatewayEvent(event push.Event) (*gatewayv1.GatewayEvent, error) {
	now := s.clock.Now().UTC().UnixMilli()
	digest := sha256.Sum256(event.PayloadBytes)
	signingFields := authn.EventSigningFields{
		EventType:   event.EventType,
		EventID:     event.EventID,
		TimestampMS: now,
		RequestID:   event.RequestID,
		TraceID:     event.TraceID,
		PayloadHash: digest[:],
	}
	signature, err := s.responseSigner.SignEvent(signingFields)
	if err != nil {
		return nil, status.Error(codes.Unavailable, "response signer is unavailable")
	}
	return &gatewayv1.GatewayEvent{
		EventType:    event.EventType,
		EventId:      event.EventID,
		TimestampMs:  now,
		PayloadBytes: bytes.Clone(event.PayloadBytes),
		PayloadHash:  bytes.Clone(digest[:]),
		Signature:    signature,
		RequestId:    event.RequestID,
		TraceId:      event.TraceID,
	}, nil
}
// mapSubscriptionError translates a push-subscription termination error into
// the gRPC status returned to the client.
func mapSubscriptionError(err error) error {
	if err == nil {
		return nil
	}
	switch {
	case errors.Is(err, push.ErrSubscriptionRevoked):
		return status.Error(codes.FailedPrecondition, "device session is revoked")
	case errors.Is(err, push.ErrSubscriptionOverflow):
		return status.Error(codes.ResourceExhausted, "push stream overflowed")
	case errors.Is(err, push.ErrHubShuttingDown):
		return status.Error(codes.Unavailable, "gateway is shutting down")
	}
	return status.Error(codes.Internal, "push stream closed unexpectedly")
}
// mapSubscriptionOutcome translates a stream termination error into the edge
// outcome recorded for telemetry; context cancellation counts as success
// because it is client-driven.
func mapSubscriptionOutcome(err error) telemetry.EdgeOutcome {
	if err == nil || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return telemetry.EdgeOutcomeSuccess
	}
	switch {
	case errors.Is(err, push.ErrSubscriptionRevoked):
		return telemetry.EdgeOutcomeRevokedSession
	case errors.Is(err, push.ErrSubscriptionOverflow):
		return telemetry.EdgeOutcomeRateLimited
	case errors.Is(err, push.ErrHubShuttingDown):
		return telemetry.EdgeOutcomeGatewayShuttingDown
	}
	return telemetry.EdgeOutcomeInternalError
}

var _ gatewayv1.EdgeGatewayServer = fanOutPushStreamService{}
+164
View File
@@ -0,0 +1,164 @@
package grpcapi
import (
"bytes"
"context"
"crypto/sha256"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/clock"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
gatewayfbs "galaxy/schema/fbs/gateway"
flatbuffers "github.com/google/flatbuffers/go"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const serverTimeEventType = "gateway.server_time"
// authenticatedStreamBinding captures the verified identity bound to one
// authenticated SubscribeEvents stream after the full ingress pipeline
// succeeds.
type authenticatedStreamBinding struct {
	UserID          string
	DeviceSessionID string
	MessageType     string
	RequestID       string
	TraceID         string
}

// authenticatedStreamBindingFromContext returns the verified stream binding
// previously attached to ctx by the authenticated push-stream service.
func authenticatedStreamBindingFromContext(ctx context.Context) (authenticatedStreamBinding, bool) {
	var zero authenticatedStreamBinding
	if ctx == nil {
		return zero, false
	}
	value := ctx.Value(authenticatedStreamBindingContextKey{})
	binding, ok := value.(authenticatedStreamBinding)
	if !ok {
		return zero, false
	}
	return binding, true
}
// authenticatedPushStreamService owns SubscribeEvents bootstrap behavior:
// bind the authenticated stream, send the initial signed server-time event,
// and then hand the stream lifecycle to the configured tail delegate.
type authenticatedPushStreamService struct {
	gatewayv1.UnimplementedEdgeGatewayServer
	tailDelegate   gatewayv1.EdgeGatewayServer
	responseSigner authn.ResponseSigner
	clock          clock.Clock
}

// SubscribeEvents binds the verified stream identity, sends the initial signed
// server-time event, and then delegates the remaining lifecycle.
func (s authenticatedPushStreamService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	// Earlier pipeline stages must have attached both the parsed envelope and
	// the resolved session; absence indicates mis-wired middleware.
	envelope, ok := parsedEnvelopeFromContext(stream.Context())
	if !ok {
		return status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	record, ok := resolvedSessionFromContext(stream.Context())
	if !ok {
		return status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	binding := authenticatedStreamBinding{
		UserID:          record.UserID,
		DeviceSessionID: record.DeviceSessionID,
		MessageType:     envelope.MessageType,
		RequestID:       envelope.RequestID,
		TraceID:         envelope.TraceID,
	}
	// Wrap the stream so the tail delegate observes a context carrying the
	// verified binding.
	boundStream := authenticatedStreamContextStream{
		ServerStreamingServer: stream,
		ctx: context.WithValue(
			stream.Context(),
			authenticatedStreamBindingContextKey{},
			binding,
		),
	}
	serverTimeMS := s.clock.Now().UTC().UnixMilli()
	payloadBytes := buildServerTimeEventPayload(serverTimeMS)
	payloadHash := sha256.Sum256(payloadBytes)
	// The bootstrap event reuses the request ID as its event ID so the client
	// can correlate it with the subscribe call.
	signature, err := s.responseSigner.SignEvent(authn.EventSigningFields{
		EventType:   serverTimeEventType,
		EventID:     envelope.RequestID,
		TimestampMS: serverTimeMS,
		RequestID:   envelope.RequestID,
		TraceID:     envelope.TraceID,
		PayloadHash: payloadHash[:],
	})
	if err != nil {
		return status.Error(codes.Unavailable, "response signer is unavailable")
	}
	if err := boundStream.Send(&gatewayv1.GatewayEvent{
		EventType:    serverTimeEventType,
		EventId:      envelope.RequestID,
		TimestampMs:  serverTimeMS,
		PayloadBytes: bytes.Clone(payloadBytes),
		PayloadHash:  bytes.Clone(payloadHash[:]),
		Signature:    signature,
		RequestId:    envelope.RequestID,
		TraceId:      envelope.TraceID,
	}); err != nil {
		return err
	}
	// Hand the remaining stream lifetime to the tail delegate (fan-out or
	// hold-open).
	return s.tailDelegate.SubscribeEvents(req, boundStream)
}
// newAuthenticatedPushStreamService wraps tailDelegate with the
// SubscribeEvents bootstrap step. Nil dependencies are replaced with safe
// defaults, matching NewFanOutPushStreamService: a nil tail delegate holds
// streams open until the client disconnects, a nil signer fails requests as
// Unavailable, and a nil clock falls back to the system clock. Without these
// defaults a nil signer or clock would panic inside SubscribeEvents when the
// nil interface is invoked.
func newAuthenticatedPushStreamService(tailDelegate gatewayv1.EdgeGatewayServer, responseSigner authn.ResponseSigner, clk clock.Clock) gatewayv1.EdgeGatewayServer {
	if tailDelegate == nil {
		tailDelegate = holdOpenSubscribeEventsService{}
	}
	if responseSigner == nil {
		// Sentinel signer reports Unavailable instead of panicking.
		responseSigner = unavailableResponseSigner{}
	}
	if clk == nil {
		clk = clock.System{}
	}
	return authenticatedPushStreamService{
		tailDelegate:   tailDelegate,
		responseSigner: responseSigner,
		clock:          clk,
	}
}
// buildServerTimeEventPayload encodes serverTimeMS as a FlatBuffers
// ServerTimeEvent table and returns the finished buffer. The builder's output
// is cloned so the returned slice does not alias builder internals.
func buildServerTimeEventPayload(serverTimeMS int64) []byte {
	// 32 bytes is a size hint; the builder grows as needed.
	builder := flatbuffers.NewBuilder(32)
	gatewayfbs.ServerTimeEventStart(builder)
	gatewayfbs.ServerTimeEventAddServerTimeMs(builder, serverTimeMS)
	eventOffset := gatewayfbs.ServerTimeEventEnd(builder)
	gatewayfbs.FinishServerTimeEventBuffer(builder, eventOffset)
	return bytes.Clone(builder.FinishedBytes())
}
// authenticatedStreamBindingContextKey is the private context key under which
// the verified stream binding is stored.
type authenticatedStreamBindingContextKey struct{}

// authenticatedStreamContextStream overrides the stream's context so
// downstream delegates observe the context carrying the authenticated stream
// binding.
type authenticatedStreamContextStream struct {
	grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
	ctx context.Context
}

// Context returns the bound context, falling back to context.Background when
// the wrapper was constructed without one.
func (s authenticatedStreamContextStream) Context() context.Context {
	if s.ctx == nil {
		return context.Background()
	}
	return s.ctx
}
// holdOpenSubscribeEventsService is the default tail delegate: it keeps the
// stream open without sending further events until the client disconnects.
type holdOpenSubscribeEventsService struct {
	gatewayv1.UnimplementedEdgeGatewayServer
}

// SubscribeEvents blocks until the stream context ends and returns its error.
func (holdOpenSubscribeEventsService) SubscribeEvents(_ *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	<-stream.Context().Done()
	return stream.Context().Err()
}

// Compile-time check that the bootstrap service satisfies the service API.
var _ gatewayv1.EdgeGatewayServer = authenticatedPushStreamService{}
+286
View File
@@ -0,0 +1,286 @@
package grpcapi
import (
"context"
"errors"
"net"
"strings"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/ratelimit"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
const (
	// authenticatedGRPCBaseBucketKeyPrefix namespaces every authenticated
	// gRPC rate-limit bucket key.
	authenticatedGRPCBaseBucketKeyPrefix = "authenticated_grpc/"
	// Per-dimension key segments; the dimension value is appended verbatim.
	authenticatedGRPCIPBucketKeySegment           = authenticatedGRPCBaseBucketKeyPrefix + "ip="
	authenticatedGRPCSessionBucketKeySegment      = authenticatedGRPCBaseBucketKeyPrefix + "session="
	authenticatedGRPCUserBucketKeySegment         = authenticatedGRPCBaseBucketKeyPrefix + "user="
	authenticatedGRPCMessageClassBucketKeySegment = authenticatedGRPCBaseBucketKeyPrefix + "message_class="
	// unknownAuthenticatedPeerIP is the placeholder used when the transport
	// peer address cannot be determined.
	unknownAuthenticatedPeerIP = "unknown"
	// Public RPC method names used in policy metadata and logging.
	authenticatedRPCExecuteCommand  = "ExecuteCommand"
	authenticatedRPCSubscribeEvents = "SubscribeEvents"
)

var (
	// ErrAuthenticatedPolicyDenied reports that the authenticated request was
	// rejected by later edge policy after transport authenticity succeeded.
	ErrAuthenticatedPolicyDenied = errors.New("authenticated request rejected by edge policy")
	// ErrAuthenticatedPolicyUnavailable reports that authenticated policy could
	// not be evaluated because its backing dependency is unavailable.
	ErrAuthenticatedPolicyUnavailable = errors.New("authenticated request policy is unavailable")
)

// AuthenticatedRequestLimiter applies authenticated gRPC rate-limit policy to
// one concrete bucket key.
type AuthenticatedRequestLimiter interface {
	// Reserve evaluates key under policy and reports whether the request may
	// proceed immediately.
	Reserve(key string, policy ratelimit.Policy) ratelimit.Decision
}

// AuthenticatedRequest describes the authenticated request metadata exposed to
// the edge-policy hook.
type AuthenticatedRequest struct {
	// RPCMethod identifies the public gRPC method being processed.
	RPCMethod string
	// PeerIP is the transport peer IP derived from the gRPC connection.
	PeerIP string
	// MessageClass is the stable rate-limit and policy class. The gateway uses
	// the full message_type literal because the v1 transport does not yet define
	// a coarser authenticated class taxonomy.
	MessageClass string
	// Envelope contains the verified transport envelope fields used by later
	// edge policy.
	Envelope AuthenticatedRequestEnvelope
	// Session contains the authenticated identity resolved from SessionCache.
	Session session.Record
}

// AuthenticatedRequestEnvelope describes the verified request envelope fields
// exposed to the edge-policy hook.
type AuthenticatedRequestEnvelope struct {
	// ProtocolVersion is the supported transport protocol version literal.
	ProtocolVersion string
	// DeviceSessionID is the authenticated device-session identifier.
	DeviceSessionID string
	// MessageType is the verified downstream routing key supplied by the client.
	MessageType string
	// TimestampMS is the client timestamp that already passed freshness checks.
	TimestampMS int64
	// RequestID is the authenticated transport request identifier.
	RequestID string
	// TraceID is the optional client-supplied correlation identifier.
	TraceID string
}

// AuthenticatedRequestPolicy evaluates later authenticated edge policy after
// transport authenticity and rate-limit checks succeed.
type AuthenticatedRequestPolicy interface {
	// Evaluate returns nil when the authenticated request may proceed. It should
	// wrap ErrAuthenticatedPolicyDenied for stable reject mapping and
	// ErrAuthenticatedPolicyUnavailable when its backing dependency is
	// temporarily unavailable.
	Evaluate(ctx context.Context, request AuthenticatedRequest) error
}
// authenticatedRateLimitService layers per-dimension rate limiting and the
// pluggable edge-policy hook in front of the delegate service.
type authenticatedRateLimitService struct {
	gatewayv1.UnimplementedEdgeGatewayServer
	delegate gatewayv1.EdgeGatewayServer
	limiter  AuthenticatedRequestLimiter
	policy   AuthenticatedRequestPolicy
	cfg      config.AuthenticatedGRPCAntiAbuseConfig
}

// ExecuteCommand applies authenticated rate limits and edge policy before
// delegating to the configured service implementation.
func (s authenticatedRateLimitService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
	if err := s.applyRateLimitsAndPolicy(ctx, authenticatedRPCExecuteCommand); err != nil {
		return nil, err
	}
	return s.delegate.ExecuteCommand(ctx, req)
}

// SubscribeEvents applies authenticated rate limits and edge policy before
// delegating to the configured service implementation.
func (s authenticatedRateLimitService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	if err := s.applyRateLimitsAndPolicy(stream.Context(), authenticatedRPCSubscribeEvents); err != nil {
		return err
	}
	return s.delegate.SubscribeEvents(req, stream)
}

// newAuthenticatedRateLimitService wraps delegate with the authenticated
// rate-limit and edge-policy gate.
func newAuthenticatedRateLimitService(delegate gatewayv1.EdgeGatewayServer, limiter AuthenticatedRequestLimiter, policy AuthenticatedRequestPolicy, cfg config.AuthenticatedGRPCAntiAbuseConfig) gatewayv1.EdgeGatewayServer {
	return authenticatedRateLimitService{
		delegate: delegate,
		limiter:  limiter,
		policy:   policy,
		cfg:      cfg,
	}
}
// applyRateLimitsAndPolicy resolves the authenticated request snapshot from
// ctx, then runs the rate-limit buckets followed by the edge-policy hook.
func (s authenticatedRateLimitService) applyRateLimitsAndPolicy(ctx context.Context, rpcMethod string) error {
	request, err := authenticatedRequestFromContext(ctx, rpcMethod)
	if err != nil {
		return err
	}
	if err := s.applyRateLimits(request); err != nil {
		return err
	}
	return s.applyPolicy(ctx, request)
}
// applyRateLimits reserves capacity in each configured authenticated bucket
// (IP, session, user, message class), rejecting the request as soon as any
// bucket denies it.
func (s authenticatedRateLimitService) applyRateLimits(request AuthenticatedRequest) error {
	type bucket struct {
		key    string
		policy config.AuthenticatedRateLimitConfig
	}
	buckets := []bucket{
		{authenticatedGRPCIPBucketKey(request.PeerIP), s.cfg.IP},
		{authenticatedGRPCSessionBucketKey(request.Envelope.DeviceSessionID), s.cfg.Session},
		{authenticatedGRPCUserBucketKey(request.Session.UserID), s.cfg.User},
		{authenticatedGRPCMessageClassBucketKey(request.MessageClass), s.cfg.MessageClass},
	}
	for _, b := range buckets {
		policy := ratelimit.Policy{
			Requests: b.policy.Requests,
			Window:   b.policy.Window,
			Burst:    b.policy.Burst,
		}
		if decision := s.limiter.Reserve(b.key, policy); !decision.Allowed {
			return status.Error(codes.ResourceExhausted, "authenticated request rate limit exceeded")
		}
	}
	return nil
}
// applyPolicy runs the edge-policy hook and maps its sentinel errors onto
// stable gRPC statuses; unrecognized errors surface as Internal.
func (s authenticatedRateLimitService) applyPolicy(ctx context.Context, request AuthenticatedRequest) error {
	err := s.policy.Evaluate(ctx, request)
	if err == nil {
		return nil
	}
	if errors.Is(err, ErrAuthenticatedPolicyDenied) {
		return status.Error(codes.PermissionDenied, "authenticated request rejected by edge policy")
	}
	if errors.Is(err, ErrAuthenticatedPolicyUnavailable) {
		return status.Error(codes.Unavailable, "authenticated request policy is unavailable")
	}
	return status.Error(codes.Internal, "authenticated request policy evaluation failed")
}
// authenticatedRequestFromContext assembles the policy-facing request snapshot
// from the parsed envelope and resolved session previously stored on ctx by
// earlier pipeline stages. Missing context values surface as Internal errors
// because they indicate mis-wired middleware rather than client mistakes.
func authenticatedRequestFromContext(ctx context.Context, rpcMethod string) (AuthenticatedRequest, error) {
	envelope, ok := parsedEnvelopeFromContext(ctx)
	if !ok {
		return AuthenticatedRequest{}, status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	record, ok := resolvedSessionFromContext(ctx)
	if !ok {
		return AuthenticatedRequest{}, status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	return AuthenticatedRequest{
		RPCMethod:    rpcMethod,
		PeerIP:       peerIPFromContext(ctx),
		MessageClass: authenticatedMessageClass(envelope.MessageType),
		Envelope: AuthenticatedRequestEnvelope{
			ProtocolVersion: envelope.ProtocolVersion,
			DeviceSessionID: envelope.DeviceSessionID,
			MessageType:     envelope.MessageType,
			TimestampMS:     envelope.TimestampMS,
			RequestID:       envelope.RequestID,
			TraceID:         envelope.TraceID,
		},
		Session: record,
	}, nil
}
// authenticatedGRPCIPBucketKey returns the per-IP rate-limit bucket key.
func authenticatedGRPCIPBucketKey(peerIP string) string {
	return authenticatedGRPCIPBucketKeySegment + peerIP
}

// authenticatedGRPCSessionBucketKey returns the per-session bucket key.
func authenticatedGRPCSessionBucketKey(deviceSessionID string) string {
	return authenticatedGRPCSessionBucketKeySegment + deviceSessionID
}

// authenticatedGRPCUserBucketKey returns the per-user bucket key.
func authenticatedGRPCUserBucketKey(userID string) string {
	return authenticatedGRPCUserBucketKeySegment + userID
}

// authenticatedGRPCMessageClassBucketKey returns the per-message-class key.
func authenticatedGRPCMessageClassBucketKey(messageClass string) string {
	return authenticatedGRPCMessageClassBucketKeySegment + messageClass
}

// authenticatedMessageClass maps a message type to its rate-limit class.
// Currently the identity mapping: the v1 transport defines no coarser
// taxonomy.
func authenticatedMessageClass(messageType string) string {
	return messageType
}
// peerIPFromContext derives the transport peer IP from the gRPC peer info on
// ctx, stripping any port; it returns "unknown" when no peer address is
// available.
func peerIPFromContext(ctx context.Context) string {
	peerInfo, ok := peer.FromContext(ctx)
	if !ok || peerInfo.Addr == nil {
		return unknownAuthenticatedPeerIP
	}
	addr := strings.TrimSpace(peerInfo.Addr.String())
	if addr == "" {
		return unknownAuthenticatedPeerIP
	}
	// Prefer just the host portion; fall back to the raw address when it does
	// not parse as host:port.
	if host, _, err := net.SplitHostPort(addr); err == nil && host != "" {
		return host
	}
	return addr
}
// noopAuthenticatedRequestPolicy allows every authenticated request; it is the
// policy used when no custom edge policy is configured.
type noopAuthenticatedRequestPolicy struct{}

// Evaluate always permits the request.
func (noopAuthenticatedRequestPolicy) Evaluate(context.Context, AuthenticatedRequest) error {
	return nil
}

// Compile-time check that the rate-limit wrapper satisfies the service API.
var _ gatewayv1.EdgeGatewayServer = authenticatedRateLimitService{}
@@ -0,0 +1,497 @@
package grpcapi
import (
"context"
"fmt"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/ratelimit"
"galaxy/gateway/internal/restapi"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TestExecuteCommandRateLimitsByIP verifies that the per-IP bucket counts
// requests across distinct sessions and users arriving from the same client
// address: the second request is rejected with ResourceExhausted and never
// reaches the delegate.
func TestExecuteCommandRateLimitsByIP(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}),
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
	require.NoError(t, err)
	// Different session and user, same IP — must be throttled.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.ResourceExhausted, status.Code(err))
	assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
	assert.Equal(t, 1, delegate.executeCalls)
}

// TestExecuteCommandRateLimitsBySession verifies that the per-session bucket
// throttles repeated requests from one device session while a different
// session belonging to the same user still passes.
func TestExecuteCommandRateLimitsBySession(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.Session = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-1"}),
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
	require.NoError(t, err)
	// Same session again — throttled.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.ResourceExhausted, status.Code(err))
	// Different session, same user — allowed.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-3"))
	require.NoError(t, err)
	assert.Equal(t, 2, delegate.executeCalls)
}

// TestExecuteCommandRateLimitsByUser verifies that the per-user bucket spans
// multiple device sessions of one user while a different user is unaffected.
func TestExecuteCommandRateLimitsByUser(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.User = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service: delegate,
		SessionCache: userMappedSessionCache(map[string]string{
			"device-session-1": "user-shared",
			"device-session-2": "user-shared",
			"device-session-3": "user-other",
		}),
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-1", "request-1"))
	require.NoError(t, err)
	// Different session, same user — throttled.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-2", "request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.ResourceExhausted, status.Code(err))
	// Different user — allowed.
	_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithSessionAndRequestID("device-session-3", "request-3"))
	require.NoError(t, err)
	assert.Equal(t, 2, delegate.executeCalls)
}
func TestExecuteCommandRateLimitsByMessageClass(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{}
server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
cfg.AntiAbuse.MessageClass = config.AuthenticatedRateLimitConfig{
Requests: 1,
Window: time.Hour,
Burst: 1,
}
}), ServerDependencies{
Service: delegate,
SessionCache: userMappedSessionCache(map[string]string{
"device-session-1": "user-1",
"device-session-2": "user-2",
}),
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-1", "request-1", "fleet.move"))
require.NoError(t, err)
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-2", "fleet.move"))
require.Error(t, err)
assert.Equal(t, codes.ResourceExhausted, status.Code(err))
_, err = client.ExecuteCommand(context.Background(), newValidExecuteCommandRequestWithMessageType("device-session-2", "request-3", "fleet.rename"))
require.NoError(t, err)
assert.Equal(t, 2, delegate.executeCalls)
}
// TestAuthenticatedPolicyHookReceivesVerifiedRequest verifies that the edge
// policy hook is invoked exactly once per accepted command and sees the fully
// verified request: RPC method, peer IP, message class, parsed envelope
// fields, and the resolved session's user ID.
func TestAuthenticatedPolicyHookReceivesVerifiedRequest(t *testing.T) {
	t.Parallel()
	policy := &recordingAuthenticatedRequestPolicy{}
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:  staticReplayStore{},
		Policy:       policy,
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)
	require.Len(t, policy.requests, 1)
	assert.Equal(t, authenticatedRPCExecuteCommand, policy.requests[0].RPCMethod)
	assert.Equal(t, "127.0.0.1", policy.requests[0].PeerIP)
	assert.Equal(t, "fleet.move", policy.requests[0].MessageClass)
	assert.Equal(t, "device-session-123", policy.requests[0].Envelope.DeviceSessionID)
	assert.Equal(t, "request-123", policy.requests[0].Envelope.RequestID)
	assert.Equal(t, "trace-123", policy.requests[0].Envelope.TraceID)
	assert.Equal(t, "user-123", policy.requests[0].Session.UserID)
	assert.Equal(t, 1, delegate.executeCalls)
}

// TestExecuteCommandPolicyRejectMapsToPermissionDenied verifies that a policy
// error wrapping ErrAuthenticatedPolicyDenied surfaces as PERMISSION_DENIED
// with a stable message and never reaches the delegate service.
func TestExecuteCommandPolicyRejectMapsToPermissionDenied(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGateway(t, ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:  staticReplayStore{},
		Policy: authenticatedRequestPolicyFunc(func(context.Context, AuthenticatedRequest) error {
			return fmt.Errorf("policy deny: %w", ErrAuthenticatedPolicyDenied)
		}),
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.PermissionDenied, status.Code(err))
	assert.Equal(t, "authenticated request rejected by edge policy", status.Convert(err).Message())
	assert.Zero(t, delegate.executeCalls)
}
// TestSubscribeEventsRateLimitRejectsStream verifies that the streaming RPC is
// covered by the same anti-abuse pipeline as the unary one: after one accepted
// stream exhausts the per-IP budget, a second stream from the same client is
// rejected with RESOURCE_EXHAUSTED before reaching the delegate.
func TestSubscribeEventsRateLimitRejectsStream(t *testing.T) {
	t.Parallel()
	delegate := &recordingEdgeGatewayService{}
	server, runGateway := newTestGatewayWithGRPCConfig(t, newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	}), ServerDependencies{
		Service:      delegate,
		SessionCache: userMappedSessionCache(map[string]string{"device-session-1": "user-1", "device-session-2": "user-2"}),
		ReplayStore:  staticReplayStore{},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	// First stream is accepted and delivers the signed bootstrap event.
	stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-1", "request-1"))
	require.NoError(t, err)
	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-1", "trace-123", testCurrentTime.UnixMilli())
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)
	// Second stream from the same peer IP is over budget.
	err = subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-2", "request-2"))
	require.Error(t, err)
	assert.Equal(t, codes.ResourceExhausted, status.Code(err))
	assert.Equal(t, "authenticated request rate limit exceeded", status.Convert(err).Message())
	assert.Equal(t, 1, delegate.subscribeCalls)
}
// TestAuthenticatedRateLimitsStayIsolatedFromPublicREST runs the public REST
// server and the authenticated gRPC server against one shared in-memory
// limiter and checks that exhausting the public-auth bucket over HTTP leaves
// the authenticated gRPC IP budget untouched.
func TestAuthenticatedRateLimitsStayIsolatedFromPublicREST(t *testing.T) {
	t.Parallel()
	sharedLimiter := ratelimit.NewInMemory()
	publicCfg := config.DefaultPublicHTTPConfig()
	publicCfg.Addr = unusedTCPAddr(t)
	// Public auth allows a single request per hour so the second HTTP call
	// below trips it.
	publicCfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
		Requests: 1,
		Window:   time.Hour,
		Burst:    1,
	}
	// Keep the identity bucket generous so only the public-auth limit can
	// be the cause of a 429 in this test.
	publicCfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}
	grpcCfg := newAuthenticatedGRPCConfigForTest(func(cfg *config.AuthenticatedGRPCConfig) {
		cfg.Addr = unusedTCPAddr(t)
		cfg.AntiAbuse.IP = config.AuthenticatedRateLimitConfig{
			Requests: 1,
			Window:   time.Hour,
			Burst:    1,
		}
	})
	restServer := restapi.NewServer(publicCfg, restapi.ServerDependencies{
		AuthService: staticAuthServiceClient{},
		Limiter:     publicLimiterAdapter{limiter: sharedLimiter},
	})
	delegate := &recordingEdgeGatewayService{}
	grpcServer := NewServer(grpcCfg, ServerDependencies{
		Service:        delegate,
		Router:         executeCommandAdapterRouter{service: delegate},
		ResponseSigner: newTestResponseSigner(),
		SessionCache:   userMappedSessionCache(map[string]string{"device-session-123": "user-123"}),
		ReplayStore:    staticReplayStore{},
		Limiter:        sharedLimiter,
		Clock:          fixedClock{now: testCurrentTime},
	})
	// Run both servers under one app lifecycle, mirroring production wiring.
	application := app.New(config.Config{ShutdownTimeout: time.Second}, restServer, grpcServer)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	runGateway := runningGateway{cancel: cancel, resultCh: resultCh}
	defer runGateway.stop(t)
	waitForHTTPHealthz(t, "http://"+publicCfg.Addr+"/healthz")
	addr := waitForListenAddr(t, grpcServer)
	// Exhaust the public-auth bucket over HTTP.
	firstPublic := sendPublicAuthRequest(t, "http://"+publicCfg.Addr+"/api/v1/public/auth/send-email-code")
	secondPublic := sendPublicAuthRequest(t, "http://"+publicCfg.Addr+"/api/v1/public/auth/send-email-code")
	assert.Equal(t, http.StatusOK, firstPublic.StatusCode)
	assert.Equal(t, http.StatusTooManyRequests, secondPublic.StatusCode)
	require.NoError(t, firstPublic.Body.Close())
	require.NoError(t, secondPublic.Body.Close())
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	// The authenticated gRPC surface still has its full IP budget.
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)
}
// newAuthenticatedGRPCConfigForTest returns the default authenticated gRPC
// configuration tuned for tests: a loopback listener on an ephemeral port, the
// shared test freshness window, and generous limits on every anti-abuse
// dimension so each test only tightens the dimension it exercises via mutate.
// A nil mutate leaves the generous defaults in place.
func newAuthenticatedGRPCConfigForTest(mutate func(*config.AuthenticatedGRPCConfig)) config.AuthenticatedGRPCConfig {
	cfg := config.DefaultAuthenticatedGRPCConfig()
	cfg.Addr = "127.0.0.1:0"
	cfg.FreshnessWindow = testFreshnessWindow
	// One generous budget shared by value across all four dimensions.
	generous := config.AuthenticatedRateLimitConfig{
		Requests: 100,
		Window:   time.Hour,
		Burst:    100,
	}
	cfg.AntiAbuse.IP = generous
	cfg.AntiAbuse.Session = generous
	cfg.AntiAbuse.User = generous
	cfg.AntiAbuse.MessageClass = generous
	if mutate != nil {
		mutate(&cfg)
	}
	return cfg
}
// newValidExecuteCommandRequestWithMessageType builds a valid ExecuteCommand
// request for the given session/request IDs, overrides its message type, and
// re-signs the envelope so the signature still covers the mutated field.
func newValidExecuteCommandRequestWithMessageType(deviceSessionID string, requestID string, messageType string) *gatewayv1.ExecuteCommandRequest {
	req := newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID, requestID)
	req.MessageType = messageType
	// Re-sign: the signature covers message_type, so the base request's
	// signature is stale after the override. Field order here must match
	// the server-side verification order.
	req.Signature = signRequest(
		req.GetProtocolVersion(),
		req.GetDeviceSessionId(),
		req.GetMessageType(),
		req.GetTimestampMs(),
		req.GetRequestId(),
		req.GetPayloadHash(),
	)
	return req
}
// userMappedSessionCache builds a staticSessionCache that resolves the given
// device-session-ID -> user-ID mapping: known sessions yield an active record
// carrying that user ID, unknown sessions report session.ErrNotFound.
func userMappedSessionCache(users map[string]string) staticSessionCache {
	lookup := func(_ context.Context, deviceSessionID string) (session.Record, error) {
		userID, found := users[deviceSessionID]
		if !found {
			return session.Record{}, session.ErrNotFound
		}
		record := newActiveSessionRecordWithSessionID(deviceSessionID)
		record.UserID = userID
		return record, nil
	}
	return staticSessionCache{lookupFunc: lookup}
}
// authenticatedRequestPolicyFunc adapts a plain function into an
// AuthenticatedRequestPolicy, mirroring the http.HandlerFunc pattern.
type authenticatedRequestPolicyFunc func(context.Context, AuthenticatedRequest) error

// Evaluate invokes the wrapped function.
func (f authenticatedRequestPolicyFunc) Evaluate(ctx context.Context, request AuthenticatedRequest) error {
	return f(ctx, request)
}

// recordingAuthenticatedRequestPolicy is a test double that records every
// request it is asked to evaluate and always allows it.
type recordingAuthenticatedRequestPolicy struct {
	requests []AuthenticatedRequest
}

// Evaluate appends request to the recorded slice and returns nil (allow).
func (p *recordingAuthenticatedRequestPolicy) Evaluate(_ context.Context, request AuthenticatedRequest) error {
	p.requests = append(p.requests, request)
	return nil
}
// publicLimiterAdapter bridges the shared in-memory limiter into the public
// REST server's rate-limit interface so both surfaces can exercise one
// limiter instance in tests.
type publicLimiterAdapter struct {
	limiter ratelimit.Limiter
}

// Reserve translates the public policy into a ratelimit.Policy, reserves a
// slot for key on the wrapped limiter, and maps the decision back into the
// REST server's representation.
func (a publicLimiterAdapter) Reserve(key string, policy config.PublicRateLimitConfig) restapi.PublicRateLimitDecision {
	translated := ratelimit.Policy{
		Requests: policy.Requests,
		Window:   policy.Window,
		Burst:    policy.Burst,
	}
	decision := a.limiter.Reserve(key, translated)
	return restapi.PublicRateLimitDecision{
		Allowed:    decision.Allowed,
		RetryAfter: decision.RetryAfter,
	}
}
// staticAuthServiceClient is a stub auth backend that always succeeds with
// fixed identifiers, letting REST tests focus on transport concerns.
type staticAuthServiceClient struct{}

// SendEmailCode always succeeds with a fixed challenge ID.
func (staticAuthServiceClient) SendEmailCode(context.Context, restapi.SendEmailCodeInput) (restapi.SendEmailCodeResult, error) {
	return restapi.SendEmailCodeResult{ChallengeID: "challenge-123"}, nil
}

// ConfirmEmailCode always succeeds with a fixed device session ID.
func (staticAuthServiceClient) ConfirmEmailCode(context.Context, restapi.ConfirmEmailCodeInput) (restapi.ConfirmEmailCodeResult, error) {
	return restapi.ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"}, nil
}
// waitForHTTPHealthz blocks until GET url answers HTTP 200, failing the test
// if the public REST server does not become healthy within two seconds.
func waitForHTTPHealthz(t *testing.T, url string) {
	t.Helper()
	client := &http.Client{Timeout: 200 * time.Millisecond}
	require.Eventually(t, func() bool {
		response, err := client.Get(url)
		if err != nil {
			return false
		}
		// This condition runs on a goroutine spawned by Eventually, so it
		// must not call require/t.FailNow (only valid on the test
		// goroutine). Treat a body-close failure as "not healthy yet"
		// instead of aborting from the wrong goroutine.
		closeErr := response.Body.Close()
		return closeErr == nil && response.StatusCode == http.StatusOK
	}, 2*time.Second, 10*time.Millisecond, "public REST server did not become healthy: %s", url)
}
// sendPublicAuthRequest POSTs a minimal JSON email payload to url and returns
// the raw response. The caller is responsible for closing the response body.
func sendPublicAuthRequest(t *testing.T, url string) *http.Response {
	t.Helper()
	body := strings.NewReader(`{"email":"pilot@example.com"}`)
	request, err := http.NewRequest(http.MethodPost, url, body)
	require.NoError(t, err)
	request.Header.Set("Content-Type", "application/json")
	httpClient := &http.Client{Timeout: time.Second}
	response, err := httpClient.Do(request)
	require.NoError(t, err)
	return response
}
// unusedTCPAddr grabs an ephemeral loopback port, immediately releases it, and
// returns the address so a test config can bind a port that was free a moment
// ago. (The OS could hand the port to someone else in between; acceptable for
// tests.)
func unusedTCPAddr(t *testing.T) string {
	t.Helper()
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	reserved := probe.Addr().String()
	require.NoError(t, probe.Close())
	return reserved
}
+260
View File
@@ -0,0 +1,260 @@
// Package grpcapi exposes the authenticated gRPC surface of the gateway.
package grpcapi
import (
"context"
"errors"
"fmt"
"net"
"sync"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/ratelimit"
"galaxy/gateway/internal/replay"
"galaxy/gateway/internal/session"
"galaxy/gateway/internal/telemetry"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.uber.org/zap"
"google.golang.org/grpc"
)
// ServerDependencies describes the optional collaborators used by the
// authenticated gRPC server. The zero value is valid and keeps the process
// runnable with the built-in unimplemented service stub; NewServer replaces
// every nil field with its documented fallback.
type ServerDependencies struct {
	// Service optionally handles the post-bootstrap SubscribeEvents lifecycle
	// after the initial authenticated service event has been sent. When nil, the
	// gateway keeps authenticated SubscribeEvents streams open until the client
	// cancels them, the server shuts down, or a later stream send fails.
	Service gatewayv1.EdgeGatewayServer
	// Router resolves the exact downstream unary client for the verified
	// message_type value. When nil, the authenticated unary surface uses an
	// empty exact-match router and returns UNIMPLEMENTED for unrouted commands.
	Router downstream.Router
	// ResponseSigner signs authenticated unary responses after downstream
	// execution succeeds. When nil, the unary surface fails closed once it needs
	// to sign a routed response.
	ResponseSigner authn.ResponseSigner
	// SessionCache resolves authenticated device sessions after the envelope
	// gate succeeds. When nil, the authenticated gRPC surface remains runnable
	// but valid envelopes fail closed as session-cache unavailable.
	SessionCache session.Cache
	// Clock provides current server time for freshness checks. When nil, the
	// authenticated gRPC surface uses the system clock.
	Clock clock.Clock
	// ReplayStore reserves authenticated request identifiers after signature
	// verification. When nil, valid requests fail closed as replay-store
	// unavailable.
	ReplayStore replay.Store
	// Limiter applies authenticated rate limits after the request passes the
	// transport authenticity checks. When nil, the authenticated gRPC surface
	// uses a process-local in-memory limiter.
	Limiter AuthenticatedRequestLimiter
	// Policy evaluates later authenticated edge policy after rate limits pass.
	// When nil, the authenticated gRPC surface applies a no-op allow policy.
	Policy AuthenticatedRequestPolicy
	// Logger writes structured logs for authenticated gRPC traffic. When nil,
	// a no-op logger is used.
	Logger *zap.Logger
	// Telemetry records low-cardinality gRPC metrics. May be nil.
	Telemetry *telemetry.Runtime
	// PushHub is the active authenticated push-stream hub. When present, the
	// server closes active streams before GracefulStop during shutdown.
	PushHub *push.Hub
}
// Server owns the authenticated gRPC listener exposed by the gateway.
type Server struct {
	// Immutable after NewServer.
	cfg     config.AuthenticatedGRPCConfig
	service gatewayv1.EdgeGatewayServer
	logger  *zap.Logger
	pushHub *push.Hub
	metrics *telemetry.Runtime

	// stateMu guards the mutable run-state below, which is set by Run and
	// read by Shutdown and listenAddr.
	stateMu  sync.RWMutex
	server   *grpc.Server
	listener net.Listener
}
// NewServer constructs an authenticated gRPC server for the supplied listener
// configuration and dependency bundle. Nil dependencies are replaced with safe
// defaults so the gateway can expose the documented transport surface with the
// full auth pipeline wired from built-in fallbacks.
//
// The service is assembled as an onion of wrappers; a request traverses them
// outermost-first: envelope validation -> session lookup -> payload-hash
// verification -> signature verification -> freshness/replay -> rate limiting
// and policy -> command routing / push streaming. The wrapping order below is
// therefore load-bearing.
func NewServer(cfg config.AuthenticatedGRPCConfig, deps ServerDependencies) *Server {
	deps = normalizeServerDependencies(deps)
	// Innermost layer: route verified commands downstream and serve the
	// authenticated push stream.
	finalService := newCommandRoutingService(
		newAuthenticatedPushStreamService(deps.Service, deps.ResponseSigner, deps.Clock),
		deps.Router,
		deps.ResponseSigner,
		deps.Clock,
		cfg.DownstreamTimeout,
	)
	return &Server{
		cfg: cfg,
		service: newEnvelopeValidatingService(
			newSessionLookupService(
				newPayloadHashVerifyingService(
					newSignatureVerifyingService(
						newFreshnessAndReplayService(
							newAuthenticatedRateLimitService(
								finalService,
								deps.Limiter,
								deps.Policy,
								cfg.AntiAbuse,
							),
							deps.Clock,
							deps.ReplayStore,
							cfg.FreshnessWindow,
						),
					),
				),
				deps.SessionCache,
			),
		),
		logger:  deps.Logger.Named("authenticated_grpc"),
		pushHub: deps.PushHub,
		metrics: deps.Telemetry,
	}
}
// Run binds the configured listener and serves the authenticated gRPC surface
// until Shutdown closes the server. It returns nil on a clean stop (including
// grpc.ErrServerStopped) and a wrapped error for listen or serve failures.
func (s *Server) Run(ctx context.Context) error {
	if ctx == nil {
		return errors.New("run authenticated gRPC server: nil context")
	}
	// Bail out early when the caller canceled before startup.
	if err := ctx.Err(); err != nil {
		return err
	}
	listener, err := net.Listen("tcp", s.cfg.Addr)
	if err != nil {
		return fmt.Errorf("run authenticated gRPC server: listen on %q: %w", s.cfg.Addr, err)
	}
	grpcServer := grpc.NewServer(
		grpc.ConnectionTimeout(s.cfg.ConnectionTimeout),
		grpc.StatsHandler(otelgrpc.NewServerHandler()),
		grpc.ChainUnaryInterceptor(observabilityUnaryInterceptor(s.logger, s.metrics)),
		grpc.ChainStreamInterceptor(observabilityStreamInterceptor(s.logger, s.metrics)),
	)
	gatewayv1.RegisterEdgeGatewayServer(grpcServer, s.service)
	// Publish run-state so Shutdown and listenAddr can observe it.
	s.stateMu.Lock()
	s.server = grpcServer
	s.listener = listener
	s.stateMu.Unlock()
	s.logger.Info("authenticated gRPC server started", zap.String("addr", listener.Addr().String()))
	// Clear run-state once Serve returns, whatever the outcome.
	defer func() {
		s.stateMu.Lock()
		s.server = nil
		s.listener = nil
		s.stateMu.Unlock()
	}()
	err = grpcServer.Serve(listener)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, grpc.ErrServerStopped):
		// Stop/GracefulStop was invoked; this is a clean shutdown.
		s.logger.Info("authenticated gRPC server stopped")
		return nil
	default:
		return fmt.Errorf("run authenticated gRPC server: serve on %q: %w", s.cfg.Addr, err)
	}
}
// Shutdown gracefully stops the authenticated gRPC server within ctx. When the
// graceful stop exceeds ctx, the server is force-stopped before returning the
// timeout to the caller. Shutdown on a server that is not running is a no-op.
func (s *Server) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown authenticated gRPC server: nil context")
	}
	s.stateMu.RLock()
	server := s.server
	s.stateMu.RUnlock()
	if server == nil {
		return nil
	}
	// Close active push streams first so GracefulStop is not held open by
	// long-lived server-streaming RPCs.
	if s.pushHub != nil {
		s.pushHub.Shutdown()
	}
	stopped := make(chan struct{})
	go func() {
		server.GracefulStop()
		close(stopped)
	}()
	select {
	case <-stopped:
		return nil
	case <-ctx.Done():
		// Deadline hit: force-stop, then wait for the GracefulStop
		// goroutine to finish before reporting the timeout.
		server.Stop()
		<-stopped
		return fmt.Errorf("shutdown authenticated gRPC server: %w", ctx.Err())
	}
}
// listenAddr reports the address the server is currently bound to, or the
// empty string while the server is not listening.
func (s *Server) listenAddr() string {
	s.stateMu.RLock()
	listener := s.listener
	s.stateMu.RUnlock()
	if listener == nil {
		return ""
	}
	return listener.Addr().String()
}
func normalizeServerDependencies(deps ServerDependencies) ServerDependencies {
if deps.Router == nil {
deps.Router = downstream.NewStaticRouter(nil)
}
if deps.ResponseSigner == nil {
deps.ResponseSigner = unavailableResponseSigner{}
}
if deps.SessionCache == nil {
deps.SessionCache = unavailableSessionCache{}
}
if deps.Clock == nil {
deps.Clock = clock.System{}
}
if deps.ReplayStore == nil {
deps.ReplayStore = unavailableReplayStore{}
}
if deps.Limiter == nil {
deps.Limiter = ratelimit.NewInMemory()
}
if deps.Policy == nil {
deps.Policy = noopAuthenticatedRequestPolicy{}
}
if deps.Logger == nil {
deps.Logger = zap.NewNop()
}
return deps
}
+332
View File
@@ -0,0 +1,332 @@
package grpcapi
import (
"context"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
// TestExecuteCommandRejectsMalformedEnvelope verifies that an empty request
// fails the envelope gate with INVALID_ARGUMENT.
func TestExecuteCommandRejectsMalformedEnvelope(t *testing.T) {
	t.Parallel()
	server, runGateway := newTestGateway(t, ServerDependencies{})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{})
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
}

// TestSubscribeEventsRejectsMalformedEnvelope verifies the same envelope gate
// applies to the streaming RPC.
func TestSubscribeEventsRejectsMalformedEnvelope(t *testing.T) {
	t.Parallel()
	server, runGateway := newTestGateway(t, ServerDependencies{})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, &gatewayv1.SubscribeEventsRequest{})
	require.Error(t, err)
	assert.Equal(t, codes.InvalidArgument, status.Code(err))
}

// TestExecuteCommandRejectsUnsupportedProtocolVersion verifies that a fully
// populated envelope carrying an unknown protocol_version is rejected with
// FAILED_PRECONDITION and a stable message.
func TestExecuteCommandRejectsUnsupportedProtocolVersion(t *testing.T) {
	t.Parallel()
	server, runGateway := newTestGateway(t, ServerDependencies{})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: "v2",
		DeviceSessionId: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMs:     123456789,
		RequestId:       "request-123",
		PayloadBytes:    []byte("payload"),
		PayloadHash:     []byte("hash"),
		Signature:       []byte("signature"),
	})
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, `unsupported protocol_version "v2"`, status.Convert(err).Message())
}
// TestExecuteCommandValidEnvelopeStillReturnsUnimplemented verifies that a
// fully valid request with no configured Router/Service ends at the default
// stub and surfaces UNIMPLEMENTED.
func TestExecuteCommandValidEnvelopeStillReturnsUnimplemented(t *testing.T) {
	t.Parallel()
	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unimplemented, status.Code(err))
}

// TestExecuteCommandMissingReplayStoreFailsClosed verifies the fail-closed
// fallback: with no ReplayStore configured, a valid request is rejected with
// UNAVAILABLE rather than being served without replay protection.
func TestExecuteCommandMissingReplayStoreFailsClosed(t *testing.T) {
	t.Parallel()
	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	_, err := client.ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
}
// TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation
// verifies the stream lifecycle without a configured Service: the signed
// bootstrap event arrives first, the stream then stays open (no further
// events) until the client cancels, at which point Recv surfaces CANCELED.
func TestSubscribeEventsValidEnvelopeSendsBootstrapEventAndWaitsForCancellation(t *testing.T) {
	t.Parallel()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
		ReplayStore: staticReplayStore{},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	stream, err := client.SubscribeEvents(ctx, newValidSubscribeEventsRequest())
	require.NoError(t, err)
	event := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
	// Recv on a separate goroutine so we can assert it stays blocked.
	recvResult := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvResult <- recvErr
	}()
	// The stream must stay open while the client context is alive.
	require.Never(t, func() bool {
		select {
		case <-recvResult:
			return true
		default:
			return false
		}
	}, 100*time.Millisecond, 10*time.Millisecond, "stream closed before cancellation")
	cancel()
	var recvErr error
	require.Eventually(t, func() bool {
		select {
		case recvErr = <-recvResult:
			return true
		default:
			return false
		}
	}, time.Second, 10*time.Millisecond, "stream did not stop after client cancellation")
	require.Error(t, recvErr)
	assert.Equal(t, codes.Canceled, status.Code(recvErr))
}

// TestSubscribeEventsMissingReplayStoreFailsClosed verifies the fail-closed
// replay-store fallback for the streaming RPC, mirroring the unary test.
func TestSubscribeEventsMissingReplayStoreFailsClosed(t *testing.T) {
	t.Parallel()
	server, runGateway := newTestGateway(t, ServerDependencies{
		SessionCache: staticSessionCache{
			lookupFunc: func(context.Context, string) (session.Record, error) {
				return newActiveSessionRecord(), nil
			},
		},
	})
	defer runGateway.stop(t)
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)
	err := subscribeEventsError(t, context.Background(), client, newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "replay store is unavailable", status.Convert(err).Message())
}
// TestServerLifecycle verifies start/stop behavior: the listener accepts a
// connection while running, and after stop a blocking dial with a short
// deadline fails because nothing is listening anymore.
func TestServerLifecycle(t *testing.T) {
	t.Parallel()
	server, runGateway := newTestGateway(t, ServerDependencies{})
	addr := waitForListenAddr(t, server)
	conn := dialGatewayClient(t, addr)
	require.NoError(t, conn.Close())
	runGateway.stop(t)
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	// WithBlock forces the dial to actually connect (or time out).
	_, err := grpc.DialContext(
		ctx,
		addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
	)
	require.Error(t, err)
}
// runningGateway holds the handles needed to stop a gateway started in the
// background: the context cancel func and the channel carrying Run's result.
type runningGateway struct {
	cancel   context.CancelFunc
	resultCh chan error
}

// newTestGateway starts a gateway with default test gRPC config (loopback
// ephemeral port, test freshness window) and the supplied dependencies.
func newTestGateway(t *testing.T, deps ServerDependencies) (*Server, runningGateway) {
	t.Helper()
	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = "127.0.0.1:0"
	grpcCfg.FreshnessWindow = testFreshnessWindow
	return newTestGatewayWithGRPCConfig(t, grpcCfg, deps)
}

// newTestGatewayWithGRPCConfig starts a gateway with the given gRPC config,
// filling common test defaults (fixed clock, test signer, adapter router when
// a Service is set), and runs it on a background goroutine. Callers must stop
// it via the returned runningGateway.
func newTestGatewayWithGRPCConfig(t *testing.T, grpcCfg config.AuthenticatedGRPCConfig, deps ServerDependencies) (*Server, runningGateway) {
	t.Helper()
	cfg := config.Config{
		ShutdownTimeout:   time.Second,
		AuthenticatedGRPC: grpcCfg,
	}
	if deps.Clock == nil {
		deps.Clock = fixedClock{now: testCurrentTime}
	}
	if deps.ResponseSigner == nil {
		deps.ResponseSigner = newTestResponseSigner()
	}
	// Route commands straight into the provided service when the test did
	// not supply its own router.
	if deps.Router == nil && deps.Service != nil {
		deps.Router = executeCommandAdapterRouter{service: deps.Service}
	}
	server := NewServer(cfg.AuthenticatedGRPC, deps)
	application := app.New(cfg, server)
	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()
	return server, runningGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}
// stop cancels the gateway's context, waits up to two seconds for Run to
// return, and fails the test if it does not stop cleanly.
func (g runningGateway) stop(t *testing.T) {
	t.Helper()
	g.cancel()
	var err error
	require.Eventually(t, func() bool {
		select {
		case err = <-g.resultCh:
			return true
		default:
			return false
		}
	}, 2*time.Second, 10*time.Millisecond, "gateway did not stop after cancellation")
	require.NoError(t, err)
}

// waitForListenAddr polls until the server reports a bound address (Run has
// reached the serving state) and returns it.
func waitForListenAddr(t *testing.T, server *Server) string {
	t.Helper()
	var addr string
	require.Eventually(t, func() bool {
		addr = server.listenAddr()
		return addr != ""
	}, time.Second, 10*time.Millisecond, "server did not start listening")
	return addr
}

// dialGatewayClient opens a blocking plaintext client connection to addr with
// a one-second deadline; callers own closing the connection.
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	conn, err := grpc.DialContext(
		ctx,
		addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
	)
	require.NoError(t, err)
	return conn
}
+126
View File
@@ -0,0 +1,126 @@
package grpcapi
import (
"context"
"errors"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// resolvedSessionFromContext returns the session record previously attached to
// ctx by the session-lookup gateway wrapper. The boolean reports whether a
// record was present; the returned record is a defensive copy.
func resolvedSessionFromContext(ctx context.Context) (session.Record, bool) {
	if ctx == nil {
		return session.Record{}, false
	}
	value := ctx.Value(resolvedSessionContextKey{})
	record, ok := value.(session.Record)
	if !ok {
		return session.Record{}, false
	}
	return cloneSessionRecord(record), true
}
// sessionLookupService resolves the authenticated session from SessionCache
// after envelope parsing succeeds and before later auth steps run.
type sessionLookupService struct {
	gatewayv1.UnimplementedEdgeGatewayServer
	// delegate receives requests only after the session resolves.
	delegate gatewayv1.EdgeGatewayServer
	// cache is the device-session source of truth for this gate.
	cache session.Cache
}

// ExecuteCommand resolves the cached session for req and only then forwards it
// to the configured delegate with the resolved session attached to ctx.
func (s sessionLookupService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
	record, err := s.lookupSession(ctx)
	if err != nil {
		return nil, err
	}
	// Clone before attaching so the delegate cannot mutate shared state
	// through the context value.
	return s.delegate.ExecuteCommand(context.WithValue(ctx, resolvedSessionContextKey{}, cloneSessionRecord(record)), req)
}

// SubscribeEvents resolves the cached session for req and only then forwards it
// to the configured delegate with the resolved session attached to the stream
// context.
func (s sessionLookupService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	record, err := s.lookupSession(stream.Context())
	if err != nil {
		return err
	}
	// Wrap the stream so Context() carries the resolved session downstream.
	return s.delegate.SubscribeEvents(req, resolvedSessionContextStream{
		ServerStreamingServer: stream,
		ctx:                   context.WithValue(stream.Context(), resolvedSessionContextKey{}, cloneSessionRecord(record)),
	})
}
// newSessionLookupService wraps delegate with the session-cache lookup gate,
// which resolves the device session before the delegate runs.
func newSessionLookupService(delegate gatewayv1.EdgeGatewayServer, cache session.Cache) gatewayv1.EdgeGatewayServer {
	gate := sessionLookupService{cache: cache, delegate: delegate}
	return gate
}
// lookupSession resolves the device session named by the parsed envelope in
// ctx and maps failures onto gRPC statuses: missing envelope context ->
// INTERNAL, unknown session -> UNAUTHENTICATED, cache failure -> UNAVAILABLE,
// revoked session -> FAILED_PRECONDITION. On success it returns a defensive
// copy of the cached record.
func (s sessionLookupService) lookupSession(ctx context.Context) (session.Record, error) {
	envelope, ok := parsedEnvelopeFromContext(ctx)
	if !ok {
		// The envelope gate must have run first; absence is a wiring bug.
		return session.Record{}, status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	record, err := s.cache.Lookup(ctx, envelope.DeviceSessionID)
	switch {
	case err == nil:
	case errors.Is(err, session.ErrNotFound):
		return session.Record{}, status.Error(codes.Unauthenticated, "unknown device session")
	default:
		return session.Record{}, status.Error(codes.Unavailable, "session cache is unavailable")
	}
	if record.Status == session.StatusRevoked {
		return session.Record{}, status.Error(codes.FailedPrecondition, "device session is revoked")
	}
	return cloneSessionRecord(record), nil
}
// cloneSessionRecord copies record, replacing a non-nil RevokedAtMS pointer
// with a pointer to a freshly allocated value so callers cannot mutate the
// cached record through the shared pointer.
func cloneSessionRecord(record session.Record) session.Record {
	cloned := record
	if ts := record.RevokedAtMS; ts != nil {
		copied := *ts
		cloned.RevokedAtMS = &copied
	}
	return cloned
}
// resolvedSessionContextKey is the private context key under which the
// resolved session record is stored.
type resolvedSessionContextKey struct{}

// resolvedSessionContextStream overrides a server stream's Context so the
// delegate observes the context carrying the resolved session.
type resolvedSessionContextStream struct {
	grpc.ServerStreamingServer[gatewayv1.GatewayEvent]
	// ctx replaces the embedded stream's context when non-nil.
	ctx context.Context
}
// Context returns the override context carrying the resolved session, falling
// back to context.Background when the override was never set.
func (s resolvedSessionContextStream) Context() context.Context {
	if s.ctx != nil {
		return s.ctx
	}
	return context.Background()
}
// unavailableSessionCache is the fail-closed fallback installed when no
// SessionCache dependency is provided: every lookup errors.
type unavailableSessionCache struct{}

// Lookup always reports the cache as unavailable.
func (unavailableSessionCache) Lookup(context.Context, string) (session.Record, error) {
	return session.Record{}, errors.New("session cache is unavailable")
}

// Compile-time check that the wrapper implements the full service interface.
var _ gatewayv1.EdgeGatewayServer = sessionLookupService{}
@@ -0,0 +1,294 @@
package grpcapi
import (
"context"
"errors"
"io"
"testing"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TestExecuteCommandRejectsUnknownSession verifies that a session-cache miss
// surfaces as Unauthenticated and never reaches the delegate service.
func TestExecuteCommandRejectsUnknownSession(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{}
	notFoundCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			return session.Record{}, session.ErrNotFound
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: notFoundCache})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	_, err := gatewayv1.NewEdgeGatewayClient(conn).ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unauthenticated, status.Code(err))
	assert.Equal(t, "unknown device session", status.Convert(err).Message())
	assert.Zero(t, backend.executeCalls)
}
// TestSubscribeEventsRejectsUnknownSession verifies that a session-cache miss
// terminates the stream with Unauthenticated before the delegate runs.
func TestSubscribeEventsRejectsUnknownSession(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{}
	notFoundCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			return session.Record{}, session.ErrNotFound
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: notFoundCache})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	err := subscribeEventsError(t, context.Background(), gatewayv1.NewEdgeGatewayClient(conn), newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unauthenticated, status.Code(err))
	assert.Equal(t, "unknown device session", status.Convert(err).Message())
	assert.Zero(t, backend.subscribeCalls)
}
// TestExecuteCommandRejectsRevokedSession verifies that a revoked session is
// rejected with FailedPrecondition before the delegate runs.
func TestExecuteCommandRejectsRevokedSession(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{}
	revokedCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			return newRevokedSessionRecord(), nil
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: revokedCache})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	_, err := gatewayv1.NewEdgeGatewayClient(conn).ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
	assert.Zero(t, backend.executeCalls)
}
// TestSubscribeEventsRejectsRevokedSession verifies that a revoked session
// terminates the stream with FailedPrecondition before the delegate runs.
func TestSubscribeEventsRejectsRevokedSession(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{}
	revokedCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			return newRevokedSessionRecord(), nil
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: revokedCache})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	err := subscribeEventsError(t, context.Background(), gatewayv1.NewEdgeGatewayClient(conn), newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
	assert.Zero(t, backend.subscribeCalls)
}
// TestExecuteCommandRejectsSessionCacheUnavailable verifies that an arbitrary
// cache failure maps to Unavailable without reaching the delegate.
func TestExecuteCommandRejectsSessionCacheUnavailable(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{}
	failingCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			return session.Record{}, errors.New("redis down")
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: failingCache})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	_, err := gatewayv1.NewEdgeGatewayClient(conn).ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
	assert.Zero(t, backend.executeCalls)
}
// TestSubscribeEventsRejectsSessionCacheUnavailable verifies that an arbitrary
// cache failure terminates the stream with Unavailable.
func TestSubscribeEventsRejectsSessionCacheUnavailable(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{}
	failingCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			return session.Record{}, errors.New("redis down")
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: failingCache})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	err := subscribeEventsError(t, context.Background(), gatewayv1.NewEdgeGatewayClient(conn), newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
	assert.Zero(t, backend.subscribeCalls)
}
// TestExecuteCommandAttachesResolvedSession verifies that the delegate sees
// the cached session record attached to the request context.
func TestExecuteCommandAttachesResolvedSession(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{
		executeCommandFunc: func(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
			resolved, ok := resolvedSessionFromContext(ctx)
			require.True(t, ok)
			assert.Equal(t, newActiveSessionRecord(), resolved)
			return &gatewayv1.ExecuteCommandResponse{RequestId: req.GetRequestId()}, nil
		},
	}
	activeCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			return newActiveSessionRecord(), nil
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: activeCache, ReplayStore: staticReplayStore{}})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	response, err := gatewayv1.NewEdgeGatewayClient(conn).ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.NoError(t, err)
	assert.Equal(t, "request-123", response.GetRequestId())
}
// TestSubscribeEventsAttachesResolvedSession verifies that the delegate's
// stream context carries the cached session record, and that the signed
// server-time bootstrap event precedes stream EOF.
func TestSubscribeEventsAttachesResolvedSession(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{
		subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
			resolved, ok := resolvedSessionFromContext(stream.Context())
			require.True(t, ok)
			assert.Equal(t, newActiveSessionRecord(), resolved)
			return nil
		},
	}
	activeCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			return newActiveSessionRecord(), nil
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: activeCache, ReplayStore: staticReplayStore{}})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	stream, err := gatewayv1.NewEdgeGatewayClient(conn).SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
	require.NoError(t, err)
	bootstrap := recvBootstrapEvent(t, stream)
	assertServerTimeBootstrapEvent(t, bootstrap, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
	_, err = stream.Recv()
	require.ErrorIs(t, err, io.EOF)
}
func TestSubscribeEventsAttachesAuthenticatedStreamBinding(t *testing.T) {
t.Parallel()
delegate := &recordingEdgeGatewayService{
subscribeEventsFunc: func(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
binding, ok := authenticatedStreamBindingFromContext(stream.Context())
require.True(t, ok)
assert.Equal(t, authenticatedStreamBinding{
UserID: "user-123",
DeviceSessionID: "device-session-123",
MessageType: "gateway.subscribe",
RequestID: "request-123",
TraceID: "trace-123",
}, binding)
return nil
},
}
server, runGateway := newTestGateway(t, ServerDependencies{
Service: delegate,
SessionCache: staticSessionCache{lookupFunc: func(context.Context, string) (session.Record, error) { return newActiveSessionRecord(), nil }},
ReplayStore: staticReplayStore{},
})
defer runGateway.stop(t)
addr := waitForListenAddr(t, server)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
stream, err := client.SubscribeEvents(context.Background(), newValidSubscribeEventsRequest())
require.NoError(t, err)
event := recvBootstrapEvent(t, stream)
assertServerTimeBootstrapEvent(t, event, newTestResponseSignerPublicKey(), "request-123", "trace-123", testCurrentTime.UnixMilli())
_, err = stream.Recv()
require.ErrorIs(t, err, io.EOF)
}
// staticSessionCache is a test session cache whose Lookup behavior is fully
// scripted by lookupFunc.
type staticSessionCache struct {
	// lookupFunc must be non-nil; it receives the lookup's device session id.
	lookupFunc func(context.Context, string) (session.Record, error)
}

// Lookup delegates to the scripted lookupFunc.
func (c staticSessionCache) Lookup(ctx context.Context, deviceSessionID string) (session.Record, error) {
	return c.lookupFunc(ctx, deviceSessionID)
}
+80
View File
@@ -0,0 +1,80 @@
package grpcapi
import (
"context"
"errors"
"galaxy/gateway/internal/authn"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// signatureVerifyingService applies client-signature verification after
// payload integrity checks and before later auth or routing steps run.
type signatureVerifyingService struct {
	gatewayv1.UnimplementedEdgeGatewayServer
	// delegate receives requests only after their signatures verify.
	delegate gatewayv1.EdgeGatewayServer
}
// ExecuteCommand verifies the caller's request signature and, only on
// success, forwards the call to the configured delegate.
func (s signatureVerifyingService) ExecuteCommand(ctx context.Context, req *gatewayv1.ExecuteCommandRequest) (*gatewayv1.ExecuteCommandResponse, error) {
	err := verifyRequestSignature(ctx)
	if err != nil {
		return nil, err
	}
	return s.delegate.ExecuteCommand(ctx, req)
}
// SubscribeEvents verifies the caller's request signature and, only on
// success, forwards the stream to the configured delegate.
func (s signatureVerifyingService) SubscribeEvents(req *gatewayv1.SubscribeEventsRequest, stream grpc.ServerStreamingServer[gatewayv1.GatewayEvent]) error {
	err := verifyRequestSignature(stream.Context())
	if err != nil {
		return err
	}
	return s.delegate.SubscribeEvents(req, stream)
}
// newSignatureVerifyingService wraps delegate with the client-signature
// verification gate.
func newSignatureVerifyingService(delegate gatewayv1.EdgeGatewayServer) gatewayv1.EdgeGatewayServer {
	gate := signatureVerifyingService{delegate: delegate}
	return gate
}
// verifyRequestSignature checks the parsed envelope's signature in ctx
// against the client public key of the resolved session record.
func verifyRequestSignature(ctx context.Context) error {
	envelope, ok := parsedEnvelopeFromContext(ctx)
	if !ok {
		return status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	record, ok := resolvedSessionFromContext(ctx)
	if !ok {
		return status.Error(codes.Internal, "authenticated request context is incomplete")
	}
	fields := authn.RequestSigningFields{
		ProtocolVersion: envelope.ProtocolVersion,
		DeviceSessionID: envelope.DeviceSessionID,
		MessageType:     envelope.MessageType,
		TimestampMS:     envelope.TimestampMS,
		RequestID:       envelope.RequestID,
		PayloadHash:     envelope.PayloadHash,
	}
	err := authn.VerifyRequestSignature(record.ClientPublicKey, envelope.Signature, fields)
	if err == nil {
		return nil
	}
	if errors.Is(err, authn.ErrInvalidClientPublicKey) {
		// A corrupt cached public key is reported as cache unavailability by
		// design; the signature test suite pins this mapping.
		return status.Error(codes.Unavailable, "session cache is unavailable")
	}
	if errors.Is(err, authn.ErrInvalidRequestSignature) {
		return status.Error(codes.Unauthenticated, "invalid request signature")
	}
	return status.Error(codes.Internal, "request signature verification failed")
}

// Compile-time check that the gate implements the gateway server interface.
var _ gatewayv1.EdgeGatewayServer = signatureVerifyingService{}
@@ -0,0 +1,188 @@
package grpcapi
import (
"context"
"testing"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TestExecuteCommandRejectsInvalidSignature verifies that tampering with the
// request signature yields Unauthenticated without reaching the delegate.
func TestExecuteCommandRejectsInvalidSignature(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{}
	activeCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			return newActiveSessionRecord(), nil
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: activeCache})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	tampered := newValidExecuteCommandRequest()
	tampered.Signature[0] ^= 0xff // corrupt one signature byte

	_, err := gatewayv1.NewEdgeGatewayClient(conn).ExecuteCommand(context.Background(), tampered)
	require.Error(t, err)
	assert.Equal(t, codes.Unauthenticated, status.Code(err))
	assert.Equal(t, "invalid request signature", status.Convert(err).Message())
	assert.Zero(t, backend.executeCalls)
}
// TestExecuteCommandRejectsWrongKey verifies that a request signed with a key
// other than the cached one is rejected as Unauthenticated.
func TestExecuteCommandRejectsWrongKey(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{}
	wrongKeyCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			record := newActiveSessionRecord()
			record.ClientPublicKey = alternateTestClientPublicKeyBase64()
			return record, nil
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: wrongKeyCache})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	_, err := gatewayv1.NewEdgeGatewayClient(conn).ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unauthenticated, status.Code(err))
	assert.Equal(t, "invalid request signature", status.Convert(err).Message())
	assert.Zero(t, backend.executeCalls)
}
// TestExecuteCommandRejectsInvalidCachedPublicKey verifies that a corrupt
// cached public key maps to Unavailable (treated as cache unavailability).
func TestExecuteCommandRejectsInvalidCachedPublicKey(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{}
	corruptKeyCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			record := newActiveSessionRecord()
			record.ClientPublicKey = "%%%not-base64%%%"
			return record, nil
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: corruptKeyCache})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	_, err := gatewayv1.NewEdgeGatewayClient(conn).ExecuteCommand(context.Background(), newValidExecuteCommandRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
	assert.Zero(t, backend.executeCalls)
}
// TestSubscribeEventsRejectsInvalidSignature verifies that tampering with the
// subscribe signature terminates the stream with Unauthenticated.
func TestSubscribeEventsRejectsInvalidSignature(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{}
	activeCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			return newActiveSessionRecord(), nil
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: activeCache})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	tampered := newValidSubscribeEventsRequest()
	tampered.Signature[0] ^= 0xff // corrupt one signature byte

	err := subscribeEventsError(t, context.Background(), gatewayv1.NewEdgeGatewayClient(conn), tampered)
	require.Error(t, err)
	assert.Equal(t, codes.Unauthenticated, status.Code(err))
	assert.Equal(t, "invalid request signature", status.Convert(err).Message())
	assert.Zero(t, backend.subscribeCalls)
}
// TestSubscribeEventsRejectsWrongKey verifies that a stream signed with a key
// other than the cached one is rejected as Unauthenticated.
func TestSubscribeEventsRejectsWrongKey(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{}
	wrongKeyCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			record := newActiveSessionRecord()
			record.ClientPublicKey = alternateTestClientPublicKeyBase64()
			return record, nil
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: wrongKeyCache})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	err := subscribeEventsError(t, context.Background(), gatewayv1.NewEdgeGatewayClient(conn), newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unauthenticated, status.Code(err))
	assert.Equal(t, "invalid request signature", status.Convert(err).Message())
	assert.Zero(t, backend.subscribeCalls)
}
// TestSubscribeEventsRejectsInvalidCachedPublicKey verifies that a corrupt
// cached public key terminates the stream with Unavailable.
func TestSubscribeEventsRejectsInvalidCachedPublicKey(t *testing.T) {
	t.Parallel()
	backend := &recordingEdgeGatewayService{}
	corruptKeyCache := staticSessionCache{
		lookupFunc: func(context.Context, string) (session.Record, error) {
			record := newActiveSessionRecord()
			record.ClientPublicKey = "%%%not-base64%%%"
			return record, nil
		},
	}
	server, runGateway := newTestGateway(t, ServerDependencies{Service: backend, SessionCache: corruptKeyCache})
	defer runGateway.stop(t)

	conn := dialGatewayClient(t, waitForListenAddr(t, server))
	defer func() { require.NoError(t, conn.Close()) }()

	err := subscribeEventsError(t, context.Background(), gatewayv1.NewEdgeGatewayClient(conn), newValidSubscribeEventsRequest())
	require.Error(t, err)
	assert.Equal(t, codes.Unavailable, status.Code(err))
	assert.Equal(t, "session cache is unavailable", status.Convert(err).Message())
	assert.Zero(t, backend.subscribeCalls)
}
@@ -0,0 +1,298 @@
package grpcapi
import (
"context"
"crypto/ed25519"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"time"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
gatewayfbs "galaxy/schema/fbs/gateway"
flatbuffers "github.com/google/flatbuffers/go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
var (
	// testCurrentTime is the frozen wall-clock instant used to timestamp and
	// sign fixture requests.
	testCurrentTime = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)
	// testFreshnessWindow is the request-timestamp tolerance used by the test
	// gateway. (Consumer not visible in this file — confirm against setup.)
	testFreshnessWindow = 5 * time.Minute
)
// newValidExecuteCommandRequest returns the canonical signed ExecuteCommand
// fixture for session "device-session-123" and request "request-123".
func newValidExecuteCommandRequest() *gatewayv1.ExecuteCommandRequest {
	return newValidExecuteCommandRequestWithSessionAndRequestID("device-session-123", "request-123")
}

// newValidExecuteCommandRequestWithSessionAndRequestID returns a signed
// ExecuteCommand fixture stamped with the frozen test time.
func newValidExecuteCommandRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.ExecuteCommandRequest {
	return newValidExecuteCommandRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli())
}
// newValidExecuteCommandRequestWithTimestamp builds an ExecuteCommand request
// with a fixed payload, its SHA-256 hash, and a signature by the primary test
// client key over the envelope fields.
func newValidExecuteCommandRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.ExecuteCommandRequest {
	payload := []byte("payload")
	digest := sha256.Sum256(payload)
	req := &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: supportedProtocolVersion,
		DeviceSessionId: deviceSessionID,
		MessageType:     "fleet.move",
		TimestampMs:     timestampMS,
		RequestId:       requestID,
		PayloadBytes:    payload,
		PayloadHash:     digest[:],
		TraceId:         "trace-123",
	}
	req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash())
	return req
}
// newValidSubscribeEventsRequest returns the canonical signed SubscribeEvents
// fixture for session "device-session-123" and request "request-123".
func newValidSubscribeEventsRequest() *gatewayv1.SubscribeEventsRequest {
	return newValidSubscribeEventsRequestWithSessionAndRequestID("device-session-123", "request-123")
}

// newValidSubscribeEventsRequestWithSessionAndRequestID returns a signed
// SubscribeEvents fixture stamped with the frozen test time.
func newValidSubscribeEventsRequestWithSessionAndRequestID(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
	return newValidSubscribeEventsRequestWithTimestamp(deviceSessionID, requestID, testCurrentTime.UnixMilli())
}
// newValidSubscribeEventsRequestWithTimestamp builds a SubscribeEvents
// request whose payload hash is the SHA-256 of an empty payload, signed by
// the primary test client key.
func newValidSubscribeEventsRequestWithTimestamp(deviceSessionID string, requestID string, timestampMS int64) *gatewayv1.SubscribeEventsRequest {
	emptyDigest := sha256.Sum256(nil)
	req := &gatewayv1.SubscribeEventsRequest{
		ProtocolVersion: supportedProtocolVersion,
		DeviceSessionId: deviceSessionID,
		MessageType:     "gateway.subscribe",
		TimestampMs:     timestampMS,
		RequestId:       requestID,
		PayloadHash:     emptyDigest[:],
		TraceId:         "trace-123",
	}
	req.Signature = signRequest(req.GetProtocolVersion(), req.GetDeviceSessionId(), req.GetMessageType(), req.GetTimestampMs(), req.GetRequestId(), req.GetPayloadHash())
	return req
}
// newActiveSessionRecord returns the default active session fixture.
func newActiveSessionRecord() session.Record {
	return newActiveSessionRecordWithSessionID("device-session-123")
}

// newActiveSessionRecordWithSessionID returns an active session fixture bound
// to deviceSessionID and the primary test client key.
func newActiveSessionRecordWithSessionID(deviceSessionID string) session.Record {
	record := session.Record{
		UserID:          "user-123",
		ClientPublicKey: testClientPublicKeyBase64(),
		Status:          session.StatusActive,
	}
	record.DeviceSessionID = deviceSessionID
	return record
}
// newRevokedSessionRecord returns a revoked variant of the default session
// fixture with a non-nil revocation timestamp.
func newRevokedSessionRecord() session.Record {
	revokedAt := int64(123456789)
	record := newActiveSessionRecord()
	record.Status = session.StatusRevoked
	record.RevokedAtMS = &revokedAt
	return record
}
// alternateTestClientPublicKeyBase64 returns the base64 public half of the
// secondary ("alternate") deterministic test key.
func alternateTestClientPublicKeyBase64() string {
	public := newTestPrivateKey("alternate").Public().(ed25519.PublicKey)
	return base64.StdEncoding.EncodeToString(public)
}

// testClientPublicKeyBase64 returns the base64 public half of the primary
// deterministic test key used to sign fixture requests.
func testClientPublicKeyBase64() string {
	public := newTestPrivateKey("primary").Public().(ed25519.PublicKey)
	return base64.StdEncoding.EncodeToString(public)
}
func signRequest(protocolVersion, deviceSessionID, messageType string, timestampMS int64, requestID string, payloadHash []byte) []byte {
return ed25519.Sign(newTestPrivateKey("primary"), authn.BuildRequestSigningInput(authn.RequestSigningFields{
ProtocolVersion: protocolVersion,
DeviceSessionID: deviceSessionID,
MessageType: messageType,
TimestampMS: timestampMS,
RequestID: requestID,
PayloadHash: payloadHash,
}))
}
// newTestPrivateKey derives a deterministic Ed25519 key from label so tests
// get stable, distinct keys without persisting key material.
func newTestPrivateKey(label string) ed25519.PrivateKey {
	digest := sha256.Sum256([]byte("gateway-grpcapi-signature-test-" + label))
	return ed25519.NewKeyFromSeed(digest[:])
}
// newTestEd25519ResponseSigner builds the deterministic response signer by
// PEM-encoding the "response-signer" test key and parsing it back through the
// production code path. It panics on failure, which only indicates a broken
// test fixture.
func newTestEd25519ResponseSigner() *authn.Ed25519ResponseSigner {
	block := &pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: mustMarshalPKCS8PrivateKey(newTestPrivateKey("response-signer")),
	}
	signer, err := authn.ParseEd25519ResponseSignerPEM(pem.EncodeToMemory(block))
	if err != nil {
		panic(err)
	}
	return signer
}
// newTestResponseSigner returns the deterministic Ed25519 response signer as
// the authn.ResponseSigner interface.
func newTestResponseSigner() authn.ResponseSigner {
	return newTestEd25519ResponseSigner()
}

// newTestResponseSignerPublicKey returns the verifying key that matches the
// test response signer's signatures.
func newTestResponseSignerPublicKey() ed25519.PublicKey {
	return newTestEd25519ResponseSigner().PublicKey()
}
func mustMarshalPKCS8PrivateKey(privateKey ed25519.PrivateKey) []byte {
encoded, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
panic(err)
}
return encoded
}
type fixedClock struct {
now time.Time
}
func (c fixedClock) Now() time.Time {
return c.now
}
// recvBootstrapEvent receives the first event from stream, failing the test
// on any receive error.
func recvBootstrapEvent(t interface {
	require.TestingT
	Helper()
}, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
	t.Helper()
	received, err := stream.Recv()
	require.NoError(t, err)
	return received
}
// subscribeEventsError opens a SubscribeEvents stream and returns the first
// error observed: the open error if any, otherwise the first Recv error.
func subscribeEventsError(t interface {
	require.TestingT
	Helper()
}, ctx context.Context, client gatewayv1.EdgeGatewayClient, req *gatewayv1.SubscribeEventsRequest) error {
	t.Helper()
	stream, openErr := client.SubscribeEvents(ctx, req)
	if openErr != nil {
		return openErr
	}
	_, recvErr := stream.Recv()
	return recvErr
}
// assertServerTimeBootstrapEvent checks the bootstrap "server time" event
// emitted at stream start: envelope correlation fields, payload-hash
// integrity, the event signature against publicKey, and the
// flatbuffer-encoded server timestamp.
func assertServerTimeBootstrapEvent(t interface {
	require.TestingT
	Helper()
}, event *gatewayv1.GatewayEvent, publicKey ed25519.PublicKey, wantRequestID string, wantTraceID string, wantTimestampMS int64) {
	t.Helper()
	require.NotNil(t, event)
	assert.Equal(t, serverTimeEventType, event.GetEventType())
	// The bootstrap event reuses the request id as its event id.
	assert.Equal(t, wantRequestID, event.GetEventId())
	assert.Equal(t, wantRequestID, event.GetRequestId())
	assert.Equal(t, wantTraceID, event.GetTraceId())
	assert.Equal(t, wantTimestampMS, event.GetTimestampMs())
	// Integrity: the payload hash must match the payload bytes...
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	// ...and the signature must verify over the event envelope fields.
	require.NoError(t, authn.VerifyEventSignature(publicKey, event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))
	// The payload decodes as a ServerTimeEvent carrying the same timestamp.
	payload := gatewayfbs.GetRootAsServerTimeEvent(event.GetPayloadBytes(), flatbuffers.UOffsetT(0))
	assert.Equal(t, wantTimestampMS, payload.ServerTimeMs())
}
type staticReplayStore struct {
reserveFunc func(context.Context, string, string, time.Duration) error
}
func (s staticReplayStore) Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
if s.reserveFunc != nil {
return s.reserveFunc(ctx, deviceSessionID, requestID, ttl)
}
return nil
}
// executeCommandAdapterRouter routes every message type to an adapter client
// backed by service.
type executeCommandAdapterRouter struct {
	service gatewayv1.EdgeGatewayServer
}

// Route ignores the message type and always returns the adapter client.
func (r executeCommandAdapterRouter) Route(string) (downstream.Client, error) {
	return executeCommandAdapterClient{service: r.service}, nil
}

// executeCommandAdapterClient adapts a gatewayv1.EdgeGatewayServer to the
// downstream.Client interface for tests.
type executeCommandAdapterClient struct {
	service gatewayv1.EdgeGatewayServer
}

// ExecuteCommand forwards command to the wrapped service and converts its
// response into a downstream.UnaryResult, defaulting an empty result code to
// "ok".
func (c executeCommandAdapterClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	grpcReq := &gatewayv1.ExecuteCommandRequest{
		ProtocolVersion: command.ProtocolVersion,
		DeviceSessionId: command.DeviceSessionID,
		MessageType:     command.MessageType,
		TimestampMs:     command.TimestampMS,
		RequestId:       command.RequestID,
		PayloadBytes:    command.PayloadBytes,
		TraceId:         command.TraceID,
	}
	response, err := c.service.ExecuteCommand(ctx, grpcReq)
	if err != nil {
		return downstream.UnaryResult{}, err
	}
	result := downstream.UnaryResult{
		ResultCode:   response.GetResultCode(),
		PayloadBytes: response.GetPayloadBytes(),
	}
	if result.ResultCode == "" {
		result.ResultCode = "ok"
	}
	return result, nil
}
// recordingDownstreamClient captures every downstream command for later
// inspection and optionally delegates to executeFunc.
type recordingDownstreamClient struct {
	executeCalls int
	commands     []downstream.AuthenticatedCommand
	executeFunc  func(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error)
}

// ExecuteCommand records the command (detaching the payload bytes so later
// caller mutations are not observed) and returns executeFunc's result when
// configured, or a canned "ok" response otherwise.
func (c *recordingDownstreamClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	c.executeCalls++
	captured := downstream.AuthenticatedCommand{
		ProtocolVersion: command.ProtocolVersion,
		UserID:          command.UserID,
		DeviceSessionID: command.DeviceSessionID,
		MessageType:     command.MessageType,
		TimestampMS:     command.TimestampMS,
		RequestID:       command.RequestID,
		TraceID:         command.TraceID,
		PayloadBytes:    append([]byte(nil), command.PayloadBytes...),
	}
	c.commands = append(c.commands, captured)
	if c.executeFunc == nil {
		return downstream.UnaryResult{ResultCode: "ok", PayloadBytes: []byte("response")}, nil
	}
	return c.executeFunc(ctx, command)
}
+84
View File
@@ -0,0 +1,84 @@
// Package logging configures the gateway structured logger and provides
// context-aware helpers for attaching OpenTelemetry trace identifiers.
package logging
import (
"context"
"strings"
"galaxy/gateway/internal/config"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// New constructs the process-wide JSON logger from cfg.
func New(cfg config.LoggingConfig) (*zap.Logger, error) {
level := zap.NewAtomicLevel()
if err := level.UnmarshalText([]byte(strings.TrimSpace(cfg.Level))); err != nil {
return nil, err
}
zapCfg := zap.NewProductionConfig()
zapCfg.Level = level
zapCfg.Sampling = nil
zapCfg.Encoding = "json"
zapCfg.EncoderConfig.TimeKey = "timestamp"
zapCfg.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
zapCfg.OutputPaths = []string{"stdout"}
zapCfg.ErrorOutputPaths = []string{"stderr"}
return zapCfg.Build()
}
// TraceFieldsFromContext returns zap fields for the active OpenTelemetry span
// when ctx carries a valid span context; otherwise it returns nil.
func TraceFieldsFromContext(ctx context.Context) []zap.Field {
	if ctx == nil {
		return nil
	}
	sc := trace.SpanContextFromContext(ctx)
	if !sc.IsValid() {
		return nil
	}
	traceField := zap.String("otel_trace_id", sc.TraceID().String())
	spanField := zap.String("otel_span_id", sc.SpanID().String())
	return []zap.Field{traceField, spanField}
}
// Sync flushes logger, swallowing the benign stdout/stderr sync errors that
// containerized or redirected process outputs commonly return. A nil logger
// is a no-op.
func Sync(logger *zap.Logger) error {
	if logger == nil {
		return nil
	}
	switch err := logger.Sync(); {
	case err == nil, isIgnorableSyncError(err):
		return nil
	default:
		return err
	}
}
func isIgnorableSyncError(err error) bool {
if err == nil {
return false
}
message := strings.ToLower(err.Error())
switch {
case strings.Contains(message, "invalid argument"):
return true
case strings.Contains(message, "bad file descriptor"):
return true
case strings.Contains(message, "inappropriate ioctl for device"):
return true
default:
return false
}
}
+385
View File
@@ -0,0 +1,385 @@
// Package push provides the in-memory hub used to fan out internal
// client-facing events to active authenticated push streams.
package push
import (
"bytes"
"errors"
"strings"
"sync"
)
// defaultSubscriptionQueueCapacity bounds each subscription's event queue
// when the caller does not supply a positive capacity.
const defaultSubscriptionQueueCapacity = 64

var (
	// ErrSubscriptionOverflow reports that one push stream stopped consuming
	// events quickly enough and its bounded queue overflowed.
	ErrSubscriptionOverflow = errors.New("push stream overflowed")
	// ErrSubscriptionRevoked reports that the authenticated device session bound
	// to the push stream was revoked and the stream must terminate.
	ErrSubscriptionRevoked = errors.New("device session is revoked")
	// ErrHubShuttingDown reports that the gateway is shutting down and all
	// active push streams must terminate promptly.
	ErrHubShuttingDown = errors.New("gateway is shutting down")
)
// StreamBinding identifies one authenticated push stream tracked by Hub.
// Both fields are validated to be non-blank at registration time.
type StreamBinding struct {
	// UserID is the verified authenticated user bound to the stream.
	UserID string
	// DeviceSessionID is the verified authenticated device session bound to the
	// stream.
	DeviceSessionID string
}
// Event is the internal client-facing event delivered from internal pub/sub to
// active push streams.
type Event struct {
	// UserID identifies the authenticated user that should receive the event.
	UserID string
	// DeviceSessionID optionally narrows delivery to one device session.
	// When empty, every stream of UserID is eligible.
	DeviceSessionID string
	// EventType identifies the stable client-facing event category.
	EventType string
	// EventID is the stable event correlation identifier.
	EventID string
	// PayloadBytes carries the opaque event payload bytes.
	PayloadBytes []byte
	// RequestID optionally correlates the event to an earlier client request.
	RequestID string
	// TraceID optionally carries tracing correlation.
	TraceID string
}
// Subscription represents one active push stream registered in Hub.
type Subscription struct {
	// hub is the owning hub used by Close to unregister; nil for a
	// zero-value subscription.
	hub *Hub
	// id is the hub-assigned registration identifier.
	id uint64
	// binding records the authenticated user and device session.
	binding StreamBinding
	// events is the bounded per-subscription delivery queue.
	events chan Event
	// done closes once the subscription is removed from the hub.
	done chan struct{}
	// closeOnce guards the one-time terminal transition in closeWithError.
	closeOnce sync.Once
	// stateMu protects err.
	stateMu sync.RWMutex
	// err is the terminal reason, if any, recorded by closeWithError.
	err error
}
// Observer receives push stream lifecycle notifications suitable for metrics
// bookkeeping. Implementations may be nil-checked by the hub before use.
type Observer interface {
	// Registered reports one active push stream binding.
	Registered(binding StreamBinding)
	// Unregistered reports that binding stopped with err. A nil err means the
	// stream ended without a hub-enforced terminal reason.
	Unregistered(binding StreamBinding, err error)
}
// Events returns the ordered event queue for the subscription; a nil receiver
// yields a nil channel that blocks forever.
func (s *Subscription) Events() <-chan Event {
	if s != nil {
		return s.events
	}
	return nil
}
// Done closes when the subscription has been removed from the hub; a nil
// receiver yields a nil channel that blocks forever.
func (s *Subscription) Done() <-chan struct{} {
	if s != nil {
		return s.done
	}
	return nil
}
// Err returns the terminal subscription error, if any, under the state lock.
func (s *Subscription) Err() error {
	if s == nil {
		return nil
	}
	s.stateMu.RLock()
	terminal := s.err
	s.stateMu.RUnlock()
	return terminal
}
// Close unregisters the subscription from its hub. It is a no-op for a nil
// subscription or one that was never registered.
func (s *Subscription) Close() {
	if s != nil && s.hub != nil {
		s.hub.unregister(s.id, nil)
	}
}
// enqueue attempts a non-blocking delivery of event to the subscription's
// bounded queue. It returns true when the event was queued or the
// subscription is already done (nothing left to deliver to), and false when
// the queue is full — callers treat false as a stream overflow.
func (s *Subscription) enqueue(event Event) bool {
	if s == nil {
		return true
	}
	// Copy before queuing; cloneEvent presumably detaches PayloadBytes so
	// later producer mutations cannot race consumers — confirm at its
	// definition (not visible here).
	cloned := cloneEvent(event)
	// Fast path: a finished subscription silently drops the event.
	select {
	case <-s.done:
		return true
	default:
	}
	// Non-blocking send: success or losing the race to done both count as
	// handled; a full queue falls through to the overflow result.
	select {
	case s.events <- cloned:
		return true
	case <-s.done:
		return true
	default:
		return false
	}
}
// closeWithError records err as the terminal subscription error and closes
// done. closeOnce makes repeated calls no-ops; the error is published under
// stateMu before done closes so callers that observe Done() firing always see
// the final value from Err().
func (s *Subscription) closeWithError(err error) {
	if s == nil {
		return
	}
	s.closeOnce.Do(func() {
		s.stateMu.Lock()
		s.err = err
		s.stateMu.Unlock()
		close(s.done)
	})
}
// Hub tracks active authenticated push streams and fans out client-facing
// events to the matching subscriptions.
type Hub struct {
	mu            sync.RWMutex // guards nextID and all three indexes
	nextID        uint64       // monotonically increasing subscription id source
	queueCapacity int          // per-subscription event channel capacity
	observer      Observer     // optional lifecycle sink; may be nil
	// byID is the primary index; byUser and bySession are secondary indexes
	// kept in lockstep for user- and session-targeted fan-out.
	byID      map[uint64]*Subscription
	byUser    map[string]map[uint64]*Subscription
	bySession map[string]map[uint64]*Subscription
}
// NewHub constructs a push hub with one bounded in-memory queue per
// subscription. Non-positive queueCapacity falls back to the package default.
// It is equivalent to NewHubWithObserver with a nil observer.
func NewHub(queueCapacity int) *Hub {
	return NewHubWithObserver(queueCapacity, nil)
}
// NewHubWithObserver constructs a push hub that also reports stream lifecycle
// changes to observer. A nil observer disables lifecycle reporting; a
// non-positive queueCapacity falls back to the package default.
func NewHubWithObserver(queueCapacity int, observer Observer) *Hub {
	capacity := queueCapacity
	if capacity <= 0 {
		capacity = defaultSubscriptionQueueCapacity
	}
	hub := &Hub{
		queueCapacity: capacity,
		observer:      observer,
		byID:          make(map[uint64]*Subscription),
		byUser:        make(map[string]map[uint64]*Subscription),
		bySession:     make(map[string]map[uint64]*Subscription),
	}
	return hub
}
// Register adds one authenticated push stream to the hub and returns its
// subscription handle. Both binding fields are trimmed and must be non-empty.
func (h *Hub) Register(binding StreamBinding) (*Subscription, error) {
	if h == nil {
		return nil, errors.New("register push subscription: nil hub")
	}
	user := strings.TrimSpace(binding.UserID)
	session := strings.TrimSpace(binding.DeviceSessionID)
	switch {
	case user == "":
		return nil, errors.New("register push subscription: user id must not be empty")
	case session == "":
		return nil, errors.New("register push subscription: device session id must not be empty")
	}
	h.mu.Lock()
	h.nextID++
	sub := &Subscription{
		hub: h,
		id:  h.nextID,
		binding: StreamBinding{
			UserID:          user,
			DeviceSessionID: session,
		},
		events: make(chan Event, h.queueCapacity),
		done:   make(chan struct{}),
	}
	h.byID[sub.id] = sub
	addIndexedSubscription(h.byUser, user, sub)
	addIndexedSubscription(h.bySession, session, sub)
	h.mu.Unlock()
	// Notify outside the lock so a slow observer cannot stall registration.
	if h.observer != nil {
		h.observer.Registered(sub.binding)
	}
	return sub, nil
}
// Publish fans out event to the matching active subscriptions. When one
// subscription queue overflows, only that subscription is closed; the rest
// keep receiving.
func (h *Hub) Publish(event Event) {
	if h == nil {
		return
	}
	for _, sub := range h.targets(event) {
		if !sub.enqueue(event) {
			h.unregister(sub.id, ErrSubscriptionOverflow)
		}
	}
}
// RevokeDeviceSession closes all active subscriptions bound to the exact
// authenticated device session identifier. Empty or unknown identifiers are
// ignored.
func (h *Hub) RevokeDeviceSession(deviceSessionID string) {
	if h == nil {
		return
	}
	session := strings.TrimSpace(deviceSessionID)
	if session == "" {
		return
	}
	// Snapshot under the read lock; unregister takes the write lock itself.
	h.mu.RLock()
	matches := cloneSubscriptions(h.bySession[session])
	h.mu.RUnlock()
	for _, match := range matches {
		h.unregister(match.id, ErrSubscriptionRevoked)
	}
}
// Shutdown closes every active subscription because the gateway is shutting
// down. Each stream terminates with ErrHubShuttingDown.
func (h *Hub) Shutdown() {
	if h == nil {
		return
	}
	// Snapshot under the read lock; unregister takes the write lock itself.
	h.mu.RLock()
	active := cloneSubscriptions(h.byID)
	h.mu.RUnlock()
	for _, sub := range active {
		h.unregister(sub.id, ErrHubShuttingDown)
	}
}
// targets returns a snapshot of the subscriptions event should reach. Events
// missing a user id, type, or id match nothing. Without a device session id
// the event fans out to every session of the user; with one, only
// subscriptions matching both the session and the user qualify.
func (h *Hub) targets(event Event) []*Subscription {
	user := strings.TrimSpace(event.UserID)
	kind := strings.TrimSpace(event.EventType)
	id := strings.TrimSpace(event.EventID)
	if h == nil || user == "" || kind == "" || id == "" {
		return nil
	}
	session := strings.TrimSpace(event.DeviceSessionID)
	h.mu.RLock()
	defer h.mu.RUnlock()
	if session == "" {
		return cloneSubscriptions(h.byUser[user])
	}
	// Filter in place over the fresh snapshot slice: a session targeted with a
	// mismatched user id must deliver nothing to that session.
	matches := cloneSubscriptions(h.bySession[session])
	kept := matches[:0]
	for _, sub := range matches {
		if sub.binding.UserID == user {
			kept = append(kept, sub)
		}
	}
	return kept
}
// unregister removes the subscription with id from every index, closes it
// with err as the terminal reason, and notifies the observer. Unknown ids and
// the zero id are no-ops, so repeated unregistration is safe.
func (h *Hub) unregister(id uint64, err error) {
	if h == nil || id == 0 {
		return
	}
	h.mu.Lock()
	sub, found := h.byID[id]
	if !found {
		h.mu.Unlock()
		return
	}
	delete(h.byID, id)
	removeIndexedSubscription(h.byUser, sub.binding.UserID, id)
	removeIndexedSubscription(h.bySession, sub.binding.DeviceSessionID, id)
	h.mu.Unlock()
	// Close and notify outside the lock to avoid re-entrant deadlocks.
	sub.closeWithError(err)
	if h.observer != nil {
		h.observer.Unregistered(sub.binding, err)
	}
}
// addIndexedSubscription inserts subscription into the index bucket for key,
// lazily creating the bucket on first use.
func addIndexedSubscription(index map[string]map[uint64]*Subscription, key string, subscription *Subscription) {
	bucket := index[key]
	if bucket == nil {
		bucket = make(map[uint64]*Subscription)
		index[key] = bucket
	}
	bucket[subscription.id] = subscription
}
// removeIndexedSubscription deletes id from the index bucket for key and
// drops the bucket itself once it is empty, keeping the index compact.
func removeIndexedSubscription(index map[string]map[uint64]*Subscription, key string, id uint64) {
	if bucket, ok := index[key]; ok {
		delete(bucket, id)
		if len(bucket) == 0 {
			delete(index, key)
		}
	}
}
// cloneSubscriptions copies the bucket values into a fresh slice so callers
// can iterate without holding the hub lock. Empty buckets yield nil.
func cloneSubscriptions(bucket map[uint64]*Subscription) []*Subscription {
	if len(bucket) == 0 {
		return nil
	}
	snapshot := make([]*Subscription, 0, len(bucket))
	for _, sub := range bucket {
		snapshot = append(snapshot, sub)
	}
	return snapshot
}
// cloneEvent returns a copy of event whose payload bytes are detached from
// the caller's slice, so later mutation by the publisher cannot leak into
// subscribers. All string fields are immutable and copied by value.
func cloneEvent(event Event) Event {
	copied := event
	copied.PayloadBytes = bytes.Clone(event.PayloadBytes)
	return copied
}
@@ -0,0 +1,77 @@
package push_test
import (
"testing"
"time"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/telemetry"
"galaxy/gateway/internal/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestHubObserverClassifiesClosureReasons drives one subscription into each
// hub-enforced closure path (overflow, revoke, shutdown) and asserts that the
// telemetry observer exports a distinct closure-reason label for each, plus
// the active-streams gauge.
func TestHubObserverClassifiesClosureReasons(t *testing.T) {
	t.Parallel()
	logger, _ := testutil.NewObservedLogger(t)
	telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)
	// Queue capacity 1 lets the second publish below overflow the first stream.
	hub := push.NewHubWithObserver(1, telemetry.NewPushObserver(telemetryRuntime))
	overflow, err := hub.Register(push.StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-overflow",
	})
	require.NoError(t, err)
	revoked, err := hub.Register(push.StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-revoked",
	})
	require.NoError(t, err)
	shutdown, err := hub.Register(push.StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-shutdown",
	})
	require.NoError(t, err)
	// First event fills the capacity-1 queue; the second triggers overflow.
	hub.Publish(push.Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-overflow",
		EventType:       "fleet.updated",
		EventID:         "event-1",
		PayloadBytes:    []byte("payload-1"),
	})
	hub.Publish(push.Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-overflow",
		EventType:       "fleet.updated",
		EventID:         "event-2",
		PayloadBytes:    []byte("payload-2"),
	})
	hub.RevokeDeviceSession("device-session-revoked")
	hub.Shutdown()
	select {
	case <-overflow.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "overflow subscription did not close")
	}
	select {
	case <-revoked.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "revoked subscription did not close")
	}
	select {
	case <-shutdown.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "shutdown subscription did not close")
	}
	// Scrape the Prometheus surface and check the per-reason closure labels.
	metricsText := testutil.ScrapeMetrics(t, telemetryRuntime.Handler())
	assert.Contains(t, metricsText, `gateway_push_stream_closures_total`)
	assert.Contains(t, metricsText, `reason="overflow"`)
	assert.Contains(t, metricsText, `reason="revoked"`)
	assert.Contains(t, metricsText, `reason="shutdown"`)
	assert.Contains(t, metricsText, `gateway_push_active_streams`)
}
+270
View File
@@ -0,0 +1,270 @@
package push
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestHubDeliversSessionTargetedEvent verifies that a session-targeted event
// reaches only the matching session, not sibling sessions of the same user or
// subscriptions of other users.
func TestHubDeliversSessionTargetedEvent(t *testing.T) {
	t.Parallel()
	hub := NewHub(4)
	target, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)
	otherSession, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
	})
	require.NoError(t, err)
	unrelatedUser, err := hub.Register(StreamBinding{
		UserID:          "user-999",
		DeviceSessionID: "device-session-3",
	})
	require.NoError(t, err)
	hub.Publish(Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
		EventType:       "fleet.updated",
		EventID:         "event-1",
		PayloadBytes:    []byte("payload-1"),
	})
	assertEvent(t, target.Events(), Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
		EventType:       "fleet.updated",
		EventID:         "event-1",
		PayloadBytes:    []byte("payload-1"),
	})
	// Session targeting must not fan out beyond the addressed session.
	assertNoEvent(t, otherSession.Events())
	assertNoEvent(t, unrelatedUser.Events())
}
// TestHubDeliversUserTargetedEventToAllUserSessions verifies that an event
// with no device session id fans out to every session of the addressed user
// (preserving correlation fields) and to no other user.
func TestHubDeliversUserTargetedEventToAllUserSessions(t *testing.T) {
	t.Parallel()
	hub := NewHub(4)
	first, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)
	second, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
	})
	require.NoError(t, err)
	unrelated, err := hub.Register(StreamBinding{
		UserID:          "user-999",
		DeviceSessionID: "device-session-3",
	})
	require.NoError(t, err)
	// No DeviceSessionID: this is a user-targeted broadcast.
	hub.Publish(Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-1",
		PayloadBytes: []byte("payload-1"),
		RequestID:    "request-1",
		TraceID:      "trace-1",
	})
	want := Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-1",
		PayloadBytes: []byte("payload-1"),
		RequestID:    "request-1",
		TraceID:      "trace-1",
	}
	assertEvent(t, first.Events(), want)
	assertEvent(t, second.Events(), want)
	assertNoEvent(t, unrelated.Events())
}
// TestSubscriptionCloseUnregistersStream verifies that a client-initiated
// Close removes the stream from the hub (no further deliveries) and leaves
// Err nil, since no hub-enforced terminal reason applies.
func TestSubscriptionCloseUnregistersStream(t *testing.T) {
	t.Parallel()
	hub := NewHub(4)
	subscription, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)
	subscription.Close()
	select {
	case <-subscription.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "subscription did not close")
	}
	// Publishing after close must not reach the unregistered stream.
	hub.Publish(Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-1",
		PayloadBytes: []byte("payload-1"),
	})
	assertNoEvent(t, subscription.Events())
	assert.NoError(t, subscription.Err())
}
// TestHubOverflowClosesOnlySlowSubscription verifies that with a capacity-1
// queue, a subscriber that never drains overflows on the second publish and
// is closed with ErrSubscriptionOverflow, while a draining subscriber on the
// same user keeps receiving.
func TestHubOverflowClosesOnlySlowSubscription(t *testing.T) {
	t.Parallel()
	hub := NewHub(1)
	slow, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)
	fast, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
	})
	require.NoError(t, err)
	hub.Publish(Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-1",
		PayloadBytes: []byte("payload-1"),
	})
	// Drain fast's queue; slow's capacity-1 queue still holds event-1.
	assertEvent(t, fast.Events(), Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-1",
		PayloadBytes: []byte("payload-1"),
	})
	hub.Publish(Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-2",
		PayloadBytes: []byte("payload-2"),
	})
	select {
	case <-slow.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "slow subscription did not close after overflow")
	}
	assert.ErrorIs(t, slow.Err(), ErrSubscriptionOverflow)
	assertEvent(t, fast.Events(), Event{
		UserID:       "user-123",
		EventType:    "fleet.updated",
		EventID:      "event-2",
		PayloadBytes: []byte("payload-2"),
	})
}
// TestHubRevokeDeviceSessionClosesOnlyMatchingSubscriptions verifies that
// revocation matches the exact device session id across users (closing both
// with ErrSubscriptionRevoked) while other sessions remain open and still
// receive events.
func TestHubRevokeDeviceSessionClosesOnlyMatchingSubscriptions(t *testing.T) {
	t.Parallel()
	hub := NewHub(4)
	targetOne, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)
	// Same session id but a different user: still revoked (session keyed).
	targetTwo, err := hub.Register(StreamBinding{
		UserID:          "user-456",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)
	otherSession, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
	})
	require.NoError(t, err)
	hub.RevokeDeviceSession("device-session-1")
	select {
	case <-targetOne.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "first matching subscription did not close after revoke")
	}
	select {
	case <-targetTwo.Done():
	case <-time.After(time.Second):
		require.FailNow(t, "second matching subscription did not close after revoke")
	}
	assert.ErrorIs(t, targetOne.Err(), ErrSubscriptionRevoked)
	assert.ErrorIs(t, targetTwo.Err(), ErrSubscriptionRevoked)
	// Negative check: a short grace window confirms the other session stays up.
	select {
	case <-otherSession.Done():
		require.FailNow(t, "unrelated session subscription closed after revoke")
	case <-time.After(50 * time.Millisecond):
	}
	hub.Publish(Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
		EventType:       "fleet.updated",
		EventID:         "event-1",
		PayloadBytes:    []byte("payload-1"),
	})
	assertEvent(t, otherSession.Events(), Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
		EventType:       "fleet.updated",
		EventID:         "event-1",
		PayloadBytes:    []byte("payload-1"),
	})
}
// TestHubRevokeDeviceSessionIgnoresUnknownOrEmptySession verifies that
// revoking an empty or unknown session id is a no-op and does not close
// unrelated subscriptions.
func TestHubRevokeDeviceSessionIgnoresUnknownOrEmptySession(t *testing.T) {
	t.Parallel()
	hub := NewHub(4)
	subscription, err := hub.Register(StreamBinding{
		UserID:          "user-123",
		DeviceSessionID: "device-session-1",
	})
	require.NoError(t, err)
	hub.RevokeDeviceSession("")
	hub.RevokeDeviceSession("missing-session")
	// Grace window: the subscription must still be open afterwards.
	select {
	case <-subscription.Done():
		require.FailNow(t, "subscription closed for empty or unknown session revoke")
	case <-time.After(50 * time.Millisecond):
	}
}
// assertEvent receives one event from eventCh within one second and asserts
// that it equals want.
func assertEvent(t *testing.T, eventCh <-chan Event, want Event) {
	t.Helper()
	select {
	case <-time.After(time.Second):
		require.FailNow(t, "event was not delivered")
	case got := <-eventCh:
		assert.Equal(t, want, got)
	}
}
// assertNoEvent fails the test when anything arrives on eventCh within a
// short grace window.
func assertNoEvent(t *testing.T, eventCh <-chan Event) {
	t.Helper()
	select {
	case <-time.After(50 * time.Millisecond):
		// Nothing delivered within the grace window: success.
	case got := <-eventCh:
		require.FailNowf(t, "unexpected event delivered", "%+v", got)
	}
}
+136
View File
@@ -0,0 +1,136 @@
// Package ratelimit provides small process-local rate-limit primitives used by
// the gateway edge policy layers.
package ratelimit
import (
"sync"
"time"
"golang.org/x/time/rate"
)
// Policy describes one token-bucket budget enforced for a concrete key.
// Reserve treats a Policy with any non-positive field as invalid.
type Policy struct {
	// Requests is the number of accepted requests replenished per Window.
	Requests int
	// Window is the interval over which Requests are replenished.
	Window time.Duration
	// Burst is the maximum number of immediately available tokens.
	Burst int
}
// Decision describes the result of one limiter reservation attempt. The zero
// Decision denies the request with no retry hint.
type Decision struct {
	// Allowed reports whether the request may proceed immediately.
	Allowed bool
	// RetryAfter is the minimum delay the caller should wait before retrying
	// when Allowed is false.
	RetryAfter time.Duration
}
// Limiter applies a policy to one concrete key. Implementations are expected
// to be safe for concurrent use (InMemory serializes via a mutex).
type Limiter interface {
	// Reserve evaluates key under policy and reports whether the request may
	// proceed immediately.
	Reserve(key string, policy Policy) Decision
}
// InMemory is a process-local Limiter backed by x/time/rate token buckets.
// State is not shared across processes; each gateway instance enforces its
// own budgets.
type InMemory struct {
	now             func() time.Time // injectable clock for tests
	cleanupInterval time.Duration    // minimum spacing between sweep passes
	mu              sync.Mutex       // guards entries and nextCleanup
	entries         map[string]*entry
	nextCleanup     time.Time // earliest time of the next expiry sweep
}
// entry is one live token bucket plus the policy fingerprint (limit, burst)
// used to detect policy changes, and an idle-expiry deadline for cleanup.
type entry struct {
	limiter   *rate.Limiter
	limit     rate.Limit // rate the bucket was built with
	burst     int        // burst the bucket was built with
	expiresAt time.Time  // refreshed on every Reserve; swept when passed
}
// NewInMemory constructs a process-local limiter suitable for one gateway
// process instance, using the real clock and a one-minute cleanup cadence.
func NewInMemory() *InMemory {
	limiter := &InMemory{
		now:             time.Now,
		cleanupInterval: time.Minute,
		entries:         map[string]*entry{},
	}
	return limiter
}
// Reserve evaluates key against policy and reports whether the request may
// proceed immediately. An invalid policy (any non-positive field) denies with
// the zero Decision. Buckets are rebuilt when the effective limit or burst
// for a key changes, so policy updates take effect immediately.
func (l *InMemory) Reserve(key string, policy Policy) Decision {
	if policy.Requests <= 0 || policy.Window <= 0 || policy.Burst <= 0 {
		return Decision{}
	}
	now := l.now()
	limit := rate.Limit(float64(policy.Requests) / policy.Window.Seconds())
	l.mu.Lock()
	defer l.mu.Unlock()
	l.cleanupExpiredBucketsLocked(now)
	current, ok := l.entries[key]
	if !ok || current.limit != limit || current.burst != policy.Burst {
		current = &entry{
			limiter: rate.NewLimiter(limit, policy.Burst),
			limit:   limit,
			burst:   policy.Burst,
		}
		l.entries[key] = current
	}
	// Keep the bucket alive while it is in use.
	current.expiresAt = now.Add(entryTTL(policy.Window))
	reservation := current.limiter.ReserveN(now, 1)
	if !reservation.OK() {
		return Decision{
			Allowed:    false,
			RetryAfter: policy.Window,
		}
	}
	retryAfter := reservation.DelayFrom(now)
	if retryAfter > 0 {
		// The request is denied, so return the reserved token to the bucket.
		// Without this, every rejected request would still consume a slot and
		// push the retry horizon further out for all subsequent callers.
		reservation.CancelAt(now)
		return Decision{
			Allowed:    false,
			RetryAfter: retryAfter,
		}
	}
	return Decision{Allowed: true}
}
// cleanupExpiredBucketsLocked sweeps idle buckets whose TTL has passed. The
// sweep runs at most once per cleanupInterval; callers must hold l.mu.
func (l *InMemory) cleanupExpiredBucketsLocked(now time.Time) {
	// The zero nextCleanup is never after now, so the first call always sweeps.
	if l.nextCleanup.After(now) {
		return
	}
	for key, bucket := range l.entries {
		if !bucket.expiresAt.After(now) {
			delete(l.entries, key)
		}
	}
	l.nextCleanup = now.Add(l.cleanupInterval)
}
func entryTTL(window time.Duration) time.Duration {
if window < time.Minute {
return time.Minute
}
return 2 * window
}
var _ Limiter = (*InMemory)(nil)
@@ -0,0 +1,49 @@
package ratelimit
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestInMemoryReserve checks basic bucket behavior: one request per window is
// admitted per key, a denial carries a positive retry hint, and distinct keys
// do not share budget.
func TestInMemoryReserve(t *testing.T) {
	t.Parallel()
	limiter := NewInMemory()
	policy := Policy{Requests: 1, Window: time.Hour, Burst: 1}
	first := limiter.Reserve("bucket-1", policy)
	second := limiter.Reserve("bucket-1", policy)
	otherBucket := limiter.Reserve("bucket-2", policy)
	assert.True(t, first.Allowed)
	assert.False(t, second.Allowed)
	assert.Positive(t, second.RetryAfter)
	assert.True(t, otherBucket.Allowed)
}
// TestInMemoryReserveResetsOnPolicyChange checks that changing a key's policy
// rebuilds its bucket, so an exhausted key is admitted again under the new
// budget.
func TestInMemoryReserveResetsOnPolicyChange(t *testing.T) {
	t.Parallel()
	limiter := NewInMemory()
	before := Policy{Requests: 1, Window: time.Hour, Burst: 1}
	after := Policy{Requests: 2, Window: time.Hour, Burst: 2}
	assert.True(t, limiter.Reserve("bucket-1", before).Allowed)
	assert.False(t, limiter.Reserve("bucket-1", before).Allowed)
	// The new policy replaces the exhausted bucket entirely.
	assert.True(t, limiter.Reserve("bucket-1", after).Allowed)
}
+131
View File
@@ -0,0 +1,131 @@
package replay
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"strings"
"time"
"galaxy/gateway/internal/config"
"github.com/redis/go-redis/v9"
)
// RedisStore implements Store with Redis SETNX reservations over a dedicated
// key namespace, so replay state is shared across gateway instances.
type RedisStore struct {
	client *redis.Client
	// keyPrefix namespaces reservation keys away from other Redis users.
	keyPrefix string
	// reserveTimeout caps each Redis round trip (Ping and Reserve).
	reserveTimeout time.Duration
}
// NewRedisStore constructs a Redis-backed replay store that reuses the
// SessionCache Redis deployment settings and applies the replay-specific key
// namespace and timeout controls from replayCfg. It validates both configs up
// front so misconfiguration fails at startup rather than on first use.
func NewRedisStore(sessionCfg config.SessionCacheRedisConfig, replayCfg config.ReplayRedisConfig) (*RedisStore, error) {
	switch {
	case strings.TrimSpace(sessionCfg.Addr) == "":
		return nil, errors.New("new redis replay store: redis addr must not be empty")
	case sessionCfg.DB < 0:
		return nil, errors.New("new redis replay store: redis db must not be negative")
	case strings.TrimSpace(replayCfg.KeyPrefix) == "":
		return nil, errors.New("new redis replay store: replay key prefix must not be empty")
	case replayCfg.ReserveTimeout <= 0:
		return nil, errors.New("new redis replay store: reserve timeout must be positive")
	}
	options := &redis.Options{
		Addr:            sessionCfg.Addr,
		Username:        sessionCfg.Username,
		Password:        sessionCfg.Password,
		DB:              sessionCfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if sessionCfg.TLSEnabled {
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}
	store := &RedisStore{
		client:         redis.NewClient(options),
		keyPrefix:      replayCfg.KeyPrefix,
		reserveTimeout: replayCfg.ReserveTimeout,
	}
	return store, nil
}
// Close releases the underlying Redis client resources. It is a no-op on a
// nil store or a store without a client.
func (s *RedisStore) Close() error {
	if s != nil && s.client != nil {
		return s.client.Close()
	}
	return nil
}
// Ping verifies that the configured Redis backend is reachable within the
// replay reserve timeout budget.
func (s *RedisStore) Ping(ctx context.Context) error {
	switch {
	case s == nil || s.client == nil:
		return errors.New("ping redis replay store: nil store")
	case ctx == nil:
		return errors.New("ping redis replay store: nil context")
	}
	pingCtx, cancel := context.WithTimeout(ctx, s.reserveTimeout)
	defer cancel()
	err := s.client.Ping(pingCtx).Err()
	if err != nil {
		return fmt.Errorf("ping redis replay store: %w", err)
	}
	return nil
}
// Reserve records the authenticated deviceSessionID and requestID pair for
// ttl. It rejects duplicates while the reservation remains active, wrapping
// ErrDuplicate as the Store contract requires.
func (s *RedisStore) Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error {
	switch {
	case s == nil || s.client == nil:
		return errors.New("reserve replay request in redis: nil store")
	case ctx == nil:
		return errors.New("reserve replay request in redis: nil context")
	case strings.TrimSpace(deviceSessionID) == "":
		return errors.New("reserve replay request in redis: empty device session id")
	case strings.TrimSpace(requestID) == "":
		return errors.New("reserve replay request in redis: empty request id")
	case ttl <= 0:
		return errors.New("reserve replay request in redis: ttl must be positive")
	}
	reserveCtx, cancel := context.WithTimeout(ctx, s.reserveTimeout)
	defer cancel()
	key := s.reservationKey(deviceSessionID, requestID)
	// SETNX is atomic first-writer-wins with the TTL applied in the same call:
	// a duplicate within ttl observes reserved == false.
	reserved, err := s.client.SetNX(reserveCtx, key, "1", ttl).Result()
	if err != nil {
		return fmt.Errorf("reserve replay request in redis: %w", err)
	}
	if !reserved {
		return fmt.Errorf("reserve replay request in redis: %w", ErrDuplicate)
	}
	return nil
}
// reservationKey builds the namespaced Redis key for one session/request
// pair. Both components are base64-encoded so user-supplied identifiers can
// never collide with the ":" separator.
func (s *RedisStore) reservationKey(deviceSessionID string, requestID string) string {
	var key strings.Builder
	key.WriteString(s.keyPrefix)
	key.WriteString(encodeKeyComponent(deviceSessionID))
	key.WriteByte(':')
	key.WriteString(encodeKeyComponent(requestID))
	return key.String()
}
// encodeKeyComponent encodes value with unpadded URL-safe base64, whose
// alphabet excludes ":", keeping key components separator-safe.
func encodeKeyComponent(value string) string {
	encoded := base64.RawURLEncoding.EncodeToString([]byte(value))
	return encoded
}
var _ Store = (*RedisStore)(nil)
+254
View File
@@ -0,0 +1,254 @@
package replay
import (
"context"
"errors"
"net"
"testing"
"time"
"galaxy/gateway/internal/config"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewRedisStore exercises constructor validation: a fully valid config
// succeeds, and each invalid field (addr, db, key prefix, timeout) yields a
// descriptive error.
func TestNewRedisStore(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	tests := []struct {
		name       string
		sessionCfg config.SessionCacheRedisConfig
		replayCfg  config.ReplayRedisConfig
		wantErr    string // empty means construction must succeed
	}{
		{
			name: "valid config",
			sessionCfg: config.SessionCacheRedisConfig{
				Addr: server.Addr(),
				DB:   2,
			},
			replayCfg: config.ReplayRedisConfig{
				KeyPrefix:      "gateway:replay:",
				ReserveTimeout: 250 * time.Millisecond,
			},
		},
		{
			name: "empty redis addr",
			replayCfg: config.ReplayRedisConfig{
				KeyPrefix:      "gateway:replay:",
				ReserveTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative redis db",
			sessionCfg: config.SessionCacheRedisConfig{
				Addr: server.Addr(),
				DB:   -1,
			},
			replayCfg: config.ReplayRedisConfig{
				KeyPrefix:      "gateway:replay:",
				ReserveTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "empty replay key prefix",
			sessionCfg: config.SessionCacheRedisConfig{
				Addr: server.Addr(),
			},
			replayCfg: config.ReplayRedisConfig{
				ReserveTimeout: 250 * time.Millisecond,
			},
			wantErr: "replay key prefix must not be empty",
		},
		{
			name: "non-positive reserve timeout",
			sessionCfg: config.SessionCacheRedisConfig{
				Addr: server.Addr(),
			},
			replayCfg: config.ReplayRedisConfig{
				KeyPrefix: "gateway:replay:",
			},
			wantErr: "reserve timeout must be positive",
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			store, err := NewRedisStore(tt.sessionCfg, tt.replayCfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				require.ErrorContains(t, err, tt.wantErr)
				return
			}
			require.NoError(t, err)
			// Successful constructions own a Redis client; release it.
			t.Cleanup(func() {
				assert.NoError(t, store.Close())
			})
		})
	}
}
// TestRedisStorePing verifies that Ping succeeds against a reachable backend.
func TestRedisStorePing(t *testing.T) {
	t.Parallel()
	backend := miniredis.RunT(t)
	store := newTestRedisStore(t, backend, config.SessionCacheRedisConfig{}, config.ReplayRedisConfig{})
	ctx := context.Background()
	require.NoError(t, store.Ping(ctx))
}
// TestRedisStoreReserve covers reservation semantics (first wins, duplicates
// rejected, per-session isolation) plus argument validation. secondReserve,
// when set, runs a follow-up reservation against the same store.
func TestRedisStoreReserve(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name            string
		sessionCfg      config.SessionCacheRedisConfig
		replayCfg       config.ReplayRedisConfig
		deviceSessionID string
		requestID       string
		ttl             time.Duration
		secondReserve   func(*testing.T, Store)
		wantErrIs       error
		wantErrText     string
	}{
		{
			name:            "first reservation succeeds",
			deviceSessionID: "device-session-123",
			requestID:       "request-123",
			ttl:             5 * time.Second,
		},
		{
			name:            "duplicate reservation is rejected",
			deviceSessionID: "device-session-123",
			requestID:       "request-123",
			ttl:             5 * time.Second,
			secondReserve: func(t *testing.T, store Store) {
				t.Helper()
				// Same session/request pair within ttl must wrap ErrDuplicate.
				err := store.Reserve(context.Background(), "device-session-123", "request-123", 5*time.Second)
				require.ErrorIs(t, err, ErrDuplicate)
			},
		},
		{
			name:            "same request id in distinct sessions does not collide",
			deviceSessionID: "device-session-123",
			requestID:       "request-123",
			ttl:             5 * time.Second,
			secondReserve: func(t *testing.T, store Store) {
				t.Helper()
				require.NoError(t, store.Reserve(context.Background(), "device-session-456", "request-123", 5*time.Second))
			},
		},
		{
			name:        "empty device session id",
			requestID:   "request-123",
			ttl:         5 * time.Second,
			wantErrText: "empty device session id",
		},
		{
			name:            "empty request id",
			deviceSessionID: "device-session-123",
			ttl:             5 * time.Second,
			wantErrText:     "empty request id",
		},
		{
			name:            "non-positive ttl",
			deviceSessionID: "device-session-123",
			requestID:       "request-123",
			wantErrText:     "ttl must be positive",
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			// Fresh backend per case keeps reservations isolated across cases.
			server := miniredis.RunT(t)
			store := newTestRedisStore(t, server, tt.sessionCfg, tt.replayCfg)
			err := store.Reserve(context.Background(), tt.deviceSessionID, tt.requestID, tt.ttl)
			if tt.wantErrIs != nil || tt.wantErrText != "" {
				require.Error(t, err)
				if tt.wantErrIs != nil {
					require.ErrorIs(t, err, tt.wantErrIs)
				}
				if tt.wantErrText != "" {
					require.ErrorContains(t, err, tt.wantErrText)
				}
				return
			}
			require.NoError(t, err)
			if tt.secondReserve != nil {
				tt.secondReserve(t, store)
			}
		})
	}
}
// TestRedisStoreReserveReturnsBackendError verifies that a connectivity
// failure surfaces as a wrapped backend error, never as ErrDuplicate, so
// callers can distinguish outages from replays.
func TestRedisStoreReserveReturnsBackendError(t *testing.T) {
	t.Parallel()
	// Point the store at a just-released loopback port so connects are refused.
	store, err := NewRedisStore(
		config.SessionCacheRedisConfig{Addr: unusedTCPAddr(t)},
		config.ReplayRedisConfig{
			KeyPrefix:      "gateway:replay:",
			ReserveTimeout: 100 * time.Millisecond,
		},
	)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, store.Close())
	})
	err = store.Reserve(context.Background(), "device-session-123", "request-123", 5*time.Second)
	require.Error(t, err)
	assert.False(t, errors.Is(err, ErrDuplicate))
	assert.ErrorContains(t, err, "reserve replay request in redis")
}
// newTestRedisStore builds a RedisStore against server, filling any zero
// config fields with test defaults, and registers cleanup that closes it.
func newTestRedisStore(t *testing.T, server *miniredis.Miniredis, sessionCfg config.SessionCacheRedisConfig, replayCfg config.ReplayRedisConfig) *RedisStore {
	t.Helper()
	if replayCfg.KeyPrefix == "" {
		replayCfg.KeyPrefix = "gateway:replay:"
	}
	if replayCfg.ReserveTimeout == 0 {
		replayCfg.ReserveTimeout = 250 * time.Millisecond
	}
	if sessionCfg.Addr == "" {
		sessionCfg.Addr = server.Addr()
	}
	store, err := NewRedisStore(sessionCfg, replayCfg)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, store.Close())
	})
	return store
}
// unusedTCPAddr grabs an ephemeral loopback port and releases it immediately,
// returning an address that is very likely to refuse connections.
func unusedTCPAddr(t *testing.T) string {
	t.Helper()
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	addr := probe.Addr().String()
	require.NoError(t, probe.Close())
	return addr
}
+24
View File
@@ -0,0 +1,24 @@
// Package replay defines the authenticated replay-reservation contract used by
// the gateway transport pipeline.
package replay
import (
"context"
"errors"
"time"
)
var (
	// ErrDuplicate reports that the request identifier has already been
	// reserved for the same device session within the active replay window.
	// Implementations wrap it, so callers match with errors.Is.
	ErrDuplicate = errors.New("replay reservation already exists")
)
// Store reserves authenticated transport request identifiers for a bounded
// replay window, giving the transport pipeline at-most-once semantics per
// device session within ttl.
type Store interface {
	// Reserve marks the deviceSessionID and requestID pair as seen for ttl.
	// Implementations must wrap ErrDuplicate when the same pair is reserved
	// again before ttl expires.
	Reserve(ctx context.Context, deviceSessionID string, requestID string, ttl time.Duration) error
}
+76
View File
@@ -0,0 +1,76 @@
package restapi
import (
"time"
"galaxy/gateway/internal/logging"
"galaxy/gateway/internal/telemetry"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
// withPublicObservability returns gin middleware that records one metric
// sample and one structured log line per public HTTP request, after the
// handler chain has run. The log level is chosen from the edge outcome:
// success logs Info, backend/internal failures log Error, and every other
// rejection logs Warn.
//
// NOTE(review): metrics is dereferenced unconditionally below — confirm that
// RecordPublicRequest tolerates a nil *telemetry.Runtime receiver, or guard
// it the way logger is guarded.
func withPublicObservability(logger *zap.Logger, metrics *telemetry.Runtime) gin.HandlerFunc {
	if logger == nil {
		logger = zap.NewNop()
	}
	return func(c *gin.Context) {
		start := time.Now()
		// Run the rest of the chain; everything below observes the result.
		c.Next()
		statusCode := c.Writer.Status()
		// FullPath is empty for unmatched routes; fall back to the raw path.
		route := c.FullPath()
		if route == "" {
			route = c.Request.URL.Path
		}
		class, ok := PublicRouteClassFromContext(c.Request.Context())
		if !ok {
			class = PublicRouteClassPublicMisc
		}
		errorCode, _ := c.Get(publicErrorCodeContextKey)
		errorCodeValue, _ := errorCode.(string)
		outcome := telemetry.OutcomeFromPublicErrorCode(statusCode, errorCodeValue)
		rejectReason := telemetry.RejectReason(outcome)
		duration := time.Since(start)
		// Low-cardinality metric attributes; reject_reason only when present.
		attrs := []attribute.KeyValue{
			attribute.String("route_class", string(class)),
			attribute.String("route", route),
			attribute.String("method", c.Request.Method),
			attribute.String("edge_outcome", string(outcome)),
		}
		if rejectReason != "" {
			attrs = append(attrs, attribute.String("reject_reason", rejectReason))
		}
		metrics.RecordPublicRequest(c.Request.Context(), attrs, duration)
		fields := []zap.Field{
			zap.String("component", "public_http"),
			zap.String("transport", "http"),
			zap.String("route", route),
			zap.String("route_class", string(class)),
			zap.String("method", c.Request.Method),
			zap.Int("status_code", statusCode),
			// Microsecond division keeps sub-millisecond precision in the log.
			zap.Float64("duration_ms", float64(duration.Microseconds())/1000),
			zap.String("edge_outcome", string(outcome)),
		}
		if rejectReason != "" {
			fields = append(fields, zap.String("reject_reason", rejectReason))
		}
		fields = append(fields, logging.TraceFieldsFromContext(c.Request.Context())...)
		switch outcome {
		case telemetry.EdgeOutcomeSuccess:
			logger.Info("public request completed", fields...)
		case telemetry.EdgeOutcomeBackendUnavailable, telemetry.EdgeOutcomeInternalError:
			logger.Error("public request failed", fields...)
		default:
			logger.Warn("public request rejected", fields...)
		}
	}
}
+30
View File
@@ -0,0 +1,30 @@
package restapi
import (
"context"
"path/filepath"
"runtime"
"testing"
"github.com/getkin/kin-openapi/openapi3"
"github.com/stretchr/testify/require"
)
// TestPublicOpenAPISpecValidates locates the repository's public OpenAPI spec
// relative to this test file, then checks that it loads, carries version
// "v1", and passes structural validation.
func TestPublicOpenAPISpecValidates(t *testing.T) {
	t.Parallel()
	_, thisFile, _, ok := runtime.Caller(0)
	require.True(t, ok)
	specPath := filepath.Join(filepath.Dir(thisFile), "..", "..", "openapi.yaml")
	loader := openapi3.NewLoader()
	doc, err := loader.LoadFromFile(specPath)
	require.NoError(t, err)
	require.NotNil(t, doc)
	require.NotNil(t, doc.Info)
	require.Equal(t, "v1", doc.Info.Version)
	require.NoError(t, doc.Validate(context.Background()))
}
@@ -0,0 +1,378 @@
package restapi
import (
"bytes"
"errors"
"io"
"math"
"net"
"net/http"
"path"
"strconv"
"strings"
"sync"
"time"
"galaxy/gateway/internal/config"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
const (
	// errorCodeRequestTooLarge is the public error code for a request body
	// that exceeds the configured class limit.
	errorCodeRequestTooLarge = "request_too_large"
	// errorCodeRateLimited is the public error code for a denied bucket
	// reservation.
	errorCodeRateLimited = "rate_limited"
	// publicRESTIPBucketKeySegment separates the client IP inside rate-limit
	// bucket keys.
	publicRESTIPBucketKeySegment = "/ip="
)

// errRequestBodyTooLarge marks reads aborted by the request body size limit.
var errRequestBodyTooLarge = errors.New("request body exceeds the configured limit")
// PublicMalformedRequestReason identifies the stable malformed-request counter
// dimension recorded by the public REST anti-abuse middleware. Values are
// intentionally low-cardinality metric labels.
type PublicMalformedRequestReason string

const (
	// PublicMalformedRequestReasonEmptyBody records a missing request body.
	PublicMalformedRequestReasonEmptyBody PublicMalformedRequestReason = "empty_body"
	// PublicMalformedRequestReasonMalformedJSON records syntactically malformed
	// JSON.
	PublicMalformedRequestReasonMalformedJSON PublicMalformedRequestReason = "malformed_json"
	// PublicMalformedRequestReasonInvalidJSONValue records JSON values whose
	// types do not match the expected request schema.
	PublicMalformedRequestReasonInvalidJSONValue PublicMalformedRequestReason = "invalid_json_value"
	// PublicMalformedRequestReasonUnknownField records JSON objects with fields
	// outside the documented schema.
	PublicMalformedRequestReasonUnknownField PublicMalformedRequestReason = "unknown_field"
	// PublicMalformedRequestReasonMultipleJSONObjects records requests that
	// contain more than one JSON object.
	PublicMalformedRequestReasonMultipleJSONObjects PublicMalformedRequestReason = "multiple_json_objects"
	// PublicMalformedRequestReasonOversizedBody records requests whose bodies
	// exceed the configured class limit.
	PublicMalformedRequestReasonOversizedBody PublicMalformedRequestReason = "oversized_body"
)
// PublicRateLimitDecision describes the outcome returned by a public REST
// limiter for one request bucket reservation attempt.
type PublicRateLimitDecision struct {
    // Allowed reports whether the request may proceed immediately.
    Allowed bool
    // RetryAfter is the minimum delay the client should wait before retrying
    // when Allowed is false. It is zero when Allowed is true.
    RetryAfter time.Duration
}

// PublicRequestLimiter applies public REST rate-limit policy to a concrete
// bucket key.
type PublicRequestLimiter interface {
    // Reserve evaluates key under policy and returns whether the request may
    // proceed immediately.
    Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision
}

// PublicRequestObserver captures low-cardinality public REST anti-abuse
// telemetry.
type PublicRequestObserver interface {
    // RecordMalformedRequest records one malformed request in class for reason.
    RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason)
}

// noopPublicRequestObserver discards all observations; it stands in when no
// real observer is wired.
type noopPublicRequestObserver struct{}

func (noopPublicRequestObserver) RecordMalformedRequest(PublicRouteClass, PublicMalformedRequestReason) {
}
// inMemoryPublicRequestLimiter is a mutex-guarded, process-local token-bucket
// limiter keyed by bucket string, with lazy TTL-based eviction of idle buckets.
type inMemoryPublicRequestLimiter struct {
    now             func() time.Time // clock injection point for tests
    cleanupInterval time.Duration    // minimum spacing between eviction sweeps
    mu              sync.Mutex       // guards entries and nextCleanup
    entries         map[string]*publicRateLimiterEntry
    nextCleanup     time.Time // earliest instant the next sweep may run
}

// publicRateLimiterEntry is one token bucket plus the policy snapshot it was
// built from (used to detect policy changes) and its eviction deadline.
type publicRateLimiterEntry struct {
    limiter   *rate.Limiter
    limit     rate.Limit
    burst     int
    expiresAt time.Time
}
// newInMemoryPublicRequestLimiter builds a process-local limiter that uses the
// wall clock and sweeps idle buckets at most once per minute.
func newInMemoryPublicRequestLimiter() *inMemoryPublicRequestLimiter {
    limiter := &inMemoryPublicRequestLimiter{
        now:             time.Now,
        cleanupInterval: time.Minute,
        entries:         map[string]*publicRateLimiterEntry{},
    }
    return limiter
}
// Reserve evaluates key under policy using a token bucket refilled at
// policy.Requests per policy.Window with capacity policy.Burst. It lazily
// creates (or rebuilds, when the effective policy changed) the per-key bucket,
// refreshes its TTL, and opportunistically evicts expired buckets.
func (l *inMemoryPublicRequestLimiter) Reserve(key string, policy config.PublicRateLimitConfig) PublicRateLimitDecision {
    now := l.now()
    limit := rate.Limit(float64(policy.Requests) / policy.Window.Seconds())

    l.mu.Lock()
    defer l.mu.Unlock()

    l.cleanupExpiredBucketsLocked(now)

    entry, ok := l.entries[key]
    if !ok || entry.limit != limit || entry.burst != policy.Burst {
        entry = &publicRateLimiterEntry{
            limiter: rate.NewLimiter(limit, policy.Burst),
            limit:   limit,
            burst:   policy.Burst,
        }
        l.entries[key] = entry
    }
    entry.expiresAt = now.Add(publicRateLimiterEntryTTL(policy.Window))

    reservation := entry.limiter.ReserveN(now, 1)
    if !reservation.OK() {
        // The bucket can never satisfy a single request (e.g. zero burst);
        // hint the full window as the retry delay.
        return PublicRateLimitDecision{
            Allowed:    false,
            RetryAfter: policy.Window,
        }
    }
    if retryAfter := reservation.DelayFrom(now); retryAfter > 0 {
        // We reject rather than delay, so return the reserved token to the
        // bucket. Without CancelAt the rejected request would still consume
        // quota, and retry hints for subsequent rejections would drift ever
        // further into the future.
        reservation.CancelAt(now)
        return PublicRateLimitDecision{
            Allowed:    false,
            RetryAfter: retryAfter,
        }
    }
    return PublicRateLimitDecision{Allowed: true}
}
// cleanupExpiredBucketsLocked drops every bucket whose TTL has elapsed. The
// sweep is throttled to at most once per cleanupInterval and must be called
// with l.mu held.
func (l *inMemoryPublicRequestLimiter) cleanupExpiredBucketsLocked(now time.Time) {
    due := l.nextCleanup.IsZero() || !now.Before(l.nextCleanup)
    if !due {
        return
    }
    for key, entry := range l.entries {
        if entry.expiresAt.After(now) {
            continue
        }
        delete(l.entries, key)
    }
    l.nextCleanup = now.Add(l.cleanupInterval)
}
func publicRateLimiterEntryTTL(window time.Duration) time.Duration {
if window < time.Minute {
return time.Minute
}
return 2 * window
}
// withPublicAntiAbuse returns the gin middleware that enforces the public REST
// anti-abuse policy: per-shape method allow-lists, per-class body-size caps,
// per-IP rate limits, and (for public auth routes) per-identity rate limits.
// Malformed-request telemetry is recorded through observer.
func withPublicAntiAbuse(policy config.PublicHTTPAntiAbuseConfig, limiter PublicRequestLimiter, observer PublicRequestObserver) gin.HandlerFunc {
    return func(c *gin.Context) {
        // The route class is read from the request context (set earlier in the
        // chain); fall back to the miscellaneous class when absent.
        class, ok := PublicRouteClassFromContext(c.Request.Context())
        if !ok {
            class = PublicRouteClassPublicMisc
        }
        allowedMethods := allowedMethodsForRequestShape(c.Request)
        if len(allowedMethods) > 0 && !isAllowedMethod(c.Request.Method, allowedMethods) {
            c.Header("Allow", strings.Join(allowedMethods, ", "))
            abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route")
            return
        }
        classPolicy := publicRoutePolicyForClass(policy, class)
        // Buffer (and replay) the body so both identity extraction below and
        // the downstream handler can read it.
        bodyBytes, err := bufferRequestBody(c.Request, classPolicy.MaxBodyBytes)
        if err != nil {
            switch {
            case errors.Is(err, errRequestBodyTooLarge):
                observer.RecordMalformedRequest(class, PublicMalformedRequestReasonOversizedBody)
                abortWithError(c, http.StatusRequestEntityTooLarge, errorCodeRequestTooLarge, "request body exceeds the configured limit")
            default:
                abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
            }
            return
        }
        // First gate: per-class, per-client-IP bucket.
        clientIP := clientIPFromRemoteAddr(c.Request.RemoteAddr)
        if decision := limiter.Reserve(publicRESTIPBucketKey(class, clientIP), classPolicy.RateLimit); !decision.Allowed {
            abortRateLimited(c, decision.RetryAfter)
            return
        }
        // Second gate: per-identity bucket, for public auth routes only.
        identity, err := extractPublicAuthIdentity(c.Request.URL.Path, bodyBytes)
        switch {
        case err == nil:
            identityPolicy := publicAuthIdentityPolicyForPath(c.Request.URL.Path, policy)
            if decision := limiter.Reserve(publicAuthIdentityBucketKey(class, identity.kind, identity.value), identityPolicy.RateLimit); !decision.Allowed {
                abortRateLimited(c, decision.RetryAfter)
                return
            }
        case errors.Is(err, errPublicAuthIdentityNotApplicable):
            // Not an auth route: no identity throttle applies.
        default:
            // Identity extraction failed. Record telemetry for malformed
            // payloads but let the handler produce the client-facing error.
            if reason, malformed := malformedRequestReasonFromError(err); malformed {
                observer.RecordMalformedRequest(class, reason)
            }
        }
        c.Next()
    }
}
// publicRoutePolicyForClass resolves the per-class anti-abuse policy, using
// the public-misc policy for any unrecognized class.
func publicRoutePolicyForClass(policy config.PublicHTTPAntiAbuseConfig, class PublicRouteClass) config.PublicRoutePolicyConfig {
    normalized := class.Normalized()
    if normalized == PublicRouteClassPublicAuth {
        return policy.PublicAuth
    }
    if normalized == PublicRouteClassBrowserBootstrap {
        return policy.BrowserBootstrap
    }
    if normalized == PublicRouteClassBrowserAsset {
        return policy.BrowserAsset
    }
    return policy.PublicMisc
}
// publicAuthIdentityPolicyForPath resolves the per-identity throttle policy
// for the given public auth endpoint; unrelated paths get the zero policy.
func publicAuthIdentityPolicyForPath(requestPath string, policy config.PublicHTTPAntiAbuseConfig) config.PublicAuthIdentityPolicyConfig {
    if requestPath == "/api/v1/public/auth/send-email-code" {
        return policy.SendEmailCodeIdentity
    }
    if requestPath == "/api/v1/public/auth/confirm-email-code" {
        return policy.ConfirmEmailCodeIdentity
    }
    return config.PublicAuthIdentityPolicyConfig{}
}
// allowedMethodsForRequestShape maps a request shape onto the method set the
// edge accepts for it; nil means no method restriction applies.
func allowedMethodsForRequestShape(r *http.Request) []string {
    if isPublicAuthPath(r.URL.Path) {
        return []string{http.MethodPost}
    }
    if isProbePath(r.URL.Path) {
        return []string{http.MethodGet}
    }
    if matchesBrowserAssetRequestShape(r) || matchesBrowserBootstrapRequestShape(r) {
        return []string{http.MethodGet, http.MethodHead}
    }
    return nil
}
// isAllowedMethod reports whether method appears in allowedMethods.
func isAllowedMethod(method string, allowedMethods []string) bool {
    for i := 0; i < len(allowedMethods); i++ {
        if allowedMethods[i] == method {
            return true
        }
    }
    return false
}
// isPublicAuthPath reports whether requestPath is one of the two public auth
// endpoints.
func isPublicAuthPath(requestPath string) bool {
    return requestPath == "/api/v1/public/auth/send-email-code" ||
        requestPath == "/api/v1/public/auth/confirm-email-code"
}
// isProbePath reports whether requestPath is a liveness or readiness probe.
func isProbePath(requestPath string) bool {
    return requestPath == "/healthz" || requestPath == "/readyz"
}
func matchesBrowserBootstrapRequestShape(r *http.Request) bool {
if r.URL.Path == "/" {
return true
}
return strings.Contains(strings.ToLower(r.Header.Get("Accept")), "text/html")
}
func matchesBrowserAssetRequestShape(r *http.Request) bool {
if strings.HasPrefix(r.URL.Path, "/assets/") {
return true
}
switch strings.ToLower(path.Ext(r.URL.Path)) {
case ".js", ".mjs", ".css", ".map", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2", ".json", ".webmanifest":
return true
default:
return false
}
}
// bufferRequestBody drains r.Body (capped at maxBodyBytes), restores an
// equivalent replayable body on r, and returns the buffered bytes. It returns
// errRequestBodyTooLarge when the body exceeds the cap.
func bufferRequestBody(r *http.Request, maxBodyBytes int64) ([]byte, error) {
    if r == nil {
        return nil, nil
    }
    if r.Body == nil {
        r.Body = io.NopCloser(bytes.NewReader(nil))
        return nil, nil
    }
    // Read one byte past the cap so oversize is detectable without draining
    // an unbounded body.
    limited := io.LimitReader(r.Body, maxBodyBytes+1)
    bodyBytes, readErr := io.ReadAll(limited)
    closeErr := r.Body.Close()
    switch {
    case readErr != nil:
        return nil, readErr
    case closeErr != nil:
        return nil, closeErr
    case int64(len(bodyBytes)) > maxBodyBytes:
        return nil, errRequestBodyTooLarge
    }
    r.Body = io.NopCloser(bytes.NewReader(bodyBytes))
    return bodyBytes, nil
}
// abortRateLimited writes the standard 429 envelope with a Retry-After hint.
func abortRateLimited(c *gin.Context, retryAfter time.Duration) {
    // The Retry-After header must be set before the error body is written.
    c.Header("Retry-After", retryAfterHeaderValue(retryAfter))
    abortWithError(c, http.StatusTooManyRequests, errorCodeRateLimited, "request rate limit exceeded")
}
func retryAfterHeaderValue(delay time.Duration) string {
seconds := int64(math.Ceil(delay.Seconds()))
if seconds < 1 {
seconds = 1
}
return strconv.FormatInt(seconds, 10)
}
// clientIPFromRemoteAddr extracts the host portion of remoteAddr, falling back
// to the trimmed address itself (or "unknown" when empty) if it does not parse
// as host:port.
func clientIPFromRemoteAddr(remoteAddr string) string {
    trimmed := strings.TrimSpace(remoteAddr)
    if host, _, err := net.SplitHostPort(trimmed); err == nil {
        return host
    }
    if trimmed == "" {
        return "unknown"
    }
    return trimmed
}
// publicRESTIPBucketKey derives the per-class, per-client-IP limiter bucket
// key.
func publicRESTIPBucketKey(class PublicRouteClass, clientIP string) string {
    return class.BaseBucketKey() + publicRESTIPBucketKeySegment + clientIP
}

// publicAuthIdentityBucketKey derives the per-class limiter bucket key for an
// extracted auth identity (kind is "email" or "challenge").
func publicAuthIdentityBucketKey(class PublicRouteClass, identityKind string, identityValue string) string {
    return class.BaseBucketKey() + "/" + identityKind + "=" + identityValue
}
@@ -0,0 +1,455 @@
package restapi
import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"galaxy/gateway/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPublicAntiAbuseRejectsOversizedBodies verifies that request bodies over
// the per-class cap get a 413 envelope, never reach the auth service, and are
// counted once as oversized-body telemetry for the route's class.
func TestPublicAntiAbuseRejectsOversizedBodies(t *testing.T) {
    t.Parallel()
    // Payloads sized just past the default per-class body limits.
    oversizedJSONBody := `{"email":"` + strings.Repeat("a", 8200) + `@example.com"}`
    oversizedConfirmJSONBody := `{"challenge_id":"` + strings.Repeat("c", 8300) + `","code":"123456","client_public_key":"key"}`
    tests := []struct {
        name      string
        method    string
        target    string
        body      string
        wantClass PublicRouteClass
    }{
        {
            name:      "send email",
            method:    http.MethodPost,
            target:    "/api/v1/public/auth/send-email-code",
            body:      oversizedJSONBody,
            wantClass: PublicRouteClassPublicAuth,
        },
        {
            name:      "confirm email",
            method:    http.MethodPost,
            target:    "/api/v1/public/auth/confirm-email-code",
            body:      oversizedConfirmJSONBody,
            wantClass: PublicRouteClassPublicAuth,
        },
        {
            // A one-byte body already exceeds the default misc-class cap for
            // probe routes.
            name:      "healthz body",
            method:    http.MethodGet,
            target:    "/healthz",
            body:      `x`,
            wantClass: PublicRouteClassPublicMisc,
        },
    }
    for _, tt := range tests {
        tt := tt
        t.Run(tt.name, func(t *testing.T) {
            t.Parallel()
            observer := &recordingPublicRequestObserver{}
            authService := &recordingAuthServiceClient{
                sendEmailCodeResult: SendEmailCodeResult{ChallengeID: "challenge-123"},
                confirmEmailCodeResult: ConfirmEmailCodeResult{
                    DeviceSessionID: "device-session-123",
                },
            }
            handler := newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), ServerDependencies{
                AuthService: authService,
                Observer:    observer,
            })
            req := httptest.NewRequest(tt.method, tt.target, strings.NewReader(tt.body))
            req.Header.Set("Content-Type", "application/json")
            recorder := httptest.NewRecorder()
            handler.ServeHTTP(recorder, req)
            assert.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code)
            assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
            assert.Equal(t, `{"error":{"code":"request_too_large","message":"request body exceeds the configured limit"}}`, recorder.Body.String())
            // The upstream auth client must never be reached.
            assert.Equal(t, 0, authService.sendEmailCodeCalls)
            assert.Equal(t, 0, authService.confirmEmailCodeCalls)
            assert.Equal(t, []malformedObservation{{
                class:  tt.wantClass,
                reason: PublicMalformedRequestReasonOversizedBody,
            }}, observer.snapshot())
        })
    }
}
// TestPublicAntiAbuseRejectsInvalidMethodsForBrowserShapes verifies that
// browser-shaped and probe requests using a disallowed method get 405 with a
// correct Allow header and the stable error envelope.
func TestPublicAntiAbuseRejectsInvalidMethodsForBrowserShapes(t *testing.T) {
    t.Parallel()
    handler := newPublicHandler(ServerDependencies{})
    tests := []struct {
        name      string
        method    string
        target    string
        accept    string
        wantAllow string
    }{
        {
            name:      "asset path",
            method:    http.MethodPost,
            target:    "/assets/app.js",
            wantAllow: "GET, HEAD",
        },
        {
            name:      "bootstrap request",
            method:    http.MethodPost,
            target:    "/",
            accept:    "text/html",
            wantAllow: "GET, HEAD",
        },
        {
            // Probes accept GET only; HEAD is rejected.
            name:      "head probe rejected",
            method:    http.MethodHead,
            target:    "/healthz",
            wantAllow: http.MethodGet,
        },
    }
    for _, tt := range tests {
        tt := tt
        t.Run(tt.name, func(t *testing.T) {
            t.Parallel()
            req := httptest.NewRequest(tt.method, tt.target, nil)
            if tt.accept != "" {
                req.Header.Set("Accept", tt.accept)
            }
            recorder := httptest.NewRecorder()
            handler.ServeHTTP(recorder, req)
            assert.Equal(t, http.StatusMethodNotAllowed, recorder.Code)
            assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
            assert.Equal(t, tt.wantAllow, recorder.Header().Get("Allow"))
            assert.Equal(t, `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`, recorder.Body.String())
        })
    }
}
// TestPublicAntiAbuseBrowserClassBucketsStayIsolatedFromPublicAuth verifies
// that exhausting a browser-class IP bucket does not throttle public auth
// traffic coming from the same client IP.
func TestPublicAntiAbuseBrowserClassBucketsStayIsolatedFromPublicAuth(t *testing.T) {
    t.Parallel()
    tests := []struct {
        name         string
        burstRequest *http.Request
    }{
        {
            name:         "browser asset",
            burstRequest: httptest.NewRequest(http.MethodGet, "/assets/app.js", nil),
        },
        {
            name: "browser bootstrap",
            burstRequest: func() *http.Request {
                req := httptest.NewRequest(http.MethodGet, "/", nil)
                req.Header.Set("Accept", "text/html")
                return req
            }(),
        },
    }
    for _, tt := range tests {
        tt := tt
        t.Run(tt.name, func(t *testing.T) {
            t.Parallel()
            // Browser classes allow a single request; public auth is generous.
            cfg := config.DefaultPublicHTTPConfig()
            cfg.AntiAbuse.BrowserAsset.RateLimit = config.PublicRateLimitConfig{
                Requests: 1,
                Window:   time.Hour,
                Burst:    1,
            }
            cfg.AntiAbuse.BrowserBootstrap.RateLimit = config.PublicRateLimitConfig{
                Requests: 1,
                Window:   time.Hour,
                Burst:    1,
            }
            cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
                Requests: 100,
                Window:   time.Hour,
                Burst:    100,
            }
            authService := &recordingAuthServiceClient{
                sendEmailCodeResult: SendEmailCodeResult{
                    ChallengeID: "challenge-123",
                },
            }
            handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
            tt.burstRequest.RemoteAddr = "192.0.2.10:1234"
            firstBurst := httptest.NewRecorder()
            handler.ServeHTTP(firstBurst, tt.burstRequest.Clone(tt.burstRequest.Context()))
            secondBurst := httptest.NewRecorder()
            handler.ServeHTTP(secondBurst, tt.burstRequest.Clone(tt.burstRequest.Context()))
            // Same client IP, different class: must not be throttled.
            authReq := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/send-email-code", strings.NewReader(`{"email":"pilot@example.com"}`))
            authReq.Header.Set("Content-Type", "application/json")
            authReq.RemoteAddr = "192.0.2.10:1234"
            authResp := httptest.NewRecorder()
            handler.ServeHTTP(authResp, authReq)
            // 404 is acceptable for the first burst hit: only the limiter's
            // behavior is under test, not the routed handler.
            assert.Equal(t, http.StatusNotFound, firstBurst.Code)
            assert.Equal(t, http.StatusTooManyRequests, secondBurst.Code)
            assert.Equal(t, http.StatusOK, authResp.Code)
            assert.Equal(t, `{"challenge_id":"challenge-123"}`, authResp.Body.String())
            assert.Equal(t, 1, authService.sendEmailCodeCalls)
        })
    }
}
// TestPublicAntiAbuseSendEmailIdentityThrottle verifies the per-email identity
// throttle on send-email-code: repeats for one address are limited while a
// different address still passes.
func TestPublicAntiAbuseSendEmailIdentityThrottle(t *testing.T) {
    t.Parallel()
    cfg := config.DefaultPublicHTTPConfig()
    // Keep the IP bucket generous so only the identity bucket can throttle.
    cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
        Requests: 100,
        Window:   time.Hour,
        Burst:    100,
    }
    cfg.AntiAbuse.SendEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
        Requests: 1,
        Window:   time.Hour,
        Burst:    1,
    }
    authService := &recordingAuthServiceClient{
        sendEmailCodeResult: SendEmailCodeResult{
            ChallengeID: "challenge-123",
        },
    }
    handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
    first := sendEmailCodeRequest(`{"email":"pilot@example.com"}`)
    second := sendEmailCodeRequest(`{"email":"pilot@example.com"}`)
    third := sendEmailCodeRequest(`{"email":"other@example.com"}`)
    firstResp := httptest.NewRecorder()
    handler.ServeHTTP(firstResp, first)
    secondResp := httptest.NewRecorder()
    handler.ServeHTTP(secondResp, second)
    thirdResp := httptest.NewRecorder()
    handler.ServeHTTP(thirdResp, third)
    assert.Equal(t, http.StatusOK, firstResp.Code)
    assert.Equal(t, http.StatusTooManyRequests, secondResp.Code)
    // The one-hour window surfaces as a 3600-second Retry-After hint.
    assert.Equal(t, "3600", secondResp.Header().Get("Retry-After"))
    assert.Equal(t, http.StatusOK, thirdResp.Code)
    assert.Equal(t, 2, authService.sendEmailCodeCalls)
    thirdInput := authService.sendEmailCodeInput
    assert.Equal(t, "other@example.com", thirdInput.Email)
}
// TestPublicAntiAbuseConfirmEmailIdentityThrottle verifies the per-challenge
// identity throttle on confirm-email-code: repeats for one challenge are
// limited while a different challenge still passes.
func TestPublicAntiAbuseConfirmEmailIdentityThrottle(t *testing.T) {
    t.Parallel()
    cfg := config.DefaultPublicHTTPConfig()
    // Keep the IP bucket generous so only the identity bucket can throttle.
    cfg.AntiAbuse.PublicAuth.RateLimit = config.PublicRateLimitConfig{
        Requests: 100,
        Window:   time.Hour,
        Burst:    100,
    }
    cfg.AntiAbuse.ConfirmEmailCodeIdentity.RateLimit = config.PublicRateLimitConfig{
        Requests: 1,
        Window:   time.Hour,
        Burst:    1,
    }
    authService := &recordingAuthServiceClient{
        confirmEmailCodeResult: ConfirmEmailCodeResult{
            DeviceSessionID: "device-session-123",
        },
    }
    handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
    first := confirmEmailCodeRequest(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`)
    second := confirmEmailCodeRequest(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`)
    third := confirmEmailCodeRequest(`{"challenge_id":"challenge-456","code":"123456","client_public_key":"public-key-material"}`)
    firstResp := httptest.NewRecorder()
    handler.ServeHTTP(firstResp, first)
    secondResp := httptest.NewRecorder()
    handler.ServeHTTP(secondResp, second)
    thirdResp := httptest.NewRecorder()
    handler.ServeHTTP(thirdResp, third)
    assert.Equal(t, http.StatusOK, firstResp.Code)
    assert.Equal(t, http.StatusTooManyRequests, secondResp.Code)
    // The one-hour window surfaces as a 3600-second Retry-After hint.
    assert.Equal(t, "3600", secondResp.Header().Get("Retry-After"))
    assert.Equal(t, http.StatusOK, thirdResp.Code)
    assert.Equal(t, 2, authService.confirmEmailCodeCalls)
    assert.Equal(t, "challenge-456", authService.confirmEmailCodeInput.ChallengeID)
}
// TestPublicAntiAbuseMalformedTelemetry verifies that each malformed-body
// shape is counted once under its stable reason, while semantic validation
// failures (well-formed JSON, bad value) are not counted as malformed.
func TestPublicAntiAbuseMalformedTelemetry(t *testing.T) {
    t.Parallel()
    tests := []struct {
        name        string
        body        string
        wantReason  PublicMalformedRequestReason
        wantRecords int
    }{
        {
            name:        "empty body",
            body:        ``,
            wantReason:  PublicMalformedRequestReasonEmptyBody,
            wantRecords: 1,
        },
        {
            name:        "malformed json",
            body:        `{"email":`,
            wantReason:  PublicMalformedRequestReasonMalformedJSON,
            wantRecords: 1,
        },
        {
            name:        "invalid json value",
            body:        `{"email":123}`,
            wantReason:  PublicMalformedRequestReasonInvalidJSONValue,
            wantRecords: 1,
        },
        {
            name:        "unknown field",
            body:        `{"email":"pilot@example.com","extra":"x"}`,
            wantReason:  PublicMalformedRequestReasonUnknownField,
            wantRecords: 1,
        },
        {
            name:        "multiple objects",
            body:        `{"email":"pilot@example.com"}{"email":"pilot@example.com"}`,
            wantReason:  PublicMalformedRequestReasonMultipleJSONObjects,
            wantRecords: 1,
        },
        {
            // Well-formed JSON with an invalid email is a validation error,
            // not malformed-request telemetry.
            name:        "validation error does not count as malformed",
            body:        `{"email":"not-an-email"}`,
            wantRecords: 0,
        },
    }
    for _, tt := range tests {
        tt := tt
        t.Run(tt.name, func(t *testing.T) {
            t.Parallel()
            observer := &recordingPublicRequestObserver{}
            authService := &recordingAuthServiceClient{}
            handler := newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), ServerDependencies{
                AuthService: authService,
                Observer:    observer,
            })
            req := sendEmailCodeRequest(tt.body)
            recorder := httptest.NewRecorder()
            handler.ServeHTTP(recorder, req)
            assert.Equal(t, http.StatusBadRequest, recorder.Code)
            assert.Equal(t, tt.wantRecords, len(observer.snapshot()))
            assert.Equal(t, 0, authService.sendEmailCodeCalls)
            if tt.wantRecords == 1 {
                assert.Equal(t, malformedObservation{
                    class:  PublicRouteClassPublicAuth,
                    reason: tt.wantReason,
                }, observer.snapshot()[0])
            }
        })
    }
}
// TestInMemoryPublicRequestLimiterCleansExpiredBuckets verifies that expired
// buckets are evicted on the next Reserve call after their TTL elapses.
func TestInMemoryPublicRequestLimiterCleansExpiredBuckets(t *testing.T) {
    t.Parallel()
    now := time.Unix(1000, 0)
    limiter := newInMemoryPublicRequestLimiter()
    // Deterministic clock under test control.
    limiter.now = func() time.Time {
        return now
    }
    limiter.cleanupInterval = time.Second
    policy := config.PublicRateLimitConfig{
        Requests: 1,
        Window:   time.Minute,
        Burst:    1,
    }
    firstDecision := limiter.Reserve("bucket-1", policy)
    secondDecision := limiter.Reserve("bucket-2", policy)
    require.True(t, firstDecision.Allowed)
    require.True(t, secondDecision.Allowed)
    require.Len(t, limiter.entries, 2)
    // Jump past the 2*window TTL so both earlier buckets expire.
    now = now.Add(3 * time.Minute)
    thirdDecision := limiter.Reserve("bucket-3", policy)
    require.True(t, thirdDecision.Allowed)
    assert.Len(t, limiter.entries, 1)
    _, exists := limiter.entries["bucket-3"]
    assert.True(t, exists)
}
// sendEmailCodeRequest builds a JSON POST to the send-email-code endpoint
// originating from a fixed test client address.
func sendEmailCodeRequest(body string) *http.Request {
    req := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/send-email-code", strings.NewReader(body))
    req.RemoteAddr = "192.0.2.10:1234"
    req.Header.Set("Content-Type", "application/json")
    return req
}
// confirmEmailCodeRequest builds a JSON POST to the confirm-email-code
// endpoint originating from a fixed test client address.
func confirmEmailCodeRequest(body string) *http.Request {
    req := httptest.NewRequest(http.MethodPost, "/api/v1/public/auth/confirm-email-code", strings.NewReader(body))
    req.RemoteAddr = "192.0.2.10:1234"
    req.Header.Set("Content-Type", "application/json")
    return req
}
// malformedObservation is one recorded (class, reason) telemetry pair.
type malformedObservation struct {
    class  PublicRouteClass
    reason PublicMalformedRequestReason
}

// recordingPublicRequestObserver is a thread-safe PublicRequestObserver test
// double that captures every malformed-request observation.
type recordingPublicRequestObserver struct {
    mu           sync.Mutex
    observations []malformedObservation
}

// RecordMalformedRequest appends one observation under the lock.
func (o *recordingPublicRequestObserver) RecordMalformedRequest(class PublicRouteClass, reason PublicMalformedRequestReason) {
    o.mu.Lock()
    defer o.mu.Unlock()
    o.observations = append(o.observations, malformedObservation{
        class:  class,
        reason: reason,
    })
}

// snapshot returns a copy of the observations recorded so far (nil when none).
func (o *recordingPublicRequestObserver) snapshot() []malformedObservation {
    o.mu.Lock()
    defer o.mu.Unlock()
    return append([]malformedObservation(nil), o.observations...)
}
+446
View File
@@ -0,0 +1,446 @@
package restapi
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/mail"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// errPublicAuthIdentityNotApplicable signals that a route carries no
// throttleable auth identity (it is not one of the public auth endpoints).
var errPublicAuthIdentityNotApplicable = errors.New("public auth identity does not apply to this route")

// malformedJSONRequestError pairs a client-safe message with the stable
// telemetry reason for a malformed JSON request body.
type malformedJSONRequestError struct {
    message string
    reason  PublicMalformedRequestReason
}

// Error returns the client-safe message; a nil receiver yields "".
func (e *malformedJSONRequestError) Error() string {
    if e == nil {
        return ""
    }
    return e.message
}

// publicAuthIdentity is the throttleable identity extracted from a public
// auth request body (kind is "email" or "challenge").
type publicAuthIdentity struct {
    kind  string
    value string
}
// AuthServiceClient defines the consumer-side contract used by public auth
// REST handlers to delegate unauthenticated authentication commands to the
// Auth / Session Service.
type AuthServiceClient interface {
    // SendEmailCode starts a login challenge for input.Email and returns the
    // challenge identifier that the client must later confirm.
    SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error)
    // ConfirmEmailCode completes a previously issued challenge, registers
    // input.ClientPublicKey for the new device session, and returns the created
    // device session identifier.
    ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error)
}

// SendEmailCodeInput describes the public REST and adapter payload used to
// request a login code for a single e-mail address.
type SendEmailCodeInput struct {
    // Email is the single client e-mail address that should receive the login
    // code challenge.
    Email string `json:"email"`
}

// SendEmailCodeResult describes the public REST and adapter payload returned
// after the Auth / Session Service creates a login challenge.
type SendEmailCodeResult struct {
    // ChallengeID identifies the issued challenge that must be confirmed by the
    // client in the next public auth step.
    ChallengeID string `json:"challenge_id"`
}

// ConfirmEmailCodeInput describes the public REST and adapter payload used to
// complete a previously issued login challenge.
type ConfirmEmailCodeInput struct {
    // ChallengeID identifies the challenge previously returned by
    // SendEmailCode.
    ChallengeID string `json:"challenge_id"`
    // Code is the verification code delivered to the client by the Auth /
    // Session Service.
    Code string `json:"code"`
    // ClientPublicKey is the standard base64-encoded raw 32-byte Ed25519 public
    // key that should be registered for the created device session.
    ClientPublicKey string `json:"client_public_key"`
}

// ConfirmEmailCodeResult describes the public REST and adapter payload
// returned after the Auth / Session Service creates a device session.
type ConfirmEmailCodeResult struct {
    // DeviceSessionID is the stable identifier of the created device session.
    DeviceSessionID string `json:"device_session_id"`
}

// AuthServiceError allows an auth adapter to project a stable public REST
// error without teaching the gateway transport layer about upstream business
// rules. The zero value is projected as a generic 500 internal error.
type AuthServiceError struct {
    // StatusCode is the HTTP status that the public REST handler should expose.
    // Values outside 400-599 are normalized to 500.
    StatusCode int
    // Code is the stable edge-level error code written into the JSON envelope.
    Code string
    // Message is the human-readable client-safe error description.
    Message string
}
// Error returns a readable "code: message" rendering of the projected error,
// degrading to whichever part is present, or to the HTTP status text when
// both are blank.
func (e *AuthServiceError) Error() string {
    code := strings.TrimSpace(e.Code)
    message := strings.TrimSpace(e.Message)
    if code == "" && message == "" {
        return http.StatusText(e.normalizedStatusCode())
    }
    if code == "" {
        return e.Message
    }
    if message == "" {
        return e.Code
    }
    return e.Code + ": " + e.Message
}
// normalizedStatusCode clamps the projected status to a valid HTTP error
// status, defaulting to 500 for nil receivers or out-of-range values.
func (e *AuthServiceError) normalizedStatusCode() int {
    if e == nil {
        return http.StatusInternalServerError
    }
    if e.StatusCode >= 400 && e.StatusCode <= 599 {
        return e.StatusCode
    }
    return http.StatusInternalServerError
}
// normalizedCode returns the trimmed edge error code, deriving a stable
// default from the normalized status when the adapter left it blank.
func (e *AuthServiceError) normalizedCode() string {
    if e == nil {
        return errorCodeInternalError
    }
    if code := strings.TrimSpace(e.Code); code != "" {
        return code
    }
    switch e.normalizedStatusCode() {
    case http.StatusServiceUnavailable:
        return errorCodeServiceUnavailable
    case http.StatusBadRequest:
        return errorCodeInvalidRequest
    default:
        return errorCodeInternalError
    }
}
// normalizedMessage returns the trimmed client-safe message, deriving a stable
// default from the normalized status when the adapter left it blank.
func (e *AuthServiceError) normalizedMessage() string {
    if e == nil {
        return "internal server error"
    }
    if message := strings.TrimSpace(e.Message); message != "" {
        return message
    }
    switch e.normalizedStatusCode() {
    case http.StatusServiceUnavailable:
        return "auth service is unavailable"
    case http.StatusBadRequest:
        return "request is invalid"
    default:
        return "internal server error"
    }
}
// unavailableAuthServiceClient keeps the public auth surface mounted until a
// concrete upstream adapter is wired into the gateway process.
type unavailableAuthServiceClient struct{}

// SendEmailCode always projects the auth service as unavailable (503).
func (unavailableAuthServiceClient) SendEmailCode(context.Context, SendEmailCodeInput) (SendEmailCodeResult, error) {
    return SendEmailCodeResult{}, &AuthServiceError{
        StatusCode: http.StatusServiceUnavailable,
        Code:       errorCodeServiceUnavailable,
        Message:    "auth service is unavailable",
    }
}

// ConfirmEmailCode always projects the auth service as unavailable (503).
func (unavailableAuthServiceClient) ConfirmEmailCode(context.Context, ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
    return ConfirmEmailCodeResult{}, &AuthServiceError{
        StatusCode: http.StatusServiceUnavailable,
        Code:       errorCodeServiceUnavailable,
        Message:    "auth service is unavailable",
    }
}
// handleSendEmailCode returns the gin handler for the public send-email-code
// endpoint: it strictly decodes and validates the JSON input, delegates to
// authService under a per-call timeout, validates the upstream result, and
// writes either the success payload or a stable error envelope.
func handleSendEmailCode(authService AuthServiceClient, timeout time.Duration) gin.HandlerFunc {
    return func(c *gin.Context) {
        var input SendEmailCodeInput
        if err := decodeJSONRequest(c.Request, &input); err != nil {
            abortInvalidRequest(c, err.Error())
            return
        }
        if err := validateSendEmailCodeInput(&input); err != nil {
            abortInvalidRequest(c, err.Error())
            return
        }
        callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
        defer cancel()
        result, err := authService.SendEmailCode(callCtx, input)
        if err != nil {
            // A deadline hit is projected as upstream unavailability rather
            // than an internal error.
            if errors.Is(err, context.DeadlineExceeded) {
                abortWithError(c, http.StatusServiceUnavailable, errorCodeServiceUnavailable, "auth service is unavailable")
                return
            }
            abortWithAuthServiceError(c, err)
            return
        }
        // Never forward a malformed upstream payload to the client.
        if err := validateSendEmailCodeResult(&result); err != nil {
            abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
            return
        }
        c.JSON(http.StatusOK, result)
    }
}
// handleConfirmEmailCode returns the gin handler for the public
// confirm-email-code endpoint: it strictly decodes and validates the JSON
// input, delegates to authService under a per-call timeout, validates the
// upstream result, and writes either the success payload or a stable error
// envelope.
func handleConfirmEmailCode(authService AuthServiceClient, timeout time.Duration) gin.HandlerFunc {
    return func(c *gin.Context) {
        var input ConfirmEmailCodeInput
        if err := decodeJSONRequest(c.Request, &input); err != nil {
            abortInvalidRequest(c, err.Error())
            return
        }
        if err := validateConfirmEmailCodeInput(&input); err != nil {
            abortInvalidRequest(c, err.Error())
            return
        }
        callCtx, cancel := context.WithTimeout(c.Request.Context(), timeout)
        defer cancel()
        result, err := authService.ConfirmEmailCode(callCtx, input)
        if err != nil {
            // A deadline hit is projected as upstream unavailability rather
            // than an internal error.
            if errors.Is(err, context.DeadlineExceeded) {
                abortWithError(c, http.StatusServiceUnavailable, errorCodeServiceUnavailable, "auth service is unavailable")
                return
            }
            abortWithAuthServiceError(c, err)
            return
        }
        // Never forward a malformed upstream payload to the client.
        if err := validateConfirmEmailCodeResult(&result); err != nil {
            abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
            return
        }
        c.JSON(http.StatusOK, result)
    }
}
// abortInvalidRequest writes the standard 400 invalid-request envelope.
func abortInvalidRequest(c *gin.Context, message string) {
    abortWithError(c, http.StatusBadRequest, errorCodeInvalidRequest, message)
}

// abortWithAuthServiceError projects an upstream auth error onto the public
// envelope, falling back to a generic 500 for unrecognized error types.
func abortWithAuthServiceError(c *gin.Context, err error) {
    var authErr *AuthServiceError
    if errors.As(err, &authErr) {
        abortWithError(c, authErr.normalizedStatusCode(), authErr.normalizedCode(), authErr.normalizedMessage())
        return
    }
    abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
}
// decodeJSONRequest strictly decodes the request body into target, treating a
// missing request or body as an empty-body malformed error.
func decodeJSONRequest(r *http.Request, target any) error {
    if r == nil || r.Body == nil {
        return &malformedJSONRequestError{
            message: "request body must not be empty",
            reason:  PublicMalformedRequestReasonEmptyBody,
        }
    }
    return decodeJSONReader(r.Body, target)
}

// decodeJSONBytes strictly decodes an already-buffered body into target.
func decodeJSONBytes(bodyBytes []byte, target any) error {
    return decodeJSONReader(bytes.NewReader(bodyBytes), target)
}
// decodeJSONReader strictly decodes exactly one JSON value from reader into
// target: unknown fields are rejected, and any content after the first value
// (other than clean EOF) is reported as a multiple-objects error.
func decodeJSONReader(reader io.Reader, target any) error {
    decoder := json.NewDecoder(reader)
    decoder.DisallowUnknownFields()
    if err := decoder.Decode(target); err != nil {
        return describeJSONDecodeError(err)
    }
    // A second decode must hit EOF; any other outcome means trailing content.
    if err := decoder.Decode(&struct{}{}); err != nil {
        if errors.Is(err, io.EOF) {
            return nil
        }
        return &malformedJSONRequestError{
            message: "request body must contain a single JSON object",
            reason:  PublicMalformedRequestReasonMultipleJSONObjects,
        }
    }
    return &malformedJSONRequestError{
        message: "request body must contain a single JSON object",
        reason:  PublicMalformedRequestReasonMultipleJSONObjects,
    }
}
// describeJSONDecodeError translates an encoding/json decode error into a
// malformedJSONRequestError carrying a client-safe message and a stable
// telemetry reason. Case order matters: typed checks run before the
// string-prefix fallback.
func describeJSONDecodeError(err error) error {
    var syntaxErr *json.SyntaxError
    var typeErr *json.UnmarshalTypeError
    switch {
    case errors.Is(err, io.EOF):
        // Bare EOF from Decode means no value was read at all.
        return &malformedJSONRequestError{
            message: "request body must not be empty",
            reason:  PublicMalformedRequestReasonEmptyBody,
        }
    case errors.As(err, &syntaxErr):
        return &malformedJSONRequestError{
            message: "request body contains malformed JSON",
            reason:  PublicMalformedRequestReasonMalformedJSON,
        }
    case errors.Is(err, io.ErrUnexpectedEOF):
        // Truncated JSON value.
        return &malformedJSONRequestError{
            message: "request body contains malformed JSON",
            reason:  PublicMalformedRequestReasonMalformedJSON,
        }
    case errors.As(err, &typeErr):
        if strings.TrimSpace(typeErr.Field) != "" {
            return &malformedJSONRequestError{
                message: fmt.Sprintf("request body contains an invalid value for %q", typeErr.Field),
                reason:  PublicMalformedRequestReasonInvalidJSONValue,
            }
        }
        return &malformedJSONRequestError{
            message: "request body contains an invalid JSON value",
            reason:  PublicMalformedRequestReasonInvalidJSONValue,
        }
    case strings.HasPrefix(err.Error(), "json: unknown field "):
        // encoding/json exposes no typed error for unknown fields; match the
        // message prefix instead.
        return &malformedJSONRequestError{
            message: fmt.Sprintf("request body contains unknown field %s", strings.TrimPrefix(err.Error(), "json: unknown field ")),
            reason:  PublicMalformedRequestReasonUnknownField,
        }
    default:
        return &malformedJSONRequestError{
            message: "request body contains invalid JSON",
            reason:  PublicMalformedRequestReasonMalformedJSON,
        }
    }
}
// validateSendEmailCodeInput trims and validates the e-mail address in place,
// requiring a single bare address with no display name.
func validateSendEmailCodeInput(input *SendEmailCodeInput) error {
    trimmed := strings.TrimSpace(input.Email)
    input.Email = trimmed
    if trimmed == "" {
        return errors.New("email must not be empty")
    }
    parsed, err := mail.ParseAddress(trimmed)
    if err != nil || parsed.Name != "" || parsed.Address != trimmed {
        return errors.New("email must be a single valid email address")
    }
    return nil
}
// validateSendEmailCodeResult trims the upstream challenge_id in place and
// rejects an empty value.
func validateSendEmailCodeResult(result *SendEmailCodeResult) error {
    trimmed := strings.TrimSpace(result.ChallengeID)
    result.ChallengeID = trimmed
    if trimmed != "" {
        return nil
    }
    return errors.New("auth service returned an empty challenge_id")
}
// validateConfirmEmailCodeInput trims all confirm-step fields in place and
// requires each of them to be non-empty, reporting the first missing field.
func validateConfirmEmailCodeInput(input *ConfirmEmailCodeInput) error {
    fields := []struct {
        value   *string
        missing string
    }{
        {&input.ChallengeID, "challenge_id must not be empty"},
        {&input.Code, "code must not be empty"},
        {&input.ClientPublicKey, "client_public_key must not be empty"},
    }
    for _, field := range fields {
        *field.value = strings.TrimSpace(*field.value)
        if *field.value == "" {
            return errors.New(field.missing)
        }
    }
    return nil
}
// validateConfirmEmailCodeResult normalizes the upstream device session
// identifier in place and fails when the auth service returned a blank one.
func validateConfirmEmailCodeResult(result *ConfirmEmailCodeResult) error {
	sessionID := strings.TrimSpace(result.DeviceSessionID)
	result.DeviceSessionID = sessionID
	if sessionID == "" {
		return errors.New("auth service returned an empty device_session_id")
	}
	return nil
}
// malformedRequestReasonFromError reports the public malformed-request reason
// carried by err when err wraps a malformedJSONRequestError; otherwise it
// returns ("", false).
func malformedRequestReasonFromError(err error) (PublicMalformedRequestReason, bool) {
	var target *malformedJSONRequestError
	if errors.As(err, &target) {
		return target.reason, true
	}
	return "", false
}
// extractPublicAuthIdentity derives the anti-abuse identity for a public auth
// request from its already-buffered JSON body. Paths outside the public auth
// surface yield errPublicAuthIdentityNotApplicable.
func extractPublicAuthIdentity(requestPath string, bodyBytes []byte) (publicAuthIdentity, error) {
	switch requestPath {
	case "/api/v1/public/auth/send-email-code":
		var payload SendEmailCodeInput
		if err := decodeJSONBytes(bodyBytes, &payload); err != nil {
			return publicAuthIdentity{}, err
		}
		if err := validateSendEmailCodeInput(&payload); err != nil {
			return publicAuthIdentity{}, err
		}
		// Code sends are keyed by the normalized email address.
		return publicAuthIdentity{kind: "email", value: payload.Email}, nil
	case "/api/v1/public/auth/confirm-email-code":
		var payload ConfirmEmailCodeInput
		if err := decodeJSONBytes(bodyBytes, &payload); err != nil {
			return publicAuthIdentity{}, err
		}
		if err := validateConfirmEmailCodeInput(&payload); err != nil {
			return publicAuthIdentity{}, err
		}
		// Confirmations are keyed by the challenge being answered.
		return publicAuthIdentity{kind: "challenge", value: payload.ChallengeID}, nil
	default:
		return publicAuthIdentity{}, errPublicAuthIdentityNotApplicable
	}
}
@@ -0,0 +1,377 @@
package restapi
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSendEmailCodeHandlerSuccess verifies the happy path of the
// send-email-code route: the submitted email is trimmed, the auth client is
// invoked exactly once under the public_auth route class, and the challenge
// ID is relayed as JSON.
func TestSendEmailCodeHandlerSuccess(t *testing.T) {
	t.Parallel()
	authService := &recordingAuthServiceClient{
		sendEmailCodeResult: SendEmailCodeResult{
			ChallengeID: "challenge-123",
		},
	}
	handler := newPublicHandler(ServerDependencies{AuthService: authService})
	// Surrounding whitespace in the email exercises input normalization.
	req := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/public/auth/send-email-code",
		strings.NewReader(`{"email":" pilot@example.com "}`),
	)
	req.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, req)
	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
	assert.Equal(t, `{"challenge_id":"challenge-123"}`, recorder.Body.String())
	assert.Equal(t, 1, authService.sendEmailCodeCalls)
	assert.Equal(t, 0, authService.confirmEmailCodeCalls)
	// The adapter must receive the trimmed email, not the raw body value.
	assert.Equal(t, SendEmailCodeInput{Email: "pilot@example.com"}, authService.sendEmailCodeInput)
	assert.True(t, authService.sendEmailCodeRouteClassOK)
	assert.Equal(t, PublicRouteClassPublicAuth, authService.sendEmailCodeRouteClass)
}
// TestConfirmEmailCodeHandlerSuccess verifies the happy path of the
// confirm-email-code route: all three fields are trimmed, the auth client is
// invoked exactly once under the public_auth route class, and the device
// session ID is relayed as JSON.
func TestConfirmEmailCodeHandlerSuccess(t *testing.T) {
	t.Parallel()
	authService := &recordingAuthServiceClient{
		confirmEmailCodeResult: ConfirmEmailCodeResult{
			DeviceSessionID: "device-session-123",
		},
	}
	handler := newPublicHandler(ServerDependencies{AuthService: authService})
	// Whitespace around every field exercises input normalization.
	req := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/public/auth/confirm-email-code",
		strings.NewReader(`{"challenge_id":" challenge-123 ","code":" 123456 ","client_public_key":" public-key-material "}`),
	)
	req.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, req)
	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
	assert.Equal(t, `{"device_session_id":"device-session-123"}`, recorder.Body.String())
	assert.Equal(t, 0, authService.sendEmailCodeCalls)
	assert.Equal(t, 1, authService.confirmEmailCodeCalls)
	assert.Equal(t, ConfirmEmailCodeInput{
		ChallengeID:     "challenge-123",
		Code:            "123456",
		ClientPublicKey: "public-key-material",
	}, authService.confirmEmailCodeInput)
	assert.True(t, authService.confirmEmailCodeRouteClassOK)
	assert.Equal(t, PublicRouteClassPublicAuth, authService.confirmEmailCodeRouteClass)
}
// TestPublicAuthHandlersRejectInvalidRequests verifies that malformed JSON
// and validation failures produce stable 400 bodies and — critically — that
// the upstream auth client is never contacted for rejected requests.
func TestPublicAuthHandlersRejectInvalidRequests(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name             string
		target           string
		body             string
		wantStatus       int
		wantBody         string
		wantSendCalls    int
		wantConfirmCalls int
	}{
		{
			name:             "send email malformed json",
			target:           "/api/v1/public/auth/send-email-code",
			body:             `{"email":`,
			wantStatus:       http.StatusBadRequest,
			wantBody:         `{"error":{"code":"invalid_request","message":"request body contains malformed JSON"}}`,
			wantSendCalls:    0,
			wantConfirmCalls: 0,
		},
		{
			name:             "send email validation error",
			target:           "/api/v1/public/auth/send-email-code",
			body:             `{"email":"not-an-email"}`,
			wantStatus:       http.StatusBadRequest,
			wantBody:         `{"error":{"code":"invalid_request","message":"email must be a single valid email address"}}`,
			wantSendCalls:    0,
			wantConfirmCalls: 0,
		},
		{
			// Whitespace-only code must be treated as empty after trimming.
			name:             "confirm email empty code",
			target:           "/api/v1/public/auth/confirm-email-code",
			body:             `{"challenge_id":"challenge-123","code":" ","client_public_key":"public-key-material"}`,
			wantStatus:       http.StatusBadRequest,
			wantBody:         `{"error":{"code":"invalid_request","message":"code must not be empty"}}`,
			wantSendCalls:    0,
			wantConfirmCalls: 0,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			authService := &recordingAuthServiceClient{}
			handler := newPublicHandler(ServerDependencies{AuthService: authService})
			req := httptest.NewRequest(http.MethodPost, tt.target, strings.NewReader(tt.body))
			req.Header.Set("Content-Type", "application/json")
			recorder := httptest.NewRecorder()
			handler.ServeHTTP(recorder, req)
			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.Equal(t, tt.wantBody, recorder.Body.String())
			// Rejected requests must never reach the upstream auth service.
			assert.Equal(t, tt.wantSendCalls, authService.sendEmailCodeCalls)
			assert.Equal(t, tt.wantConfirmCalls, authService.confirmEmailCodeCalls)
		})
	}
}
// TestPublicAuthHandlersMapAdapterErrors verifies how auth-client errors are
// projected to public responses: typed AuthServiceError values pass through
// status/code/message (with blank fields normalized to the internal-error
// shape), while unexpected errors collapse to a generic 500.
func TestPublicAuthHandlersMapAdapterErrors(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name       string
		target     string
		body       string
		authClient *recordingAuthServiceClient
		wantStatus int
		wantBody   string
	}{
		{
			name:   "auth service projected bad request",
			target: "/api/v1/public/auth/confirm-email-code",
			body:   `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
			authClient: &recordingAuthServiceClient{
				confirmEmailCodeErr: &AuthServiceError{
					StatusCode: http.StatusBadRequest,
					Code:       errorCodeInvalidRequest,
					Message:    "confirmation code is invalid",
				},
			},
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":{"code":"invalid_request","message":"confirmation code is invalid"}}`,
		},
		{
			// Upstream-defined codes outside the gateway's own set must be
			// relayed verbatim.
			name:   "auth service projected custom too many requests",
			target: "/api/v1/public/auth/send-email-code",
			body:   `{"email":"pilot@example.com"}`,
			authClient: &recordingAuthServiceClient{
				sendEmailCodeErr: &AuthServiceError{
					StatusCode: http.StatusTooManyRequests,
					Code:       "upstream_rate_limited",
					Message:    "too many attempts for this email",
				},
			},
			wantStatus: http.StatusTooManyRequests,
			wantBody:   `{"error":{"code":"upstream_rate_limited","message":"too many attempts for this email"}}`,
		},
		{
			// Blank code/message on a typed error fall back to the generic
			// internal-error body while keeping the upstream status.
			name:   "auth service projected gateway normalizes blank gateway error fields",
			target: "/api/v1/public/auth/confirm-email-code",
			body:   `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
			authClient: &recordingAuthServiceClient{
				confirmEmailCodeErr: &AuthServiceError{
					StatusCode: http.StatusBadGateway,
				},
			},
			wantStatus: http.StatusBadGateway,
			wantBody:   `{"error":{"code":"internal_error","message":"internal server error"}}`,
		},
		{
			// Untyped errors must never leak their text to the client.
			name:   "unexpected auth service error",
			target: "/api/v1/public/auth/send-email-code",
			body:   `{"email":"pilot@example.com"}`,
			authClient: &recordingAuthServiceClient{
				sendEmailCodeErr: errors.New("boom"),
			},
			wantStatus: http.StatusInternalServerError,
			wantBody:   `{"error":{"code":"internal_error","message":"internal server error"}}`,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			handler := newPublicHandler(ServerDependencies{AuthService: tt.authClient})
			req := httptest.NewRequest(http.MethodPost, tt.target, strings.NewReader(tt.body))
			req.Header.Set("Content-Type", "application/json")
			recorder := httptest.NewRecorder()
			handler.ServeHTTP(recorder, req)
			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.Equal(t, tt.wantBody, recorder.Body.String())
		})
	}
}
// TestDefaultAuthServiceReturnsServiceUnavailable verifies the nil-AuthService
// default: public auth routes stay mounted but answer 503 with a stable body,
// while probe endpoints keep working.
func TestDefaultAuthServiceReturnsServiceUnavailable(t *testing.T) {
	t.Parallel()
	handler := newPublicHandler(ServerDependencies{})
	tests := []struct {
		name       string
		method     string
		target     string
		body       string
		wantStatus int
		wantBody   string
	}{
		{
			name:       "send email code",
			method:     http.MethodPost,
			target:     "/api/v1/public/auth/send-email-code",
			body:       `{"email":"pilot@example.com"}`,
			wantStatus: http.StatusServiceUnavailable,
			wantBody:   `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`,
		},
		{
			name:       "confirm email code",
			method:     http.MethodPost,
			target:     "/api/v1/public/auth/confirm-email-code",
			body:       `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`,
			wantStatus: http.StatusServiceUnavailable,
			wantBody:   `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`,
		},
		{
			// Health probing must not depend on the auth upstream.
			name:       "healthz remains available",
			method:     http.MethodGet,
			target:     "/healthz",
			wantStatus: http.StatusOK,
			wantBody:   `{"status":"ok"}`,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			req := httptest.NewRequest(tt.method, tt.target, strings.NewReader(tt.body))
			if tt.body != "" {
				req.Header.Set("Content-Type", "application/json")
			}
			recorder := httptest.NewRecorder()
			handler.ServeHTTP(recorder, req)
			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, jsonContentType, recorder.Header().Get("Content-Type"))
			assert.Equal(t, tt.wantBody, recorder.Body.String())
		})
	}
}
// TestPublicAuthHandlerTimeoutMapsToServiceUnavailable verifies that an
// upstream deadline (context.DeadlineExceeded from the auth client) is
// reported to clients as 503 rather than a generic 500.
func TestPublicAuthHandlerTimeoutMapsToServiceUnavailable(t *testing.T) {
	t.Parallel()
	authService := &recordingAuthServiceClient{
		sendEmailCodeErr: context.DeadlineExceeded,
	}
	// A tiny upstream timeout documents the scenario; the canned error is what
	// actually drives the mapping.
	cfg := config.DefaultPublicHTTPConfig()
	cfg.AuthUpstreamTimeout = 5 * time.Millisecond
	handler := newPublicHandlerWithConfig(cfg, ServerDependencies{AuthService: authService})
	req := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/public/auth/send-email-code",
		strings.NewReader(`{"email":"pilot@example.com"}`),
	)
	req.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, req)
	assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
	assert.Equal(t, `{"error":{"code":"service_unavailable","message":"auth service is unavailable"}}`, recorder.Body.String())
}
// TestPublicAuthLogsDoNotContainSensitiveFields drives a successful
// confirm-email-code request through an observed logger and asserts that none
// of the secret request material (challenge ID, code, public key, email)
// appears anywhere in the log output.
func TestPublicAuthLogsDoNotContainSensitiveFields(t *testing.T) {
	t.Parallel()
	logger, buffer := testutil.NewObservedLogger(t)
	handler := newPublicHandler(ServerDependencies{
		Logger: logger,
		AuthService: &recordingAuthServiceClient{
			confirmEmailCodeResult: ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"},
		},
	})
	req := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/public/auth/confirm-email-code",
		strings.NewReader(`{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key-material"}`),
	)
	req.Header.Set("Content-Type", "application/json")
	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, req)
	require.Equal(t, http.StatusOK, recorder.Code)
	logOutput := buffer.String()
	assert.NotContains(t, logOutput, "challenge-123")
	assert.NotContains(t, logOutput, "123456")
	assert.NotContains(t, logOutput, "public-key-material")
	assert.NotContains(t, logOutput, "pilot@example.com")
}
// recordingAuthServiceClient captures handler inputs and route classification
// so tests can assert the exact adapter delegation contract.
type recordingAuthServiceClient struct {
	// SendEmailCode: canned result/error plus recorded call details.
	sendEmailCodeResult       SendEmailCodeResult
	sendEmailCodeErr          error
	sendEmailCodeInput        SendEmailCodeInput
	sendEmailCodeRouteClass   PublicRouteClass
	sendEmailCodeRouteClassOK bool
	sendEmailCodeCalls        int

	// ConfirmEmailCode: canned result/error plus recorded call details.
	confirmEmailCodeResult       ConfirmEmailCodeResult
	confirmEmailCodeErr          error
	confirmEmailCodeInput        ConfirmEmailCodeInput
	confirmEmailCodeRouteClass   PublicRouteClass
	confirmEmailCodeRouteClassOK bool
	confirmEmailCodeCalls        int
}
// SendEmailCode records the call count, input, and the route class found in
// ctx, then returns the canned result and error.
func (c *recordingAuthServiceClient) SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error) {
	c.sendEmailCodeCalls++
	c.sendEmailCodeInput = input
	c.sendEmailCodeRouteClass, c.sendEmailCodeRouteClassOK = PublicRouteClassFromContext(ctx)
	return c.sendEmailCodeResult, c.sendEmailCodeErr
}
// ConfirmEmailCode records the call count, input, and the route class found in
// ctx, then returns the canned result and error.
func (c *recordingAuthServiceClient) ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
	c.confirmEmailCodeCalls++
	c.confirmEmailCodeInput = input
	c.confirmEmailCodeRouteClass, c.confirmEmailCodeRouteClassOK = PublicRouteClassFromContext(ctx)
	return c.confirmEmailCodeResult, c.confirmEmailCodeErr
}
+388
View File
@@ -0,0 +1,388 @@
// Package restapi exposes the unauthenticated public REST surface of the
// gateway.
package restapi
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"sync"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/telemetry"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
"go.uber.org/zap"
)
const (
	// jsonContentType is the Content-Type emitted for every JSON response.
	jsonContentType = "application/json; charset=utf-8"

	// Stable machine-readable error codes surfaced in the public error body.
	errorCodeInvalidRequest     = "invalid_request"
	errorCodeNotFound           = "not_found"
	errorCodeMethodNotAllowed   = "method_not_allowed"
	errorCodeInternalError      = "internal_error"
	errorCodeServiceUnavailable = "service_unavailable"

	// publicRESTBaseBucketKeyPrefix prefixes the per-class rate-limit bucket
	// keys produced by PublicRouteClass.BaseBucketKey.
	publicRESTBaseBucketKeyPrefix = "public_rest/class="
)
// PublicRouteClass identifies the public traffic class assigned to an incoming
// REST request before route handling and edge policy evaluation.
type PublicRouteClass string

// The stable public route class set. Values outside this set are collapsed to
// PublicRouteClassPublicMisc by Normalized.
const (
	// PublicRouteClassPublicAuth identifies public authentication commands.
	PublicRouteClassPublicAuth PublicRouteClass = "public_auth"
	// PublicRouteClassBrowserBootstrap identifies browser bootstrap traffic such
	// as the main document request.
	PublicRouteClassBrowserBootstrap PublicRouteClass = "browser_bootstrap"
	// PublicRouteClassBrowserAsset identifies browser asset requests.
	PublicRouteClassBrowserAsset PublicRouteClass = "browser_asset"
	// PublicRouteClassPublicMisc identifies public traffic that does not match a
	// more specific class.
	PublicRouteClassPublicMisc PublicRouteClass = "public_misc"
)
var configureGinModeOnce sync.Once
// Normalized returns c when it belongs to the stable public route class set.
// Unknown or empty values collapse to PublicRouteClassPublicMisc so edge
// policy code can rely on a fixed anti-abuse namespace.
func (c PublicRouteClass) Normalized() PublicRouteClass {
	known := c == PublicRouteClassPublicAuth ||
		c == PublicRouteClassBrowserBootstrap ||
		c == PublicRouteClassBrowserAsset ||
		c == PublicRouteClassPublicMisc
	if known {
		return c
	}
	return PublicRouteClassPublicMisc
}
// BaseBucketKey returns the canonical base rate-limit namespace for c, scoped
// only by the normalized public route class. Callers may append further
// subject dimensions (IP, identity) without redefining the class namespace.
func (c PublicRouteClass) BaseBucketKey() string {
	normalized := c.Normalized()
	return publicRESTBaseBucketKeyPrefix + string(normalized)
}
// PublicTrafficClassifier maps public REST requests to the public anti-abuse
// class used by the gateway edge. The server normalizes classifier outputs to
// the stable class set before storing them in request context.
type PublicTrafficClassifier interface {
	// Classify returns the route class for the request; implementations need
	// not normalize — the caller does.
	Classify(*http.Request) PublicRouteClass
}
// ServerDependencies describes the optional collaborators used by the public
// REST server. The zero value is valid and keeps the process runnable with the
// built-in defaults (see normalizeServerDependencies).
type ServerDependencies struct {
	// Classifier assigns the public anti-abuse class before route handling.
	// When nil, the gateway default classifier is used.
	Classifier PublicTrafficClassifier
	// AuthService delegates public auth commands to the Auth / Session Service.
	// When nil, public auth routes remain mounted and return a stable
	// service-unavailable response.
	AuthService AuthServiceClient
	// Limiter applies the public REST rate-limit policy. When nil, a default
	// process-local in-memory limiter is used.
	Limiter PublicRequestLimiter
	// Observer records malformed-request telemetry for the public REST layer.
	// When nil, a no-op observer is used.
	Observer PublicRequestObserver
	// Logger writes structured transport logs for public REST traffic. When nil,
	// a no-op logger is used.
	Logger *zap.Logger
	// Telemetry records low-cardinality edge metrics. When nil, metrics are
	// disabled.
	Telemetry *telemetry.Runtime
}
// Server owns the public unauthenticated REST listener exposed by the gateway.
type Server struct {
	cfg     config.PublicHTTPConfig
	handler http.Handler
	logger  *zap.Logger

	// stateMu guards server and listener, which are set for the duration of
	// Run and cleared on exit so Shutdown and listenAddr observe a consistent
	// view of the running state.
	stateMu  sync.RWMutex
	server   *http.Server
	listener net.Listener
}
// NewServer constructs a public REST server for the supplied listener
// configuration and dependency bundle. Nil dependencies are replaced with safe
// defaults so the gateway can still expose the documented public surface.
func NewServer(cfg config.PublicHTTPConfig, deps ServerDependencies) *Server {
	normalized := normalizeServerDependencies(deps)
	server := &Server{
		cfg:     cfg,
		handler: newPublicHandlerWithConfig(cfg, normalized),
		logger:  normalized.Logger.Named("public_http"),
	}
	return server
}
// Run binds the configured listener and serves the public REST surface until
// Shutdown closes the server.
//
// ctx is only consulted before binding: a nil or already-canceled context
// aborts the start. After Serve begins, ctx is not watched again — stopping
// the server is the job of Shutdown. NOTE(review): presumably the application
// runner invokes Shutdown on context cancellation (the lifecycle test relies
// on this); confirm, otherwise Serve blocks indefinitely after cancel.
func (s *Server) Run(ctx context.Context) error {
	if ctx == nil {
		return errors.New("run public REST server: nil context")
	}
	if err := ctx.Err(); err != nil {
		return err
	}
	listener, err := net.Listen("tcp", s.cfg.Addr)
	if err != nil {
		return fmt.Errorf("run public REST server: listen on %q: %w", s.cfg.Addr, err)
	}
	server := &http.Server{
		Handler:           s.handler,
		ReadHeaderTimeout: s.cfg.ReadHeaderTimeout,
		ReadTimeout:       s.cfg.ReadTimeout,
		IdleTimeout:       s.cfg.IdleTimeout,
	}
	// Publish the live server/listener so Shutdown and listenAddr can reach
	// them while Serve below is blocking.
	s.stateMu.Lock()
	s.server = server
	s.listener = listener
	s.stateMu.Unlock()
	s.logger.Info("public REST server started", zap.String("addr", listener.Addr().String()))
	// Clear the published state on exit so a later Shutdown becomes a no-op.
	defer func() {
		s.stateMu.Lock()
		s.server = nil
		s.listener = nil
		s.stateMu.Unlock()
	}()
	err = server.Serve(listener)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, http.ErrServerClosed):
		// Expected path: Shutdown closed the server gracefully.
		s.logger.Info("public REST server stopped")
		return nil
	default:
		return fmt.Errorf("run public REST server: serve on %q: %w", s.cfg.Addr, err)
	}
}
// Shutdown gracefully stops the public REST server within ctx. It is a no-op
// when the server is not currently running.
func (s *Server) Shutdown(ctx context.Context) error {
	if ctx == nil {
		return errors.New("shutdown public REST server: nil context")
	}
	// Snapshot the running server under the read lock; Run may clear it
	// concurrently once Serve returns.
	s.stateMu.RLock()
	current := s.server
	s.stateMu.RUnlock()
	if current == nil {
		return nil
	}
	err := current.Shutdown(ctx)
	if err == nil || errors.Is(err, http.ErrServerClosed) {
		return nil
	}
	return fmt.Errorf("shutdown public REST server: %w", err)
}
// PublicRouteClassFromContext returns the previously classified normalized
// public route class stored in ctx, reporting false when none was stored.
func PublicRouteClassFromContext(ctx context.Context) (PublicRouteClass, bool) {
	if ctx == nil {
		return "", false
	}
	if class, ok := ctx.Value(publicRouteClassContextKey{}).(PublicRouteClass); ok {
		return class.Normalized(), true
	}
	return "", false
}
// publicRouteClassContextKey is the private context key under which the
// normalized public route class is stored by withPublicRouteClass.
type publicRouteClassContextKey struct{}

// defaultPublicTrafficClassifier is the built-in classifier used when the
// dependency bundle does not supply one.
type defaultPublicTrafficClassifier struct{}
// Classify maps the incoming request into a stable public route class that can
// later drive anti-abuse policy and rate limiting. More specific classes win:
// public auth first, then browser bootstrap, then browser assets.
func (defaultPublicTrafficClassifier) Classify(r *http.Request) PublicRouteClass {
	if isPublicAuthRequest(r) {
		return PublicRouteClassPublicAuth
	}
	if isBrowserBootstrapRequest(r) {
		return PublicRouteClassBrowserBootstrap
	}
	if isBrowserAssetRequest(r) {
		return PublicRouteClassBrowserAsset
	}
	return PublicRouteClassPublicMisc
}
// normalizeServerDependencies fills every nil collaborator with its safe
// default so the rest of the package never has to nil-check dependencies.
func normalizeServerDependencies(deps ServerDependencies) ServerDependencies {
	normalized := deps
	if normalized.Classifier == nil {
		normalized.Classifier = defaultPublicTrafficClassifier{}
	}
	if normalized.AuthService == nil {
		normalized.AuthService = unavailableAuthServiceClient{}
	}
	if normalized.Limiter == nil {
		normalized.Limiter = newInMemoryPublicRequestLimiter()
	}
	if normalized.Observer == nil {
		normalized.Observer = noopPublicRequestObserver{}
	}
	if normalized.Logger == nil {
		normalized.Logger = zap.NewNop()
	}
	return normalized
}
// newPublicHandler builds a public handler with the default public HTTP
// configuration; convenience wrapper used primarily by tests.
func newPublicHandler(deps ServerDependencies) http.Handler {
	return newPublicHandlerWithConfig(config.DefaultPublicHTTPConfig(), deps)
}
// newPublicHandlerWithConfig assembles the gin router for the public REST
// surface: recovery, tracing, observability, route classification, anti-abuse
// — in that order — followed by the concrete routes and stable 404/405 JSON
// fallbacks.
func newPublicHandlerWithConfig(cfg config.PublicHTTPConfig, deps ServerDependencies) http.Handler {
	// gin's mode is process-global state; set it exactly once.
	configureGinModeOnce.Do(func() {
		gin.SetMode(gin.ReleaseMode)
	})
	deps = normalizeServerDependencies(deps)
	router := gin.New()
	router.HandleMethodNotAllowed = true
	// Panics surface as the stable internal-error JSON body instead of gin's
	// default response.
	router.Use(gin.CustomRecovery(func(c *gin.Context, _ any) {
		abortWithError(c, http.StatusInternalServerError, errorCodeInternalError, "internal server error")
	}))
	router.Use(otelgin.Middleware("galaxy-edge-gateway-public"))
	router.Use(withPublicObservability(deps.Logger.Named("public_http"), deps.Telemetry))
	// Classification must run before anti-abuse so the limiter can key off the
	// route class stored in the request context.
	router.Use(withPublicRouteClass(deps.Classifier))
	router.Use(withPublicAntiAbuse(cfg.AntiAbuse, deps.Limiter, deps.Observer))
	router.GET("/healthz", handleHealthz)
	router.GET("/readyz", handleReadyz)
	router.POST("/api/v1/public/auth/send-email-code", handleSendEmailCode(deps.AuthService, cfg.AuthUpstreamTimeout))
	router.POST("/api/v1/public/auth/confirm-email-code", handleConfirmEmailCode(deps.AuthService, cfg.AuthUpstreamTimeout))
	router.NoMethod(func(c *gin.Context) {
		// Advertise the allowed methods only for known routes.
		allowMethods := allowedMethodsForPath(c.Request.URL.Path)
		if allowMethods != "" {
			c.Header("Allow", allowMethods)
		}
		abortWithError(c, http.StatusMethodNotAllowed, errorCodeMethodNotAllowed, "request method is not allowed for this route")
	})
	router.NoRoute(func(c *gin.Context) {
		abortWithError(c, http.StatusNotFound, errorCodeNotFound, "resource was not found")
	})
	return router
}
// handleHealthz answers the liveness probe with a static OK body.
func handleHealthz(c *gin.Context) {
	c.JSON(http.StatusOK, statusResponse{Status: "ok"})
}
// handleReadyz answers the readiness probe with a static ready body.
func handleReadyz(c *gin.Context) {
	c.JSON(http.StatusOK, statusResponse{Status: "ready"})
}
// withPublicRouteClass classifies each request and stores the normalized class
// in the request context for downstream middleware and handlers.
func withPublicRouteClass(classifier PublicTrafficClassifier) gin.HandlerFunc {
	return func(c *gin.Context) {
		normalized := classifier.Classify(c.Request).Normalized()
		enriched := context.WithValue(c.Request.Context(), publicRouteClassContextKey{}, normalized)
		c.Request = c.Request.WithContext(enriched)
		c.Next()
	}
}
// isPublicAuthRequest reports whether r is a POST to a public auth command
// path.
func isPublicAuthRequest(r *http.Request) bool {
	return r.Method == http.MethodPost && isPublicAuthPath(r.URL.Path)
}
// isBrowserBootstrapRequest reports whether r looks like a browser document
// load: a GET of the root path, or any request matching the bootstrap shape.
func isBrowserBootstrapRequest(r *http.Request) bool {
	rootDocument := r.Method == http.MethodGet && r.URL.Path == "/"
	return rootDocument || matchesBrowserBootstrapRequestShape(r)
}
// isBrowserAssetRequest reports whether r is a safe (GET/HEAD) request whose
// shape matches a browser asset fetch.
func isBrowserAssetRequest(r *http.Request) bool {
	switch r.Method {
	case http.MethodGet, http.MethodHead:
		return matchesBrowserAssetRequestShape(r)
	default:
		return false
	}
}
// statusResponse is the JSON body returned by the health and readiness probes.
type statusResponse struct {
	Status string `json:"status"`
}

// errorResponse is the stable JSON envelope wrapping every public error.
type errorResponse struct {
	Error errorBody `json:"error"`
}

// errorBody carries the machine-readable code and human-readable message of a
// public error.
type errorBody struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}
// abortWithError records the public error code on the gin context (consumed by
// observability middleware) and aborts the request with the stable JSON error
// envelope.
func abortWithError(c *gin.Context, statusCode int, code string, message string) {
	// Guard the whole function: the previous nil-check covered only c.Set and
	// the code would still have panicked on c.AbortWithStatusJSON below.
	if c == nil {
		return
	}
	c.Set(publicErrorCodeContextKey, code)
	c.AbortWithStatusJSON(statusCode, errorResponse{
		Error: errorBody{
			Code:    code,
			Message: message,
		},
	})
}

// publicErrorCodeContextKey is the gin context key under which the public
// error code is stored for log and metric enrichment.
const publicErrorCodeContextKey = "public_error_code"
// allowedMethodsForPath returns the Allow header value for a known public
// route, or the empty string when the path is not registered.
func allowedMethodsForPath(requestPath string) string {
	if requestPath == "/healthz" || requestPath == "/readyz" {
		return http.MethodGet
	}
	if requestPath == "/api/v1/public/auth/send-email-code" ||
		requestPath == "/api/v1/public/auth/confirm-email-code" {
		return http.MethodPost
	}
	return ""
}
// listenAddr reports the bound listener address while the server is running,
// or "" when no listener is active. Tests use it to discover the chosen port.
func (s *Server) listenAddr() string {
	s.stateMu.RLock()
	defer s.stateMu.RUnlock()
	listener := s.listener
	if listener == nil {
		return ""
	}
	return listener.Addr().String()
}
+459
View File
@@ -0,0 +1,459 @@
package restapi
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPublicHandlerHealthEndpoints exercises the probe routes plus the stable
// 404/405 JSON fallbacks, including the Allow header advertised for known
// routes hit with the wrong method.
func TestPublicHandlerHealthEndpoints(t *testing.T) {
	t.Parallel()
	handler := newPublicHandler(ServerDependencies{})
	tests := []struct {
		name       string
		method     string
		target     string
		wantStatus int
		wantType   string
		wantBody   string
		wantAllow  string
	}{
		{
			name:       "healthz",
			method:     http.MethodGet,
			target:     "/healthz",
			wantStatus: http.StatusOK,
			wantType:   jsonContentType,
			wantBody:   `{"status":"ok"}`,
		},
		{
			name:       "readyz",
			method:     http.MethodGet,
			target:     "/readyz",
			wantStatus: http.StatusOK,
			wantType:   jsonContentType,
			wantBody:   `{"status":"ready"}`,
		},
		{
			name:       "wrong method on known route",
			method:     http.MethodPost,
			target:     "/healthz",
			wantStatus: http.StatusMethodNotAllowed,
			wantType:   jsonContentType,
			wantBody:   `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`,
			wantAllow:  http.MethodGet,
		},
		{
			name:       "unknown route",
			method:     http.MethodGet,
			target:     "/unknown",
			wantStatus: http.StatusNotFound,
			wantType:   jsonContentType,
			wantBody:   `{"error":{"code":"not_found","message":"resource was not found"}}`,
		},
		{
			name:       "wrong method on public auth route",
			method:     http.MethodGet,
			target:     "/api/v1/public/auth/send-email-code",
			wantStatus: http.StatusMethodNotAllowed,
			wantType:   jsonContentType,
			wantBody:   `{"error":{"code":"method_not_allowed","message":"request method is not allowed for this route"}}`,
			wantAllow:  http.MethodPost,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			req := httptest.NewRequest(tt.method, tt.target, nil)
			recorder := httptest.NewRecorder()
			handler.ServeHTTP(recorder, req)
			assert.Equal(t, tt.wantStatus, recorder.Code)
			assert.Equal(t, tt.wantType, recorder.Header().Get("Content-Type"))
			assert.Equal(t, tt.wantBody, recorder.Body.String())
			// Empty wantAllow asserts the header is absent.
			assert.Equal(t, tt.wantAllow, recorder.Header().Get("Allow"))
		})
	}
}
// TestDefaultPublicTrafficClassifier covers the built-in classifier's mapping
// from request shape (method, path, Accept header) to public route class,
// including the precedence of public auth over browser-looking requests.
func TestDefaultPublicTrafficClassifier(t *testing.T) {
	t.Parallel()
	classifier := defaultPublicTrafficClassifier{}
	tests := []struct {
		name      string
		method    string
		target    string
		accept    string
		wantClass PublicRouteClass
	}{
		{
			name:      "public auth route",
			method:    http.MethodPost,
			target:    "/api/v1/public/auth/send-email-code",
			wantClass: PublicRouteClassPublicAuth,
		},
		{
			name:      "public auth confirm route",
			method:    http.MethodPost,
			target:    "/api/v1/public/auth/confirm-email-code",
			wantClass: PublicRouteClassPublicAuth,
		},
		{
			name:      "browser bootstrap route",
			method:    http.MethodGet,
			target:    "/",
			wantClass: PublicRouteClassBrowserBootstrap,
		},
		{
			name:      "browser asset route",
			method:    http.MethodGet,
			target:    "/assets/app.js",
			wantClass: PublicRouteClassBrowserAsset,
		},
		{
			name:      "browser asset head request",
			method:    http.MethodHead,
			target:    "/assets/app.js",
			wantClass: PublicRouteClassBrowserAsset,
		},
		{
			name:      "browser asset extension request",
			method:    http.MethodGet,
			target:    "/manifest.webmanifest",
			wantClass: PublicRouteClassBrowserAsset,
		},
		{
			name:      "public misc route",
			method:    http.MethodPost,
			target:    "/api/v1/public/unknown",
			wantClass: PublicRouteClassPublicMisc,
		},
		{
			// An Accept header listing text/html marks a document navigation.
			name:      "html accept bootstrap route",
			method:    http.MethodGet,
			target:    "/app",
			accept:    "application/json, text/html;q=0.9",
			wantClass: PublicRouteClassBrowserBootstrap,
		},
		{
			// Classification precedence: auth paths win even with a browser
			// Accept header.
			name:      "public auth wins over browser accept header",
			method:    http.MethodPost,
			target:    "/api/v1/public/auth/confirm-email-code",
			accept:    "text/html",
			wantClass: PublicRouteClassPublicAuth,
		},
		{
			name:      "probe with html accept is bootstrap",
			method:    http.MethodGet,
			target:    "/healthz",
			accept:    "text/html",
			wantClass: PublicRouteClassBrowserBootstrap,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			req := httptest.NewRequest(tt.method, tt.target, nil)
			if tt.accept != "" {
				req.Header.Set("Accept", tt.accept)
			}
			assert.Equal(t, tt.wantClass, classifier.Classify(req))
		})
	}
}
// TestPublicRouteClassNormalized verifies that the four stable classes pass
// through Normalized unchanged and that unknown or empty values collapse to
// PublicRouteClassPublicMisc.
func TestPublicRouteClassNormalized(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name  string
		input PublicRouteClass
		want  PublicRouteClass
	}{
		{
			name:  "public auth",
			input: PublicRouteClassPublicAuth,
			want:  PublicRouteClassPublicAuth,
		},
		{
			name:  "browser bootstrap",
			input: PublicRouteClassBrowserBootstrap,
			want:  PublicRouteClassBrowserBootstrap,
		},
		{
			name:  "browser asset",
			input: PublicRouteClassBrowserAsset,
			want:  PublicRouteClassBrowserAsset,
		},
		{
			name:  "public misc",
			input: PublicRouteClassPublicMisc,
			want:  PublicRouteClassPublicMisc,
		},
		{
			name:  "unknown collapses to misc",
			input: PublicRouteClass("unexpected"),
			want:  PublicRouteClassPublicMisc,
		},
		{
			name:  "empty collapses to misc",
			input: PublicRouteClass(""),
			want:  PublicRouteClassPublicMisc,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			assert.Equal(t, tt.want, tt.input.Normalized())
		})
	}
}
// TestPublicRouteClassBaseBucketKeyIsolation checks each class's rate-limit
// bucket key and ensures the stable classes never collide with one another.
func TestPublicRouteClassBaseBucketKeyIsolation(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name    string
		class   PublicRouteClass
		wantKey string
	}{
		{
			name:    "public auth",
			class:   PublicRouteClassPublicAuth,
			wantKey: "public_rest/class=public_auth",
		},
		{
			name:    "browser bootstrap",
			class:   PublicRouteClassBrowserBootstrap,
			wantKey: "public_rest/class=browser_bootstrap",
		},
		{
			name:    "browser asset",
			class:   PublicRouteClassBrowserAsset,
			wantKey: "public_rest/class=browser_asset",
		},
		{
			name:    "public misc",
			class:   PublicRouteClassPublicMisc,
			wantKey: "public_rest/class=public_misc",
		},
		{
			name:    "unknown collapses to misc namespace",
			class:   PublicRouteClass("unexpected"),
			wantKey: "public_rest/class=public_misc",
		},
	}
	seenKeys := make(map[string]PublicRouteClass, len(tests))
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			assert.Equal(t, tt.wantKey, tt.class.BaseBucketKey())
		})
		// Collision detection runs in the outer (sequential) loop, not the
		// parallel subtest, so seenKeys needs no locking.
		normalizedClass := tt.class.Normalized()
		// Unknown classes intentionally share the misc namespace; skip them so
		// the deliberate collapse is not reported as a collision.
		if normalizedClass == PublicRouteClassPublicMisc && tt.class != PublicRouteClassPublicMisc {
			continue
		}
		if previousClass, exists := seenKeys[tt.wantKey]; exists {
			require.FailNowf(t, "bucket key collision", "class %q collides with %q on key %q", tt.class, previousClass, tt.wantKey)
		}
		seenKeys[tt.wantKey] = tt.class
	}
	assert.NotEqual(t, PublicRouteClassPublicAuth.BaseBucketKey(), PublicRouteClassBrowserBootstrap.BaseBucketKey())
	assert.NotEqual(t, PublicRouteClassPublicAuth.BaseBucketKey(), PublicRouteClassBrowserAsset.BaseBucketKey())
}
// TestWithPublicRouteClassStoresClassInContext verifies that the middleware
// stores the classifier's output in the request context where handlers can
// retrieve it via PublicRouteClassFromContext.
func TestWithPublicRouteClassStoresClassInContext(t *testing.T) {
	t.Parallel()
	router := gin.New()
	router.Use(withPublicRouteClass(staticClassifier{class: PublicRouteClassBrowserAsset}))
	router.GET("/assets/app.js", func(c *gin.Context) {
		class, ok := PublicRouteClassFromContext(c.Request.Context())
		require.True(t, ok)
		assert.Equal(t, PublicRouteClassBrowserAsset, class)
		c.JSON(http.StatusOK, statusResponse{Status: "ok"})
	})
	req := httptest.NewRequest(http.MethodGet, "/assets/app.js", nil)
	recorder := httptest.NewRecorder()
	router.ServeHTTP(recorder, req)
	// A 200 with the probe body proves the in-handler assertions executed.
	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, `{"status":"ok"}`, recorder.Body.String())
}
// TestWithPublicRouteClassNormalizesUnsupportedClassToPublicMisc verifies that
// classifier outputs outside the stable set (unknown or empty) are stored as
// PublicRouteClassPublicMisc.
func TestWithPublicRouteClassNormalizesUnsupportedClassToPublicMisc(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name  string
		class PublicRouteClass
	}{
		{
			name:  "unknown class",
			class: PublicRouteClass("unexpected"),
		},
		{
			name:  "empty class",
			class: PublicRouteClass(""),
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			router := gin.New()
			router.Use(withPublicRouteClass(staticClassifier{class: tt.class}))
			router.GET("/", func(c *gin.Context) {
				class, ok := PublicRouteClassFromContext(c.Request.Context())
				require.True(t, ok)
				assert.Equal(t, PublicRouteClassPublicMisc, class)
				c.JSON(http.StatusOK, statusResponse{Status: "ok"})
			})
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			recorder := httptest.NewRecorder()
			router.ServeHTTP(recorder, req)
			// A 200 with the probe body proves the in-handler assertions ran.
			assert.Equal(t, http.StatusOK, recorder.Code)
			assert.Equal(t, `{"status":"ok"}`, recorder.Body.String())
		})
	}
}
// TestServerLifecycle starts the gateway app on an ephemeral listener, waits
// until /healthz responds, then cancels the run context and verifies that
// Run returns promptly and without error.
func TestServerLifecycle(t *testing.T) {
	t.Parallel()

	publicHTTPCfg := config.DefaultPublicHTTPConfig()
	publicHTTPCfg.Addr = "127.0.0.1:0"
	publicHTTPCfg.AntiAbuse.PublicMisc.RateLimit = config.PublicRateLimitConfig{
		Requests: 1000,
		Window:   time.Minute,
		Burst:    1000,
	}
	cfg := config.Config{
		ShutdownTimeout: time.Second,
		PublicHTTP:      publicHTTPCfg,
	}

	server := NewServer(cfg.PublicHTTP, ServerDependencies{})
	application := app.New(cfg, server)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	done := make(chan error, 1)
	go func() {
		done <- application.Run(ctx)
	}()

	addr := waitForListenAddr(t, server)
	waitForHealthResponse(t, addr)
	cancel()

	select {
	case err := <-done:
		require.NoError(t, err)
	case <-time.After(2 * time.Second):
		require.FailNow(t, "Run() did not return after cancellation")
	}
}
// staticClassifier is a test double that classifies every request into one
// fixed public route class.
type staticClassifier struct {
	class PublicRouteClass
}

// Classify ignores the request and always returns the configured class.
func (c staticClassifier) Classify(*http.Request) PublicRouteClass {
	return c.class
}
// waitForListenAddr polls the server until its listener address is available,
// failing the test after roughly one second.
func waitForListenAddr(t *testing.T, server *Server) string {
	t.Helper()

	const pollInterval = 10 * time.Millisecond
	stopAt := time.Now().Add(time.Second)
	for time.Now().Before(stopAt) {
		if addr := server.listenAddr(); addr != "" {
			return addr
		}
		time.Sleep(pollInterval)
	}
	require.FailNow(t, "server did not start listening")
	return ""
}
// waitForHealthResponse polls GET /healthz on addr until it answers 200 with
// the expected JSON body, failing the test after roughly one second.
func waitForHealthResponse(t *testing.T, addr string) {
	t.Helper()

	httpClient := &http.Client{Timeout: 100 * time.Millisecond}
	url := "http://" + addr + "/healthz"
	stopAt := time.Now().Add(time.Second)
	for time.Now().Before(stopAt) {
		resp, err := httpClient.Get(url)
		if err != nil {
			// The listener may not be accepting yet; keep retrying.
			time.Sleep(10 * time.Millisecond)
			continue
		}
		payload, readErr := io.ReadAll(resp.Body)
		closeErr := resp.Body.Close()
		require.NoError(t, readErr)
		require.NoError(t, closeErr)
		require.Equal(t, http.StatusOK, resp.StatusCode)
		assert.Equal(t, `{"status":"ok"}`, strings.TrimSpace(string(payload)))
		return
	}
	require.FailNowf(t, "health check timed out", "url=%s", url)
}
+88
View File
@@ -0,0 +1,88 @@
package session
import (
"context"
"errors"
"fmt"
"strings"
"sync"
)
// MemoryCache stores session record snapshots in process-local memory. It is
// intended for the authenticated gateway hot path and deliberately keeps no
// TTL or size-based eviction policy.
type MemoryCache struct {
	mu      sync.RWMutex      // guards records
	records map[string]Record // keyed by device session id
}
// NewMemoryCache constructs an empty process-local session snapshot store.
func NewMemoryCache() *MemoryCache {
	cache := &MemoryCache{records: map[string]Record{}}
	return cache
}
// Lookup resolves deviceSessionID from the process-local snapshot map.
func (c *MemoryCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
	if c == nil {
		return Record{}, errors.New("lookup session from in-memory cache: nil cache")
	}
	// NOTE(review): the fmt.Sprint comparison rejects context.TODO() by its
	// Stringer output; confirm this stringly-typed check is intentional.
	if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
		return Record{}, errors.New("lookup session from in-memory cache: nil context")
	}
	if strings.TrimSpace(deviceSessionID) == "" {
		return Record{}, errors.New("lookup session from in-memory cache: empty device session id")
	}

	c.mu.RLock()
	defer c.mu.RUnlock()
	record, found := c.records[deviceSessionID]
	if !found {
		return Record{}, fmt.Errorf("lookup session from in-memory cache: %w", ErrNotFound)
	}
	// Return a detached copy so callers cannot mutate cached state.
	return cloneRecord(record), nil
}
// Upsert stores record in the process-local snapshot map after validating the
// same session invariants expected from the Redis-backed cache.
func (c *MemoryCache) Upsert(record Record) error {
	if c == nil {
		return errors.New("upsert session into in-memory cache: nil cache")
	}
	if err := validateRecord(record.DeviceSessionID, record); err != nil {
		return fmt.Errorf("upsert session into in-memory cache: %w", err)
	}

	// Store a detached copy so later caller mutations cannot leak in.
	snapshot := cloneRecord(record)
	c.mu.Lock()
	defer c.mu.Unlock()
	c.records[record.DeviceSessionID] = snapshot
	return nil
}
// Delete removes the local snapshot for deviceSessionID when one exists.
func (c *MemoryCache) Delete(deviceSessionID string) {
	if c == nil {
		return
	}
	if strings.TrimSpace(deviceSessionID) == "" {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.records, deviceSessionID)
}
// cloneRecord returns a copy of record whose RevokedAtMS pointer (when set)
// points at fresh storage, so callers never share mutable state.
func cloneRecord(record Record) Record {
	copied := record
	if record.RevokedAtMS == nil {
		return copied
	}
	ms := *record.RevokedAtMS
	copied.RevokedAtMS = &ms
	return copied
}
var _ SnapshotStore = (*MemoryCache)(nil)
+68
View File
@@ -0,0 +1,68 @@
package session
import (
"context"
"errors"
"fmt"
)
// ReadThroughCache resolves authenticated sessions from a process-local
// SnapshotStore first and falls back to another Cache only on a local miss.
type ReadThroughCache struct {
	local    SnapshotStore // consulted first; seeded on fallback hits
	fallback Cache         // authoritative source, consulted on local miss
}
// NewReadThroughCache constructs a hot-path cache that seeds local snapshots
// from fallback on demand.
func NewReadThroughCache(local SnapshotStore, fallback Cache) (*ReadThroughCache, error) {
	switch {
	case local == nil:
		return nil, errors.New("new read-through session cache: nil local cache")
	case fallback == nil:
		return nil, errors.New("new read-through session cache: nil fallback cache")
	}
	cache := &ReadThroughCache{local: local, fallback: fallback}
	return cache, nil
}
// Lookup resolves deviceSessionID from local first, then performs one fallback
// lookup on a local miss and seeds the local cache with the returned snapshot.
//
// All failure paths wrap errors with this layer's context; ErrNotFound stays
// detectable through errors.Is.
func (c *ReadThroughCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
	if c == nil {
		return Record{}, errors.New("lookup session from read-through cache: nil cache")
	}

	record, err := c.local.Lookup(ctx, deviceSessionID)
	if err == nil {
		return record, nil
	}
	if !errors.Is(err, ErrNotFound) {
		return Record{}, fmt.Errorf("lookup session from read-through cache: %w", err)
	}

	// Local miss: consult the authoritative fallback exactly once.
	record, err = c.fallback.Lookup(ctx, deviceSessionID)
	if err != nil {
		// Wrap for consistency with the other paths; previously the fallback
		// error was returned bare and lost this layer's context.
		return Record{}, fmt.Errorf("lookup session from read-through cache: %w", err)
	}
	if err := c.local.Upsert(record); err != nil {
		return Record{}, fmt.Errorf("lookup session from read-through cache: seed local cache: %w", err)
	}
	// Return a detached copy so callers cannot mutate the seeded snapshot.
	return cloneRecord(record), nil
}
// Local returns the mutable process-local snapshot store used by c.
func (c *ReadThroughCache) Local() SnapshotStore {
	if c != nil {
		return c.local
	}
	return nil
}
var _ Cache = (*ReadThroughCache)(nil)
@@ -0,0 +1,176 @@
package session
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMemoryCacheLookupReturnsClonedRecord verifies that mutating a record
// returned by Lookup does not change the snapshot held inside the cache.
func TestMemoryCacheLookupReturnsClonedRecord(t *testing.T) {
	t.Parallel()

	const sessionID = "device-session-123"
	revokedAtMS := int64(123456789)

	cache := NewMemoryCache()
	require.NoError(t, cache.Upsert(Record{
		DeviceSessionID: sessionID,
		UserID:          "user-123",
		ClientPublicKey: "public-key-123",
		Status:          StatusRevoked,
		RevokedAtMS:     &revokedAtMS,
	}))

	first, err := cache.Lookup(context.Background(), sessionID)
	require.NoError(t, err)
	require.NotNil(t, first.RevokedAtMS)
	*first.RevokedAtMS = 1

	second, err := cache.Lookup(context.Background(), sessionID)
	require.NoError(t, err)
	require.NotNil(t, second.RevokedAtMS)
	assert.Equal(t, revokedAtMS, *second.RevokedAtMS)
}
func TestReadThroughCacheLocalHitSkipsFallback(t *testing.T) {
t.Parallel()
local := NewMemoryCache()
require.NoError(t, local.Upsert(Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: StatusActive,
}))
fallback := &recordingCache{
lookupFunc: func(context.Context, string) (Record, error) {
return Record{}, errors.New("fallback should not be called")
},
}
cache, err := NewReadThroughCache(local, fallback)
require.NoError(t, err)
record, err := cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
assert.Equal(t, Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: StatusActive,
}, record)
assert.Equal(t, 0, fallback.lookupCalls)
}
func TestReadThroughCacheFallbackSeedsLocalCache(t *testing.T) {
t.Parallel()
local := NewMemoryCache()
fallback := &recordingCache{
lookupFunc: func(context.Context, string) (Record, error) {
return Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: StatusActive,
}, nil
},
}
cache, err := NewReadThroughCache(local, fallback)
require.NoError(t, err)
record, err := cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
assert.Equal(t, 1, fallback.lookupCalls)
assert.Equal(t, "user-123", record.UserID)
record, err = cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
assert.Equal(t, 1, fallback.lookupCalls)
assert.Equal(t, "user-123", record.UserID)
}
func TestReadThroughCacheKeepsRevokedSnapshotLocal(t *testing.T) {
t.Parallel()
revokedAtMS := int64(123456789)
local := NewMemoryCache()
fallback := &recordingCache{
lookupFunc: func(context.Context, string) (Record, error) {
return Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: StatusRevoked,
RevokedAtMS: &revokedAtMS,
}, nil
},
}
cache, err := NewReadThroughCache(local, fallback)
require.NoError(t, err)
record, err := cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
require.NotNil(t, record.RevokedAtMS)
assert.Equal(t, StatusRevoked, record.Status)
assert.Equal(t, 1, fallback.lookupCalls)
record, err = cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
require.NotNil(t, record.RevokedAtMS)
assert.Equal(t, StatusRevoked, record.Status)
assert.Equal(t, revokedAtMS, *record.RevokedAtMS)
assert.Equal(t, 1, fallback.lookupCalls)
}
// TestReadThroughCacheReturnsClonedFallbackRecord verifies that mutating the
// record returned from a fallback-seeded lookup leaves the local copy intact.
func TestReadThroughCacheReturnsClonedFallbackRecord(t *testing.T) {
	t.Parallel()

	revokedAtMS := int64(123456789)
	local := NewMemoryCache()
	fallback := &recordingCache{
		lookupFunc: func(context.Context, string) (Record, error) {
			return Record{
				DeviceSessionID: "device-session-123",
				UserID:          "user-123",
				ClientPublicKey: "public-key-123",
				Status:          StatusRevoked,
				RevokedAtMS:     &revokedAtMS,
			}, nil
		},
	}
	cache, err := NewReadThroughCache(local, fallback)
	require.NoError(t, err)

	returned, err := cache.Lookup(context.Background(), "device-session-123")
	require.NoError(t, err)
	require.NotNil(t, returned.RevokedAtMS)
	*returned.RevokedAtMS = 1

	stored, err := local.Lookup(context.Background(), "device-session-123")
	require.NoError(t, err)
	require.NotNil(t, stored.RevokedAtMS)
	assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
}
// recordingCache is a Cache test double that counts Lookup invocations and
// delegates to an optional stub implementation.
type recordingCache struct {
	lookupCalls int
	lookupFunc  func(context.Context, string) (Record, error)
}

// Lookup counts the call, then forwards to lookupFunc when one is configured.
func (c *recordingCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
	c.lookupCalls++
	if c.lookupFunc == nil {
		return Record{}, errors.New("lookup is not implemented")
	}
	return c.lookupFunc(ctx, deviceSessionID)
}
+192
View File
@@ -0,0 +1,192 @@
package session
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"time"
"galaxy/gateway/internal/config"
"github.com/redis/go-redis/v9"
)
// RedisCache implements Cache with Redis GET lookups over strict JSON session
// records.
type RedisCache struct {
	client        *redis.Client // read-only from the gateway's perspective
	keyPrefix     string        // prepended to the device session id to form the key
	lookupTimeout time.Duration // per-operation budget for GET and PING
}

// redisRecord is the strict JSON wire shape of one cached session record.
type redisRecord struct {
	DeviceSessionID string `json:"device_session_id"`
	UserID          string `json:"user_id"`
	ClientPublicKey string `json:"client_public_key"`
	Status          Status `json:"status"`
	RevokedAtMS     *int64 `json:"revoked_at_ms,omitempty"`
}
// NewRedisCache constructs a Redis-backed SessionCache from cfg. The returned
// cache is read-only from the gateway perspective and does not write or mutate
// Redis state.
func NewRedisCache(cfg config.SessionCacheRedisConfig) (*RedisCache, error) {
	switch {
	case strings.TrimSpace(cfg.Addr) == "":
		return nil, errors.New("new redis session cache: redis addr must not be empty")
	case cfg.DB < 0:
		return nil, errors.New("new redis session cache: redis db must not be negative")
	case cfg.LookupTimeout <= 0:
		return nil, errors.New("new redis session cache: lookup timeout must be positive")
	}

	options := &redis.Options{
		Addr:            cfg.Addr,
		Username:        cfg.Username,
		Password:        cfg.Password,
		DB:              cfg.DB,
		Protocol:        2,
		DisableIdentity: true,
	}
	if cfg.TLSEnabled {
		options.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
	}

	cache := &RedisCache{
		client:        redis.NewClient(options),
		keyPrefix:     cfg.KeyPrefix,
		lookupTimeout: cfg.LookupTimeout,
	}
	return cache, nil
}
// Close releases the underlying Redis client resources.
func (c *RedisCache) Close() error {
	if c != nil && c.client != nil {
		return c.client.Close()
	}
	return nil
}
// Ping verifies that the configured Redis backend is reachable within the
// cache lookup timeout budget.
func (c *RedisCache) Ping(ctx context.Context) error {
	switch {
	case c == nil || c.client == nil:
		return errors.New("ping redis session cache: nil cache")
	case ctx == nil:
		return errors.New("ping redis session cache: nil context")
	}

	pingCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout)
	defer cancel()
	err := c.client.Ping(pingCtx).Err()
	if err != nil {
		return fmt.Errorf("ping redis session cache: %w", err)
	}
	return nil
}
// Lookup resolves deviceSessionID from Redis, validates the cached JSON
// payload strictly, and returns the decoded session record.
func (c *RedisCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
	if c == nil || c.client == nil {
		return Record{}, errors.New("lookup session from redis: nil cache")
	}
	// NOTE(review): the fmt.Sprint comparison rejects context.TODO() by its
	// Stringer output (TestRedisCacheLookupNilContext relies on this); confirm
	// the stringly-typed check is intentional.
	if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
		return Record{}, errors.New("lookup session from redis: nil context")
	}
	if strings.TrimSpace(deviceSessionID) == "" {
		return Record{}, errors.New("lookup session from redis: empty device session id")
	}

	lookupCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout)
	defer cancel()

	payload, err := c.client.Get(lookupCtx, c.lookupKey(deviceSessionID)).Bytes()
	if errors.Is(err, redis.Nil) {
		// A missing key is a cache miss, not a backend failure.
		return Record{}, fmt.Errorf("lookup session from redis: %w", ErrNotFound)
	}
	if err != nil {
		return Record{}, fmt.Errorf("lookup session from redis: %w", err)
	}

	record, err := decodeRedisRecord(deviceSessionID, payload)
	if err != nil {
		return Record{}, fmt.Errorf("lookup session from redis: %w", err)
	}
	return record, nil
}
// lookupKey builds the Redis key for deviceSessionID by prepending the
// configured key prefix.
func (c *RedisCache) lookupKey(deviceSessionID string) string {
	return c.keyPrefix + deviceSessionID
}
// decodeRedisRecord strictly decodes payload into a session Record and
// validates it against expectedDeviceSessionID.
//
// Decoding rejects unknown fields and any trailing JSON after the first
// value, so a cached payload must be exactly one well-formed record object.
func decodeRedisRecord(expectedDeviceSessionID string, payload []byte) (Record, error) {
	decoder := json.NewDecoder(bytes.NewReader(payload))
	decoder.DisallowUnknownFields()

	var stored redisRecord
	if err := decoder.Decode(&stored); err != nil {
		return Record{}, fmt.Errorf("decode redis session record: %w", err)
	}
	// A second decode must report EOF; anything else means trailing input.
	// errors.Is replaces the original direct != io.EOF comparison so the
	// check stays correct even if a future decoder wraps the sentinel.
	switch err := decoder.Decode(&struct{}{}); {
	case err == nil:
		return Record{}, errors.New("decode redis session record: unexpected trailing JSON input")
	case !errors.Is(err, io.EOF):
		return Record{}, fmt.Errorf("decode redis session record: %w", err)
	}

	record := Record{
		DeviceSessionID: stored.DeviceSessionID,
		UserID:          stored.UserID,
		ClientPublicKey: stored.ClientPublicKey,
		Status:          stored.Status,
		RevokedAtMS:     cloneOptionalInt64(stored.RevokedAtMS),
	}
	if err := validateRecord(expectedDeviceSessionID, record); err != nil {
		return Record{}, err
	}
	return record, nil
}
// validateRecord checks the session invariants shared by every cache backend:
// a non-empty id matching the requested one, non-empty identity material, and
// a known status.
func validateRecord(expectedDeviceSessionID string, record Record) error {
	switch {
	case record.DeviceSessionID == "":
		return errors.New("session record device_session_id must not be empty")
	case record.DeviceSessionID != expectedDeviceSessionID:
		return fmt.Errorf("session record device_session_id %q does not match requested %q", record.DeviceSessionID, expectedDeviceSessionID)
	case record.UserID == "":
		return errors.New("session record user_id must not be empty")
	case record.ClientPublicKey == "":
		return errors.New("session record client_public_key must not be empty")
	case !record.Status.IsKnown():
		return fmt.Errorf("session record status %q is unsupported", record.Status)
	}
	return nil
}
// cloneOptionalInt64 returns nil for a nil input, otherwise a pointer to a
// fresh copy of the pointee so callers never share storage.
func cloneOptionalInt64(value *int64) *int64 {
	if value == nil {
		return nil
	}
	copied := *value
	return &copied
}
var _ Cache = (*RedisCache)(nil)
+331
View File
@@ -0,0 +1,331 @@
package session
import (
"context"
"encoding/json"
"errors"
"testing"
"time"
"galaxy/gateway/internal/config"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewRedisCache exercises constructor validation: a complete config
// succeeds, while a missing addr, negative DB, or unset lookup timeout is
// rejected with a descriptive error.
func TestNewRedisCache(t *testing.T) {
	t.Parallel()
	server := miniredis.RunT(t)
	tests := []struct {
		name    string
		cfg     config.SessionCacheRedisConfig
		wantErr string // substring expected in the error; empty means success
	}{
		{
			name: "valid config",
			cfg: config.SessionCacheRedisConfig{
				Addr:          server.Addr(),
				DB:            2,
				KeyPrefix:     "gateway:session:",
				LookupTimeout: 250 * time.Millisecond,
			},
		},
		{
			name: "empty addr",
			cfg: config.SessionCacheRedisConfig{
				LookupTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis addr must not be empty",
		},
		{
			name: "negative db",
			cfg: config.SessionCacheRedisConfig{
				Addr:          server.Addr(),
				DB:            -1,
				LookupTimeout: 250 * time.Millisecond,
			},
			wantErr: "redis db must not be negative",
		},
		{
			name: "non-positive lookup timeout",
			cfg: config.SessionCacheRedisConfig{
				Addr: server.Addr(),
			},
			wantErr: "lookup timeout must be positive",
		},
	}
	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-Go 1.22 semantics)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			cache, err := NewRedisCache(tt.cfg)
			if tt.wantErr != "" {
				require.Error(t, err)
				require.ErrorContains(t, err, tt.wantErr)
				return
			}
			require.NoError(t, err)
			// Successful constructions own a client; release it after the test.
			t.Cleanup(func() {
				assert.NoError(t, cache.Close())
			})
		})
	}
}
// TestRedisCachePing verifies that Ping succeeds against a reachable backend.
func TestRedisCachePing(t *testing.T) {
	t.Parallel()

	backend := miniredis.RunT(t)
	cache := newTestRedisCache(t, backend, config.SessionCacheRedisConfig{})
	require.NoError(t, cache.Ping(context.Background()))
}
// TestRedisCacheLookup drives RedisCache.Lookup through seeded miniredis
// fixtures: a clean hit, a miss, a revoked record, malformed/invalid payloads,
// and key-prefix isolation.
func TestRedisCacheLookup(t *testing.T) {
	t.Parallel()
	revokedAtMS := int64(123456789)
	tests := []struct {
		name          string
		cfg           config.SessionCacheRedisConfig
		requestID     string // device session id passed to Lookup
		seed          func(*testing.T, *miniredis.Miniredis, config.SessionCacheRedisConfig)
		want          Record // expected record on success
		wantErrIs     error  // sentinel the error must wrap, when set
		wantErrText   string // substring the error must contain, when set
		assertErrText string // additional substring assertion, when set
	}{
		{
			name:      "active cache hit",
			requestID: "device-session-123",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-123", redisRecord{
					DeviceSessionID: "device-session-123",
					UserID:          "user-123",
					ClientPublicKey: "public-key-123",
					Status:          StatusActive,
				})
			},
			want: Record{
				DeviceSessionID: "device-session-123",
				UserID:          "user-123",
				ClientPublicKey: "public-key-123",
				Status:          StatusActive,
			},
		},
		{
			// No seed: the key is absent, so Lookup must report ErrNotFound.
			name:      "missing session",
			requestID: "device-session-404",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			wantErrIs:     ErrNotFound,
			assertErrText: "session cache record not found",
		},
		{
			name:      "revoked session",
			requestID: "device-session-revoked",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-revoked", redisRecord{
					DeviceSessionID: "device-session-revoked",
					UserID:          "user-777",
					ClientPublicKey: "public-key-777",
					Status:          StatusRevoked,
					RevokedAtMS:     &revokedAtMS,
				})
			},
			want: Record{
				DeviceSessionID: "device-session-revoked",
				UserID:          "user-777",
				ClientPublicKey: "public-key-777",
				Status:          StatusRevoked,
				RevokedAtMS:     &revokedAtMS,
			},
		},
		{
			name:      "malformed json",
			requestID: "device-session-bad-json",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				// NOTE(review): Set's error return is ignored here; consider
				// asserting it like setRedisSessionRecord does.
				server.Set(cfg.KeyPrefix+"device-session-bad-json", "{")
			},
			wantErrText: "decode redis session record",
		},
		{
			name:      "unknown status",
			requestID: "device-session-unknown-status",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-unknown-status", redisRecord{
					DeviceSessionID: "device-session-unknown-status",
					UserID:          "user-1",
					ClientPublicKey: "public-key-1",
					Status:          Status("paused"),
				})
			},
			wantErrText: `status "paused" is unsupported`,
		},
		{
			name:      "missing required field",
			requestID: "device-session-missing-user",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-missing-user", redisRecord{
					DeviceSessionID: "device-session-missing-user",
					ClientPublicKey: "public-key-1",
					Status:          StatusActive,
				})
			},
			wantErrText: "user_id must not be empty",
		},
		{
			name:      "device session id mismatch",
			requestID: "device-session-requested",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-requested", redisRecord{
					DeviceSessionID: "device-session-other",
					UserID:          "user-1",
					ClientPublicKey: "public-key-1",
					Status:          StatusActive,
				})
			},
			wantErrText: `does not match requested "device-session-requested"`,
		},
		{
			// Seeds the same id under two prefixes; only the configured
			// prefix's record may be returned.
			name:      "key prefix is honored",
			requestID: "device-session-prefixed",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "custom:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-prefixed", redisRecord{
					DeviceSessionID: "device-session-prefixed",
					UserID:          "user-prefixed",
					ClientPublicKey: "public-key-prefixed",
					Status:          StatusActive,
				})
				setRedisSessionRecord(t, server, "gateway:session:device-session-prefixed", redisRecord{
					DeviceSessionID: "device-session-prefixed",
					UserID:          "wrong-user",
					ClientPublicKey: "wrong-key",
					Status:          StatusRevoked,
				})
			},
			want: Record{
				DeviceSessionID: "device-session-prefixed",
				UserID:          "user-prefixed",
				ClientPublicKey: "public-key-prefixed",
				Status:          StatusActive,
			},
		},
	}
	for _, tt := range tests {
		tt := tt // capture for parallel subtests (pre-Go 1.22 semantics)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			server := miniredis.RunT(t)
			// Each subtest gets its own backend; addr/DB/timeout are fixed
			// here so fixtures only vary the key prefix.
			cfg := tt.cfg
			cfg.Addr = server.Addr()
			cfg.DB = 0
			cfg.LookupTimeout = 250 * time.Millisecond
			if tt.seed != nil {
				tt.seed(t, server, cfg)
			}
			cache := newTestRedisCache(t, server, cfg)
			record, err := cache.Lookup(context.Background(), tt.requestID)
			if tt.wantErrIs != nil || tt.wantErrText != "" {
				require.Error(t, err)
				if tt.wantErrIs != nil {
					assert.ErrorIs(t, err, tt.wantErrIs)
				}
				if tt.wantErrText != "" {
					assert.ErrorContains(t, err, tt.wantErrText)
				}
				if tt.assertErrText != "" {
					assert.ErrorContains(t, err, tt.assertErrText)
				}
				return
			}
			require.NoError(t, err)
			assert.Equal(t, tt.want, record)
		})
	}
}
// newTestRedisCache builds a RedisCache against server, defaulting the addr
// and lookup timeout when unset, and closes it via t.Cleanup.
func newTestRedisCache(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) *RedisCache {
	t.Helper()

	if cfg.Addr == "" {
		cfg.Addr = server.Addr()
	}
	if cfg.LookupTimeout == 0 {
		cfg.LookupTimeout = 250 * time.Millisecond
	}

	cache, err := NewRedisCache(cfg)
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, cache.Close())
	})
	return cache
}
// setRedisSessionRecord marshals record and stores it in server under key,
// failing the test on any error.
func setRedisSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record redisRecord) {
	t.Helper()
	payload, err := json.Marshal(record)
	require.NoError(t, err)
	// The original ignored Set's error return; assert it so a misconfigured
	// fixture fails loudly instead of silently seeding nothing.
	require.NoError(t, server.Set(key, string(payload)))
}
// TestRedisCacheLookupNilContext verifies that context.TODO() is rejected as
// a nil context and is not misreported as a cache miss.
func TestRedisCacheLookupNilContext(t *testing.T) {
	t.Parallel()

	backend := miniredis.RunT(t)
	cache := newTestRedisCache(t, backend, config.SessionCacheRedisConfig{})

	_, err := cache.Lookup(context.TODO(), "device-session-123")
	require.Error(t, err)
	assert.False(t, errors.Is(err, ErrNotFound))
	assert.ErrorContains(t, err, "nil context")
}
+80
View File
@@ -0,0 +1,80 @@
// Package session defines the authenticated session-cache contract used by the
// gateway hot path.
package session
import (
"context"
"errors"
)
var (
	// ErrNotFound reports that SessionCache does not currently contain the
	// requested device session identifier. Backends wrap it, so compare with
	// errors.Is.
	ErrNotFound = errors.New("session cache record not found")
)

// Cache resolves authenticated device-session state from the gateway hot-path
// cache.
type Cache interface {
	// Lookup returns the cached record for deviceSessionID. Implementations must
	// wrap ErrNotFound when the cache does not contain the requested record.
	Lookup(ctx context.Context, deviceSessionID string) (Record, error)
}

// SnapshotStore stores mutable session record snapshots inside one gateway
// process and exposes the same read contract as Cache for the hot path.
type SnapshotStore interface {
	Cache

	// Upsert stores record under record.DeviceSessionID, replacing any previous
	// snapshot for that session.
	Upsert(record Record) error

	// Delete removes the local snapshot for deviceSessionID when it exists.
	Delete(deviceSessionID string)
}
// Status identifies the cached lifecycle state of a device session.
type Status string

const (
	// StatusActive reports that the cached device session may continue through
	// later authenticated gateway checks.
	StatusActive Status = "active"

	// StatusRevoked reports that the cached device session has been revoked and
	// must be rejected before later auth steps run.
	StatusRevoked Status = "revoked"
)

// Record is the minimum authenticated session state required by the gateway
// before signature verification begins.
type Record struct {
	// DeviceSessionID is the stable device-session identifier resolved from the
	// hot-path cache.
	DeviceSessionID string

	// UserID is the authenticated user identity bound to DeviceSessionID.
	UserID string

	// ClientPublicKey is the standard base64-encoded raw Ed25519 public key
	// material used for request-signature verification.
	ClientPublicKey string

	// Status reports whether the cached session is active or revoked.
	Status Status

	// RevokedAtMS optionally records when the device session was revoked.
	// Presumably Unix milliseconds given the MS suffix — confirm with writer.
	RevokedAtMS *int64
}
// IsKnown reports whether s is one of the session states supported by the
// gateway.
func (s Status) IsKnown() bool {
	return s == StatusActive || s == StatusRevoked
}
+102
View File
@@ -0,0 +1,102 @@
// Package telemetry provides shared edge observability helpers used by the
// gateway transports and internal event consumers.
package telemetry
import (
"net/http"
"strings"
"google.golang.org/grpc/codes"
)
// EdgeOutcome is the stable low-cardinality outcome vocabulary shared by REST,
// gRPC, push shutdown, and observability backends.
type EdgeOutcome string

// Edge outcome values. The string form doubles as the reject reason emitted by
// RejectReason, so these values must stay stable.
const (
	EdgeOutcomeSuccess               EdgeOutcome = "success"
	EdgeOutcomeMalformedRequest      EdgeOutcome = "malformed_request"
	EdgeOutcomeRequestTooLarge       EdgeOutcome = "request_too_large"
	EdgeOutcomeUnsupportedProtocol   EdgeOutcome = "unsupported_protocol"
	EdgeOutcomeUnknownSession        EdgeOutcome = "unknown_session"
	EdgeOutcomeRevokedSession        EdgeOutcome = "revoked_session"
	EdgeOutcomeInvalidSignature      EdgeOutcome = "invalid_signature"
	EdgeOutcomeStaleRequest          EdgeOutcome = "stale_request"
	EdgeOutcomeReplayDetected        EdgeOutcome = "replay_detected"
	EdgeOutcomeRateLimited           EdgeOutcome = "rate_limited"
	EdgeOutcomePolicyDenied          EdgeOutcome = "policy_denied"
	EdgeOutcomeDownstreamUnavailable EdgeOutcome = "downstream_unavailable"
	EdgeOutcomeBackendUnavailable    EdgeOutcome = "backend_unavailable"
	EdgeOutcomeInternalError         EdgeOutcome = "internal_error"
	EdgeOutcomeGatewayShuttingDown   EdgeOutcome = "gateway_shutting_down"
)
// RejectReason returns the stable reject reason for outcome. Success does not
// produce a reject reason.
func RejectReason(outcome EdgeOutcome) string {
	switch outcome {
	case EdgeOutcomeSuccess:
		return ""
	default:
		return string(outcome)
	}
}
// OutcomeFromPublicErrorCode maps the stable public REST error envelope into
// the shared edge-outcome vocabulary.
func OutcomeFromPublicErrorCode(statusCode int, code string) EdgeOutcome {
	trimmed := strings.TrimSpace(code)

	// An empty code means no error envelope: classify purely by status.
	if trimmed == "" {
		if statusCode >= http.StatusBadRequest {
			return EdgeOutcomeInternalError
		}
		return EdgeOutcomeSuccess
	}

	switch trimmed {
	case "invalid_request", "method_not_allowed", "not_found":
		return EdgeOutcomeMalformedRequest
	case "request_too_large":
		return EdgeOutcomeRequestTooLarge
	case "rate_limited":
		return EdgeOutcomeRateLimited
	case "service_unavailable":
		return EdgeOutcomeBackendUnavailable
	}

	// Unknown codes: 5xx is an internal error, everything else a bad request.
	if statusCode >= http.StatusInternalServerError {
		return EdgeOutcomeInternalError
	}
	return EdgeOutcomeMalformedRequest
}
// OutcomeFromGRPCStatus maps the stable authenticated gRPC reject contract
// into the shared edge-outcome vocabulary.
func OutcomeFromGRPCStatus(code codes.Code, message string) EdgeOutcome {
	switch code {
	case codes.OK:
		return EdgeOutcomeSuccess
	case codes.InvalidArgument:
		return EdgeOutcomeMalformedRequest
	case codes.FailedPrecondition:
		switch {
		case strings.Contains(message, "unsupported protocol_version"):
			return EdgeOutcomeUnsupportedProtocol
		case message == "device session is revoked":
			return EdgeOutcomeRevokedSession
		case message == "request timestamp is outside the freshness window":
			return EdgeOutcomeStaleRequest
		case message == "request replay detected":
			return EdgeOutcomeReplayDetected
		}
	case codes.Unauthenticated:
		switch message {
		case "unknown device session":
			return EdgeOutcomeUnknownSession
		case "invalid request signature":
			return EdgeOutcomeInvalidSignature
		}
	case codes.ResourceExhausted:
		if message == "authenticated request rate limit exceeded" {
			return EdgeOutcomeRateLimited
		}
	case codes.PermissionDenied:
		if message == "authenticated request rejected by edge policy" {
			return EdgeOutcomePolicyDenied
		}
	case codes.Unavailable:
		switch message {
		case "downstream service is unavailable":
			return EdgeOutcomeDownstreamUnavailable
		case "gateway is shutting down":
			return EdgeOutcomeGatewayShuttingDown
		}
		// Any other Unavailable message is a generic backend failure.
		return EdgeOutcomeBackendUnavailable
	}
	// Unrecognized code/message pairs collapse to internal_error.
	return EdgeOutcomeInternalError
}
+48
View File
@@ -0,0 +1,48 @@
package telemetry
import (
"context"
"errors"
"galaxy/gateway/internal/push"
"go.opentelemetry.io/otel/attribute"
)
// PushObserver adapts Runtime to the push.Observer interface.
type PushObserver struct {
	runtime *Runtime // nil-safe: methods no-op when runtime is nil
}

// NewPushObserver constructs a push stream observer backed by runtime.
func NewPushObserver(runtime *Runtime) *PushObserver {
	return &PushObserver{runtime: runtime}
}
// Registered records one active push stream.
func (o *PushObserver) Registered(_ push.StreamBinding) {
	if o != nil && o.runtime != nil {
		o.runtime.AddActivePushStream(context.Background(), 1)
	}
}
// Unregistered records one active-stream decrement and one closure reason for
// hub-enforced shutdown, overflow, or revocation.
func (o *PushObserver) Unregistered(_ push.StreamBinding, err error) {
	if o == nil || o.runtime == nil {
		return
	}
	ctx := context.Background()
	o.runtime.AddActivePushStream(ctx, -1)

	// Only hub-enforced closure causes get a reason counter; any other error
	// (including nil) just decrements the active-stream gauge above.
	var reason string
	switch {
	case errors.Is(err, push.ErrSubscriptionOverflow):
		reason = "overflow"
	case errors.Is(err, push.ErrSubscriptionRevoked):
		reason = "revoked"
	case errors.Is(err, push.ErrHubShuttingDown):
		reason = "shutdown"
	default:
		return
	}
	o.runtime.RecordPushStreamClosure(ctx, attribute.String("reason", reason))
}
+254
View File
@@ -0,0 +1,254 @@
package telemetry
import (
"context"
"errors"
"net/http"
"os"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
otelprom "go.opentelemetry.io/otel/exporters/prometheus"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/propagation"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.uber.org/zap"
)
// defaultServiceName is used when the OTEL_SERVICE_NAME environment variable
// is empty or unset.
const defaultServiceName = "galaxy-edge-gateway"

// Runtime owns the shared OpenTelemetry providers, the Prometheus metrics
// handler, and the custom low-cardinality edge instruments.
type Runtime struct {
	logger         *zap.Logger
	tracerProvider *sdktrace.TracerProvider
	meterProvider  *sdkmetric.MeterProvider
	promHandler    http.Handler

	// Public REST instruments.
	publicRequests metric.Int64Counter
	publicDuration metric.Float64Histogram

	// Authenticated gRPC instruments.
	grpcRequests metric.Int64Counter
	grpcDuration metric.Float64Histogram

	// Push instruments.
	pushActiveStreams metric.Int64UpDownCounter
	pushStreamClosers metric.Int64Counter

	// Internal event consumer instruments.
	internalEventDrops metric.Int64Counter
}
// New constructs the gateway telemetry runtime, registers global providers,
// and returns the Prometheus handler used by the admin listener.
//
// The service name comes from OTEL_SERVICE_NAME, falling back to
// defaultServiceName. The global OTel tracer/meter providers and the W3C
// propagators are installed only after every instrument has been created
// successfully, so a failed construction never leaves half-initialized
// globals behind; providers started along the way are shut down on error.
func New(ctx context.Context, logger *zap.Logger) (*Runtime, error) {
	if logger == nil {
		logger = zap.NewNop()
	}

	serviceName := strings.TrimSpace(os.Getenv("OTEL_SERVICE_NAME"))
	if serviceName == "" {
		serviceName = defaultServiceName
	}

	res, err := resource.New(
		ctx,
		resource.WithAttributes(attribute.String("service.name", serviceName)),
	)
	if err != nil {
		return nil, fmt.Errorf("build telemetry resource: %w", err)
	}

	tracerProvider, err := newTracerProvider(ctx, res)
	if err != nil {
		return nil, fmt.Errorf("build tracer provider: %w", err)
	}

	registry := prometheus.NewRegistry()
	exporter, err := otelprom.New(otelprom.WithRegisterer(registry))
	if err != nil {
		// Release the tracer provider that was already started.
		err = errors.Join(err, tracerProvider.Shutdown(ctx))
		return nil, fmt.Errorf("build prometheus exporter: %w", err)
	}
	meterProvider := sdkmetric.NewMeterProvider(
		sdkmetric.WithResource(res),
		sdkmetric.WithReader(exporter),
	)

	// Create every custom instrument up front, collecting failures instead of
	// repeating the same error check per instrument.
	meter := meterProvider.Meter("galaxy/gateway")
	var instErr error
	counter := func(name string) metric.Int64Counter {
		c, cErr := meter.Int64Counter(name)
		instErr = errors.Join(instErr, cErr)
		return c
	}
	histogram := func(name string) metric.Float64Histogram {
		h, hErr := meter.Float64Histogram(name, metric.WithUnit("ms"))
		instErr = errors.Join(instErr, hErr)
		return h
	}
	upDown := func(name string) metric.Int64UpDownCounter {
		u, uErr := meter.Int64UpDownCounter(name)
		instErr = errors.Join(instErr, uErr)
		return u
	}

	runtime := &Runtime{
		logger:             logger,
		tracerProvider:     tracerProvider,
		meterProvider:      meterProvider,
		promHandler:        promhttp.HandlerFor(registry, promhttp.HandlerOpts{}),
		publicRequests:     counter("gateway.public_http.requests"),
		publicDuration:     histogram("gateway.public_http.duration"),
		grpcRequests:       counter("gateway.authenticated_grpc.requests"),
		grpcDuration:       histogram("gateway.authenticated_grpc.duration"),
		pushActiveStreams:  upDown("gateway.push.active_streams"),
		pushStreamClosers:  counter("gateway.push.stream_closures"),
		internalEventDrops: counter("gateway.internal_event_drops"),
	}
	if instErr != nil {
		// Flush and release both providers before reporting the failure.
		instErr = errors.Join(instErr, runtime.Shutdown(ctx))
		return nil, fmt.Errorf("create gateway instruments: %w", instErr)
	}

	// Install the process-wide globals only once construction fully succeeded.
	otel.SetTracerProvider(tracerProvider)
	otel.SetMeterProvider(meterProvider)
	otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
		propagation.TraceContext{},
		propagation.Baggage{},
	))
	return runtime, nil
}
// Handler returns the Prometheus handler that should be mounted on the admin
// listener. A nil runtime, or one without a handler, yields a 404 handler so
// callers can mount the result unconditionally.
func (r *Runtime) Handler() http.Handler {
	if r != nil && r.promHandler != nil {
		return r.promHandler
	}
	return http.NotFoundHandler()
}
// Shutdown flushes the configured telemetry providers. Both providers are
// always attempted; their errors, if any, are joined into the result.
func (r *Runtime) Shutdown(ctx context.Context) error {
	if r == nil {
		return nil
	}
	var errs []error
	if r.meterProvider != nil {
		errs = append(errs, r.meterProvider.Shutdown(ctx))
	}
	if r.tracerProvider != nil {
		errs = append(errs, r.tracerProvider.Shutdown(ctx))
	}
	// errors.Join discards nil entries, so a clean shutdown returns nil.
	return errors.Join(errs...)
}
// RecordPublicRequest records one public REST request outcome: a count
// increment plus the request duration in milliseconds, both labeled with
// attrs.
func (r *Runtime) RecordPublicRequest(ctx context.Context, attrs []attribute.KeyValue, duration time.Duration) {
	if r == nil {
		return
	}
	withAttrs := metric.WithAttributes(attrs...)
	r.publicDuration.Record(ctx, duration.Seconds()*1000, withAttrs)
	r.publicRequests.Add(ctx, 1, withAttrs)
}
// RecordAuthenticatedGRPC records one authenticated gRPC request or stream
// outcome: a count increment plus the duration in milliseconds, both labeled
// with attrs.
func (r *Runtime) RecordAuthenticatedGRPC(ctx context.Context, attrs []attribute.KeyValue, duration time.Duration) {
	if r == nil {
		return
	}
	withAttrs := metric.WithAttributes(attrs...)
	r.grpcDuration.Record(ctx, duration.Seconds()*1000, withAttrs)
	r.grpcRequests.Add(ctx, 1, withAttrs)
}
// AddActivePushStream records one active-stream delta (positive on register,
// negative on unregister), labeled with attrs.
func (r *Runtime) AddActivePushStream(ctx context.Context, delta int64, attrs ...attribute.KeyValue) {
	if r != nil {
		r.pushActiveStreams.Add(ctx, delta, metric.WithAttributes(attrs...))
	}
}
// RecordPushStreamClosure records one push-stream closure, labeled with attrs
// (typically a single "reason" attribute).
func (r *Runtime) RecordPushStreamClosure(ctx context.Context, attrs ...attribute.KeyValue) {
	if r != nil {
		r.pushStreamClosers.Add(ctx, 1, metric.WithAttributes(attrs...))
	}
}
// RecordInternalEventDrop records one malformed or rejected internal event,
// labeled with attrs.
func (r *Runtime) RecordInternalEventDrop(ctx context.Context, attrs ...attribute.KeyValue) {
	if r != nil {
		r.internalEventDrops.Add(ctx, 1, metric.WithAttributes(attrs...))
	}
}
// newTracerProvider builds the trace provider selected by the standard OTEL
// environment variables.
//
// OTEL_TRACES_EXPORTER: empty or "none" disables exporting (a provider with
// no exporter is returned); "otlp" enables the OTLP exporter. The transport
// comes from OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, falling back to
// OTEL_EXPORTER_OTLP_PROTOCOL: "http/protobuf" (the default) or "grpc".
// Unsupported values are reported in the error so misconfiguration is easy
// to spot.
func newTracerProvider(ctx context.Context, res *resource.Resource) (*sdktrace.TracerProvider, error) {
	exporterName := strings.TrimSpace(os.Getenv("OTEL_TRACES_EXPORTER"))
	if exporterName == "" || exporterName == "none" {
		return sdktrace.NewTracerProvider(sdktrace.WithResource(res)), nil
	}
	if exporterName != "otlp" {
		return nil, fmt.Errorf("unsupported OTEL_TRACES_EXPORTER value %q (want \"otlp\" or \"none\")", exporterName)
	}

	protocol := strings.TrimSpace(os.Getenv("OTEL_EXPORTER_OTLP_TRACES_PROTOCOL"))
	if protocol == "" {
		protocol = strings.TrimSpace(os.Getenv("OTEL_EXPORTER_OTLP_PROTOCOL"))
	}

	var (
		exporter sdktrace.SpanExporter
		err      error
	)
	switch protocol {
	case "", "http/protobuf":
		exporter, err = otlptracehttp.New(ctx)
	case "grpc":
		exporter, err = otlptracegrpc.New(ctx)
	default:
		return nil, fmt.Errorf("unsupported OTEL exporter protocol %q (want \"http/protobuf\" or \"grpc\")", protocol)
	}
	if err != nil {
		return nil, fmt.Errorf("create OTLP trace exporter: %w", err)
	}

	return sdktrace.NewTracerProvider(
		sdktrace.WithBatcher(exporter),
		sdktrace.WithResource(res),
	), nil
}
@@ -0,0 +1,94 @@
package testutil
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"galaxy/gateway/internal/logging"
"galaxy/gateway/internal/telemetry"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// LogBuffer is a concurrency-safe in-memory buffer used by observability
// tests. The zero value is ready to use.
type LogBuffer struct {
	mu  sync.Mutex
	buf bytes.Buffer
}

// Write appends p to the buffer, satisfying io.Writer.
func (lb *LogBuffer) Write(p []byte) (int, error) {
	lb.mu.Lock()
	defer lb.mu.Unlock()
	return lb.buf.Write(p)
}

// String returns a snapshot of everything written so far.
func (lb *LogBuffer) String() string {
	lb.mu.Lock()
	defer lb.mu.Unlock()
	return lb.buf.String()
}
// NewObservedLogger constructs a JSON zap logger that writes into an in-memory
// buffer suitable for log assertions. The logger is synced via t.Cleanup.
func NewObservedLogger(t *testing.T) (*zap.Logger, *LogBuffer) {
	t.Helper()

	sink := &LogBuffer{}

	cfg := zap.NewProductionEncoderConfig()
	cfg.TimeKey = "timestamp"
	cfg.EncodeTime = zapcore.ISO8601TimeEncoder

	logger := zap.New(zapcore.NewCore(
		zapcore.NewJSONEncoder(cfg),
		zapcore.Lock(zapcore.AddSync(sink)),
		zap.DebugLevel,
	))
	t.Cleanup(func() {
		require.NoError(t, logging.Sync(logger))
	})
	return logger, sink
}
// NewTelemetryRuntime constructs a telemetry runtime for tests and shuts it
// down automatically, bounding the shutdown flush to one second.
func NewTelemetryRuntime(t *testing.T, logger *zap.Logger) *telemetry.Runtime {
	t.Helper()

	rt, err := telemetry.New(context.Background(), logger)
	require.NoError(t, err)

	t.Cleanup(func() {
		shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		require.NoError(t, rt.Shutdown(shutdownCtx))
	})
	return rt
}
// ScrapeMetrics returns the Prometheus exposition produced by handler,
// failing the test unless the scrape answers 200 OK.
func ScrapeMetrics(t *testing.T, handler http.Handler) string {
	t.Helper()

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/metrics", nil))
	require.Equal(t, http.StatusOK, rec.Code)
	return rec.Body.String()
}