feat: backend service

This commit is contained in:
Ilia Denisov
2026-05-06 10:14:55 +03:00
committed by GitHub
parent 3e2622757e
commit f446c6a2ac
1486 changed files with 49720 additions and 266401 deletions
-80
View File
@@ -1,80 +0,0 @@
package authn
import (
"crypto/ed25519"
"encoding/binary"
"errors"
)
const (
// EventDomainMarkerV1 binds the v1 server event signature to the Galaxy
// gateway transport contract.
EventDomainMarkerV1 = "galaxy-event-v1"
)
var (
// ErrInvalidEventSignature reports that a gateway stream event signature is
// not a raw Ed25519 signature for the canonical event signing input.
ErrInvalidEventSignature = errors.New("invalid event signature")
)
// EventSigningFields contains the canonical v1 stream-event fields that are
// bound into the server signing input.
type EventSigningFields struct {
// EventType identifies the stable client-facing event category.
EventType string
// EventID is the stable event correlation identifier.
EventID string
// TimestampMS carries the server event timestamp in milliseconds.
TimestampMS int64
// RequestID optionally correlates the event to the opening client request.
RequestID string
// TraceID optionally carries the client-supplied tracing correlation value.
TraceID string
// PayloadHash is the raw SHA-256 digest of event payload bytes.
PayloadHash []byte
}
// BuildEventSigningInput returns the canonical byte sequence the v1 gateway
// stream-event signature covers. String and byte fields are length-prefixed
// with uvarint(len(field)) followed by raw bytes, while TimestampMS is
// appended as an 8-byte big-endian uint64.
func BuildEventSigningInput(fields EventSigningFields) []byte {
size := len(EventDomainMarkerV1) +
len(fields.EventType) +
len(fields.EventID) +
len(fields.RequestID) +
len(fields.TraceID) +
len(fields.PayloadHash) +
(6 * binary.MaxVarintLen64) +
8
buf := make([]byte, 0, size)
buf = appendLengthPrefixedString(buf, EventDomainMarkerV1)
buf = appendLengthPrefixedString(buf, fields.EventType)
buf = appendLengthPrefixedString(buf, fields.EventID)
buf = binary.BigEndian.AppendUint64(buf, uint64(fields.TimestampMS))
buf = appendLengthPrefixedString(buf, fields.RequestID)
buf = appendLengthPrefixedString(buf, fields.TraceID)
buf = appendLengthPrefixedBytes(buf, fields.PayloadHash)
return buf
}
// VerifyEventSignature verifies that signature authenticates fields under
// publicKey using the canonical v1 event signing input.
func VerifyEventSignature(publicKey ed25519.PublicKey, signature []byte, fields EventSigningFields) error {
if len(publicKey) != ed25519.PublicKeySize || len(signature) != ed25519.SignatureSize {
return ErrInvalidEventSignature
}
if !ed25519.Verify(publicKey, BuildEventSigningInput(fields), signature) {
return ErrInvalidEventSignature
}
return nil
}
-111
View File
@@ -1,111 +0,0 @@
package authn
import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBuildEventSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
t.Parallel()
base := EventSigningFields{
EventType: "gateway.server_time",
EventID: "request-123",
TimestampMS: 123456789,
RequestID: "request-123",
TraceID: "trace-123",
PayloadHash: mustSHA256([]byte("payload")),
}
baseInput := BuildEventSigningInput(base)
tests := []struct {
name string
mutate func(EventSigningFields) EventSigningFields
}{
{
name: "event type",
mutate: func(fields EventSigningFields) EventSigningFields {
fields.EventType = "gateway.other"
return fields
},
},
{
name: "event id",
mutate: func(fields EventSigningFields) EventSigningFields {
fields.EventID = "request-456"
return fields
},
},
{
name: "timestamp",
mutate: func(fields EventSigningFields) EventSigningFields {
fields.TimestampMS++
return fields
},
},
{
name: "request id",
mutate: func(fields EventSigningFields) EventSigningFields {
fields.RequestID = "request-456"
return fields
},
},
{
name: "trace id",
mutate: func(fields EventSigningFields) EventSigningFields {
fields.TraceID = "trace-456"
return fields
},
},
{
name: "payload hash",
mutate: func(fields EventSigningFields) EventSigningFields {
fields.PayloadHash = mustSHA256([]byte("other"))
return fields
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mutated := BuildEventSigningInput(tt.mutate(base))
assert.False(t, bytes.Equal(baseInput, mutated))
})
}
}
func TestSignAndVerifyEventSignature(t *testing.T) {
t.Parallel()
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
signer, err := NewEd25519ResponseSigner(privateKey)
require.NoError(t, err)
fields := EventSigningFields{
EventType: "gateway.server_time",
EventID: "request-123",
TimestampMS: 123456789,
RequestID: "request-123",
TraceID: "trace-123",
PayloadHash: mustSHA256([]byte("payload")),
}
signature, err := signer.SignEvent(fields)
require.NoError(t, err)
require.NoError(t, VerifyEventSignature(signer.PublicKey(), signature, fields))
fields.TraceID = "changed"
require.ErrorIs(t, VerifyEventSignature(signer.PublicKey(), signature, fields), ErrInvalidEventSignature)
}
-101
View File
@@ -1,101 +0,0 @@
// Package authn defines authenticated transport helpers shared by the gateway
// edge verification pipeline.
package authn
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"errors"
)
const (
// RequestDomainMarkerV1 binds the v1 client request signature to the Galaxy
// gateway transport contract.
RequestDomainMarkerV1 = "galaxy-request-v1"
)
var (
// ErrInvalidPayloadHash reports that payloadHash is not a raw SHA-256 digest.
ErrInvalidPayloadHash = errors.New("payload_hash must be a 32-byte SHA-256 digest")
// ErrPayloadHashMismatch reports that payloadHash does not match payloadBytes.
ErrPayloadHashMismatch = errors.New("payload_hash does not match payload_bytes")
)
// RequestSigningFields contains the canonical v1 request fields that are bound
// into the client signing input after the gateway validates and normalizes the
// request envelope.
type RequestSigningFields struct {
// ProtocolVersion identifies the transport envelope version.
ProtocolVersion string
// DeviceSessionID identifies the authenticated device session bound to the
// request.
DeviceSessionID string
// MessageType is the stable downstream routing key.
MessageType string
// TimestampMS carries the client request timestamp in milliseconds.
TimestampMS int64
// RequestID is the transport correlation and anti-replay identifier.
RequestID string
// PayloadHash is the raw SHA-256 digest of payload bytes.
PayloadHash []byte
}
// BuildRequestSigningInput returns the canonical byte sequence the v1 client
// request signature covers. String and byte fields are length-prefixed with
// uvarint(len(field)) followed by raw bytes, while TimestampMS is appended as
// an 8-byte big-endian uint64. The caller is expected to pass fields that have
// already passed earlier envelope validation.
func BuildRequestSigningInput(fields RequestSigningFields) []byte {
size := len(RequestDomainMarkerV1) +
len(fields.ProtocolVersion) +
len(fields.DeviceSessionID) +
len(fields.MessageType) +
len(fields.RequestID) +
len(fields.PayloadHash) +
(6 * binary.MaxVarintLen64) +
8
buf := make([]byte, 0, size)
buf = appendLengthPrefixedString(buf, RequestDomainMarkerV1)
buf = appendLengthPrefixedString(buf, fields.ProtocolVersion)
buf = appendLengthPrefixedString(buf, fields.DeviceSessionID)
buf = appendLengthPrefixedString(buf, fields.MessageType)
buf = binary.BigEndian.AppendUint64(buf, uint64(fields.TimestampMS))
buf = appendLengthPrefixedString(buf, fields.RequestID)
buf = appendLengthPrefixedBytes(buf, fields.PayloadHash)
return buf
}
// VerifyPayloadHash checks that payloadHash is the raw SHA-256 digest of
// payloadBytes. Empty payloadBytes are valid and must use sha256.Sum256(nil).
func VerifyPayloadHash(payloadBytes, payloadHash []byte) error {
if len(payloadHash) != sha256.Size {
return ErrInvalidPayloadHash
}
sum := sha256.Sum256(payloadBytes)
if !bytes.Equal(sum[:], payloadHash) {
return ErrPayloadHashMismatch
}
return nil
}
func appendLengthPrefixedString(dst []byte, value string) []byte {
return appendLengthPrefixedBytes(dst, []byte(value))
}
func appendLengthPrefixedBytes(dst []byte, value []byte) []byte {
dst = binary.AppendUvarint(dst, uint64(len(value)))
dst = append(dst, value...)
return dst
}
-163
View File
@@ -1,163 +0,0 @@
package authn
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestVerifyPayloadHash(t *testing.T) {
t.Parallel()
payloadSum := sha256.Sum256([]byte("payload"))
emptySum := sha256.Sum256(nil)
otherSum := sha256.Sum256([]byte("other"))
tests := []struct {
name string
payload []byte
payloadHash []byte
wantErr error
}{
{
name: "matches non-empty payload",
payload: []byte("payload"),
payloadHash: payloadSum[:],
},
{
name: "matches empty payload",
payload: nil,
payloadHash: emptySum[:],
},
{
name: "rejects digest with invalid length",
payload: []byte("payload"),
payloadHash: []byte("short"),
wantErr: ErrInvalidPayloadHash,
},
{
name: "rejects digest mismatch",
payload: []byte("payload"),
payloadHash: otherSum[:],
wantErr: ErrPayloadHashMismatch,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := VerifyPayloadHash(tt.payload, tt.payloadHash)
if tt.wantErr == nil {
require.NoError(t, err)
return
}
require.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestBuildRequestSigningInput(t *testing.T) {
t.Parallel()
fields := RequestSigningFields{
ProtocolVersion: "v1",
DeviceSessionID: "device-session-123",
MessageType: "fleet.move",
TimestampMS: 123456789,
RequestID: "request-123",
PayloadHash: mustSHA256([]byte("payload")),
}
got := BuildRequestSigningInput(fields)
want, err := hex.DecodeString("1167616c6178792d726571756573742d7631027631126465766963652d73657373696f6e2d3132330a666c6565742e6d6f766500000000075bcd150b726571756573742d31323320239f59ed55e737c77147cf55ad0c1b030b6d7ee748a7426952f9b852d5a935e5")
require.NoError(t, err)
assert.Equal(t, want, got)
}
func TestBuildRequestSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
t.Parallel()
base := RequestSigningFields{
ProtocolVersion: "v1",
DeviceSessionID: "device-session-123",
MessageType: "fleet.move",
TimestampMS: 123456789,
RequestID: "request-123",
PayloadHash: mustSHA256([]byte("payload")),
}
baseInput := BuildRequestSigningInput(base)
tests := []struct {
name string
mutate func(RequestSigningFields) RequestSigningFields
}{
{
name: "protocol version",
mutate: func(fields RequestSigningFields) RequestSigningFields {
fields.ProtocolVersion = "v2"
return fields
},
},
{
name: "device session id",
mutate: func(fields RequestSigningFields) RequestSigningFields {
fields.DeviceSessionID = "device-session-456"
return fields
},
},
{
name: "message type",
mutate: func(fields RequestSigningFields) RequestSigningFields {
fields.MessageType = "fleet.attack"
return fields
},
},
{
name: "timestamp",
mutate: func(fields RequestSigningFields) RequestSigningFields {
fields.TimestampMS++
return fields
},
},
{
name: "request id",
mutate: func(fields RequestSigningFields) RequestSigningFields {
fields.RequestID = "request-456"
return fields
},
},
{
name: "payload hash",
mutate: func(fields RequestSigningFields) RequestSigningFields {
fields.PayloadHash = mustSHA256([]byte("other"))
return fields
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mutated := BuildRequestSigningInput(tt.mutate(base))
assert.False(t, bytes.Equal(baseInput, mutated))
})
}
}
func mustSHA256(payload []byte) []byte {
sum := sha256.Sum256(payload)
return sum[:]
}
-189
View File
@@ -1,189 +0,0 @@
package authn
import (
"bytes"
"crypto/ed25519"
"crypto/x509"
"encoding/binary"
"encoding/pem"
"errors"
"fmt"
"os"
)
const (
// ResponseDomainMarkerV1 binds the v1 server response signature to the
// Galaxy gateway transport contract.
ResponseDomainMarkerV1 = "galaxy-response-v1"
)
var (
// ErrInvalidResponsePrivateKeyPEM reports that the configured response
// signer private key is not a strict PKCS#8 PEM-encoded private key.
ErrInvalidResponsePrivateKeyPEM = errors.New("response signer private key is not a valid PKCS#8 PEM block")
// ErrInvalidResponsePrivateKey reports that the configured response signer
// private key is not an Ed25519 private key.
ErrInvalidResponsePrivateKey = errors.New("response signer private key must be an Ed25519 PKCS#8 private key")
// ErrInvalidResponseSignature reports that a server response signature is
// not a raw Ed25519 signature for the canonical response signing input.
ErrInvalidResponseSignature = errors.New("invalid response signature")
)
// ResponseSigningFields contains the canonical v1 response fields that are
// bound into the server signing input.
type ResponseSigningFields struct {
// ProtocolVersion identifies the transport envelope version.
ProtocolVersion string
// RequestID is the transport correlation identifier copied from the
// authenticated request.
RequestID string
// TimestampMS carries the server response timestamp in milliseconds.
TimestampMS int64
// ResultCode is the opaque downstream result code returned to the client.
ResultCode string
// PayloadHash is the raw SHA-256 digest of response payload bytes.
PayloadHash []byte
}
// ResponseSigner signs authenticated unary responses and client-facing stream
// events with one server-side key.
type ResponseSigner interface {
// SignResponse returns the raw Ed25519 signature for the canonical response
// signing input built from fields.
SignResponse(fields ResponseSigningFields) ([]byte, error)
// SignEvent returns the raw Ed25519 signature for the canonical event
// signing input built from fields.
SignEvent(fields EventSigningFields) ([]byte, error)
}
// Ed25519ResponseSigner signs authenticated responses with one Ed25519 private
// key loaded during process startup.
type Ed25519ResponseSigner struct {
privateKey ed25519.PrivateKey
}
// NewEd25519ResponseSigner validates privateKey and constructs a signer using
// a defensive key copy.
func NewEd25519ResponseSigner(privateKey ed25519.PrivateKey) (*Ed25519ResponseSigner, error) {
if len(privateKey) != ed25519.PrivateKeySize {
return nil, ErrInvalidResponsePrivateKey
}
return &Ed25519ResponseSigner{
privateKey: bytes.Clone(privateKey),
}, nil
}
// LoadEd25519ResponseSignerFromPEMFile loads a strict PKCS#8 PEM-encoded
// Ed25519 private key from path and constructs a signer.
func LoadEd25519ResponseSignerFromPEMFile(path string) (*Ed25519ResponseSigner, error) {
pemBytes, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read response signer private key PEM: %w", err)
}
signer, err := ParseEd25519ResponseSignerPEM(pemBytes)
if err != nil {
return nil, err
}
return signer, nil
}
// ParseEd25519ResponseSignerPEM parses one strict PKCS#8 PEM-encoded Ed25519
// private key and constructs a signer from it.
func ParseEd25519ResponseSignerPEM(pemBytes []byte) (*Ed25519ResponseSigner, error) {
block, rest := pem.Decode(pemBytes)
if block == nil || block.Type != "PRIVATE KEY" || len(bytes.TrimSpace(rest)) > 0 {
return nil, ErrInvalidResponsePrivateKeyPEM
}
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, ErrInvalidResponsePrivateKeyPEM
}
privateKey, ok := parsedKey.(ed25519.PrivateKey)
if !ok {
return nil, ErrInvalidResponsePrivateKey
}
return NewEd25519ResponseSigner(privateKey)
}
// PublicKey returns the Ed25519 public key that corresponds to the configured
// response signer private key.
func (s *Ed25519ResponseSigner) PublicKey() ed25519.PublicKey {
if s == nil {
return nil
}
publicKey, _ := s.privateKey.Public().(ed25519.PublicKey)
return bytes.Clone(publicKey)
}
// SignResponse signs the canonical v1 response signing input built from
// fields.
func (s *Ed25519ResponseSigner) SignResponse(fields ResponseSigningFields) ([]byte, error) {
if s == nil || len(s.privateKey) != ed25519.PrivateKeySize {
return nil, ErrInvalidResponsePrivateKey
}
signature := ed25519.Sign(s.privateKey, BuildResponseSigningInput(fields))
return bytes.Clone(signature), nil
}
// SignEvent signs the canonical v1 stream-event signing input built from
// fields.
func (s *Ed25519ResponseSigner) SignEvent(fields EventSigningFields) ([]byte, error) {
if s == nil || len(s.privateKey) != ed25519.PrivateKeySize {
return nil, ErrInvalidResponsePrivateKey
}
signature := ed25519.Sign(s.privateKey, BuildEventSigningInput(fields))
return bytes.Clone(signature), nil
}
// BuildResponseSigningInput returns the canonical byte sequence the v1 server
// response signature covers. String and byte fields are length-prefixed with
// uvarint(len(field)) followed by raw bytes, while TimestampMS is appended as
// an 8-byte big-endian uint64.
func BuildResponseSigningInput(fields ResponseSigningFields) []byte {
size := len(ResponseDomainMarkerV1) +
len(fields.ProtocolVersion) +
len(fields.RequestID) +
len(fields.ResultCode) +
len(fields.PayloadHash) +
(5 * binary.MaxVarintLen64) +
8
buf := make([]byte, 0, size)
buf = appendLengthPrefixedString(buf, ResponseDomainMarkerV1)
buf = appendLengthPrefixedString(buf, fields.ProtocolVersion)
buf = appendLengthPrefixedString(buf, fields.RequestID)
buf = binary.BigEndian.AppendUint64(buf, uint64(fields.TimestampMS))
buf = appendLengthPrefixedString(buf, fields.ResultCode)
buf = appendLengthPrefixedBytes(buf, fields.PayloadHash)
return buf
}
// VerifyResponseSignature verifies that signature authenticates fields under
// publicKey using the canonical v1 response signing input.
func VerifyResponseSignature(publicKey ed25519.PublicKey, signature []byte, fields ResponseSigningFields) error {
if len(publicKey) != ed25519.PublicKeySize || len(signature) != ed25519.SignatureSize {
return ErrInvalidResponseSignature
}
if !ed25519.Verify(publicKey, BuildResponseSigningInput(fields), signature) {
return ErrInvalidResponseSignature
}
return nil
}
-146
View File
@@ -1,146 +0,0 @@
package authn
import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBuildResponseSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
t.Parallel()
base := ResponseSigningFields{
ProtocolVersion: "v1",
RequestID: "request-123",
TimestampMS: 123456789,
ResultCode: "ok",
PayloadHash: mustSHA256([]byte("payload")),
}
baseInput := BuildResponseSigningInput(base)
tests := []struct {
name string
mutate func(ResponseSigningFields) ResponseSigningFields
}{
{
name: "protocol version",
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
fields.ProtocolVersion = "v2"
return fields
},
},
{
name: "request id",
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
fields.RequestID = "request-456"
return fields
},
},
{
name: "timestamp",
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
fields.TimestampMS++
return fields
},
},
{
name: "result code",
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
fields.ResultCode = "denied"
return fields
},
},
{
name: "payload hash",
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
fields.PayloadHash = mustSHA256([]byte("other"))
return fields
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mutated := BuildResponseSigningInput(tt.mutate(base))
assert.False(t, bytes.Equal(baseInput, mutated))
})
}
}
func TestParseEd25519ResponseSignerPEMRejectsMalformedPEM(t *testing.T) {
t.Parallel()
_, err := ParseEd25519ResponseSignerPEM([]byte("not-pem"))
require.ErrorIs(t, err, ErrInvalidResponsePrivateKeyPEM)
}
func TestParseEd25519ResponseSignerPEMRejectsNonPKCS8PEM(t *testing.T) {
t.Parallel()
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pemBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
require.NoError(t, err)
block := pem.Block{
Type: "ED25519 PRIVATE KEY",
Bytes: pemBytes,
}
_, err = ParseEd25519ResponseSignerPEM(pem.EncodeToMemory(&block))
require.ErrorIs(t, err, ErrInvalidResponsePrivateKeyPEM)
}
func TestParseEd25519ResponseSignerPEMRejectsNonEd25519Key(t *testing.T) {
t.Parallel()
privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err)
pemBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
require.NoError(t, err)
_, err = ParseEd25519ResponseSignerPEM(pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: pemBytes,
}))
require.ErrorIs(t, err, ErrInvalidResponsePrivateKey)
}
func TestSignAndVerifyResponseSignature(t *testing.T) {
t.Parallel()
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
signer, err := NewEd25519ResponseSigner(privateKey)
require.NoError(t, err)
fields := ResponseSigningFields{
ProtocolVersion: "v1",
RequestID: "request-123",
TimestampMS: 123456789,
ResultCode: "ok",
PayloadHash: mustSHA256([]byte("payload")),
}
signature, err := signer.SignResponse(fields)
require.NoError(t, err)
require.NoError(t, VerifyResponseSignature(signer.PublicKey(), signature, fields))
fields.ResultCode = "changed"
require.ErrorIs(t, VerifyResponseSignature(signer.PublicKey(), signature, fields), ErrInvalidResponseSignature)
}
-47
View File
@@ -1,47 +0,0 @@
package authn
import (
"crypto/ed25519"
"encoding/base64"
"errors"
)
var (
// ErrInvalidClientPublicKey reports that cached client public key material
// is not a base64-encoded raw Ed25519 public key.
ErrInvalidClientPublicKey = errors.New("client_public_key is not a valid base64-encoded Ed25519 public key")
// ErrInvalidRequestSignature reports that a request signature is not a raw
// Ed25519 signature for the canonical request signing input.
ErrInvalidRequestSignature = errors.New("invalid request signature")
)
// VerifyRequestSignature validates the base64-encoded raw Ed25519 public key
// from session cache, builds the canonical v1 signing input from fields, and
// verifies that signature authenticates the request.
func VerifyRequestSignature(clientPublicKey string, signature []byte, fields RequestSigningFields) error {
publicKey, err := decodeClientPublicKey(clientPublicKey)
if err != nil {
return err
}
if len(signature) != ed25519.SignatureSize {
return ErrInvalidRequestSignature
}
if !ed25519.Verify(publicKey, BuildRequestSigningInput(fields), signature) {
return ErrInvalidRequestSignature
}
return nil
}
func decodeClientPublicKey(value string) (ed25519.PublicKey, error) {
decoded, err := base64.StdEncoding.Strict().DecodeString(value)
if err != nil {
return nil, ErrInvalidClientPublicKey
}
if len(decoded) != ed25519.PublicKeySize {
return nil, ErrInvalidClientPublicKey
}
return ed25519.PublicKey(decoded), nil
}
-137
View File
@@ -1,137 +0,0 @@
package authn
import (
"crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"testing"
"github.com/stretchr/testify/require"
)
func TestVerifyRequestSignature(t *testing.T) {
t.Parallel()
clientPrivateKey := newTestPrivateKey("primary")
clientPublicKey := clientPrivateKey.Public().(ed25519.PublicKey)
otherPrivateKey := newTestPrivateKey("other")
fields := RequestSigningFields{
ProtocolVersion: "v1",
DeviceSessionID: "device-session-123",
MessageType: "fleet.move",
TimestampMS: 123456789,
RequestID: "request-123",
PayloadHash: mustSHA256([]byte("payload")),
}
signature := ed25519.Sign(clientPrivateKey, BuildRequestSigningInput(fields))
tests := []struct {
name string
clientPublicKey string
signature []byte
fields RequestSigningFields
wantErr error
}{
{
name: "valid signature",
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
signature: signature,
fields: fields,
},
{
name: "message type change rejects signature",
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
signature: signature,
fields: func() RequestSigningFields {
mutated := fields
mutated.MessageType = "fleet.attack"
return mutated
}(),
wantErr: ErrInvalidRequestSignature,
},
{
name: "request id change rejects signature",
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
signature: signature,
fields: func() RequestSigningFields {
mutated := fields
mutated.RequestID = "request-456"
return mutated
}(),
wantErr: ErrInvalidRequestSignature,
},
{
name: "payload hash change rejects signature",
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
signature: signature,
fields: func() RequestSigningFields {
mutated := fields
mutated.PayloadHash = mustSHA256([]byte("other"))
return mutated
}(),
wantErr: ErrInvalidRequestSignature,
},
{
name: "wrong key rejects signature",
clientPublicKey: base64.StdEncoding.EncodeToString(otherPrivateKey.Public().(ed25519.PublicKey)),
signature: signature,
fields: fields,
wantErr: ErrInvalidRequestSignature,
},
{
name: "bit flipped signature rejects",
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
signature: func() []byte {
corrupted := append([]byte(nil), signature...)
corrupted[0] ^= 0xff
return corrupted
}(),
fields: fields,
wantErr: ErrInvalidRequestSignature,
},
{
name: "invalid signature length rejects",
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
signature: signature[:len(signature)-1],
fields: fields,
wantErr: ErrInvalidRequestSignature,
},
{
name: "invalid base64 public key rejects",
clientPublicKey: "%%%not-base64%%%",
signature: signature,
fields: fields,
wantErr: ErrInvalidClientPublicKey,
},
{
name: "invalid public key length rejects",
clientPublicKey: base64.StdEncoding.EncodeToString([]byte("short")),
signature: signature,
fields: fields,
wantErr: ErrInvalidClientPublicKey,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
err := VerifyRequestSignature(tt.clientPublicKey, tt.signature, tt.fields)
if tt.wantErr == nil {
require.NoError(t, err)
return
}
require.ErrorIs(t, err, tt.wantErr)
})
}
}
func newTestPrivateKey(label string) ed25519.PrivateKey {
seed := sha256.Sum256([]byte("gateway-authn-signature-test-" + label))
return ed25519.NewKeyFromSeed(seed[:])
}
+138
View File
@@ -0,0 +1,138 @@
package backendclient
import (
"errors"
"fmt"
"net/url"
"strings"
"time"
)
// Config describes the backend endpoint and gateway client identity used
// to construct a Client. All fields are required when the gateway is
// expected to talk to a real backend; the empty value yields an
// always-unavailable client.
type Config struct {
// HTTPBaseURL is the absolute base URL of the backend HTTP listener
// (`/api/v1/{public,user,internal}/*`). Required.
HTTPBaseURL string
// GRPCPushURL is the dial target of the backend `Push.SubscribePush`
// listener (`host:port`). Required.
GRPCPushURL string
// GatewayClientID is the durable identifier this gateway instance
// presents to backend in `GatewaySubscribeRequest.gateway_client_id`.
// Required.
GatewayClientID string
// HTTPTimeout bounds individual REST calls. Must be positive.
HTTPTimeout time.Duration
// PushReconnectBaseBackoff is the starting delay between reconnect
// attempts of `Push.SubscribePush`. Must be positive.
PushReconnectBaseBackoff time.Duration
// PushReconnectMaxBackoff is the upper bound for exponential
// reconnect delays. Must be greater than or equal to
// PushReconnectBaseBackoff.
PushReconnectMaxBackoff time.Duration
}
// Validate reports a formatted error when cfg is missing required
// values. The empty value is invalid; callers that intentionally omit
// the backend may bypass this check by skipping NewClient entirely.
func (cfg Config) Validate() error {
trimmed := strings.TrimSpace(cfg.HTTPBaseURL)
if trimmed == "" {
return errors.New("backendclient: HTTPBaseURL must not be empty")
}
parsed, err := url.Parse(strings.TrimRight(trimmed, "/"))
if err != nil {
return fmt.Errorf("backendclient: parse HTTPBaseURL: %w", err)
}
if parsed.Scheme == "" || parsed.Host == "" {
return errors.New("backendclient: HTTPBaseURL must be absolute")
}
if strings.TrimSpace(cfg.GRPCPushURL) == "" {
return errors.New("backendclient: GRPCPushURL must not be empty")
}
if strings.TrimSpace(cfg.GatewayClientID) == "" {
return errors.New("backendclient: GatewayClientID must not be empty")
}
if cfg.HTTPTimeout <= 0 {
return errors.New("backendclient: HTTPTimeout must be positive")
}
if cfg.PushReconnectBaseBackoff <= 0 {
return errors.New("backendclient: PushReconnectBaseBackoff must be positive")
}
if cfg.PushReconnectMaxBackoff < cfg.PushReconnectBaseBackoff {
return errors.New("backendclient: PushReconnectMaxBackoff must be >= PushReconnectBaseBackoff")
}
return nil
}
// Client aggregates the REST and gRPC adapters that talk to backend.
// One value is shared across the gateway process; all methods are safe
// for concurrent use.
type Client struct {
rest *RESTClient
push *PushClient
}
// NewClient constructs a Client that targets the configured backend.
// REST adapter is always built. The gRPC push adapter is built lazily
// when StartPush is called so unit tests can construct a Client with a
// stubbed push transport.
func NewClient(cfg Config) (*Client, error) {
if err := cfg.Validate(); err != nil {
return nil, err
}
rest, err := NewRESTClient(cfg)
if err != nil {
return nil, err
}
push, err := NewPushClient(cfg)
if err != nil {
return nil, err
}
return &Client{rest: rest, push: push}, nil
}
// REST returns the REST adapter. The returned value is nil when the
// Client was constructed without a backend; callers must guard.
func (c *Client) REST() *RESTClient {
if c == nil {
return nil
}
return c.rest
}
// Push returns the gRPC push adapter. The returned value is nil when
// the Client was constructed without a backend.
func (c *Client) Push() *PushClient {
if c == nil {
return nil
}
return c.push
}
// Close releases idle HTTP connections and closes the gRPC push
// connection. Safe to call multiple times.
func (c *Client) Close() error {
if c == nil {
return nil
}
var firstErr error
if c.rest != nil {
if err := c.rest.Close(); err != nil {
firstErr = err
}
}
if c.push != nil {
if err := c.push.Close(); err != nil && firstErr == nil {
firstErr = err
}
}
return firstErr
}
+18
View File
@@ -0,0 +1,18 @@
// Package backendclient is the gateway-side adapter to the consolidated
// `backend` service. It bundles every gateway → backend conversation:
//
// - public REST (`/api/v1/public/auth/*`) used by the public auth
// surface,
// - internal REST (`/api/v1/internal/sessions/*`,
// `/api/v1/internal/users/*/account-internal`) used by the
// authenticated request pipeline,
// - authenticated user REST (`/api/v1/user/*`) used by the gRPC
// downstream router after envelope verification,
// - gRPC `Push.SubscribePush` used to receive `client_event` and
// `session_invalidation` frames from backend.
//
// One env-driven Config describes the backend endpoint and the gateway
// client identity. A single Client value is wired by `cmd/gateway` and
// shared by all consumers (rest API public auth handler, gRPC session
// cache, downstream user/lobby routes, and the push subscriber).
package backendclient
@@ -0,0 +1,197 @@
package backendclient
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"galaxy/gateway/internal/downstream"
lobbymodel "galaxy/model/lobby"
"galaxy/transcoder"
)
const (
lobbyResultCodeOK = "ok"
defaultLobbyErrorCodeInvalid = "invalid_request"
defaultLobbyErrorCodeNoSubj = "subject_not_found"
defaultLobbyErrorCodeForbid = "forbidden"
defaultLobbyErrorCodeConfl = "conflict"
defaultLobbyErrorCodeIntErr = "internal_error"
)
var stableLobbyErrorMessages = map[string]string{
defaultLobbyErrorCodeInvalid: "request is invalid",
defaultLobbyErrorCodeNoSubj: "subject not found",
defaultLobbyErrorCodeForbid: "operation is forbidden for the calling user",
defaultLobbyErrorCodeConfl: "request conflicts with current state",
defaultLobbyErrorCodeIntErr: "internal server error",
}
// ExecuteLobbyCommand routes one authenticated lobby command into
// backend's `/api/v1/user/lobby/*` endpoints.
func (c *RESTClient) ExecuteLobbyCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
if c == nil || c.httpClient == nil {
return downstream.UnaryResult{}, errors.New("backendclient: execute lobby command: nil client")
}
if ctx == nil {
return downstream.UnaryResult{}, errors.New("backendclient: execute lobby command: nil context")
}
if err := ctx.Err(); err != nil {
return downstream.UnaryResult{}, err
}
if strings.TrimSpace(command.UserID) == "" {
return downstream.UnaryResult{}, errors.New("backendclient: execute lobby command: user_id must not be empty")
}
switch command.MessageType {
case lobbymodel.MessageTypeMyGamesList:
if _, err := transcoder.PayloadToMyGamesListRequest(command.PayloadBytes); err != nil {
return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute lobby command %q: %w", command.MessageType, err)
}
return c.executeLobbyMyGames(ctx, command.UserID)
case lobbymodel.MessageTypeOpenEnrollment:
req, err := transcoder.PayloadToOpenEnrollmentRequest(command.PayloadBytes)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute lobby command %q: %w", command.MessageType, err)
}
return c.executeLobbyOpenEnrollment(ctx, command.UserID, req)
default:
return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute lobby command: unsupported message type %q", command.MessageType)
}
}
func (c *RESTClient) executeLobbyMyGames(ctx context.Context, userID string) (downstream.UnaryResult, error) {
body, status, err := c.do(ctx, http.MethodGet, c.baseURL+"/api/v1/user/lobby/my/games", userID, nil)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute lobby.my.games.list: %w", err)
}
if status == http.StatusOK {
var response lobbymodel.MyGamesListResponse
if err := decodeStrictJSON(body, &response); err != nil {
return downstream.UnaryResult{}, fmt.Errorf("decode success response: %w", err)
}
payloadBytes, err := transcoder.MyGamesListResponseToPayload(&response)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("encode success response payload: %w", err)
}
return downstream.UnaryResult{
ResultCode: lobbyResultCodeOK,
PayloadBytes: payloadBytes,
}, nil
}
return projectLobbyErrorResponse(status, body)
}
func (c *RESTClient) executeLobbyOpenEnrollment(ctx context.Context, userID string, req *lobbymodel.OpenEnrollmentRequest) (downstream.UnaryResult, error) {
if req == nil || strings.TrimSpace(req.GameID) == "" {
return downstream.UnaryResult{}, errors.New("execute lobby.game.open-enrollment: game_id must not be empty")
}
target := c.baseURL + "/api/v1/user/lobby/games/" + url.PathEscape(req.GameID) + "/open-enrollment"
body, status, err := c.do(ctx, http.MethodPost, target, userID, struct{}{})
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute lobby.game.open-enrollment: %w", err)
}
if status == http.StatusOK {
// Backend returns the full LobbyGameDetail; gateway projects the
// minimal {game_id, status} pair onto the existing wire shape.
var detail struct {
GameID string `json:"game_id"`
Status string `json:"status"`
}
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&detail); err != nil {
return downstream.UnaryResult{}, fmt.Errorf("decode success response: %w", err)
}
payloadBytes, err := transcoder.OpenEnrollmentResponseToPayload(&lobbymodel.OpenEnrollmentResponse{
GameID: detail.GameID,
Status: detail.Status,
})
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("encode success response payload: %w", err)
}
return downstream.UnaryResult{
ResultCode: lobbyResultCodeOK,
PayloadBytes: payloadBytes,
}, nil
}
return projectLobbyErrorResponse(status, body)
}
func projectLobbyErrorResponse(statusCode int, payload []byte) (downstream.UnaryResult, error) {
switch {
case statusCode == http.StatusServiceUnavailable:
return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
case statusCode >= 400 && statusCode <= 599:
errResp, err := decodeLobbyError(statusCode, payload)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("decode error response: %w", err)
}
payloadBytes, err := transcoder.LobbyErrorResponseToPayload(errResp)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("encode error response payload: %w", err)
}
return downstream.UnaryResult{
ResultCode: errResp.Error.Code,
PayloadBytes: payloadBytes,
}, nil
default:
return downstream.UnaryResult{}, fmt.Errorf("unexpected HTTP status %d", statusCode)
}
}
func decodeLobbyError(statusCode int, payload []byte) (*lobbymodel.ErrorResponse, error) {
var response lobbymodel.ErrorResponse
decoder := json.NewDecoder(bytes.NewReader(payload))
decoder.DisallowUnknownFields()
if err := decoder.Decode(&response); err != nil {
return nil, err
}
if err := decoder.Decode(&struct{}{}); err != io.EOF {
if err == nil {
return nil, errors.New("unexpected trailing JSON input")
}
return nil, err
}
response.Error.Code = normalizeLobbyErrorCode(statusCode, response.Error.Code)
response.Error.Message = normalizeLobbyErrorMessage(response.Error.Code, response.Error.Message)
if strings.TrimSpace(response.Error.Code) == "" {
return nil, errors.New("missing error code")
}
if strings.TrimSpace(response.Error.Message) == "" {
return nil, errors.New("missing error message")
}
return &response, nil
}
func normalizeLobbyErrorCode(statusCode int, code string) string {
if trimmed := strings.TrimSpace(code); trimmed != "" {
return trimmed
}
switch statusCode {
case http.StatusBadRequest:
return defaultLobbyErrorCodeInvalid
case http.StatusForbidden:
return defaultLobbyErrorCodeForbid
case http.StatusNotFound:
return defaultLobbyErrorCodeNoSubj
case http.StatusConflict:
return defaultLobbyErrorCodeConfl
default:
return defaultLobbyErrorCodeIntErr
}
}
func normalizeLobbyErrorMessage(code, message string) string {
if trimmed := strings.TrimSpace(message); trimmed != "" {
return trimmed
}
if stable, ok := stableLobbyErrorMessages[code]; ok {
return stable
}
return stableLobbyErrorMessages[defaultLobbyErrorCodeIntErr]
}
@@ -0,0 +1,148 @@
package backendclient
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
)
// SendEmailCodeInput is the public REST and adapter payload used to
// request a login code for a single e-mail address.
type SendEmailCodeInput struct {
Email string `json:"email"`
PreferredLanguage string `json:"-"`
}
// SendEmailCodeResult is the public REST and adapter payload returned
// after backend creates a login challenge.
type SendEmailCodeResult struct {
ChallengeID string `json:"challenge_id"`
}
// ConfirmEmailCodeInput is the public REST and adapter payload used to
// complete a previously issued login challenge.
type ConfirmEmailCodeInput struct {
ChallengeID string `json:"challenge_id"`
Code string `json:"code"`
ClientPublicKey string `json:"client_public_key"`
TimeZone string `json:"time_zone"`
}
// ConfirmEmailCodeResult is the public REST and adapter payload
// returned after backend creates a device session.
type ConfirmEmailCodeResult struct {
DeviceSessionID string `json:"device_session_id"`
}
// AuthError lets a public REST handler project a stable error envelope
// without re-deriving backend semantics. StatusCode is the HTTP status
// the gateway should return; Code and Message form the JSON envelope.
type AuthError struct {
StatusCode int
Code string
Message string
}
// Error returns a readable representation of the projected auth error.
func (e *AuthError) Error() string {
if e == nil {
return ""
}
return fmt.Sprintf("backendclient auth error: status=%d code=%s message=%s", e.StatusCode, e.Code, e.Message)
}
// SendEmailCode delegates the public send-email-code route to backend.
func (c *RESTClient) SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error) {
if strings.TrimSpace(input.Email) == "" {
return SendEmailCodeResult{}, errors.New("backendclient: send email code: email must not be empty")
}
body, status, err := c.doWithHeaders(ctx, http.MethodPost, c.baseURL+"/api/v1/public/auth/send-email-code", "", input, map[string]string{
"Accept-Language": resolvePreferredLanguage(input.PreferredLanguage),
})
if err != nil {
return SendEmailCodeResult{}, fmt.Errorf("backendclient: send email code: %w", err)
}
switch {
case status == http.StatusOK:
var result SendEmailCodeResult
if err := decodeStrictJSON(body, &result); err != nil {
return SendEmailCodeResult{}, fmt.Errorf("backendclient: send email code: decode success response: %w", err)
}
if strings.TrimSpace(result.ChallengeID) == "" {
return SendEmailCodeResult{}, errors.New("backendclient: send email code: challenge_id must not be empty")
}
return result, nil
case status >= 400 && status <= 599:
authErr, decodeErr := decodeAuthError(status, body)
if decodeErr != nil {
return SendEmailCodeResult{}, fmt.Errorf("backendclient: send email code: %w", decodeErr)
}
return SendEmailCodeResult{}, authErr
default:
return SendEmailCodeResult{}, fmt.Errorf("backendclient: send email code: unexpected HTTP status %d", status)
}
}
// ConfirmEmailCode delegates the public confirm-email-code route to
// backend.
func (c *RESTClient) ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
if strings.TrimSpace(input.ChallengeID) == "" {
return ConfirmEmailCodeResult{}, errors.New("backendclient: confirm email code: challenge_id must not be empty")
}
body, status, err := c.doWithHeaders(ctx, http.MethodPost, c.baseURL+"/api/v1/public/auth/confirm-email-code", "", input, nil)
if err != nil {
return ConfirmEmailCodeResult{}, fmt.Errorf("backendclient: confirm email code: %w", err)
}
switch {
case status == http.StatusOK:
var result ConfirmEmailCodeResult
if err := decodeStrictJSON(body, &result); err != nil {
return ConfirmEmailCodeResult{}, fmt.Errorf("backendclient: confirm email code: decode success response: %w", err)
}
if strings.TrimSpace(result.DeviceSessionID) == "" {
return ConfirmEmailCodeResult{}, errors.New("backendclient: confirm email code: device_session_id must not be empty")
}
return result, nil
case status >= 400 && status <= 599:
authErr, decodeErr := decodeAuthError(status, body)
if decodeErr != nil {
return ConfirmEmailCodeResult{}, fmt.Errorf("backendclient: confirm email code: %w", decodeErr)
}
return ConfirmEmailCodeResult{}, authErr
default:
return ConfirmEmailCodeResult{}, fmt.Errorf("backendclient: confirm email code: unexpected HTTP status %d", status)
}
}
// resolvePreferredLanguage returns a non-empty Accept-Language value or
// the empty string when input is unset; downstream HTTP request helpers
// drop the header on empty values.
func resolvePreferredLanguage(preferred string) string {
return strings.TrimSpace(preferred)
}
type authErrorEnvelope struct {
Error *authErrorBody `json:"error"`
}
type authErrorBody struct {
Code string `json:"code"`
Message string `json:"message"`
}
func decodeAuthError(statusCode int, payload []byte) (*AuthError, error) {
var envelope authErrorEnvelope
if err := decodeStrictJSON(payload, &envelope); err != nil {
return nil, fmt.Errorf("decode error response: %w", err)
}
if envelope.Error == nil {
return nil, errors.New("decode error response: missing error object")
}
return &AuthError{
StatusCode: statusCode,
Code: envelope.Error.Code,
Message: envelope.Error.Message,
}, nil
}
@@ -0,0 +1,266 @@
// PushClient — gateway-side gRPC consumer of `Push.SubscribePush`.
//
// One PushClient is wired for the gateway lifecycle. Run keeps the
// subscription open, reconnects on every transport error with
// exponential backoff (capped at PushReconnectMaxBackoff), and forwards
// every received PushEvent to the configured EventHandler. The cursor
// of the last successfully handled event is remembered in process
// memory only (see `backend/README.md` and `backend/docs/` D2). On reconnect
// it is replayed back to backend so any events still in the freshness-
// window ring are received exactly once.
package backendclient
import (
"context"
"errors"
"fmt"
"io"
"math/rand/v2"
"sync"
"time"
pushv1 "galaxy/backend/proto/push/v1"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
// EventHandler receives every PushEvent successfully drained from the
// backend stream. Implementations must be concurrency-safe and must not
// block; PushClient owns the calling goroutine and waits for Handle to
// return before reading the next event.
type EventHandler interface {
Handle(context.Context, *pushv1.PushEvent)
}
// EventHandlerFunc adapts a plain function to the EventHandler
// contract.
type EventHandlerFunc func(context.Context, *pushv1.PushEvent)
// Handle implements EventHandler.
func (f EventHandlerFunc) Handle(ctx context.Context, ev *pushv1.PushEvent) { f(ctx, ev) }
// PushClient is the gRPC adapter that owns the long-lived
// SubscribePush stream.
type PushClient struct {
cfg Config
dialOpts []grpc.DialOption
clock func() time.Time
sleep func(context.Context, time.Duration) error
logger *zap.Logger
handler EventHandler
mu sync.Mutex
cursor string
connMu sync.Mutex
conn *grpc.ClientConn
}
// NewPushClient constructs a PushClient. The default dial uses
// transport credentials INSECURE; deployments behind TLS must wrap the
// returned client with an alternative DialOption set via
// WithDialOptions before calling Run.
func NewPushClient(cfg Config) (*PushClient, error) {
if err := cfg.Validate(); err != nil {
return nil, err
}
return &PushClient{
cfg: cfg,
dialOpts: []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
},
clock: time.Now,
sleep: defaultSleep,
logger: zap.NewNop(),
}, nil
}
// WithDialOptions overrides the default dial options used when opening
// the gRPC connection. Tests typically pass `grpc.WithContextDialer` so
// `grpc.NewClient` connects to a `bufconn` listener.
func (c *PushClient) WithDialOptions(opts ...grpc.DialOption) *PushClient {
if c == nil {
return nil
}
c.dialOpts = append([]grpc.DialOption(nil), opts...)
return c
}
// WithLogger replaces the structured logger.
func (c *PushClient) WithLogger(logger *zap.Logger) *PushClient {
if c == nil {
return nil
}
if logger == nil {
logger = zap.NewNop()
}
c.logger = logger.Named("push_client")
return c
}
// WithHandler installs the EventHandler. Run returns an error if no
// handler has been installed.
func (c *PushClient) WithHandler(handler EventHandler) *PushClient {
if c == nil {
return nil
}
c.handler = handler
return c
}
// Cursor returns the cursor of the last event delivered to the handler.
// Useful for tests and operator inspection. Returns the empty string
// before any event has been processed.
func (c *PushClient) Cursor() string {
if c == nil {
return ""
}
c.mu.Lock()
defer c.mu.Unlock()
return c.cursor
}
// Run opens the SubscribePush stream and forwards events until ctx is
// cancelled. Network errors are retried with exponential backoff up to
// PushReconnectMaxBackoff; ctx cancellation is the only terminal exit.
func (c *PushClient) Run(ctx context.Context) error {
if c == nil {
return errors.New("backendclient.PushClient.Run: nil client")
}
if ctx == nil {
return errors.New("backendclient.PushClient.Run: nil context")
}
if c.handler == nil {
return errors.New("backendclient.PushClient.Run: handler is required")
}
conn, err := grpc.NewClient(c.cfg.GRPCPushURL, c.dialOpts...)
if err != nil {
return fmt.Errorf("backendclient.PushClient.Run: dial backend push: %w", err)
}
c.connMu.Lock()
c.conn = conn
c.connMu.Unlock()
defer func() {
c.connMu.Lock()
_ = c.conn.Close()
c.conn = nil
c.connMu.Unlock()
}()
pushAPI := pushv1.NewPushClient(conn)
backoff := c.cfg.PushReconnectBaseBackoff
for {
if err := ctx.Err(); err != nil {
return err
}
err := c.runOnce(ctx, pushAPI)
switch {
case err == nil, errors.Is(err, context.Canceled):
return ctx.Err()
case status.Code(err) == codes.Aborted:
c.logger.Info("backend replaced push subscription; reconnecting")
case errors.Is(err, io.EOF):
c.logger.Info("backend push stream closed; reconnecting")
default:
c.logger.Warn("backend push stream error; reconnecting",
zap.Error(err),
zap.Duration("backoff", backoff),
)
}
if err := c.sleep(ctx, jitter(backoff)); err != nil {
return err
}
backoff = nextBackoff(backoff, c.cfg.PushReconnectMaxBackoff)
}
}
// Shutdown is a no-op kept for `app.Component` compatibility. The
// SubscribePush call exits when its parent context is cancelled.
func (c *PushClient) Shutdown(_ context.Context) error { return nil }
// Close closes the underlying gRPC connection if it is open. Idempotent.
func (c *PushClient) Close() error {
if c == nil {
return nil
}
c.connMu.Lock()
defer c.connMu.Unlock()
if c.conn == nil {
return nil
}
err := c.conn.Close()
c.conn = nil
return err
}
func (c *PushClient) runOnce(ctx context.Context, pushAPI pushv1.PushClient) error {
stream, err := pushAPI.SubscribePush(ctx, &pushv1.GatewaySubscribeRequest{
GatewayClientId: c.cfg.GatewayClientID,
Cursor: c.Cursor(),
})
if err != nil {
return fmt.Errorf("subscribe push: %w", err)
}
for {
ev, err := stream.Recv()
if err != nil {
return err
}
c.handler.Handle(ctx, ev)
if cursor := ev.GetCursor(); cursor != "" {
c.setCursor(cursor)
}
}
}
func (c *PushClient) setCursor(cursor string) {
c.mu.Lock()
c.cursor = cursor
c.mu.Unlock()
}
func nextBackoff(current, max time.Duration) time.Duration {
doubled := current * 2
if doubled > max {
return max
}
if doubled <= 0 {
return max
}
return doubled
}
// jitter returns d with ±20% multiplicative noise so multiple gateway
// instances do not retry in lockstep after a backend restart.
func jitter(d time.Duration) time.Duration {
if d <= 0 {
return d
}
noise := 1 + (rand.Float64()-0.5)*0.4
return time.Duration(float64(d) * noise)
}
func defaultSleep(ctx context.Context, d time.Duration) error {
if d <= 0 {
return nil
}
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
@@ -0,0 +1,132 @@
package backendclient_test
import (
"context"
"net"
"sync"
"testing"
"time"
backendpush "galaxy/backend/push"
pushv1 "galaxy/backend/proto/push/v1"
"galaxy/gateway/internal/backendclient"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
)
// bufconnPushService starts an in-process backend push.Service backed by
// a *grpc.Server on a bufconn listener and returns the dial option that
// gateway PushClient should use to connect to it.
type bufconnPushService struct {
Service *backendpush.Service
dial func(context.Context, string) (net.Conn, error)
stop func()
}
func newBufconnPushService(t *testing.T) *bufconnPushService {
t.Helper()
service, err := backendpush.NewService(backendpush.ServiceConfig{
FreshnessWindow: time.Minute,
RingCapacity: 16,
PerConnBuffer: 8,
}, nil, nil)
require.NoError(t, err)
listener := bufconn.Listen(1 << 16)
server := grpc.NewServer()
pushv1.RegisterPushServer(server, service)
go func() {
_ = server.Serve(listener)
}()
stop := func() {
service.Close()
server.Stop()
_ = listener.Close()
}
t.Cleanup(stop)
return &bufconnPushService{
Service: service,
dial: func(_ context.Context, _ string) (net.Conn, error) { return listener.Dial() },
stop: stop,
}
}
func TestPushClientDeliversClientEventsAndAdvancesCursor(t *testing.T) {
t.Parallel()
svc := newBufconnPushService(t)
type received struct {
event *pushv1.PushEvent
cursor string
}
out := make(chan received, 4)
cfg := backendclient.Config{
HTTPBaseURL: "http://example.invalid",
GRPCPushURL: "passthrough://bufconn",
GatewayClientID: "gw-1",
HTTPTimeout: time.Second,
PushReconnectBaseBackoff: 10 * time.Millisecond,
PushReconnectMaxBackoff: 100 * time.Millisecond,
}
client, err := backendclient.NewPushClient(cfg)
require.NoError(t, err)
client.WithDialOptions(
grpc.WithContextDialer(svc.dial),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
client.WithHandler(backendclient.EventHandlerFunc(func(_ context.Context, ev *pushv1.PushEvent) {
out <- received{event: ev, cursor: ev.GetCursor()}
}))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var (
runErr error
wg sync.WaitGroup
)
wg.Add(1)
go func() {
defer wg.Done()
runErr = client.Run(ctx)
}()
// Wait for backend service to register the subscription.
require.Eventually(t, func() bool { return svc.Service.SubscriberCount() == 1 }, time.Second, 10*time.Millisecond)
userID := uuid.New()
require.NoError(t, svc.Service.PublishClientEvent(context.Background(), userID, nil, "lobby.invite.received", map[string]any{"x": 1.0}, "evt-1", "req-1", "trace-1"))
select {
case got := <-out:
ce := got.event.GetClientEvent()
require.NotNil(t, ce)
assert.Equal(t, userID.String(), ce.GetUserId())
assert.Equal(t, "lobby.invite.received", ce.GetKind())
assert.Equal(t, "evt-1", ce.GetEventId())
assert.Equal(t, "req-1", ce.GetRequestId())
assert.Equal(t, "trace-1", ce.GetTraceId())
assert.NotEmpty(t, got.cursor)
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for client event")
}
require.Eventually(t, func() bool { return client.Cursor() != "" }, time.Second, 10*time.Millisecond)
cancel()
wg.Wait()
if runErr != nil && runErr != context.Canceled {
t.Fatalf("unexpected run error: %v", runErr)
}
}
+256
View File
@@ -0,0 +1,256 @@
package backendclient
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"galaxy/gateway/internal/session"
)
// HeaderUserID is the trusted gateway → backend identity header.
const HeaderUserID = "X-User-Id"
// errSessionNotFound is the public error returned by LookupSession when
// backend reports HTTP 404 for a device session id. It wraps
// session.ErrNotFound so callers can keep using the existing typed
// equality check at the gateway hot path.
func errSessionNotFound() error {
return fmt.Errorf("backendclient: lookup session: %w", session.ErrNotFound)
}
// RESTClient owns the gateway's HTTP conversation with backend.
//
// All methods are safe for concurrent use.
type RESTClient struct {
baseURL string
httpClient *http.Client
}
// NewRESTClient constructs a RESTClient targeting the backend HTTP
// listener configured in cfg.
func NewRESTClient(cfg Config) (*RESTClient, error) {
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, errors.New("backendclient: default HTTP transport is not *http.Transport")
}
parsed, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.HTTPBaseURL), "/"))
if err != nil {
return nil, fmt.Errorf("backendclient: parse HTTPBaseURL: %w", err)
}
if parsed.Scheme == "" || parsed.Host == "" {
return nil, errors.New("backendclient: HTTPBaseURL must be absolute")
}
return &RESTClient{
baseURL: parsed.String(),
httpClient: &http.Client{
Transport: transport.Clone(),
Timeout: cfg.HTTPTimeout,
},
}, nil
}
// Close releases idle HTTP connections owned by the client transport.
func (c *RESTClient) Close() error {
if c == nil || c.httpClient == nil {
return nil
}
type idleCloser interface {
CloseIdleConnections()
}
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
transport.CloseIdleConnections()
}
return nil
}
// LookupSession resolves deviceSessionID against
// `GET /api/v1/internal/sessions/{device_session_id}`.
// Returns session.ErrNotFound (wrapped) when backend reports 404.
func (c *RESTClient) LookupSession(ctx context.Context, deviceSessionID string) (session.Record, error) {
if c == nil || c.httpClient == nil {
return session.Record{}, errors.New("backendclient: nil REST client")
}
if strings.TrimSpace(deviceSessionID) == "" {
return session.Record{}, errors.New("backendclient: lookup session: device_session_id must not be empty")
}
target := c.baseURL + "/api/v1/internal/sessions/" + url.PathEscape(deviceSessionID)
body, status, err := c.do(ctx, http.MethodGet, target, "", nil)
if err != nil {
return session.Record{}, fmt.Errorf("backendclient: lookup session: %w", err)
}
switch {
case status == http.StatusOK:
return decodeDeviceSession(deviceSessionID, body)
case status == http.StatusNotFound:
return session.Record{}, errSessionNotFound()
default:
return session.Record{}, fmt.Errorf("backendclient: lookup session: unexpected HTTP status %d", status)
}
}
// RevokeSession asks backend to revoke a single device session by id.
func (c *RESTClient) RevokeSession(ctx context.Context, deviceSessionID string) error {
if strings.TrimSpace(deviceSessionID) == "" {
return errors.New("backendclient: revoke session: device_session_id must not be empty")
}
target := c.baseURL + "/api/v1/internal/sessions/" + url.PathEscape(deviceSessionID) + "/revoke"
_, status, err := c.do(ctx, http.MethodPost, target, "", nil)
if err != nil {
return fmt.Errorf("backendclient: revoke session: %w", err)
}
if status == http.StatusOK || status == http.StatusNoContent {
return nil
}
if status == http.StatusNotFound {
return errSessionNotFound()
}
return fmt.Errorf("backendclient: revoke session: unexpected HTTP status %d", status)
}
// RevokeAllSessionsForUser asks backend to revoke every active device
// session belonging to userID.
func (c *RESTClient) RevokeAllSessionsForUser(ctx context.Context, userID string) error {
if strings.TrimSpace(userID) == "" {
return errors.New("backendclient: revoke-all sessions: user_id must not be empty")
}
target := c.baseURL + "/api/v1/internal/sessions/users/" + url.PathEscape(userID) + "/revoke-all"
_, status, err := c.do(ctx, http.MethodPost, target, "", nil)
if err != nil {
return fmt.Errorf("backendclient: revoke-all sessions: %w", err)
}
if status == http.StatusOK || status == http.StatusNoContent {
return nil
}
if status == http.StatusNotFound {
return errSessionNotFound()
}
return fmt.Errorf("backendclient: revoke-all sessions: unexpected HTTP status %d", status)
}
// do executes a JSON request and reads the response body. userID, when
// non-empty, is sent as the X-User-Id header (required for `/api/v1/user/*`).
func (c *RESTClient) do(ctx context.Context, method, target, userID string, body any) ([]byte, int, error) {
return c.doWithHeaders(ctx, method, target, userID, body, nil)
}
// doWithHeaders is the shared transport entry point. extraHeaders are
// applied verbatim after Content-Type/X-User-Id; an empty value drops
// the header so callers can pass optional language tags etc.
func (c *RESTClient) doWithHeaders(ctx context.Context, method, target, userID string, body any, extraHeaders map[string]string) ([]byte, int, error) {
if c == nil || c.httpClient == nil {
return nil, 0, errors.New("nil REST client")
}
if ctx == nil {
return nil, 0, errors.New("nil context")
}
var reader io.Reader
if body != nil {
buf, err := json.Marshal(body)
if err != nil {
return nil, 0, fmt.Errorf("marshal request body: %w", err)
}
reader = bytes.NewReader(buf)
}
req, err := http.NewRequestWithContext(ctx, method, target, reader)
if err != nil {
return nil, 0, fmt.Errorf("build request: %w", err)
}
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
if userID != "" {
req.Header.Set(HeaderUserID, userID)
}
for key, value := range extraHeaders {
if strings.TrimSpace(value) == "" {
continue
}
req.Header.Set(key, value)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, 0, err
}
defer resp.Body.Close()
payload, err := io.ReadAll(resp.Body)
if err != nil {
return nil, resp.StatusCode, fmt.Errorf("read response body: %w", err)
}
return payload, resp.StatusCode, nil
}
// deviceSessionWire mirrors backend openapi `DeviceSession`.
type deviceSessionWire struct {
DeviceSessionID string `json:"device_session_id"`
UserID string `json:"user_id"`
Status string `json:"status"`
ClientPublicKey string `json:"client_public_key,omitempty"`
CreatedAt time.Time `json:"created_at"`
RevokedAt *time.Time `json:"revoked_at,omitempty"`
LastSeenAt *time.Time `json:"last_seen_at,omitempty"`
}
func decodeDeviceSession(expectedDeviceSessionID string, payload []byte) (session.Record, error) {
var wire deviceSessionWire
if err := decodeStrictJSON(payload, &wire); err != nil {
return session.Record{}, fmt.Errorf("decode device session: %w", err)
}
if strings.TrimSpace(wire.DeviceSessionID) == "" {
return session.Record{}, errors.New("decode device session: device_session_id must not be empty")
}
if wire.DeviceSessionID != expectedDeviceSessionID {
return session.Record{}, fmt.Errorf("decode device session: device_session_id %q does not match requested %q", wire.DeviceSessionID, expectedDeviceSessionID)
}
if strings.TrimSpace(wire.UserID) == "" {
return session.Record{}, errors.New("decode device session: user_id must not be empty")
}
status := session.Status(strings.TrimSpace(wire.Status))
if !status.IsKnown() {
return session.Record{}, fmt.Errorf("decode device session: status %q is unsupported", wire.Status)
}
if status == session.StatusActive && strings.TrimSpace(wire.ClientPublicKey) == "" {
return session.Record{}, errors.New("decode device session: active record missing client_public_key")
}
record := session.Record{
DeviceSessionID: wire.DeviceSessionID,
UserID: wire.UserID,
ClientPublicKey: wire.ClientPublicKey,
Status: status,
}
if wire.RevokedAt != nil {
ms := wire.RevokedAt.UnixMilli()
record.RevokedAtMS = &ms
}
return record, nil
}
func decodeStrictJSON(payload []byte, target any) error {
decoder := json.NewDecoder(bytes.NewReader(payload))
decoder.DisallowUnknownFields()
if err := decoder.Decode(target); err != nil {
return err
}
if err := decoder.Decode(&struct{}{}); err != io.EOF {
if err == nil {
return errors.New("unexpected trailing JSON input")
}
return err
}
return nil
}
+190
View File
@@ -0,0 +1,190 @@
package backendclient_test
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"galaxy/gateway/internal/backendclient"
"galaxy/gateway/internal/session"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newRESTClient(t *testing.T, server *httptest.Server) *backendclient.RESTClient {
t.Helper()
cfg := backendclient.Config{
HTTPBaseURL: server.URL,
GRPCPushURL: "passthrough://test",
GatewayClientID: "test-gateway",
HTTPTimeout: time.Second,
PushReconnectBaseBackoff: 10 * time.Millisecond,
PushReconnectMaxBackoff: 100 * time.Millisecond,
}
client, err := backendclient.NewRESTClient(cfg)
require.NoError(t, err)
t.Cleanup(func() { _ = client.Close() })
return client
}
func TestRESTClientLookupSessionReturnsActiveRecord(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
require.Equal(t, "/api/v1/internal/sessions/device-1", r.URL.Path)
writeJSON(t, w, http.StatusOK, map[string]any{
"device_session_id": "device-1",
"user_id": "user-1",
"status": "active",
"client_public_key": "pk-1",
"created_at": "2026-04-01T00:00:00Z",
})
}))
t.Cleanup(server.Close)
client := newRESTClient(t, server)
rec, err := client.LookupSession(context.Background(), "device-1")
require.NoError(t, err)
assert.Equal(t, session.Record{
DeviceSessionID: "device-1",
UserID: "user-1",
ClientPublicKey: "pk-1",
Status: session.StatusActive,
}, rec)
}
func TestRESTClientLookupSessionReturnsRevokedRecord(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeJSON(t, w, http.StatusOK, map[string]any{
"device_session_id": "device-2",
"user_id": "user-2",
"status": "revoked",
"client_public_key": "pk-2",
"created_at": "2026-04-01T00:00:00Z",
"revoked_at": "2026-04-01T00:01:00Z",
})
}))
t.Cleanup(server.Close)
client := newRESTClient(t, server)
rec, err := client.LookupSession(context.Background(), "device-2")
require.NoError(t, err)
assert.Equal(t, session.StatusRevoked, rec.Status)
require.NotNil(t, rec.RevokedAtMS)
assert.Equal(t, time.Date(2026, 4, 1, 0, 1, 0, 0, time.UTC).UnixMilli(), *rec.RevokedAtMS)
}
func TestRESTClientLookupSessionMapsNotFound(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
writeJSON(t, w, http.StatusNotFound, map[string]any{"error": map[string]any{"code": "subject_not_found", "message": "missing"}})
}))
t.Cleanup(server.Close)
client := newRESTClient(t, server)
_, err := client.LookupSession(context.Background(), "missing")
require.Error(t, err)
assert.True(t, errors.Is(err, session.ErrNotFound))
}
func TestRESTClientLookupSessionRejectsMismatchedID(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
writeJSON(t, w, http.StatusOK, map[string]any{
"device_session_id": "other",
"user_id": "user-1",
"status": "active",
"client_public_key": "pk-1",
"created_at": "2026-04-01T00:00:00Z",
})
}))
t.Cleanup(server.Close)
client := newRESTClient(t, server)
_, err := client.LookupSession(context.Background(), "device-1")
require.Error(t, err)
assert.Contains(t, err.Error(), "does not match requested")
}
func TestRESTClientSendEmailCodeForwardsAcceptLanguage(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/api/v1/public/auth/send-email-code", r.URL.Path)
require.Equal(t, "ru-RU", r.Header.Get("Accept-Language"))
writeJSON(t, w, http.StatusOK, map[string]any{"challenge_id": "challenge-1"})
}))
t.Cleanup(server.Close)
client := newRESTClient(t, server)
out, err := client.SendEmailCode(context.Background(), backendclient.SendEmailCodeInput{
Email: "user@example.com",
PreferredLanguage: "ru-RU",
})
require.NoError(t, err)
assert.Equal(t, "challenge-1", out.ChallengeID)
}
func TestRESTClientSendEmailCodeProjectsAuthError(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
writeJSON(t, w, http.StatusBadRequest, map[string]any{
"error": map[string]any{"code": "invalid_request", "message": "bad email"},
})
}))
t.Cleanup(server.Close)
client := newRESTClient(t, server)
_, err := client.SendEmailCode(context.Background(), backendclient.SendEmailCodeInput{Email: "user@example.com"})
require.Error(t, err)
var authErr *backendclient.AuthError
require.ErrorAs(t, err, &authErr)
assert.Equal(t, http.StatusBadRequest, authErr.StatusCode)
assert.Equal(t, "invalid_request", authErr.Code)
assert.Equal(t, "bad email", authErr.Message)
}
func TestRESTClientConfirmEmailCodeReturnsDeviceSession(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/api/v1/public/auth/confirm-email-code", r.URL.Path)
var body backendclient.ConfirmEmailCodeInput
require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
assert.Equal(t, "challenge-1", body.ChallengeID)
writeJSON(t, w, http.StatusOK, map[string]any{"device_session_id": "device-1"})
}))
t.Cleanup(server.Close)
client := newRESTClient(t, server)
out, err := client.ConfirmEmailCode(context.Background(), backendclient.ConfirmEmailCodeInput{
ChallengeID: "challenge-1",
Code: "12345",
})
require.NoError(t, err)
assert.Equal(t, "device-1", out.DeviceSessionID)
}
func writeJSON(t *testing.T, w http.ResponseWriter, status int, body any) {
t.Helper()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
require.NoError(t, json.NewEncoder(w).Encode(body))
}
// guard ensures package keeps testify dependency.
var _ = strings.TrimSpace
+67
View File
@@ -0,0 +1,67 @@
package backendclient
import (
"context"
"galaxy/gateway/internal/downstream"
lobbymodel "galaxy/model/lobby"
usermodel "galaxy/model/user"
)
// UserRoutes returns the authenticated `user.*` downstream routes
// served by backend. When client is nil every route resolves to a
// dependency-unavailable client so the static router still recognises
// the message types.
func UserRoutes(client *RESTClient) map[string]downstream.Client {
target := downstream.Client(unavailableClient{})
if client != nil {
target = userCommandClient{rest: client}
}
return map[string]downstream.Client{
usermodel.MessageTypeGetMyAccount: target,
usermodel.MessageTypeUpdateMyProfile: target,
usermodel.MessageTypeUpdateMySettings: target,
}
}
// LobbyRoutes returns the authenticated `lobby.*` downstream routes
// served by backend. When client is nil every route resolves to a
// dependency-unavailable client.
func LobbyRoutes(client *RESTClient) map[string]downstream.Client {
target := downstream.Client(unavailableClient{})
if client != nil {
target = lobbyCommandClient{rest: client}
}
return map[string]downstream.Client{
lobbymodel.MessageTypeMyGamesList: target,
lobbymodel.MessageTypeOpenEnrollment: target,
}
}
type unavailableClient struct{}
func (unavailableClient) ExecuteCommand(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
}
type userCommandClient struct {
rest *RESTClient
}
func (c userCommandClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
return c.rest.ExecuteUserCommand(ctx, command)
}
type lobbyCommandClient struct {
rest *RESTClient
}
func (c lobbyCommandClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
return c.rest.ExecuteLobbyCommand(ctx, command)
}
var (
_ downstream.Client = unavailableClient{}
_ downstream.Client = userCommandClient{}
_ downstream.Client = lobbyCommandClient{}
)
@@ -0,0 +1,166 @@
package backendclient
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"galaxy/gateway/internal/downstream"
usermodel "galaxy/model/user"
"galaxy/transcoder"
)
const (
userCommandResultCodeOK = "ok"
defaultUserErrorCode = "internal_error"
)
var stableUserErrorMessages = map[string]string{
"invalid_request": "request is invalid",
"subject_not_found": "subject not found",
"conflict": "request conflicts with current state",
defaultUserErrorCode: "internal server error",
}
// ExecuteUserCommand routes one authenticated user-surface command into
// backend's `/api/v1/user/*` endpoints. The function is registered for
// the message types listed in `galaxy/model/user`.
func (c *RESTClient) ExecuteUserCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
if c == nil || c.httpClient == nil {
return downstream.UnaryResult{}, errors.New("backendclient: execute user command: nil client")
}
if ctx == nil {
return downstream.UnaryResult{}, errors.New("backendclient: execute user command: nil context")
}
if err := ctx.Err(); err != nil {
return downstream.UnaryResult{}, err
}
if strings.TrimSpace(command.UserID) == "" {
return downstream.UnaryResult{}, errors.New("backendclient: execute user command: user_id must not be empty")
}
switch command.MessageType {
case usermodel.MessageTypeGetMyAccount:
if _, err := transcoder.PayloadToGetMyAccountRequest(command.PayloadBytes); err != nil {
return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute user command %q: %w", command.MessageType, err)
}
return c.executeUserAccountGet(ctx, command.UserID)
case usermodel.MessageTypeUpdateMyProfile:
req, err := transcoder.PayloadToUpdateMyProfileRequest(command.PayloadBytes)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute user command %q: %w", command.MessageType, err)
}
return c.executeUserAccountUpdateProfile(ctx, command.UserID, req)
case usermodel.MessageTypeUpdateMySettings:
req, err := transcoder.PayloadToUpdateMySettingsRequest(command.PayloadBytes)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute user command %q: %w", command.MessageType, err)
}
return c.executeUserAccountUpdateSettings(ctx, command.UserID, req)
default:
return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute user command: unsupported message type %q", command.MessageType)
}
}
func (c *RESTClient) executeUserAccountGet(ctx context.Context, userID string) (downstream.UnaryResult, error) {
body, status, err := c.do(ctx, http.MethodGet, c.baseURL+"/api/v1/user/account", userID, nil)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute user.account.get: %w", err)
}
return projectUserResponse(status, body)
}
func (c *RESTClient) executeUserAccountUpdateProfile(ctx context.Context, userID string, req *usermodel.UpdateMyProfileRequest) (downstream.UnaryResult, error) {
body, status, err := c.do(ctx, http.MethodPatch, c.baseURL+"/api/v1/user/account/profile", userID, req)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute user.profile.update: %w", err)
}
return projectUserResponse(status, body)
}
func (c *RESTClient) executeUserAccountUpdateSettings(ctx context.Context, userID string, req *usermodel.UpdateMySettingsRequest) (downstream.UnaryResult, error) {
body, status, err := c.do(ctx, http.MethodPatch, c.baseURL+"/api/v1/user/account/settings", userID, req)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute user.settings.update: %w", err)
}
return projectUserResponse(status, body)
}
func projectUserResponse(statusCode int, payload []byte) (downstream.UnaryResult, error) {
switch {
case statusCode == http.StatusOK:
var response usermodel.AccountResponse
if err := decodeStrictJSON(payload, &response); err != nil {
return downstream.UnaryResult{}, fmt.Errorf("decode success response: %w", err)
}
payloadBytes, err := transcoder.AccountResponseToPayload(&response)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("encode success response payload: %w", err)
}
return downstream.UnaryResult{
ResultCode: userCommandResultCodeOK,
PayloadBytes: payloadBytes,
}, nil
case statusCode == http.StatusServiceUnavailable:
return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
case statusCode >= 400 && statusCode <= 599:
errResp, err := decodeUserError(statusCode, payload)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("decode error response: %w", err)
}
payloadBytes, err := transcoder.ErrorResponseToPayload(errResp)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("encode error response payload: %w", err)
}
return downstream.UnaryResult{
ResultCode: errResp.Error.Code,
PayloadBytes: payloadBytes,
}, nil
default:
return downstream.UnaryResult{}, fmt.Errorf("unexpected HTTP status %d", statusCode)
}
}
func decodeUserError(statusCode int, payload []byte) (*usermodel.ErrorResponse, error) {
var response usermodel.ErrorResponse
if err := decodeStrictJSON(payload, &response); err != nil {
return nil, err
}
response.Error.Code = normalizeUserErrorCode(statusCode, response.Error.Code)
response.Error.Message = normalizeUserErrorMessage(response.Error.Code, response.Error.Message)
if strings.TrimSpace(response.Error.Code) == "" {
return nil, errors.New("missing error code")
}
if strings.TrimSpace(response.Error.Message) == "" {
return nil, errors.New("missing error message")
}
return &response, nil
}
func normalizeUserErrorCode(statusCode int, code string) string {
if trimmed := strings.TrimSpace(code); trimmed != "" {
return trimmed
}
switch statusCode {
case http.StatusBadRequest:
return "invalid_request"
case http.StatusNotFound:
return "subject_not_found"
case http.StatusConflict:
return "conflict"
default:
return defaultUserErrorCode
}
}
func normalizeUserErrorMessage(code, message string) string {
if trimmed := strings.TrimSpace(message); trimmed != "" {
return trimmed
}
if stable, ok := stableUserErrorMessages[code]; ok {
return stable
}
return stableUserErrorMessages[defaultUserErrorCode]
}
+176 -312
View File
@@ -44,20 +44,34 @@ const (
// configures the timeout budget used for public auth upstream calls.
publicAuthUpstreamTimeoutEnvVar = "GATEWAY_PUBLIC_AUTH_UPSTREAM_TIMEOUT"
// authServiceBaseURLEnvVar names the environment variable that configures
// the optional Auth / Session Service public HTTP base URL used by gateway
// public-auth delegation.
authServiceBaseURLEnvVar = "GATEWAY_AUTH_SERVICE_BASE_URL"
// backendHTTPURLEnvVar names the environment variable that configures
// the absolute base URL of the consolidated backend HTTP listener used
// for public auth, internal session lookup, and authenticated user /
// lobby commands.
backendHTTPURLEnvVar = "GATEWAY_BACKEND_HTTP_URL"
// userServiceBaseURLEnvVar names the environment variable that configures
// the optional User Service internal HTTP base URL used by authenticated
// gateway self-service delegation.
userServiceBaseURLEnvVar = "GATEWAY_USER_SERVICE_BASE_URL"
// backendGRPCPushURLEnvVar names the environment variable that
// configures the dial target of backend's gRPC `Push.SubscribePush`
// listener.
backendGRPCPushURLEnvVar = "GATEWAY_BACKEND_GRPC_PUSH_URL"
// lobbyServiceBaseURLEnvVar names the environment variable that configures
// the optional Game Lobby public HTTP base URL used by authenticated
// gateway platform-command delegation.
lobbyServiceBaseURLEnvVar = "GATEWAY_LOBBY_SERVICE_BASE_URL"
// backendGatewayClientIDEnvVar names the environment variable that
// configures the durable identifier this gateway instance presents to
// backend in `GatewaySubscribeRequest.gateway_client_id`.
backendGatewayClientIDEnvVar = "GATEWAY_BACKEND_GATEWAY_CLIENT_ID"
// backendHTTPTimeoutEnvVar names the environment variable that
// configures the per-call timeout applied to backend HTTP requests.
backendHTTPTimeoutEnvVar = "GATEWAY_BACKEND_HTTP_TIMEOUT"
// backendPushReconnectBaseBackoffEnvVar names the environment variable
// that configures the starting delay between reconnect attempts of the
// gRPC SubscribePush stream.
backendPushReconnectBaseBackoffEnvVar = "GATEWAY_BACKEND_PUSH_RECONNECT_BASE_BACKOFF"
// backendPushReconnectMaxBackoffEnvVar names the environment variable
// that configures the upper bound for exponential reconnect delays.
backendPushReconnectMaxBackoffEnvVar = "GATEWAY_BACKEND_PUSH_RECONNECT_MAX_BACKOFF"
// adminHTTPAddrEnvVar names the environment variable that configures the
// private admin HTTP listener address. When it is empty, the admin listener
@@ -152,14 +166,6 @@ const (
// rate-limit burst.
authenticatedGRPCMessageClassRateLimitBurstEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_MESSAGE_CLASS_RATE_LIMIT_BURST"
// sessionCacheRedisKeyPrefixEnvVar names the environment variable that
// configures the Redis key prefix used for SessionCache records.
sessionCacheRedisKeyPrefixEnvVar = "GATEWAY_SESSION_CACHE_REDIS_KEY_PREFIX"
// sessionCacheRedisLookupTimeoutEnvVar names the environment variable that
// configures the timeout used for SessionCache Redis lookups.
sessionCacheRedisLookupTimeoutEnvVar = "GATEWAY_SESSION_CACHE_REDIS_LOOKUP_TIMEOUT"
// replayRedisKeyPrefixEnvVar names the environment variable that configures
// the Redis key prefix used for authenticated replay reservations.
replayRedisKeyPrefixEnvVar = "GATEWAY_REPLAY_REDIS_KEY_PREFIX"
@@ -169,24 +175,6 @@ const (
// startup connectivity checks.
replayRedisReserveTimeoutEnvVar = "GATEWAY_REPLAY_REDIS_RESERVE_TIMEOUT"
// sessionEventsRedisStreamEnvVar names the environment variable that
// configures the Redis Stream key consumed for session lifecycle updates.
sessionEventsRedisStreamEnvVar = "GATEWAY_SESSION_EVENTS_REDIS_STREAM"
// sessionEventsRedisReadBlockTimeoutEnvVar names the environment variable
// that configures the blocking read timeout used by the session event
// subscriber.
sessionEventsRedisReadBlockTimeoutEnvVar = "GATEWAY_SESSION_EVENTS_REDIS_READ_BLOCK_TIMEOUT"
// clientEventsRedisStreamEnvVar names the environment variable that
// configures the Redis Stream key consumed for client-facing push events.
clientEventsRedisStreamEnvVar = "GATEWAY_CLIENT_EVENTS_REDIS_STREAM"
// clientEventsRedisReadBlockTimeoutEnvVar names the environment variable
// that configures the blocking read timeout used by the client-event
// subscriber.
clientEventsRedisReadBlockTimeoutEnvVar = "GATEWAY_CLIENT_EVENTS_REDIS_READ_BLOCK_TIMEOUT"
// responseSignerPrivateKeyPEMPathEnvVar names the environment variable that
// configures the path to the PKCS#8 PEM-encoded Ed25519 private key used to
// sign authenticated unary responses and stream events.
@@ -293,13 +281,13 @@ const (
defaultPublicHTTPAddr = ":8080"
defaultPublicHTTPReadHeaderTimeout = 2 * time.Second
defaultPublicHTTPReadTimeout = 10 * time.Second
defaultPublicHTTPIdleTimeout = time.Minute
defaultPublicAuthUpstreamTimeout = 3 * time.Second
defaultPublicHTTPReadTimeout = 10 * time.Second
defaultPublicHTTPIdleTimeout = time.Minute
defaultPublicAuthUpstreamTimeout = 3 * time.Second
defaultAdminHTTPReadHeaderTimeout = 2 * time.Second
defaultAdminHTTPReadTimeout = 10 * time.Second
defaultAdminHTTPIdleTimeout = time.Minute
defaultAdminHTTPReadTimeout = 10 * time.Second
defaultAdminHTTPIdleTimeout = time.Minute
// defaultAuthenticatedGRPCAddr is applied when
// authenticatedGRPCAddrEnvVar is absent.
@@ -307,48 +295,46 @@ const (
defaultAuthenticatedGRPCConnectionTimeout = 5 * time.Second
defaultAuthenticatedGRPCDownstreamTimeout = 5 * time.Second
defaultAuthenticatedGRPCFreshnessWindow = 5 * time.Minute
defaultAuthenticatedGRPCFreshnessWindow = 5 * time.Minute
defaultAuthenticatedGRPCIPRateLimitRequests = 120
defaultAuthenticatedGRPCIPRateLimitBurst = 40
defaultAuthenticatedGRPCIPRateLimitBurst = 40
defaultAuthenticatedGRPCSessionRateLimitRequests = 60
defaultAuthenticatedGRPCSessionRateLimitBurst = 20
defaultAuthenticatedGRPCSessionRateLimitBurst = 20
defaultAuthenticatedGRPCUserRateLimitRequests = 120
defaultAuthenticatedGRPCUserRateLimitBurst = 40
defaultAuthenticatedGRPCUserRateLimitBurst = 40
defaultAuthenticatedGRPCMessageClassRateLimitRequests = 60
defaultAuthenticatedGRPCMessageClassRateLimitBurst = 20
defaultAuthenticatedGRPCMessageClassRateLimitBurst = 20
defaultSessionCacheRedisKeyPrefix = "gateway:session:"
defaultSessionCacheRedisLookupTimeout = 250 * time.Millisecond
defaultReplayRedisKeyPrefix = "gateway:replay:"
defaultReplayRedisKeyPrefix = "gateway:replay:"
defaultReplayRedisReserveTimeout = 250 * time.Millisecond
defaultSessionEventsRedisReadBlockTimeout = time.Second
defaultClientEventsRedisReadBlockTimeout = time.Second
defaultBackendHTTPTimeout = 5 * time.Second
defaultBackendPushReconnectBaseBackoff = 250 * time.Millisecond
defaultBackendPushReconnectMaxBackoff = 30 * time.Second
defaultPublicAuthMaxBodyBytes = int64(8192)
defaultPublicAuthRateLimitRequests = 30
defaultPublicAuthRateLimitBurst = 10
defaultPublicAuthRateLimitBurst = 10
defaultBrowserBootstrapRateLimitRequests = 60
defaultBrowserBootstrapRateLimitBurst = 20
defaultBrowserBootstrapRateLimitBurst = 20
defaultBrowserAssetRateLimitRequests = 300
defaultBrowserAssetRateLimitBurst = 80
defaultBrowserAssetRateLimitBurst = 80
defaultPublicMiscRateLimitRequests = 30
defaultPublicMiscRateLimitBurst = 10
defaultPublicMiscRateLimitBurst = 10
defaultSendEmailCodeIdentityRateLimitRequests = 3
defaultSendEmailCodeIdentityRateLimitBurst = 1
defaultSendEmailCodeIdentityRateLimitBurst = 1
defaultConfirmEmailCodeIdentityRateLimitRequests = 6
defaultConfirmEmailCodeIdentityRateLimitBurst = 2
defaultConfirmEmailCodeIdentityRateLimitBurst = 2
)
var (
@@ -462,31 +448,35 @@ type PublicHTTPConfig struct {
AntiAbuse PublicHTTPAntiAbuseConfig
}
// AuthServiceConfig describes the optional public-auth upstream used by the
// gateway runtime.
type AuthServiceConfig struct {
// BaseURL is the absolute base URL of the Auth / Session Service public
// HTTP API. When BaseURL is empty, the gateway keeps using its built-in
// unavailable public-auth adapter.
BaseURL string
}
// BackendConfig describes the consolidated backend service the gateway
// talks to. Every authenticated and public HTTP request is forwarded to
// `HTTPBaseURL`; the gRPC `Push.SubscribePush` stream is opened against
// `GRPCPushURL`.
type BackendConfig struct {
// HTTPBaseURL is the absolute base URL of the backend HTTP listener
// (`/api/v1/{public,user,internal}/*`). Required.
HTTPBaseURL string
// UserServiceConfig describes the optional authenticated self-service upstream
// used by the gateway runtime.
type UserServiceConfig struct {
// BaseURL is the absolute base URL of the User Service internal HTTP API.
// When BaseURL is empty, the gateway keeps using its built-in unavailable
// downstream adapter for the reserved `user.*` routes.
BaseURL string
}
// GRPCPushURL is the dial target of the backend `Push.SubscribePush`
// listener (`host:port`). Required.
GRPCPushURL string
// LobbyServiceConfig describes the optional authenticated platform-command
// upstream used by the gateway runtime.
type LobbyServiceConfig struct {
// BaseURL is the absolute base URL of the Game Lobby public HTTP API.
// When BaseURL is empty, the gateway keeps using its built-in unavailable
// downstream adapter for the reserved `lobby.*` routes.
BaseURL string
// GatewayClientID is the durable identifier this gateway instance
// presents to backend in `GatewaySubscribeRequest.gateway_client_id`.
// Required.
GatewayClientID string
// HTTPTimeout bounds individual REST calls. Must be positive.
HTTPTimeout time.Duration
// PushReconnectBaseBackoff is the starting delay between reconnect
// attempts of `Push.SubscribePush`. Must be positive.
PushReconnectBaseBackoff time.Duration
// PushReconnectMaxBackoff is the upper bound for exponential
// reconnect delays. Must be greater than or equal to
// PushReconnectBaseBackoff.
PushReconnectMaxBackoff time.Duration
}
// AdminHTTPConfig describes the private operational HTTP listener used for
@@ -531,18 +521,6 @@ type AuthenticatedGRPCConfig struct {
AntiAbuse AuthenticatedGRPCAntiAbuseConfig
}
// SessionCacheRedisConfig describes the namespace and timeout used for
// authenticated SessionCache lookups. Connection topology is shared with the
// other Redis-backed gateway components and lives on Config.Redis (see
// `pkg/redisconn`).
type SessionCacheRedisConfig struct {
// KeyPrefix is prepended to every SessionCache Redis key.
KeyPrefix string
// LookupTimeout bounds individual SessionCache Redis operations.
LookupTimeout time.Duration
}
// ReplayRedisConfig describes the Redis namespace and timeout used for
// authenticated replay reservations.
type ReplayRedisConfig struct {
@@ -553,29 +531,6 @@ type ReplayRedisConfig struct {
ReserveTimeout time.Duration
}
// SessionEventsRedisConfig describes the Redis Stream consumed by the gateway
// to keep the process-local session cache synchronized with session lifecycle
// updates.
type SessionEventsRedisConfig struct {
// Stream is the Redis Stream key carrying full session snapshot events.
Stream string
// ReadBlockTimeout bounds one blocking XREAD call so shutdown remains
// responsive even when the stream is idle.
ReadBlockTimeout time.Duration
}
// ClientEventsRedisConfig describes the Redis Stream consumed by the gateway
// to deliver client-facing events to active push streams.
type ClientEventsRedisConfig struct {
// Stream is the Redis Stream key carrying client-facing event entries.
Stream string
// ReadBlockTimeout bounds one blocking XREAD call so shutdown remains
// responsive even when the stream is idle.
ReadBlockTimeout time.Duration
}
// ResponseSignerConfig describes the private-key material used to sign
// authenticated unary responses and stream events.
type ResponseSignerConfig struct {
@@ -603,17 +558,10 @@ type Config struct {
// PublicHTTP configures the public unauthenticated REST listener.
PublicHTTP PublicHTTPConfig
// AuthService configures the optional public-auth delegation to the Auth /
// Session Service.
AuthService AuthServiceConfig
// UserService configures the optional authenticated self-service
// delegation to User Service.
UserService UserServiceConfig
// LobbyService configures the optional authenticated platform-command
// delegation to Game Lobby.
LobbyService LobbyServiceConfig
// Backend configures the consolidated backend the gateway forwards
// every public auth and authenticated user/lobby request to and the
// gRPC `Push.SubscribePush` stream consumed for inbound events.
Backend BackendConfig
// AdminHTTP configures the optional private admin listener used for metrics
// exposure.
@@ -622,25 +570,16 @@ type Config struct {
// AuthenticatedGRPC configures the authenticated gRPC listener.
AuthenticatedGRPC AuthenticatedGRPCConfig
// Redis carries the master/replica/password connection topology shared by
// every gateway Redis component, sourced from the GATEWAY_REDIS_*
// environment variables managed by `pkg/redisconn`.
// Redis carries the master/replica/password connection topology used
// by the anti-replay reservation store, sourced from the
// GATEWAY_REDIS_* environment variables managed by `pkg/redisconn`.
// The implementation dropped session cache projection and the two Redis
// Streams; Redis is now used only for replay reservations.
Redis redisconn.Config
// SessionCacheRedis configures the Redis-backed authenticated SessionCache.
SessionCacheRedis SessionCacheRedisConfig
// ReplayRedis configures the Redis-backed authenticated ReplayStore.
ReplayRedis ReplayRedisConfig
// SessionEventsRedis configures the Redis Stream consumed for session cache
// updates and revocations.
SessionEventsRedis SessionEventsRedisConfig
// ClientEventsRedis configures the Redis Stream consumed for client-facing
// push delivery.
ClientEventsRedis ClientEventsRedisConfig
// ResponseSigner configures the authenticated response and event signer
// loaded during startup.
ResponseSigner ResponseSignerConfig
@@ -650,53 +589,53 @@ type Config struct {
// for the public REST surface.
func DefaultPublicHTTPConfig() PublicHTTPConfig {
return PublicHTTPConfig{
Addr: defaultPublicHTTPAddr,
ReadHeaderTimeout: defaultPublicHTTPReadHeaderTimeout,
ReadTimeout: defaultPublicHTTPReadTimeout,
IdleTimeout: defaultPublicHTTPIdleTimeout,
Addr: defaultPublicHTTPAddr,
ReadHeaderTimeout: defaultPublicHTTPReadHeaderTimeout,
ReadTimeout: defaultPublicHTTPReadTimeout,
IdleTimeout: defaultPublicHTTPIdleTimeout,
AuthUpstreamTimeout: defaultPublicAuthUpstreamTimeout,
AntiAbuse: PublicHTTPAntiAbuseConfig{
PublicAuth: PublicRoutePolicyConfig{
MaxBodyBytes: defaultPublicAuthMaxBodyBytes,
RateLimit: PublicRateLimitConfig{
Requests: defaultPublicAuthRateLimitRequests,
Window: defaultClassRateLimitWindow,
Burst: defaultPublicAuthRateLimitBurst,
Window: defaultClassRateLimitWindow,
Burst: defaultPublicAuthRateLimitBurst,
},
},
BrowserBootstrap: PublicRoutePolicyConfig{
RateLimit: PublicRateLimitConfig{
Requests: defaultBrowserBootstrapRateLimitRequests,
Window: defaultClassRateLimitWindow,
Burst: defaultBrowserBootstrapRateLimitBurst,
Window: defaultClassRateLimitWindow,
Burst: defaultBrowserBootstrapRateLimitBurst,
},
},
BrowserAsset: PublicRoutePolicyConfig{
RateLimit: PublicRateLimitConfig{
Requests: defaultBrowserAssetRateLimitRequests,
Window: defaultClassRateLimitWindow,
Burst: defaultBrowserAssetRateLimitBurst,
Window: defaultClassRateLimitWindow,
Burst: defaultBrowserAssetRateLimitBurst,
},
},
PublicMisc: PublicRoutePolicyConfig{
RateLimit: PublicRateLimitConfig{
Requests: defaultPublicMiscRateLimitRequests,
Window: defaultClassRateLimitWindow,
Burst: defaultPublicMiscRateLimitBurst,
Window: defaultClassRateLimitWindow,
Burst: defaultPublicMiscRateLimitBurst,
},
},
SendEmailCodeIdentity: PublicAuthIdentityPolicyConfig{
RateLimit: PublicRateLimitConfig{
Requests: defaultSendEmailCodeIdentityRateLimitRequests,
Window: defaultIdentityRateLimitWindow,
Burst: defaultSendEmailCodeIdentityRateLimitBurst,
Window: defaultIdentityRateLimitWindow,
Burst: defaultSendEmailCodeIdentityRateLimitBurst,
},
},
ConfirmEmailCodeIdentity: PublicAuthIdentityPolicyConfig{
RateLimit: PublicRateLimitConfig{
Requests: defaultConfirmEmailCodeIdentityRateLimitRequests,
Window: defaultIdentityRateLimitWindow,
Burst: defaultConfirmEmailCodeIdentityRateLimitBurst,
Window: defaultIdentityRateLimitWindow,
Burst: defaultConfirmEmailCodeIdentityRateLimitBurst,
},
},
},
@@ -708,8 +647,8 @@ func DefaultPublicHTTPConfig() PublicHTTPConfig {
func DefaultAdminHTTPConfig() AdminHTTPConfig {
return AdminHTTPConfig{
ReadHeaderTimeout: defaultAdminHTTPReadHeaderTimeout,
ReadTimeout: defaultAdminHTTPReadTimeout,
IdleTimeout: defaultAdminHTTPIdleTimeout,
ReadTimeout: defaultAdminHTTPReadTimeout,
IdleTimeout: defaultAdminHTTPIdleTimeout,
}
}
@@ -717,30 +656,30 @@ func DefaultAdminHTTPConfig() AdminHTTPConfig {
// anti-abuse settings for the authenticated gRPC surface.
func DefaultAuthenticatedGRPCConfig() AuthenticatedGRPCConfig {
return AuthenticatedGRPCConfig{
Addr: defaultAuthenticatedGRPCAddr,
Addr: defaultAuthenticatedGRPCAddr,
ConnectionTimeout: defaultAuthenticatedGRPCConnectionTimeout,
DownstreamTimeout: defaultAuthenticatedGRPCDownstreamTimeout,
FreshnessWindow: defaultAuthenticatedGRPCFreshnessWindow,
FreshnessWindow: defaultAuthenticatedGRPCFreshnessWindow,
AntiAbuse: AuthenticatedGRPCAntiAbuseConfig{
IP: AuthenticatedRateLimitConfig{
Requests: defaultAuthenticatedGRPCIPRateLimitRequests,
Window: defaultClassRateLimitWindow,
Burst: defaultAuthenticatedGRPCIPRateLimitBurst,
Window: defaultClassRateLimitWindow,
Burst: defaultAuthenticatedGRPCIPRateLimitBurst,
},
Session: AuthenticatedRateLimitConfig{
Requests: defaultAuthenticatedGRPCSessionRateLimitRequests,
Window: defaultClassRateLimitWindow,
Burst: defaultAuthenticatedGRPCSessionRateLimitBurst,
Window: defaultClassRateLimitWindow,
Burst: defaultAuthenticatedGRPCSessionRateLimitBurst,
},
User: AuthenticatedRateLimitConfig{
Requests: defaultAuthenticatedGRPCUserRateLimitRequests,
Window: defaultClassRateLimitWindow,
Burst: defaultAuthenticatedGRPCUserRateLimitBurst,
Window: defaultClassRateLimitWindow,
Burst: defaultAuthenticatedGRPCUserRateLimitBurst,
},
MessageClass: AuthenticatedRateLimitConfig{
Requests: defaultAuthenticatedGRPCMessageClassRateLimitRequests,
Window: defaultClassRateLimitWindow,
Burst: defaultAuthenticatedGRPCMessageClassRateLimitBurst,
Window: defaultClassRateLimitWindow,
Burst: defaultAuthenticatedGRPCMessageClassRateLimitBurst,
},
},
}
@@ -751,39 +690,23 @@ func DefaultLoggingConfig() LoggingConfig {
return LoggingConfig{Level: defaultLogLevel}
}
// DefaultSessionCacheRedisConfig returns the default optional namespace and
// timeout settings for the Redis-backed authenticated SessionCache.
func DefaultSessionCacheRedisConfig() SessionCacheRedisConfig {
return SessionCacheRedisConfig{
KeyPrefix: defaultSessionCacheRedisKeyPrefix,
LookupTimeout: defaultSessionCacheRedisLookupTimeout,
}
}
// DefaultReplayRedisConfig returns the default Redis key namespace and timeout
// used for authenticated replay reservations.
func DefaultReplayRedisConfig() ReplayRedisConfig {
return ReplayRedisConfig{
KeyPrefix: defaultReplayRedisKeyPrefix,
KeyPrefix: defaultReplayRedisKeyPrefix,
ReserveTimeout: defaultReplayRedisReserveTimeout,
}
}
// DefaultSessionEventsRedisConfig returns the default optional settings for the
// session lifecycle event subscriber. Stream remains empty and must be
// supplied explicitly.
func DefaultSessionEventsRedisConfig() SessionEventsRedisConfig {
return SessionEventsRedisConfig{
ReadBlockTimeout: defaultSessionEventsRedisReadBlockTimeout,
}
}
// DefaultClientEventsRedisConfig returns the default optional settings for the
// client-facing event subscriber. Stream remains empty and must be supplied
// explicitly.
func DefaultClientEventsRedisConfig() ClientEventsRedisConfig {
return ClientEventsRedisConfig{
ReadBlockTimeout: defaultClientEventsRedisReadBlockTimeout,
// DefaultBackendConfig returns the default backend settings used for the
// gateway → backend HTTP and gRPC conversation. URL fields stay empty and
// must be supplied explicitly via env vars.
func DefaultBackendConfig() BackendConfig {
return BackendConfig{
HTTPTimeout: defaultBackendHTTPTimeout,
PushReconnectBaseBackoff: defaultBackendPushReconnectBaseBackoff,
PushReconnectMaxBackoff: defaultBackendPushReconnectMaxBackoff,
}
}
@@ -793,44 +716,19 @@ func DefaultResponseSignerConfig() ResponseSignerConfig {
return ResponseSignerConfig{}
}
// DefaultAuthServiceConfig returns the default public-auth upstream settings.
// The zero value keeps the built-in unavailable adapter active.
func DefaultAuthServiceConfig() AuthServiceConfig {
return AuthServiceConfig{}
}
// DefaultUserServiceConfig returns the default authenticated self-service
// upstream settings. The zero value keeps the built-in unavailable adapter
// active for reserved `user.*` routes.
func DefaultUserServiceConfig() UserServiceConfig {
return UserServiceConfig{}
}
// DefaultLobbyServiceConfig returns the default authenticated platform-command
// upstream settings. The zero value keeps the built-in unavailable adapter
// active for reserved `lobby.*` routes.
func DefaultLobbyServiceConfig() LobbyServiceConfig {
return LobbyServiceConfig{}
}
// LoadFromEnv loads Config from the process environment, applies defaults for
// omitted settings, and validates the resulting values.
func LoadFromEnv() (Config, error) {
cfg := Config{
ShutdownTimeout: defaultShutdownTimeout,
Logging: DefaultLoggingConfig(),
PublicHTTP: DefaultPublicHTTPConfig(),
AuthService: DefaultAuthServiceConfig(),
UserService: DefaultUserServiceConfig(),
LobbyService: DefaultLobbyServiceConfig(),
AdminHTTP: DefaultAdminHTTPConfig(),
AuthenticatedGRPC: DefaultAuthenticatedGRPCConfig(),
Redis: redisconn.DefaultConfig(),
SessionCacheRedis: DefaultSessionCacheRedisConfig(),
ReplayRedis: DefaultReplayRedisConfig(),
SessionEventsRedis: DefaultSessionEventsRedisConfig(),
ClientEventsRedis: DefaultClientEventsRedisConfig(),
ResponseSigner: DefaultResponseSignerConfig(),
ShutdownTimeout: defaultShutdownTimeout,
Logging: DefaultLoggingConfig(),
PublicHTTP: DefaultPublicHTTPConfig(),
Backend: DefaultBackendConfig(),
AdminHTTP: DefaultAdminHTTPConfig(),
AuthenticatedGRPC: DefaultAuthenticatedGRPCConfig(),
Redis: redisconn.DefaultConfig(),
ReplayRedis: DefaultReplayRedisConfig(),
ResponseSigner: DefaultResponseSignerConfig(),
}
rawShutdownTimeout, ok := os.LookupEnv(shutdownTimeoutEnvVar)
@@ -876,20 +774,30 @@ func LoadFromEnv() (Config, error) {
}
cfg.PublicHTTP.AuthUpstreamTimeout = publicAuthUpstreamTimeout
rawAuthServiceBaseURL, ok := os.LookupEnv(authServiceBaseURLEnvVar)
if ok {
cfg.AuthService.BaseURL = rawAuthServiceBaseURL
if v, ok := os.LookupEnv(backendHTTPURLEnvVar); ok {
cfg.Backend.HTTPBaseURL = v
}
rawUserServiceBaseURL, ok := os.LookupEnv(userServiceBaseURLEnvVar)
if ok {
cfg.UserService.BaseURL = rawUserServiceBaseURL
if v, ok := os.LookupEnv(backendGRPCPushURLEnvVar); ok {
cfg.Backend.GRPCPushURL = v
}
rawLobbyServiceBaseURL, ok := os.LookupEnv(lobbyServiceBaseURLEnvVar)
if ok {
cfg.LobbyService.BaseURL = rawLobbyServiceBaseURL
if v, ok := os.LookupEnv(backendGatewayClientIDEnvVar); ok {
cfg.Backend.GatewayClientID = v
}
backendHTTPTimeout, err := loadDurationEnvWithDefault(backendHTTPTimeoutEnvVar, cfg.Backend.HTTPTimeout)
if err != nil {
return Config{}, err
}
cfg.Backend.HTTPTimeout = backendHTTPTimeout
backendPushReconnectBaseBackoff, err := loadDurationEnvWithDefault(backendPushReconnectBaseBackoffEnvVar, cfg.Backend.PushReconnectBaseBackoff)
if err != nil {
return Config{}, err
}
cfg.Backend.PushReconnectBaseBackoff = backendPushReconnectBaseBackoff
backendPushReconnectMaxBackoff, err := loadDurationEnvWithDefault(backendPushReconnectMaxBackoffEnvVar, cfg.Backend.PushReconnectMaxBackoff)
if err != nil {
return Config{}, err
}
cfg.Backend.PushReconnectMaxBackoff = backendPushReconnectMaxBackoff
rawAdminHTTPAddr, ok := os.LookupEnv(adminHTTPAddrEnvVar)
if ok {
@@ -987,17 +895,6 @@ func LoadFromEnv() (Config, error) {
}
cfg.Redis = redisConn
rawSessionCacheRedisKeyPrefix, ok := os.LookupEnv(sessionCacheRedisKeyPrefixEnvVar)
if ok {
cfg.SessionCacheRedis.KeyPrefix = rawSessionCacheRedisKeyPrefix
}
sessionCacheRedisLookupTimeout, err := loadDurationEnvWithDefault(sessionCacheRedisLookupTimeoutEnvVar, cfg.SessionCacheRedis.LookupTimeout)
if err != nil {
return Config{}, err
}
cfg.SessionCacheRedis.LookupTimeout = sessionCacheRedisLookupTimeout
rawReplayRedisKeyPrefix, ok := os.LookupEnv(replayRedisKeyPrefixEnvVar)
if ok {
cfg.ReplayRedis.KeyPrefix = rawReplayRedisKeyPrefix
@@ -1009,28 +906,6 @@ func LoadFromEnv() (Config, error) {
}
cfg.ReplayRedis.ReserveTimeout = replayRedisReserveTimeout
rawSessionEventsRedisStream, ok := os.LookupEnv(sessionEventsRedisStreamEnvVar)
if ok {
cfg.SessionEventsRedis.Stream = rawSessionEventsRedisStream
}
sessionEventsRedisReadBlockTimeout, err := loadDurationEnvWithDefault(sessionEventsRedisReadBlockTimeoutEnvVar, cfg.SessionEventsRedis.ReadBlockTimeout)
if err != nil {
return Config{}, err
}
cfg.SessionEventsRedis.ReadBlockTimeout = sessionEventsRedisReadBlockTimeout
rawClientEventsRedisStream, ok := os.LookupEnv(clientEventsRedisStreamEnvVar)
if ok {
cfg.ClientEventsRedis.Stream = rawClientEventsRedisStream
}
clientEventsRedisReadBlockTimeout, err := loadDurationEnvWithDefault(clientEventsRedisReadBlockTimeoutEnvVar, cfg.ClientEventsRedis.ReadBlockTimeout)
if err != nil {
return Config{}, err
}
cfg.ClientEventsRedis.ReadBlockTimeout = clientEventsRedisReadBlockTimeout
rawSignerKeyPath, ok := os.LookupEnv(responseSignerPrivateKeyPEMPathEnvVar)
if ok {
cfg.ResponseSigner.PrivateKeyPEMPath = rawSignerKeyPath
@@ -1127,27 +1002,34 @@ func LoadFromEnv() (Config, error) {
if cfg.PublicHTTP.AuthUpstreamTimeout <= 0 {
return Config{}, fmt.Errorf("load gateway config: %s must be positive", publicAuthUpstreamTimeoutEnvVar)
}
cfg.AuthService.BaseURL = strings.TrimSpace(cfg.AuthService.BaseURL)
if cfg.AuthService.BaseURL != "" {
parsedAuthServiceBaseURL, err := url.Parse(cfg.AuthService.BaseURL)
if err != nil {
return Config{}, fmt.Errorf("load gateway config: parse %s: %w", authServiceBaseURLEnvVar, err)
}
if parsedAuthServiceBaseURL.Scheme == "" || parsedAuthServiceBaseURL.Host == "" {
return Config{}, fmt.Errorf("load gateway config: %s must be an absolute URL", authServiceBaseURLEnvVar)
}
cfg.AuthService.BaseURL = strings.TrimRight(parsedAuthServiceBaseURL.String(), "/")
cfg.Backend.HTTPBaseURL = strings.TrimSpace(cfg.Backend.HTTPBaseURL)
if cfg.Backend.HTTPBaseURL == "" {
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", backendHTTPURLEnvVar)
}
cfg.UserService.BaseURL = strings.TrimSpace(cfg.UserService.BaseURL)
if cfg.UserService.BaseURL != "" {
parsedUserServiceBaseURL, err := url.Parse(cfg.UserService.BaseURL)
if err != nil {
return Config{}, fmt.Errorf("load gateway config: parse %s: %w", userServiceBaseURLEnvVar, err)
}
if parsedUserServiceBaseURL.Scheme == "" || parsedUserServiceBaseURL.Host == "" {
return Config{}, fmt.Errorf("load gateway config: %s must be an absolute URL", userServiceBaseURLEnvVar)
}
cfg.UserService.BaseURL = strings.TrimRight(parsedUserServiceBaseURL.String(), "/")
parsedBackendHTTP, err := url.Parse(strings.TrimRight(cfg.Backend.HTTPBaseURL, "/"))
if err != nil {
return Config{}, fmt.Errorf("load gateway config: parse %s: %w", backendHTTPURLEnvVar, err)
}
if parsedBackendHTTP.Scheme == "" || parsedBackendHTTP.Host == "" {
return Config{}, fmt.Errorf("load gateway config: %s must be an absolute URL", backendHTTPURLEnvVar)
}
cfg.Backend.HTTPBaseURL = parsedBackendHTTP.String()
cfg.Backend.GRPCPushURL = strings.TrimSpace(cfg.Backend.GRPCPushURL)
if cfg.Backend.GRPCPushURL == "" {
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", backendGRPCPushURLEnvVar)
}
cfg.Backend.GatewayClientID = strings.TrimSpace(cfg.Backend.GatewayClientID)
if cfg.Backend.GatewayClientID == "" {
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", backendGatewayClientIDEnvVar)
}
if cfg.Backend.HTTPTimeout <= 0 {
return Config{}, fmt.Errorf("load gateway config: %s must be positive", backendHTTPTimeoutEnvVar)
}
if cfg.Backend.PushReconnectBaseBackoff <= 0 {
return Config{}, fmt.Errorf("load gateway config: %s must be positive", backendPushReconnectBaseBackoffEnvVar)
}
if cfg.Backend.PushReconnectMaxBackoff < cfg.Backend.PushReconnectBaseBackoff {
return Config{}, fmt.Errorf("load gateway config: %s must be >= %s", backendPushReconnectMaxBackoffEnvVar, backendPushReconnectBaseBackoffEnvVar)
}
if addr := strings.TrimSpace(cfg.AdminHTTP.Addr); addr != "" {
cfg.AdminHTTP.Addr = addr
@@ -1208,30 +1090,12 @@ func LoadFromEnv() (Config, error) {
if err := cfg.Redis.Validate(); err != nil {
return Config{}, fmt.Errorf("load gateway config: redis: %w", err)
}
if strings.TrimSpace(cfg.SessionCacheRedis.KeyPrefix) == "" {
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", sessionCacheRedisKeyPrefixEnvVar)
}
if cfg.SessionCacheRedis.LookupTimeout <= 0 {
return Config{}, fmt.Errorf("load gateway config: %s must be positive", sessionCacheRedisLookupTimeoutEnvVar)
}
if strings.TrimSpace(cfg.ReplayRedis.KeyPrefix) == "" {
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", replayRedisKeyPrefixEnvVar)
}
if cfg.ReplayRedis.ReserveTimeout <= 0 {
return Config{}, fmt.Errorf("load gateway config: %s must be positive", replayRedisReserveTimeoutEnvVar)
}
if strings.TrimSpace(cfg.SessionEventsRedis.Stream) == "" {
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", sessionEventsRedisStreamEnvVar)
}
if cfg.SessionEventsRedis.ReadBlockTimeout <= 0 {
return Config{}, fmt.Errorf("load gateway config: %s must be positive", sessionEventsRedisReadBlockTimeoutEnvVar)
}
if strings.TrimSpace(cfg.ClientEventsRedis.Stream) == "" {
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", clientEventsRedisStreamEnvVar)
}
if cfg.ClientEventsRedis.ReadBlockTimeout <= 0 {
return Config{}, fmt.Errorf("load gateway config: %s must be positive", clientEventsRedisReadBlockTimeoutEnvVar)
}
if strings.TrimSpace(cfg.ResponseSigner.PrivateKeyPEMPath) == "" {
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", responseSignerPrivateKeyPEMPathEnvVar)
}
File diff suppressed because it is too large Load Diff
@@ -1,329 +0,0 @@
// Package lobbyservice implements the authenticated Gateway -> Game Lobby
// downstream adapter. It forwards verified authenticated commands as
// trusted-internal HTTP requests against Game Lobby's public REST surface,
// transporting the calling user identity through the `X-User-Id` header.
package lobbyservice
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"galaxy/gateway/internal/downstream"
lobbymodel "galaxy/model/lobby"
"galaxy/transcoder"
)
const (
myGamesListPath = "/api/v1/lobby/my/games"
openEnrollmentPathFormat = "/api/v1/lobby/games/%s/open-enrollment"
resultCodeOK = "ok"
defaultErrorCodeBadRequest = "invalid_request"
defaultErrorCodeNotFound = "subject_not_found"
defaultErrorCodeForbidden = "forbidden"
defaultErrorCodeConflict = "conflict"
defaultErrorCodeInternalError = "internal_error"
headerCallingUserID = "X-User-Id"
)
var stableErrorMessages = map[string]string{
defaultErrorCodeBadRequest: "request is invalid",
defaultErrorCodeNotFound: "subject not found",
defaultErrorCodeForbidden: "operation is forbidden for the calling user",
defaultErrorCodeConflict: "request conflicts with current state",
defaultErrorCodeInternalError: "internal server error",
}
// HTTPClient implements downstream.Client against the trusted Game Lobby
// public REST API while preserving FlatBuffers at the external authenticated
// gateway boundary.
type HTTPClient struct {
baseURL string
httpClient *http.Client
}
// NewHTTPClient constructs one Game Lobby downstream client backed by the
// public REST API at baseURL.
func NewHTTPClient(baseURL string) (*HTTPClient, error) {
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, errors.New("new lobby service HTTP client: default transport is not *http.Transport")
}
return newHTTPClient(baseURL, &http.Client{
Transport: transport.Clone(),
})
}
func newHTTPClient(baseURL string, httpClient *http.Client) (*HTTPClient, error) {
if httpClient == nil {
return nil, errors.New("new lobby service HTTP client: http client must not be nil")
}
trimmedBaseURL := strings.TrimSpace(baseURL)
if trimmedBaseURL == "" {
return nil, errors.New("new lobby service HTTP client: base URL must not be empty")
}
parsedBaseURL, err := url.Parse(strings.TrimRight(trimmedBaseURL, "/"))
if err != nil {
return nil, fmt.Errorf("new lobby service HTTP client: parse base URL: %w", err)
}
if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
return nil, errors.New("new lobby service HTTP client: base URL must be absolute")
}
return &HTTPClient{
baseURL: parsedBaseURL.String(),
httpClient: httpClient,
}, nil
}
// Close releases idle HTTP connections owned by the client transport.
func (c *HTTPClient) Close() error {
if c == nil || c.httpClient == nil {
return nil
}
type idleCloser interface {
CloseIdleConnections()
}
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
transport.CloseIdleConnections()
}
return nil
}
// ExecuteCommand routes one authenticated gateway command to the matching
// trusted Game Lobby public REST route.
func (c *HTTPClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
if c == nil || c.httpClient == nil {
return downstream.UnaryResult{}, errors.New("execute lobby service command: nil client")
}
if ctx == nil {
return downstream.UnaryResult{}, errors.New("execute lobby service command: nil context")
}
if err := ctx.Err(); err != nil {
return downstream.UnaryResult{}, err
}
if strings.TrimSpace(command.UserID) == "" {
return downstream.UnaryResult{}, errors.New("execute lobby service command: user_id must not be empty")
}
switch command.MessageType {
case lobbymodel.MessageTypeMyGamesList:
if _, err := transcoder.PayloadToMyGamesListRequest(command.PayloadBytes); err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute lobby service command %q: %w", command.MessageType, err)
}
return c.executeMyGamesList(ctx, command.UserID)
case lobbymodel.MessageTypeOpenEnrollment:
request, err := transcoder.PayloadToOpenEnrollmentRequest(command.PayloadBytes)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute lobby service command %q: %w", command.MessageType, err)
}
return c.executeOpenEnrollment(ctx, command.UserID, request)
default:
return downstream.UnaryResult{}, fmt.Errorf("execute lobby service command: unsupported message type %q", command.MessageType)
}
}
func (c *HTTPClient) executeMyGamesList(ctx context.Context, userID string) (downstream.UnaryResult, error) {
payload, statusCode, err := c.doRequest(ctx, http.MethodGet, c.baseURL+myGamesListPath, userID, nil)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute my games list: %w", err)
}
if statusCode == http.StatusOK {
var response lobbymodel.MyGamesListResponse
if err := decodeStrictJSONPayload(payload, &response); err != nil {
return downstream.UnaryResult{}, fmt.Errorf("decode success response: %w", err)
}
payloadBytes, err := transcoder.MyGamesListResponseToPayload(&response)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("encode success response payload: %w", err)
}
return downstream.UnaryResult{
ResultCode: resultCodeOK,
PayloadBytes: payloadBytes,
}, nil
}
return projectErrorResponse(statusCode, payload)
}
func (c *HTTPClient) executeOpenEnrollment(ctx context.Context, userID string, request *lobbymodel.OpenEnrollmentRequest) (downstream.UnaryResult, error) {
if request == nil || strings.TrimSpace(request.GameID) == "" {
return downstream.UnaryResult{}, errors.New("execute open enrollment: game_id must not be empty")
}
target := c.baseURL + fmt.Sprintf(openEnrollmentPathFormat, url.PathEscape(request.GameID))
payload, statusCode, err := c.doRequest(ctx, http.MethodPost, target, userID, struct{}{})
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute open enrollment: %w", err)
}
if statusCode == http.StatusOK {
// Lobby's open-enrollment endpoint returns the full game record;
// the gateway boundary projects the minimal status pair.
var fullRecord struct {
GameID string `json:"game_id"`
Status string `json:"status"`
}
if err := json.Unmarshal(payload, &fullRecord); err != nil {
return downstream.UnaryResult{}, fmt.Errorf("decode success response: %w", err)
}
payloadBytes, err := transcoder.OpenEnrollmentResponseToPayload(&lobbymodel.OpenEnrollmentResponse{
GameID: fullRecord.GameID,
Status: fullRecord.Status,
})
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("encode success response payload: %w", err)
}
return downstream.UnaryResult{
ResultCode: resultCodeOK,
PayloadBytes: payloadBytes,
}, nil
}
return projectErrorResponse(statusCode, payload)
}
func (c *HTTPClient) doRequest(ctx context.Context, method, targetURL, userID string, requestBody any) ([]byte, int, error) {
if c == nil || c.httpClient == nil {
return nil, 0, errors.New("nil client")
}
var bodyReader io.Reader
if requestBody != nil {
body, err := json.Marshal(requestBody)
if err != nil {
return nil, 0, fmt.Errorf("marshal request body: %w", err)
}
bodyReader = bytes.NewReader(body)
}
request, err := http.NewRequestWithContext(ctx, method, targetURL, bodyReader)
if err != nil {
return nil, 0, fmt.Errorf("build request: %w", err)
}
if requestBody != nil {
request.Header.Set("Content-Type", "application/json")
}
request.Header.Set(headerCallingUserID, userID)
response, err := c.httpClient.Do(request)
if err != nil {
return nil, 0, err
}
defer response.Body.Close()
payload, err := io.ReadAll(response.Body)
if err != nil {
return nil, 0, fmt.Errorf("read response body: %w", err)
}
return payload, response.StatusCode, nil
}
func projectErrorResponse(statusCode int, payload []byte) (downstream.UnaryResult, error) {
switch {
case statusCode == http.StatusServiceUnavailable:
return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
case statusCode >= 400 && statusCode <= 599:
errorResponse, err := decodeLobbyError(statusCode, payload)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("decode error response: %w", err)
}
payloadBytes, err := transcoder.LobbyErrorResponseToPayload(errorResponse)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("encode error response payload: %w", err)
}
return downstream.UnaryResult{
ResultCode: errorResponse.Error.Code,
PayloadBytes: payloadBytes,
}, nil
default:
return downstream.UnaryResult{}, fmt.Errorf("unexpected HTTP status %d", statusCode)
}
}
func decodeLobbyError(statusCode int, payload []byte) (*lobbymodel.ErrorResponse, error) {
var response lobbymodel.ErrorResponse
if err := decodeStrictJSONPayload(payload, &response); err != nil {
return nil, err
}
response.Error.Code = normalizeErrorCode(statusCode, response.Error.Code)
response.Error.Message = normalizeErrorMessage(response.Error.Code, response.Error.Message)
if strings.TrimSpace(response.Error.Code) == "" {
return nil, errors.New("missing error code")
}
if strings.TrimSpace(response.Error.Message) == "" {
return nil, errors.New("missing error message")
}
return &response, nil
}
func normalizeErrorCode(statusCode int, code string) string {
trimmed := strings.TrimSpace(code)
if trimmed != "" {
return trimmed
}
switch statusCode {
case http.StatusBadRequest:
return defaultErrorCodeBadRequest
case http.StatusForbidden:
return defaultErrorCodeForbidden
case http.StatusNotFound:
return defaultErrorCodeNotFound
case http.StatusConflict:
return defaultErrorCodeConflict
default:
return defaultErrorCodeInternalError
}
}
func normalizeErrorMessage(code, message string) string {
trimmed := strings.TrimSpace(message)
if trimmed != "" {
return trimmed
}
if stable, ok := stableErrorMessages[code]; ok {
return stable
}
return stableErrorMessages[defaultErrorCodeInternalError]
}
func decodeStrictJSONPayload(payload []byte, target any) error {
decoder := json.NewDecoder(bytes.NewReader(payload))
decoder.DisallowUnknownFields()
if err := decoder.Decode(target); err != nil {
return err
}
if err := decoder.Decode(&struct{}{}); err != io.EOF {
if err == nil {
return errors.New("unexpected trailing JSON input")
}
return err
}
return nil
}
var _ downstream.Client = (*HTTPClient)(nil)
@@ -1,212 +0,0 @@
package lobbyservice_test
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/downstream/lobbyservice"
lobbymodel "galaxy/model/lobby"
"galaxy/transcoder"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExecuteMyGamesListSuccess(t *testing.T) {
t.Parallel()
expectedResponse := lobbymodel.MyGamesListResponse{
Items: []lobbymodel.GameSummary{
{
GameID: "game-1",
GameName: "Nebula Clash",
GameType: "private",
Status: "draft",
OwnerUserID: "user-1",
MinPlayers: 2,
MaxPlayers: 8,
EnrollmentEndsAt: time.Date(2026, 5, 1, 12, 0, 0, 0, time.UTC),
CreatedAt: time.Date(2026, 4, 28, 9, 0, 0, 0, time.UTC),
UpdatedAt: time.Date(2026, 4, 28, 9, 5, 0, 0, time.UTC),
},
},
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Equal(t, "/api/v1/lobby/my/games", r.URL.Path)
assert.Equal(t, "user-1", r.Header.Get("X-User-Id"))
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(expectedResponse))
}))
t.Cleanup(server.Close)
client, err := lobbyservice.NewHTTPClient(server.URL)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, client.Close()) })
requestBytes, err := transcoder.MyGamesListRequestToPayload(&lobbymodel.MyGamesListRequest{})
require.NoError(t, err)
result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
MessageType: lobbymodel.MessageTypeMyGamesList,
UserID: "user-1",
PayloadBytes: requestBytes,
})
require.NoError(t, err)
assert.Equal(t, "ok", result.ResultCode)
decoded, err := transcoder.PayloadToMyGamesListResponse(result.PayloadBytes)
require.NoError(t, err)
require.Len(t, decoded.Items, 1)
assert.Equal(t, expectedResponse.Items[0].GameID, decoded.Items[0].GameID)
assert.Equal(t, expectedResponse.Items[0].OwnerUserID, decoded.Items[0].OwnerUserID)
assert.Equal(t, expectedResponse.Items[0].MinPlayers, decoded.Items[0].MinPlayers)
}
func TestExecuteOpenEnrollmentSuccess(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "/api/v1/lobby/games/game-77/open-enrollment", r.URL.Path)
assert.Equal(t, "owner-1", r.Header.Get("X-User-Id"))
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"game_id": "game-77",
"status": "enrollment_open",
}))
}))
t.Cleanup(server.Close)
client, err := lobbyservice.NewHTTPClient(server.URL)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, client.Close()) })
requestBytes, err := transcoder.OpenEnrollmentRequestToPayload(&lobbymodel.OpenEnrollmentRequest{GameID: "game-77"})
require.NoError(t, err)
result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
MessageType: lobbymodel.MessageTypeOpenEnrollment,
UserID: "owner-1",
PayloadBytes: requestBytes,
})
require.NoError(t, err)
assert.Equal(t, "ok", result.ResultCode)
decoded, err := transcoder.PayloadToOpenEnrollmentResponse(result.PayloadBytes)
require.NoError(t, err)
assert.Equal(t, "game-77", decoded.GameID)
assert.Equal(t, "enrollment_open", decoded.Status)
}
func TestExecuteOpenEnrollmentForbiddenProjectsErrorEnvelope(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"error": map[string]string{
"code": "forbidden",
"message": "only the game owner may open enrollment",
},
}))
}))
t.Cleanup(server.Close)
client, err := lobbyservice.NewHTTPClient(server.URL)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, client.Close()) })
requestBytes, err := transcoder.OpenEnrollmentRequestToPayload(&lobbymodel.OpenEnrollmentRequest{GameID: "game-77"})
require.NoError(t, err)
result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
MessageType: lobbymodel.MessageTypeOpenEnrollment,
UserID: "non-owner",
PayloadBytes: requestBytes,
})
require.NoError(t, err)
assert.Equal(t, "forbidden", result.ResultCode)
decoded, err := transcoder.PayloadToLobbyErrorResponse(result.PayloadBytes)
require.NoError(t, err)
assert.Equal(t, "forbidden", decoded.Error.Code)
assert.NotEmpty(t, decoded.Error.Message)
}
func TestExecuteCommandUnavailableProjectsErrUnavailable(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
t.Cleanup(server.Close)
client, err := lobbyservice.NewHTTPClient(server.URL)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, client.Close()) })
requestBytes, err := transcoder.MyGamesListRequestToPayload(&lobbymodel.MyGamesListRequest{})
require.NoError(t, err)
_, err = client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
MessageType: lobbymodel.MessageTypeMyGamesList,
UserID: "user-1",
PayloadBytes: requestBytes,
})
require.Error(t, err)
assert.True(t, errors.Is(err, downstream.ErrDownstreamUnavailable))
}
func TestExecuteCommandRejectsEmptyUserID(t *testing.T) {
t.Parallel()
client, err := lobbyservice.NewHTTPClient("http://127.0.0.1:1")
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, client.Close()) })
requestBytes, err := transcoder.MyGamesListRequestToPayload(&lobbymodel.MyGamesListRequest{})
require.NoError(t, err)
_, err = client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
MessageType: lobbymodel.MessageTypeMyGamesList,
UserID: "",
PayloadBytes: requestBytes,
})
require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "user_id"), "error must mention user_id; got %q", err.Error())
}
func TestNewRoutesReservesUnavailableClientWhenBaseURLEmpty(t *testing.T) {
t.Parallel()
routes, closeFn, err := lobbyservice.NewRoutes("")
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, closeFn()) })
require.Contains(t, routes, lobbymodel.MessageTypeMyGamesList)
require.Contains(t, routes, lobbymodel.MessageTypeOpenEnrollment)
requestBytes, err := transcoder.MyGamesListRequestToPayload(&lobbymodel.MyGamesListRequest{})
require.NoError(t, err)
_, err = routes[lobbymodel.MessageTypeMyGamesList].ExecuteCommand(
context.Background(),
downstream.AuthenticatedCommand{
MessageType: lobbymodel.MessageTypeMyGamesList,
UserID: "user-1",
PayloadBytes: requestBytes,
},
)
require.Error(t, err)
assert.True(t, errors.Is(err, downstream.ErrDownstreamUnavailable))
}
@@ -1,45 +0,0 @@
package lobbyservice
import (
"context"
"galaxy/gateway/internal/downstream"
lobbymodel "galaxy/model/lobby"
)
var noOpClose = func() error { return nil }
// NewRoutes returns the reserved authenticated gateway routes owned by
// the Gateway -> Game Lobby boundary.
//
// When baseURL is empty, the returned routes still reserve the stable
// `lobby.*` message types but resolve them to a dependency-unavailable
// client so callers receive the transport-level unavailable outcome
// instead of a route-miss error.
func NewRoutes(baseURL string) (map[string]downstream.Client, func() error, error) {
client := downstream.Client(unavailableClient{})
closeFn := noOpClose
if baseURL != "" {
httpClient, err := NewHTTPClient(baseURL)
if err != nil {
return nil, nil, err
}
client = httpClient
closeFn = httpClient.Close
}
return map[string]downstream.Client{
lobbymodel.MessageTypeMyGamesList: client,
lobbymodel.MessageTypeOpenEnrollment: client,
}, closeFn, nil
}
type unavailableClient struct{}
func (unavailableClient) ExecuteCommand(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
}
var _ downstream.Client = unavailableClient{}
@@ -1,311 +0,0 @@
// Package userservice implements the authenticated Gateway -> User Service
// self-service downstream adapter.
package userservice
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"galaxy/gateway/internal/downstream"
usermodel "galaxy/model/user"
"galaxy/transcoder"
)
const (
getMyAccountResultCodeOK = "ok"
userServiceAccountPathSuffix = "/account"
userServiceProfilePathSuffix = "/profile"
userServiceSettingsPathSuffix = "/settings"
)
var stableErrorMessages = map[string]string{
"invalid_request": "request is invalid",
"subject_not_found": "subject not found",
"conflict": "request conflicts with current state",
"internal_error": "internal server error",
}
// HTTPClient implements downstream.Client against the trusted internal User
// Service REST API while preserving FlatBuffers at the external authenticated
// gateway boundary.
type HTTPClient struct {
baseURL string
httpClient *http.Client
}
// NewHTTPClient constructs one User Service downstream client backed by the
// trusted internal REST API at baseURL.
func NewHTTPClient(baseURL string) (*HTTPClient, error) {
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, errors.New("new user service HTTP client: default transport is not *http.Transport")
}
return newHTTPClient(baseURL, &http.Client{
Transport: transport.Clone(),
})
}
func newHTTPClient(baseURL string, httpClient *http.Client) (*HTTPClient, error) {
if httpClient == nil {
return nil, errors.New("new user service HTTP client: http client must not be nil")
}
trimmedBaseURL := strings.TrimSpace(baseURL)
if trimmedBaseURL == "" {
return nil, errors.New("new user service HTTP client: base URL must not be empty")
}
parsedBaseURL, err := url.Parse(strings.TrimRight(trimmedBaseURL, "/"))
if err != nil {
return nil, fmt.Errorf("new user service HTTP client: parse base URL: %w", err)
}
if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
return nil, errors.New("new user service HTTP client: base URL must be absolute")
}
return &HTTPClient{
baseURL: parsedBaseURL.String(),
httpClient: httpClient,
}, nil
}
// Close releases idle HTTP connections owned by the client transport.
func (c *HTTPClient) Close() error {
if c == nil || c.httpClient == nil {
return nil
}
type idleCloser interface {
CloseIdleConnections()
}
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
transport.CloseIdleConnections()
}
return nil
}
// ExecuteCommand routes one authenticated gateway command to the matching
// trusted internal User Service self-service route.
func (c *HTTPClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
if c == nil || c.httpClient == nil {
return downstream.UnaryResult{}, errors.New("execute user service command: nil client")
}
if ctx == nil {
return downstream.UnaryResult{}, errors.New("execute user service command: nil context")
}
if err := ctx.Err(); err != nil {
return downstream.UnaryResult{}, err
}
if strings.TrimSpace(command.UserID) == "" {
return downstream.UnaryResult{}, errors.New("execute user service command: user_id must not be empty")
}
switch command.MessageType {
case usermodel.MessageTypeGetMyAccount:
if _, err := transcoder.PayloadToGetMyAccountRequest(command.PayloadBytes); err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute user service command %q: %w", command.MessageType, err)
}
return c.executeGetMyAccount(ctx, command.UserID)
case usermodel.MessageTypeUpdateMyProfile:
request, err := transcoder.PayloadToUpdateMyProfileRequest(command.PayloadBytes)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute user service command %q: %w", command.MessageType, err)
}
return c.executeUpdateMyProfile(ctx, command.UserID, request)
case usermodel.MessageTypeUpdateMySettings:
request, err := transcoder.PayloadToUpdateMySettingsRequest(command.PayloadBytes)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute user service command %q: %w", command.MessageType, err)
}
return c.executeUpdateMySettings(ctx, command.UserID, request)
default:
return downstream.UnaryResult{}, fmt.Errorf("execute user service command: unsupported message type %q", command.MessageType)
}
}
func (c *HTTPClient) executeGetMyAccount(ctx context.Context, userID string) (downstream.UnaryResult, error) {
payload, statusCode, err := c.doRequest(ctx, http.MethodGet, c.userPath(userID, userServiceAccountPathSuffix), nil)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute get my account: %w", err)
}
return projectResponse(statusCode, payload)
}
func (c *HTTPClient) executeUpdateMyProfile(ctx context.Context, userID string, request *usermodel.UpdateMyProfileRequest) (downstream.UnaryResult, error) {
payload, statusCode, err := c.doRequest(ctx, http.MethodPost, c.userPath(userID, userServiceProfilePathSuffix), request)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute update my profile: %w", err)
}
return projectResponse(statusCode, payload)
}
func (c *HTTPClient) executeUpdateMySettings(ctx context.Context, userID string, request *usermodel.UpdateMySettingsRequest) (downstream.UnaryResult, error) {
payload, statusCode, err := c.doRequest(ctx, http.MethodPost, c.userPath(userID, userServiceSettingsPathSuffix), request)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("execute update my settings: %w", err)
}
return projectResponse(statusCode, payload)
}
func (c *HTTPClient) doRequest(ctx context.Context, method string, targetURL string, requestBody any) ([]byte, int, error) {
if c == nil || c.httpClient == nil {
return nil, 0, errors.New("nil client")
}
var bodyReader io.Reader
if requestBody != nil {
payload, err := json.Marshal(requestBody)
if err != nil {
return nil, 0, fmt.Errorf("marshal request body: %w", err)
}
bodyReader = bytes.NewReader(payload)
}
request, err := http.NewRequestWithContext(ctx, method, targetURL, bodyReader)
if err != nil {
return nil, 0, fmt.Errorf("build request: %w", err)
}
if requestBody != nil {
request.Header.Set("Content-Type", "application/json")
}
response, err := c.httpClient.Do(request)
if err != nil {
return nil, 0, err
}
defer response.Body.Close()
payload, err := io.ReadAll(response.Body)
if err != nil {
return nil, 0, fmt.Errorf("read response body: %w", err)
}
return payload, response.StatusCode, nil
}
func (c *HTTPClient) userPath(userID string, suffix string) string {
return c.baseURL + "/api/v1/internal/users/" + url.PathEscape(userID) + suffix
}
func projectResponse(statusCode int, payload []byte) (downstream.UnaryResult, error) {
switch {
case statusCode == http.StatusOK:
var response usermodel.AccountResponse
if err := decodeStrictJSONPayload(payload, &response); err != nil {
return downstream.UnaryResult{}, fmt.Errorf("decode success response: %w", err)
}
payloadBytes, err := transcoder.AccountResponseToPayload(&response)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("encode success response payload: %w", err)
}
return downstream.UnaryResult{
ResultCode: getMyAccountResultCodeOK,
PayloadBytes: payloadBytes,
}, nil
case statusCode == http.StatusServiceUnavailable:
return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
case statusCode >= 400 && statusCode <= 599:
errorResponse, err := decodeUserServiceError(statusCode, payload)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("decode error response: %w", err)
}
payloadBytes, err := transcoder.ErrorResponseToPayload(errorResponse)
if err != nil {
return downstream.UnaryResult{}, fmt.Errorf("encode error response payload: %w", err)
}
return downstream.UnaryResult{
ResultCode: errorResponse.Error.Code,
PayloadBytes: payloadBytes,
}, nil
default:
return downstream.UnaryResult{}, fmt.Errorf("unexpected HTTP status %d", statusCode)
}
}
func decodeUserServiceError(statusCode int, payload []byte) (*usermodel.ErrorResponse, error) {
var response usermodel.ErrorResponse
if err := decodeStrictJSONPayload(payload, &response); err != nil {
return nil, err
}
response.Error.Code = normalizeErrorCode(statusCode, response.Error.Code)
response.Error.Message = normalizeErrorMessage(response.Error.Code, response.Error.Message)
if strings.TrimSpace(response.Error.Code) == "" {
return nil, errors.New("missing error code")
}
if strings.TrimSpace(response.Error.Message) == "" {
return nil, errors.New("missing error message")
}
return &response, nil
}
func normalizeErrorCode(statusCode int, code string) string {
trimmed := strings.TrimSpace(code)
if trimmed != "" {
return trimmed
}
switch statusCode {
case http.StatusBadRequest:
return "invalid_request"
case http.StatusNotFound:
return "subject_not_found"
case http.StatusConflict:
return "conflict"
default:
return "internal_error"
}
}
func normalizeErrorMessage(code string, message string) string {
trimmed := strings.TrimSpace(message)
if trimmed != "" {
return trimmed
}
if stable, ok := stableErrorMessages[code]; ok {
return stable
}
return stableErrorMessages["internal_error"]
}
func decodeStrictJSONPayload(payload []byte, target any) error {
decoder := json.NewDecoder(bytes.NewReader(payload))
decoder.DisallowUnknownFields()
if err := decoder.Decode(target); err != nil {
return err
}
if err := decoder.Decode(&struct{}{}); err != io.EOF {
if err == nil {
return errors.New("unexpected trailing JSON input")
}
return err
}
return nil
}
var _ downstream.Client = (*HTTPClient)(nil)
@@ -1,400 +0,0 @@
package userservice
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"galaxy/gateway/internal/downstream"
usermodel "galaxy/model/user"
"galaxy/transcoder"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewHTTPClient(t *testing.T) {
t.Parallel()
tests := []struct {
name string
baseURL string
wantURL string
wantErr string
}{
{
name: "absolute URL is normalized",
baseURL: " http://127.0.0.1:8081/ ",
wantURL: "http://127.0.0.1:8081",
},
{
name: "empty base URL is rejected",
baseURL: " ",
wantErr: "base URL must not be empty",
},
{
name: "relative base URL is rejected",
baseURL: "/relative",
wantErr: "base URL must be absolute",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client, err := NewHTTPClient(tt.baseURL)
if tt.wantErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantURL, client.baseURL)
})
}
}
func TestHTTPClientExecuteGetMyAccountSuccess(t *testing.T) {
t.Parallel()
wantResponse := sampleAccountResponse()
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
require.Equal(t, http.MethodGet, request.Method)
require.Equal(t, "/api/v1/internal/users/user-123/account", request.URL.Path)
require.NoError(t, json.NewEncoder(writer).Encode(wantResponse))
}))
defer server.Close()
client := newTestHTTPClient(t, server)
payload, err := transcoder.GetMyAccountRequestToPayload(&usermodel.GetMyAccountRequest{})
require.NoError(t, err)
result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
UserID: "user-123",
MessageType: usermodel.MessageTypeGetMyAccount,
PayloadBytes: payload,
})
require.NoError(t, err)
assert.Equal(t, getMyAccountResultCodeOK, result.ResultCode)
decoded, err := transcoder.PayloadToAccountResponse(result.PayloadBytes)
require.NoError(t, err)
assert.Equal(t, wantResponse, decoded)
}
func TestHTTPClientExecuteUpdateMyProfileProjectsConflict(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
require.Equal(t, http.MethodPost, request.Method)
require.Equal(t, "/api/v1/internal/users/user-123/profile", request.URL.Path)
body, err := io.ReadAll(request.Body)
require.NoError(t, err)
require.JSONEq(t, `{"display_name":"NovaPrime"}`, string(body))
writer.WriteHeader(http.StatusConflict)
require.NoError(t, json.NewEncoder(writer).Encode(&usermodel.ErrorResponse{
Error: usermodel.ErrorBody{
Code: "conflict",
Message: "request conflicts with current state",
},
}))
}))
defer server.Close()
client := newTestHTTPClient(t, server)
payload, err := transcoder.UpdateMyProfileRequestToPayload(&usermodel.UpdateMyProfileRequest{DisplayName: "NovaPrime"})
require.NoError(t, err)
result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
UserID: "user-123",
MessageType: usermodel.MessageTypeUpdateMyProfile,
PayloadBytes: payload,
})
require.NoError(t, err)
assert.Equal(t, "conflict", result.ResultCode)
decoded, err := transcoder.PayloadToErrorResponse(result.PayloadBytes)
require.NoError(t, err)
assert.Equal(t, &usermodel.ErrorResponse{
Error: usermodel.ErrorBody{
Code: "conflict",
Message: "request conflicts with current state",
},
}, decoded)
}
func TestHTTPClientExecuteUpdateMySettingsProjectsInvalidRequest(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
require.Equal(t, http.MethodPost, request.Method)
require.Equal(t, "/api/v1/internal/users/user-123/settings", request.URL.Path)
body, err := io.ReadAll(request.Body)
require.NoError(t, err)
require.JSONEq(t, `{"preferred_language":"bad","time_zone":"Mars/Base"}`, string(body))
writer.WriteHeader(http.StatusBadRequest)
require.NoError(t, json.NewEncoder(writer).Encode(&usermodel.ErrorResponse{
Error: usermodel.ErrorBody{
Code: "invalid_request",
Message: "request is invalid",
},
}))
}))
defer server.Close()
client := newTestHTTPClient(t, server)
payload, err := transcoder.UpdateMySettingsRequestToPayload(&usermodel.UpdateMySettingsRequest{
PreferredLanguage: "bad",
TimeZone: "Mars/Base",
})
require.NoError(t, err)
result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
UserID: "user-123",
MessageType: usermodel.MessageTypeUpdateMySettings,
PayloadBytes: payload,
})
require.NoError(t, err)
assert.Equal(t, "invalid_request", result.ResultCode)
decoded, err := transcoder.PayloadToErrorResponse(result.PayloadBytes)
require.NoError(t, err)
assert.Equal(t, "invalid_request", decoded.Error.Code)
}
func TestHTTPClientExecuteCommandProjectsSubjectNotFound(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusNotFound)
require.NoError(t, json.NewEncoder(writer).Encode(&usermodel.ErrorResponse{
Error: usermodel.ErrorBody{
Code: "subject_not_found",
Message: "subject not found",
},
}))
}))
defer server.Close()
client := newTestHTTPClient(t, server)
payload, err := transcoder.GetMyAccountRequestToPayload(&usermodel.GetMyAccountRequest{})
require.NoError(t, err)
result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
UserID: "user-missing",
MessageType: usermodel.MessageTypeGetMyAccount,
PayloadBytes: payload,
})
require.NoError(t, err)
assert.Equal(t, "subject_not_found", result.ResultCode)
}
func TestHTTPClientExecuteCommandMaps503ToUnavailable(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusServiceUnavailable)
require.NoError(t, json.NewEncoder(writer).Encode(&usermodel.ErrorResponse{
Error: usermodel.ErrorBody{
Code: "service_unavailable",
Message: "service is unavailable",
},
}))
}))
defer server.Close()
client := newTestHTTPClient(t, server)
payload, err := transcoder.GetMyAccountRequestToPayload(&usermodel.GetMyAccountRequest{})
require.NoError(t, err)
_, err = client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
UserID: "user-123",
MessageType: usermodel.MessageTypeGetMyAccount,
PayloadBytes: payload,
})
require.Error(t, err)
assert.ErrorIs(t, err, downstream.ErrDownstreamUnavailable)
}
func TestHTTPClientExecuteCommandUsesCallerContext(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
<-request.Context().Done()
}))
defer server.Close()
client := newTestHTTPClient(t, server)
payload, err := transcoder.GetMyAccountRequestToPayload(&usermodel.GetMyAccountRequest{})
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond)
defer cancel()
_, err = client.ExecuteCommand(ctx, downstream.AuthenticatedCommand{
UserID: "user-123",
MessageType: usermodel.MessageTypeGetMyAccount,
PayloadBytes: payload,
})
require.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
}
func TestHTTPClientExecuteCommandRejectsMalformedSuccessPayload(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
_, _ = writer.Write([]byte(`{"account":{"user_id":"user-123","unexpected":true}}`))
}))
defer server.Close()
client := newTestHTTPClient(t, server)
payload, err := transcoder.GetMyAccountRequestToPayload(&usermodel.GetMyAccountRequest{})
require.NoError(t, err)
_, err = client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
UserID: "user-123",
MessageType: usermodel.MessageTypeGetMyAccount,
PayloadBytes: payload,
})
require.Error(t, err)
assert.Contains(t, err.Error(), "decode success response")
}
func TestHTTPClientExecuteCommandRejectsUnsupportedMessageType(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.NotFoundHandler())
defer server.Close()
client := newTestHTTPClient(t, server)
_, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
UserID: "user-123",
MessageType: "user.unsupported",
PayloadBytes: []byte("payload"),
})
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported message type")
}
func TestNewRoutesReserveUserMessageTypesWhenUnconfigured(t *testing.T) {
t.Parallel()
routes, closeFn, err := NewRoutes("")
require.NoError(t, err)
require.NoError(t, closeFn())
router := downstream.NewStaticRouter(routes)
for _, messageType := range []string{
usermodel.MessageTypeGetMyAccount,
usermodel.MessageTypeUpdateMyProfile,
usermodel.MessageTypeUpdateMySettings,
} {
client, routeErr := router.Route(messageType)
require.NoError(t, routeErr)
_, execErr := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
UserID: "user-123",
MessageType: messageType,
})
require.Error(t, execErr)
assert.ErrorIs(t, execErr, downstream.ErrDownstreamUnavailable)
}
}
func TestUnavailableClientReturnsDownstreamUnavailable(t *testing.T) {
t.Parallel()
_, err := unavailableClient{}.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{})
require.Error(t, err)
assert.ErrorIs(t, err, downstream.ErrDownstreamUnavailable)
}
func newTestHTTPClient(t *testing.T, server *httptest.Server) *HTTPClient {
t.Helper()
client, err := newHTTPClient(server.URL, server.Client())
require.NoError(t, err)
return client
}
func sampleAccountResponse() *usermodel.AccountResponse {
now := time.Date(2026, time.April, 9, 10, 0, 0, 0, time.UTC)
expiresAt := now.Add(30 * 24 * time.Hour)
return &usermodel.AccountResponse{
Account: usermodel.Account{
UserID: "user-123",
Email: "pilot@example.com",
UserName: "player-abcdefgh",
DisplayName: "PilotNova",
PreferredLanguage: "en",
TimeZone: "Europe/Kaliningrad",
DeclaredCountry: "DE",
Entitlement: usermodel.EntitlementSnapshot{
PlanCode: "free",
IsPaid: false,
Source: "auth_registration",
Actor: usermodel.ActorRef{Type: "service", ID: "user-service"},
ReasonCode: "initial_free_entitlement",
StartsAt: now,
UpdatedAt: now,
},
ActiveSanctions: []usermodel.ActiveSanction{
{
SanctionCode: "profile_update_block",
Scope: "lobby",
ReasonCode: "manual_block",
Actor: usermodel.ActorRef{Type: "admin", ID: "admin-1"},
AppliedAt: now,
ExpiresAt: &expiresAt,
},
},
ActiveLimits: []usermodel.ActiveLimit{
{
LimitCode: "max_owned_private_games",
Value: 3,
ReasonCode: "manual_override",
Actor: usermodel.ActorRef{Type: "admin", ID: "admin-1"},
AppliedAt: now,
},
},
CreatedAt: now,
UpdatedAt: now,
},
}
}
func TestDecodeUserServiceErrorNormalizesBlankFields(t *testing.T) {
t.Parallel()
response, err := decodeUserServiceError(http.StatusBadRequest, []byte(`{"error":{"code":" ","message":" "}}`))
require.NoError(t, err)
assert.Equal(t, "invalid_request", response.Error.Code)
assert.Equal(t, "request is invalid", response.Error.Message)
}
func TestHTTPClientExecuteCommandRejectsNilContext(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.NotFoundHandler())
defer server.Close()
client := newTestHTTPClient(t, server)
_, err := client.ExecuteCommand(nil, downstream.AuthenticatedCommand{})
require.Error(t, err)
assert.Contains(t, err.Error(), "nil context")
}
@@ -1,46 +0,0 @@
package userservice
import (
"context"
"galaxy/gateway/internal/downstream"
usermodel "galaxy/model/user"
)
var noOpClose = func() error { return nil }
// NewRoutes returns the reserved authenticated gateway routes owned by the
// Gateway -> User self-service boundary.
//
// When baseURL is empty, the returned routes still reserve the stable
// `user.*` message types but resolve them to a dependency-unavailable client
// so callers receive the transport-level unavailable outcome instead of a
// route-miss error.
func NewRoutes(baseURL string) (map[string]downstream.Client, func() error, error) {
client := downstream.Client(unavailableClient{})
closeFn := noOpClose
if baseURL != "" {
httpClient, err := NewHTTPClient(baseURL)
if err != nil {
return nil, nil, err
}
client = httpClient
closeFn = httpClient.Close
}
return map[string]downstream.Client{
usermodel.MessageTypeGetMyAccount: client,
usermodel.MessageTypeUpdateMyProfile: client,
usermodel.MessageTypeUpdateMySettings: client,
}, closeFn, nil
}
type unavailableClient struct{}
func (unavailableClient) ExecuteCommand(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
}
var _ downstream.Client = unavailableClient{}
@@ -1,299 +0,0 @@
package events
import (
"bytes"
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/telemetry"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
const clientEventReadCount int64 = 128
// ClientEventPublisher accepts decoded client-facing events from the internal
// event subscriber.
type ClientEventPublisher interface {
// Publish fans out event to the currently active push streams.
Publish(event push.Event)
}
// RedisClientEventSubscriber consumes client-facing events from one Redis
// Stream and forwards them to the configured publisher.
type RedisClientEventSubscriber struct {
client *redis.Client
stream string
pingTimeout time.Duration
readBlockTimeout time.Duration
publisher ClientEventPublisher
logger *zap.Logger
metrics *telemetry.Runtime
startedOnce sync.Once
started chan struct{}
}
// NewRedisClientEventSubscriber constructs a Redis Stream subscriber that uses
// client and forwards decoded client-facing events to publisher.
func NewRedisClientEventSubscriber(client *redis.Client, sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher) (*RedisClientEventSubscriber, error) {
return NewRedisClientEventSubscriberWithObservability(client, sessionCfg, eventsCfg, publisher, nil, nil)
}
// NewRedisClientEventSubscriberWithObservability constructs a Redis Stream
// subscriber that also records malformed or dropped internal events. The
// subscriber does not own the client; the runtime supplies a shared
// *redis.Client.
func NewRedisClientEventSubscriberWithObservability(client *redis.Client, sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisClientEventSubscriber, error) {
if client == nil {
return nil, errors.New("new redis client event subscriber: nil redis client")
}
if sessionCfg.LookupTimeout <= 0 {
return nil, errors.New("new redis client event subscriber: lookup timeout must be positive")
}
if strings.TrimSpace(eventsCfg.Stream) == "" {
return nil, errors.New("new redis client event subscriber: stream must not be empty")
}
if eventsCfg.ReadBlockTimeout <= 0 {
return nil, errors.New("new redis client event subscriber: read block timeout must be positive")
}
if publisher == nil {
return nil, errors.New("new redis client event subscriber: nil publisher")
}
if logger == nil {
logger = zap.NewNop()
}
return &RedisClientEventSubscriber{
client: client,
stream: eventsCfg.Stream,
pingTimeout: sessionCfg.LookupTimeout,
readBlockTimeout: eventsCfg.ReadBlockTimeout,
publisher: publisher,
logger: logger.Named("client_event_subscriber"),
metrics: metrics,
started: make(chan struct{}),
}, nil
}
// Run consumes client-facing events until ctx is canceled or Redis returns an
// unexpected error.
func (s *RedisClientEventSubscriber) Run(ctx context.Context) error {
if s == nil || s.client == nil {
return errors.New("run redis client event subscriber: nil subscriber")
}
if ctx == nil {
return errors.New("run redis client event subscriber: nil context")
}
if err := ctx.Err(); err != nil {
return err
}
lastID, err := s.resolveStartID(ctx)
if err != nil {
return err
}
s.signalStarted()
for {
streams, err := s.client.XRead(ctx, &redis.XReadArgs{
Streams: []string{s.stream, lastID},
Count: clientEventReadCount,
Block: s.readBlockTimeout,
}).Result()
switch {
case err == nil:
for _, stream := range streams {
for _, message := range stream.Messages {
s.publishMessage(message)
lastID = message.ID
}
}
continue
case errors.Is(err, redis.Nil):
continue
case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
return ctx.Err()
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded), errors.Is(err, redis.ErrClosed):
return fmt.Errorf("run redis client event subscriber: %w", err)
default:
return fmt.Errorf("run redis client event subscriber: %w", err)
}
}
}
func (s *RedisClientEventSubscriber) resolveStartID(ctx context.Context) (string, error) {
messages, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
switch {
case err == nil:
case errors.Is(err, redis.Nil):
return "0-0", nil
default:
return "", fmt.Errorf("run redis client event subscriber: resolve stream tail: %w", err)
}
if len(messages) == 0 {
return "0-0", nil
}
return messages[0].ID, nil
}
// Shutdown is a no-op kept for App framework compatibility. The blocking
// XRead loop terminates when its context is cancelled by the parent runtime,
// which also owns and closes the shared Redis client.
func (s *RedisClientEventSubscriber) Shutdown(ctx context.Context) error {
if ctx == nil {
return errors.New("shutdown redis client event subscriber: nil context")
}
return nil
}
// Close is a no-op kept for backwards-compatible cleanup wiring; the
// subscriber does not own the shared Redis client.
func (s *RedisClientEventSubscriber) Close() error {
return nil
}
func (s *RedisClientEventSubscriber) signalStarted() {
s.startedOnce.Do(func() {
close(s.started)
})
}
func (s *RedisClientEventSubscriber) publishMessage(message redis.XMessage) {
event, err := decodeClientEvent(message.Values)
if err != nil {
s.logger.Warn("dropped malformed client event",
zap.String("stream", s.stream),
zap.String("message_id", message.ID),
zap.Error(err),
)
s.metrics.RecordInternalEventDrop(context.Background(),
attribute.String("component", "client_event_subscriber"),
attribute.String("reason", "malformed_event"),
)
return
}
s.publisher.Publish(event)
}
func decodeClientEvent(values map[string]any) (push.Event, error) {
requiredKeys := map[string]struct{}{
"user_id": {},
"event_type": {},
"event_id": {},
"payload_bytes": {},
}
optionalKeys := map[string]struct{}{
"device_session_id": {},
"request_id": {},
"trace_id": {},
}
for key := range values {
if _, ok := requiredKeys[key]; ok {
continue
}
if _, ok := optionalKeys[key]; ok {
continue
}
return push.Event{}, fmt.Errorf("decode client event: unsupported field %q", key)
}
userID, err := requiredStringField(values, "user_id")
if err != nil {
return push.Event{}, err
}
eventType, err := requiredStringField(values, "event_type")
if err != nil {
return push.Event{}, err
}
eventID, err := requiredStringField(values, "event_id")
if err != nil {
return push.Event{}, err
}
payloadBytes, err := requiredBytesField(values, "payload_bytes")
if err != nil {
return push.Event{}, err
}
event := push.Event{
UserID: userID,
EventType: eventType,
EventID: eventID,
PayloadBytes: payloadBytes,
}
if deviceSessionID, ok, err := optionalStringField(values, "device_session_id"); err != nil {
return push.Event{}, err
} else if ok {
event.DeviceSessionID = strings.TrimSpace(deviceSessionID)
}
if requestID, ok, err := optionalStringField(values, "request_id"); err != nil {
return push.Event{}, err
} else if ok {
event.RequestID = requestID
}
if traceID, ok, err := optionalStringField(values, "trace_id"); err != nil {
return push.Event{}, err
} else if ok {
event.TraceID = traceID
}
return event, nil
}
func requiredBytesField(values map[string]any, field string) ([]byte, error) {
value, ok := values[field]
if !ok {
return nil, fmt.Errorf("decode client event: missing %s", field)
}
byteValue, err := coerceBytes(value)
if err != nil {
return nil, fmt.Errorf("decode client event: %s: %w", field, err)
}
return byteValue, nil
}
func optionalStringField(values map[string]any, field string) (string, bool, error) {
value, ok := values[field]
if !ok {
return "", false, nil
}
stringValue, err := coerceString(value)
if err != nil {
return "", false, fmt.Errorf("decode client event: %s: %w", field, err)
}
return stringValue, true, nil
}
func coerceBytes(value any) ([]byte, error) {
switch typed := value.(type) {
case string:
return []byte(typed), nil
case []byte:
return bytes.Clone(typed), nil
default:
return nil, fmt.Errorf("unsupported type %T", value)
}
}
@@ -1,289 +0,0 @@
package events
import (
"context"
"strings"
"sync"
"testing"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/testutil"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRedisClientEventSubscriberPublishesValidEvent(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := &recordingClientEventPublisher{}
subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
running := runTestClientEventSubscriber(t, subscriber)
defer running.stop(t)
addClientEvent(t, server, "gateway:client_events", map[string]any{
"user_id": "user-123",
"device_session_id": "device-session-123",
"event_type": "fleet.updated",
"event_id": "event-123",
"payload_bytes": []byte("payload-123"),
"request_id": "request-123",
"trace_id": "trace-123",
})
require.Eventually(t, func() bool {
return len(publisher.events()) == 1
}, time.Second, 10*time.Millisecond)
assert.Equal(t, []push.Event{{
UserID: "user-123",
DeviceSessionID: "device-session-123",
EventType: "fleet.updated",
EventID: "event-123",
PayloadBytes: []byte("payload-123"),
RequestID: "request-123",
TraceID: "trace-123",
}}, publisher.events())
}
func TestRedisClientEventSubscriberSkipsMalformedEventAndContinues(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := &recordingClientEventPublisher{}
subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
running := runTestClientEventSubscriber(t, subscriber)
defer running.stop(t)
addClientEvent(t, server, "gateway:client_events", map[string]any{
"user_id": "user-123",
"event_type": "fleet.updated",
"event_id": "event-bad",
"payload_bytes": []byte("payload-bad"),
"unexpected": "boom",
})
addClientEvent(t, server, "gateway:client_events", map[string]any{
"user_id": "user-123",
"event_type": "fleet.updated",
"event_id": "event-good",
"payload_bytes": []byte("payload-good"),
})
require.Eventually(t, func() bool {
events := publisher.events()
return len(events) == 1 && events[0].EventID == "event-good"
}, time.Second, 10*time.Millisecond)
}
func TestRedisClientEventSubscriberStartsFromCurrentTail(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := &recordingClientEventPublisher{}
addClientEvent(t, server, "gateway:client_events", map[string]any{
"user_id": "user-123",
"event_type": "fleet.updated",
"event_id": "event-old",
"payload_bytes": []byte("payload-old"),
})
subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
running := runTestClientEventSubscriber(t, subscriber)
defer running.stop(t)
assert.Never(t, func() bool {
return len(publisher.events()) > 0
}, 100*time.Millisecond, 10*time.Millisecond)
addClientEvent(t, server, "gateway:client_events", map[string]any{
"user_id": "user-123",
"event_type": "fleet.updated",
"event_id": "event-new",
"payload_bytes": []byte("payload-new"),
})
require.Eventually(t, func() bool {
events := publisher.events()
return len(events) == 1 && events[0].EventID == "event-new"
}, time.Second, 10*time.Millisecond)
}
func TestRedisClientEventSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := &recordingClientEventPublisher{}
subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
ctx, cancel := context.WithCancel(context.Background())
resultCh := make(chan error, 1)
go func() {
resultCh <- subscriber.Run(ctx)
}()
select {
case <-subscriber.started:
case <-time.After(time.Second):
require.FailNow(t, "subscriber did not start")
}
cancel()
require.NoError(t, subscriber.Shutdown(context.Background()))
select {
case err := <-resultCh:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
require.FailNow(t, "subscriber did not stop after shutdown")
}
}
func TestRedisClientEventSubscriberLogsAndCountsMalformedEvents(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
publisher := &recordingClientEventPublisher{}
logger, logBuffer := testutil.NewObservedLogger(t)
telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)
subscriber, err := NewRedisClientEventSubscriberWithObservability(
newTestRedisClient(t, server),
config.SessionCacheRedisConfig{
KeyPrefix: "gateway:session:",
LookupTimeout: 250 * time.Millisecond,
},
config.ClientEventsRedisConfig{
Stream: "gateway:client_events",
ReadBlockTimeout: 25 * time.Millisecond,
},
publisher,
logger,
telemetryRuntime,
)
require.NoError(t, err)
running := runTestClientEventSubscriber(t, subscriber)
defer running.stop(t)
addClientEvent(t, server, "gateway:client_events", map[string]any{
"user_id": "user-123",
"event_type": "fleet.updated",
"event_id": "event-bad",
"payload_bytes": []byte("payload-bad"),
"unexpected": "boom",
})
require.Eventually(t, func() bool {
return strings.Contains(logBuffer.String(), "dropped malformed client event")
}, time.Second, 10*time.Millisecond)
metricsText := testutil.ScrapeMetrics(t, telemetryRuntime.Handler())
assert.Contains(t, metricsText, `gateway_internal_event_drops_total`)
assert.Contains(t, metricsText, `component="client_event_subscriber"`)
assert.Contains(t, metricsText, `reason="malformed_event"`)
}
func newTestRedisClientEventSubscriber(t *testing.T, server *miniredis.Miniredis, publisher ClientEventPublisher) *RedisClientEventSubscriber {
t.Helper()
subscriber, err := NewRedisClientEventSubscriber(
newTestRedisClient(t, server),
config.SessionCacheRedisConfig{
KeyPrefix: "gateway:session:",
LookupTimeout: 250 * time.Millisecond,
},
config.ClientEventsRedisConfig{
Stream: "gateway:client_events",
ReadBlockTimeout: 25 * time.Millisecond,
},
publisher,
)
require.NoError(t, err)
return subscriber
}
func addClientEvent(t *testing.T, server *miniredis.Miniredis, stream string, values map[string]any) {
t.Helper()
client := redis.NewClient(&redis.Options{
Addr: server.Addr(),
Protocol: 2,
DisableIdentity: true,
})
defer func() {
assert.NoError(t, client.Close())
}()
err := client.XAdd(context.Background(), &redis.XAddArgs{
Stream: stream,
Values: values,
}).Err()
require.NoError(t, err)
}
type runningClientEventSubscriber struct {
cancel context.CancelFunc
resultCh chan error
}
func runTestClientEventSubscriber(t *testing.T, subscriber *RedisClientEventSubscriber) runningClientEventSubscriber {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
resultCh := make(chan error, 1)
go func() {
resultCh <- subscriber.Run(ctx)
}()
select {
case <-subscriber.started:
case <-time.After(time.Second):
require.FailNow(t, "subscriber did not start")
}
return runningClientEventSubscriber{
cancel: cancel,
resultCh: resultCh,
}
}
func (r runningClientEventSubscriber) stop(t *testing.T) {
t.Helper()
r.cancel()
select {
case err := <-r.resultCh:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
require.FailNow(t, "subscriber did not stop")
}
}
type recordingClientEventPublisher struct {
mu sync.Mutex
records []push.Event
}
func (p *recordingClientEventPublisher) Publish(event push.Event) {
p.mu.Lock()
defer p.mu.Unlock()
p.records = append(p.records, event)
}
func (p *recordingClientEventPublisher) events() []push.Event {
p.mu.Lock()
defer p.mu.Unlock()
cloned := make([]push.Event, len(p.records))
copy(cloned, p.records)
return cloned
}
+145
View File
@@ -0,0 +1,145 @@
// Package events translates inbound `pushv1.PushEvent` frames received
// from backend into actions on the gateway-side push hub. It replaces
// the Stage <6.2 Redis Stream subscribers (`session_events`,
// `client_events`) with a single dispatcher driven by the gRPC
// SubscribePush stream.
package events
import (
"context"
"strings"
pushv1 "galaxy/backend/proto/push/v1"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/telemetry"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
// SessionInvalidator closes every active push subscription bound to a
// (device_session_id) or every session of a user when the backend emits
// a SessionInvalidation frame. *push.Hub satisfies this contract.
type SessionInvalidator interface {
RevokeDeviceSession(deviceSessionID string)
RevokeAllForUser(userID string)
}
// EventPublisher fans out a translated client event to active push
// subscriptions. *push.Hub satisfies this contract.
type EventPublisher interface {
Publish(event push.Event)
}
// Dispatcher converts inbound `pushv1.PushEvent` frames into either a
// hub Publish or a hub revocation. Malformed frames are dropped and
// counted via telemetry; observability mirrors the previous
// RecordInternalEventDrop semantics.
type Dispatcher struct {
publisher EventPublisher
invalidator SessionInvalidator
logger *zap.Logger
metrics *telemetry.Runtime
}
// NewDispatcher constructs a Dispatcher. publisher and invalidator are
// required; logger and metrics may be nil.
func NewDispatcher(publisher EventPublisher, invalidator SessionInvalidator, logger *zap.Logger, metrics *telemetry.Runtime) *Dispatcher {
if logger == nil {
logger = zap.NewNop()
}
return &Dispatcher{
publisher: publisher,
invalidator: invalidator,
logger: logger.Named("push_dispatcher"),
metrics: metrics,
}
}
// Handle implements backendclient.EventHandler. It is safe for
// concurrent use; the caller serialises ev within its goroutine.
func (d *Dispatcher) Handle(ctx context.Context, ev *pushv1.PushEvent) {
if d == nil || ev == nil {
return
}
switch kind := ev.GetKind().(type) {
case *pushv1.PushEvent_ClientEvent:
d.handleClientEvent(ctx, kind.ClientEvent, ev.GetCursor())
case *pushv1.PushEvent_SessionInvalidation:
d.handleSessionInvalidation(kind.SessionInvalidation)
default:
d.logger.Warn("dropped malformed push event",
zap.String("cursor", ev.GetCursor()),
zap.String("reason", "unknown_kind"),
)
d.recordDrop(ctx, "unknown_kind")
}
}
func (d *Dispatcher) handleClientEvent(ctx context.Context, ce *pushv1.ClientEvent, cursor string) {
if ce == nil || d.publisher == nil {
return
}
userID := strings.TrimSpace(ce.GetUserId())
kind := strings.TrimSpace(ce.GetKind())
eventID := strings.TrimSpace(ce.GetEventId())
if userID == "" || kind == "" || eventID == "" {
d.logger.Warn("dropped malformed client event",
zap.String("cursor", cursor),
zap.String("user_id", userID),
zap.String("kind", kind),
zap.String("event_id", eventID),
)
d.recordDrop(ctx, "malformed_client_event")
return
}
d.publisher.Publish(push.Event{
UserID: userID,
DeviceSessionID: strings.TrimSpace(ce.GetDeviceSessionId()),
EventType: kind,
EventID: eventID,
PayloadBytes: cloneBytes(ce.GetPayload()),
RequestID: ce.GetRequestId(),
TraceID: ce.GetTraceId(),
})
}
func (d *Dispatcher) handleSessionInvalidation(si *pushv1.SessionInvalidation) {
if si == nil || d.invalidator == nil {
return
}
userID := strings.TrimSpace(si.GetUserId())
deviceSessionID := strings.TrimSpace(si.GetDeviceSessionId())
switch {
case deviceSessionID != "":
d.invalidator.RevokeDeviceSession(deviceSessionID)
case userID != "":
d.invalidator.RevokeAllForUser(userID)
default:
d.logger.Warn("dropped malformed session_invalidation: user_id and device_session_id both empty")
}
}
func (d *Dispatcher) recordDrop(ctx context.Context, reason string) {
if d.metrics == nil {
return
}
d.metrics.RecordInternalEventDrop(ctx,
attribute.String("component", "push_dispatcher"),
attribute.String("reason", reason),
)
}
func cloneBytes(in []byte) []byte {
if len(in) == 0 {
return nil
}
out := make([]byte, len(in))
copy(out, in)
return out
}
+157
View File
@@ -0,0 +1,157 @@
package events_test
import (
"context"
"sync"
"testing"
pushv1 "galaxy/backend/proto/push/v1"
"galaxy/gateway/internal/events"
"galaxy/gateway/internal/push"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type capturePublisher struct {
mu sync.Mutex
events []push.Event
}
func (c *capturePublisher) Publish(event push.Event) {
c.mu.Lock()
defer c.mu.Unlock()
c.events = append(c.events, event)
}
func (c *capturePublisher) snapshot() []push.Event {
c.mu.Lock()
defer c.mu.Unlock()
out := make([]push.Event, len(c.events))
copy(out, c.events)
return out
}
type captureInvalidator struct {
mu sync.Mutex
devices []string
users []string
}
func (c *captureInvalidator) RevokeDeviceSession(id string) {
c.mu.Lock()
defer c.mu.Unlock()
c.devices = append(c.devices, id)
}
func (c *captureInvalidator) RevokeAllForUser(id string) {
c.mu.Lock()
defer c.mu.Unlock()
c.users = append(c.users, id)
}
func (c *captureInvalidator) snapshot() ([]string, []string) {
c.mu.Lock()
defer c.mu.Unlock()
d := append([]string(nil), c.devices...)
u := append([]string(nil), c.users...)
return d, u
}
func TestDispatcherForwardsClientEventToPublisher(t *testing.T) {
t.Parallel()
pub := &capturePublisher{}
inv := &captureInvalidator{}
disp := events.NewDispatcher(pub, inv, nil, nil)
disp.Handle(context.Background(), &pushv1.PushEvent{
Cursor: "00000000000000000001",
Kind: &pushv1.PushEvent_ClientEvent{
ClientEvent: &pushv1.ClientEvent{
UserId: "user-1",
DeviceSessionId: "device-1",
Kind: "lobby.invite.received",
Payload: []byte(`{"x":1}`),
EventId: "route-1",
RequestId: "req-1",
TraceId: "trace-1",
},
},
})
got := pub.snapshot()
require.Len(t, got, 1)
assert.Equal(t, push.Event{
UserID: "user-1",
DeviceSessionID: "device-1",
EventType: "lobby.invite.received",
EventID: "route-1",
PayloadBytes: []byte(`{"x":1}`),
RequestID: "req-1",
TraceID: "trace-1",
}, got[0])
devices, users := inv.snapshot()
assert.Empty(t, devices)
assert.Empty(t, users)
}
func TestDispatcherDropsClientEventMissingEventID(t *testing.T) {
t.Parallel()
pub := &capturePublisher{}
disp := events.NewDispatcher(pub, &captureInvalidator{}, nil, nil)
disp.Handle(context.Background(), &pushv1.PushEvent{
Kind: &pushv1.PushEvent_ClientEvent{
ClientEvent: &pushv1.ClientEvent{
UserId: "user-1",
Kind: "lobby.invite.received",
},
},
})
assert.Empty(t, pub.snapshot())
}
func TestDispatcherSessionInvalidationByDeviceID(t *testing.T) {
t.Parallel()
inv := &captureInvalidator{}
disp := events.NewDispatcher(&capturePublisher{}, inv, nil, nil)
disp.Handle(context.Background(), &pushv1.PushEvent{
Kind: &pushv1.PushEvent_SessionInvalidation{
SessionInvalidation: &pushv1.SessionInvalidation{
UserId: "user-1",
DeviceSessionId: "device-1",
Reason: "auth.revoke_session",
},
},
})
devices, users := inv.snapshot()
assert.Equal(t, []string{"device-1"}, devices)
assert.Empty(t, users)
}
func TestDispatcherSessionInvalidationFanOutForUser(t *testing.T) {
t.Parallel()
inv := &captureInvalidator{}
disp := events.NewDispatcher(&capturePublisher{}, inv, nil, nil)
disp.Handle(context.Background(), &pushv1.PushEvent{
Kind: &pushv1.PushEvent_SessionInvalidation{
SessionInvalidation: &pushv1.SessionInvalidation{
UserId: "user-1",
Reason: "auth.revoke_all_for_user",
},
},
})
devices, users := inv.snapshot()
assert.Empty(t, devices)
assert.Equal(t, []string{"user-1"}, users)
}
@@ -1,396 +0,0 @@
package events
import (
"context"
"crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"errors"
"net"
"sync"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/grpcapi"
"galaxy/gateway/internal/replay"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
)
var testNow = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)
func TestAuthenticatedGatewayWarmsLocalSessionCache(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
local := session.NewMemoryCache()
fallback := &countingSessionCache{
records: map[string]session.Record{
"device-session-123": newActiveSessionRecord("user-123"),
},
}
readThrough, err := session.NewReadThroughCache(local, fallback)
require.NoError(t, err)
subscriber := newTestRedisSessionSubscriber(t, server, local)
downstreamClient := &recordingDownstreamClient{}
addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
defer running.stop(t)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
require.NoError(t, err)
assert.Equal(t, 1, fallback.lookupCalls())
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
require.NoError(t, err)
assert.Equal(t, 1, fallback.lookupCalls())
assert.Len(t, downstreamClient.commands(), 2)
}
func TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
local := session.NewMemoryCache()
fallback := &countingSessionCache{
records: map[string]session.Record{
"device-session-123": newActiveSessionRecord("user-123"),
},
}
readThrough, err := session.NewReadThroughCache(local, fallback)
require.NoError(t, err)
subscriber := newTestRedisSessionSubscriber(t, server, local)
downstreamClient := &recordingDownstreamClient{}
addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
defer running.stop(t)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
require.NoError(t, err)
assert.Equal(t, 1, fallback.lookupCalls())
addSessionEvent(t, server, "gateway:session_events", map[string]string{
"device_session_id": "device-session-123",
"user_id": "user-456",
"client_public_key": testClientPublicKeyBase64(),
"status": string(session.StatusActive),
})
require.Eventually(t, func() bool {
record, lookupErr := local.Lookup(context.Background(), "device-session-123")
return lookupErr == nil && record.UserID == "user-456"
}, time.Second, 10*time.Millisecond)
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
require.NoError(t, err)
assert.Equal(t, 1, fallback.lookupCalls())
commands := downstreamClient.commands()
require.Len(t, commands, 2)
assert.Equal(t, "user-456", commands[1].UserID)
}
func TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
local := session.NewMemoryCache()
fallback := &countingSessionCache{
records: map[string]session.Record{
"device-session-123": newActiveSessionRecord("user-123"),
},
}
readThrough, err := session.NewReadThroughCache(local, fallback)
require.NoError(t, err)
subscriber := newTestRedisSessionSubscriber(t, server, local)
downstreamClient := &recordingDownstreamClient{}
addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
defer running.stop(t)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
require.NoError(t, err)
assert.Equal(t, 1, fallback.lookupCalls())
addSessionEvent(t, server, "gateway:session_events", map[string]string{
"device_session_id": "device-session-123",
"user_id": "user-123",
"client_public_key": testClientPublicKeyBase64(),
"status": string(session.StatusRevoked),
"revoked_at_ms": "123456789",
})
require.Eventually(t, func() bool {
record, lookupErr := local.Lookup(context.Background(), "device-session-123")
return lookupErr == nil && record.Status == session.StatusRevoked
}, time.Second, 10*time.Millisecond)
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
require.Error(t, err)
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
assert.Equal(t, 1, fallback.lookupCalls())
}
type runningAuthenticatedGateway struct {
cancel context.CancelFunc
resultCh chan error
}
func runAuthenticatedGateway(t *testing.T, sessionCache session.Cache, subscriber *RedisSessionSubscriber, downstreamClient downstream.Client) (string, runningAuthenticatedGateway) {
t.Helper()
addr := unusedTCPAddr(t)
grpcCfg := config.DefaultAuthenticatedGRPCConfig()
grpcCfg.Addr = addr
grpcCfg.FreshnessWindow = 5 * time.Minute
router := downstream.NewStaticRouter(map[string]downstream.Client{
"fleet.move": downstreamClient,
})
gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
Router: router,
ResponseSigner: newTestResponseSigner(t),
SessionCache: sessionCache,
ReplayStore: staticReplayStore{},
Clock: fixedClock{now: testNow},
})
application := app.New(
config.Config{
ShutdownTimeout: time.Second,
AuthenticatedGRPC: grpcCfg,
},
gateway,
subscriber,
)
ctx, cancel := context.WithCancel(context.Background())
resultCh := make(chan error, 1)
go func() {
resultCh <- application.Run(ctx)
}()
select {
case <-subscriber.started:
case <-time.After(time.Second):
require.FailNow(t, "session subscriber did not start")
}
return addr, runningAuthenticatedGateway{
cancel: cancel,
resultCh: resultCh,
}
}
func (g runningAuthenticatedGateway) stop(t *testing.T) {
t.Helper()
g.cancel()
select {
case err := <-g.resultCh:
require.NoError(t, err)
case <-time.After(2 * time.Second):
require.FailNow(t, "gateway did not stop after cancellation")
}
}
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
t.Helper()
var conn *grpc.ClientConn
require.Eventually(t, func() bool {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
candidate, err := grpc.DialContext(
ctx,
addr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
if err != nil {
if candidate != nil {
_ = candidate.Close()
}
return false
}
conn = candidate
return true
}, 2*time.Second, 10*time.Millisecond, "gateway did not accept gRPC connections")
return conn
}
func unusedTCPAddr(t *testing.T) string {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := listener.Addr().String()
require.NoError(t, listener.Close())
return addr
}
func newExecuteCommandRequest(requestID string) *gatewayv1.ExecuteCommandRequest {
payloadBytes := []byte("payload")
payloadHash := sha256.Sum256(payloadBytes)
req := &gatewayv1.ExecuteCommandRequest{
ProtocolVersion: "v1",
DeviceSessionId: "device-session-123",
MessageType: "fleet.move",
TimestampMs: testNow.UnixMilli(),
RequestId: requestID,
PayloadBytes: payloadBytes,
PayloadHash: payloadHash[:],
TraceId: "trace-123",
}
req.Signature = ed25519.Sign(testClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
ProtocolVersion: req.GetProtocolVersion(),
DeviceSessionID: req.GetDeviceSessionId(),
MessageType: req.GetMessageType(),
TimestampMS: req.GetTimestampMs(),
RequestID: req.GetRequestId(),
PayloadHash: req.GetPayloadHash(),
}))
return req
}
func newActiveSessionRecord(userID string) session.Record {
return session.Record{
DeviceSessionID: "device-session-123",
UserID: userID,
ClientPublicKey: testClientPublicKeyBase64(),
Status: session.StatusActive,
}
}
func testClientPrivateKey() ed25519.PrivateKey {
seed := sha256.Sum256([]byte("gateway-events-grpc-test-client"))
return ed25519.NewKeyFromSeed(seed[:])
}
func testClientPublicKeyBase64() string {
return base64.StdEncoding.EncodeToString(testClientPrivateKey().Public().(ed25519.PublicKey))
}
func newTestResponseSigner(t *testing.T) authn.ResponseSigner {
t.Helper()
seed := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
signer, err := authn.NewEd25519ResponseSigner(ed25519.NewKeyFromSeed(seed[:]))
require.NoError(t, err)
return signer
}
type fixedClock struct {
now time.Time
}
func (c fixedClock) Now() time.Time {
return c.now
}
var _ clock.Clock = fixedClock{}
type staticReplayStore struct{}
func (staticReplayStore) Reserve(context.Context, string, string, time.Duration) error {
return nil
}
var _ replay.Store = staticReplayStore{}
type countingSessionCache struct {
mu sync.Mutex
records map[string]session.Record
lookupCount int
}
func (c *countingSessionCache) Lookup(context.Context, string) (session.Record, error) {
c.mu.Lock()
defer c.mu.Unlock()
c.lookupCount++
record, ok := c.records["device-session-123"]
if !ok {
return session.Record{}, errors.New("lookup session from counting cache: session cache record not found")
}
return record, nil
}
func (c *countingSessionCache) lookupCalls() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.lookupCount
}
type recordingDownstreamClient struct {
mu sync.Mutex
captured []downstream.AuthenticatedCommand
}
func (c *recordingDownstreamClient) ExecuteCommand(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
c.mu.Lock()
c.captured = append(c.captured, command)
c.mu.Unlock()
return downstream.UnaryResult{
ResultCode: "ok",
PayloadBytes: []byte("response"),
}, nil
}
func (c *recordingDownstreamClient) commands() []downstream.AuthenticatedCommand {
c.mu.Lock()
defer c.mu.Unlock()
cloned := make([]downstream.AuthenticatedCommand, len(c.captured))
copy(cloned, c.captured)
return cloned
}
@@ -1,447 +0,0 @@
package events
import (
"context"
"crypto/ed25519"
"crypto/sha256"
"encoding/base64"
"testing"
"time"
"galaxy/gateway/internal/app"
"galaxy/gateway/internal/authn"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/grpcapi"
"galaxy/gateway/internal/push"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
notificationfbs "galaxy/schema/fbs/notification"
"github.com/alicebob/miniredis/v2"
flatbuffers "github.com/google/flatbuffers/go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
sessionCache := session.NewMemoryCache()
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-3", "user-999")))
pushHub := push.NewHub(4)
clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
defer running.stop(t)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
targetOneCtx, cancelTargetOne := context.WithCancel(context.Background())
defer cancelTargetOne()
targetOne, err := client.SubscribeEvents(targetOneCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
require.NoError(t, err)
assertPushBootstrapEvent(t, recvPushEvent(t, targetOne), "request-1", "trace-device-session-1")
targetTwoCtx, cancelTargetTwo := context.WithCancel(context.Background())
defer cancelTargetTwo()
targetTwo, err := client.SubscribeEvents(targetTwoCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
require.NoError(t, err)
assertPushBootstrapEvent(t, recvPushEvent(t, targetTwo), "request-2", "trace-device-session-2")
unrelatedCtx, cancelUnrelated := context.WithCancel(context.Background())
defer cancelUnrelated()
unrelated, err := client.SubscribeEvents(unrelatedCtx, newPushSubscribeEventsRequest("device-session-3", "request-3"))
require.NoError(t, err)
assertPushBootstrapEvent(t, recvPushEvent(t, unrelated), "request-3", "trace-device-session-3")
payloadBytes := buildGameTurnReadyPayload(t, "game-123", 54)
addClientEvent(t, server, "gateway:client_events", map[string]any{
"user_id": "user-123",
"event_type": "game.turn.ready",
"event_id": "event-123",
"payload_bytes": payloadBytes,
"request_id": "request-123",
"trace_id": "trace-123",
})
firstDelivered := recvPushEvent(t, targetOne)
assertSignedPushEvent(t, firstDelivered, push.Event{
UserID: "user-123",
EventType: "game.turn.ready",
EventID: "event-123",
PayloadBytes: payloadBytes,
RequestID: "request-123",
TraceID: "trace-123",
})
assertDecodedGameTurnReadyPayload(t, firstDelivered.GetPayloadBytes(), "game-123", 54)
secondDelivered := recvPushEvent(t, targetTwo)
assertSignedPushEvent(t, secondDelivered, push.Event{
UserID: "user-123",
EventType: "game.turn.ready",
EventID: "event-123",
PayloadBytes: payloadBytes,
RequestID: "request-123",
TraceID: "trace-123",
})
assertDecodedGameTurnReadyPayload(t, secondDelivered.GetPayloadBytes(), "game-123", 54)
assertNoPushEvent(t, unrelated, cancelUnrelated)
}
func TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
sessionCache := session.NewMemoryCache()
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))
pushHub := push.NewHub(4)
clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
defer running.stop(t)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
otherCtx, cancelOther := context.WithCancel(context.Background())
defer cancelOther()
otherStream, err := client.SubscribeEvents(otherCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
require.NoError(t, err)
assertPushBootstrapEvent(t, recvPushEvent(t, otherStream), "request-1", "trace-device-session-1")
targetCtx, cancelTarget := context.WithCancel(context.Background())
defer cancelTarget()
targetStream, err := client.SubscribeEvents(targetCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
require.NoError(t, err)
assertPushBootstrapEvent(t, recvPushEvent(t, targetStream), "request-2", "trace-device-session-2")
addClientEvent(t, server, "gateway:client_events", map[string]any{
"user_id": "user-123",
"device_session_id": "device-session-2",
"event_type": "fleet.updated",
"event_id": "event-456",
"payload_bytes": []byte("payload-456"),
})
assertSignedPushEvent(t, recvPushEvent(t, targetStream), push.Event{
UserID: "user-123",
DeviceSessionID: "device-session-2",
EventType: "fleet.updated",
EventID: "event-456",
PayloadBytes: []byte("payload-456"),
})
assertNoPushEvent(t, otherStream, cancelOther)
}
func TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
sessionCache := session.NewMemoryCache()
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
pushHub := push.NewHub(4)
clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
sessionSubscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, sessionCache, pushHub)
addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber, sessionSubscriber)
defer running.stop(t)
select {
case <-sessionSubscriber.started:
case <-time.After(time.Second):
require.FailNow(t, "session subscriber did not start")
}
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
streamCtx, cancelStream := context.WithCancel(context.Background())
defer cancelStream()
stream, err := client.SubscribeEvents(streamCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
require.NoError(t, err)
assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")
addSessionEvent(t, server, "gateway:session_events", map[string]string{
"device_session_id": "device-session-1",
"user_id": "user-123",
"client_public_key": pushClientPublicKeyBase64(),
"status": string(session.StatusRevoked),
"revoked_at_ms": "123456789",
})
require.Eventually(t, func() bool {
record, lookupErr := sessionCache.Lookup(context.Background(), "device-session-1")
return lookupErr == nil && record.Status == session.StatusRevoked
}, time.Second, 10*time.Millisecond)
recvErrCh := make(chan error, 1)
go func() {
_, recvErr := stream.Recv()
recvErrCh <- recvErr
}()
select {
case recvErr := <-recvErrCh:
require.Error(t, recvErr)
assert.Equal(t, codes.FailedPrecondition, status.Code(recvErr))
assert.Equal(t, "device session is revoked", status.Convert(recvErr).Message())
case <-time.After(time.Second):
require.FailNow(t, "stream did not close after revoke")
}
reopened, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-2"))
if err == nil {
_, err = reopened.Recv()
}
require.Error(t, err)
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
}
func TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
sessionCache := session.NewMemoryCache()
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
pushHub := push.NewHub(4)
clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
defer running.stop(t)
conn := dialGatewayClient(t, addr)
defer func() {
require.NoError(t, conn.Close())
}()
client := gatewayv1.NewEdgeGatewayClient(conn)
stream, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-1"))
require.NoError(t, err)
assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")
recvErrCh := make(chan error, 1)
go func() {
_, recvErr := stream.Recv()
recvErrCh <- recvErr
}()
running.cancel()
select {
case recvErr := <-recvErrCh:
require.Error(t, recvErr)
assert.Equal(t, codes.Unavailable, status.Code(recvErr))
assert.Equal(t, "gateway is shutting down", status.Convert(recvErr).Message())
case <-time.After(time.Second):
require.FailNow(t, "stream did not close after gateway shutdown")
}
}
func runPushGateway(t *testing.T, sessionCache session.Cache, pushHub *push.Hub, clientSubscriber *RedisClientEventSubscriber, extraComponents ...app.Component) (string, runningAuthenticatedGateway) {
t.Helper()
addr := unusedTCPAddr(t)
grpcCfg := config.DefaultAuthenticatedGRPCConfig()
grpcCfg.Addr = addr
grpcCfg.FreshnessWindow = 5 * time.Minute
responseSigner := newTestResponseSigner(t)
gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
Service: grpcapi.NewFanOutPushStreamService(pushHub, responseSigner, fixedClock{now: testNow}, zap.NewNop()),
ResponseSigner: responseSigner,
SessionCache: sessionCache,
ReplayStore: staticReplayStore{},
Clock: fixedClock{now: testNow},
PushHub: pushHub,
})
components := []app.Component{gateway, clientSubscriber}
components = append(components, extraComponents...)
application := app.New(
config.Config{
ShutdownTimeout: time.Second,
AuthenticatedGRPC: grpcCfg,
},
components...,
)
ctx, cancel := context.WithCancel(context.Background())
resultCh := make(chan error, 1)
go func() {
resultCh <- application.Run(ctx)
}()
select {
case <-clientSubscriber.started:
case <-time.After(time.Second):
require.FailNow(t, "client event subscriber did not start")
}
return addr, runningAuthenticatedGateway{
cancel: cancel,
resultCh: resultCh,
}
}
func newPushActiveSessionRecord(deviceSessionID string, userID string) session.Record {
return session.Record{
DeviceSessionID: deviceSessionID,
UserID: userID,
ClientPublicKey: pushClientPublicKeyBase64(),
Status: session.StatusActive,
}
}
func newPushSubscribeEventsRequest(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
payloadHash := sha256.Sum256(nil)
traceID := "trace-" + deviceSessionID
req := &gatewayv1.SubscribeEventsRequest{
ProtocolVersion: "v1",
DeviceSessionId: deviceSessionID,
MessageType: "gateway.subscribe",
TimestampMs: testNow.UnixMilli(),
RequestId: requestID,
PayloadHash: payloadHash[:],
TraceId: traceID,
}
req.Signature = ed25519.Sign(pushClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
ProtocolVersion: req.GetProtocolVersion(),
DeviceSessionID: req.GetDeviceSessionId(),
MessageType: req.GetMessageType(),
TimestampMS: req.GetTimestampMs(),
RequestID: req.GetRequestId(),
PayloadHash: req.GetPayloadHash(),
}))
return req
}
func recvPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
t.Helper()
event, err := stream.Recv()
require.NoError(t, err)
return event
}
func assertPushBootstrapEvent(t *testing.T, event *gatewayv1.GatewayEvent, wantRequestID string, wantTraceID string) {
t.Helper()
require.NotNil(t, event)
assert.Equal(t, "gateway.server_time", event.GetEventType())
assert.Equal(t, wantRequestID, event.GetEventId())
assert.Equal(t, wantRequestID, event.GetRequestId())
assert.Equal(t, wantTraceID, event.GetTraceId())
require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
EventType: event.GetEventType(),
EventID: event.GetEventId(),
TimestampMS: event.GetTimestampMs(),
RequestID: event.GetRequestId(),
TraceID: event.GetTraceId(),
PayloadHash: event.GetPayloadHash(),
}))
}
func assertSignedPushEvent(t *testing.T, event *gatewayv1.GatewayEvent, want push.Event) {
t.Helper()
require.NotNil(t, event)
assert.Equal(t, want.EventType, event.GetEventType())
assert.Equal(t, want.EventID, event.GetEventId())
assert.Equal(t, want.RequestID, event.GetRequestId())
assert.Equal(t, want.TraceID, event.GetTraceId())
assert.Equal(t, want.PayloadBytes, event.GetPayloadBytes())
require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
EventType: event.GetEventType(),
EventID: event.GetEventId(),
TimestampMS: event.GetTimestampMs(),
RequestID: event.GetRequestId(),
TraceID: event.GetTraceId(),
PayloadHash: event.GetPayloadHash(),
}))
}
func assertNoPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent], cancel context.CancelFunc) {
t.Helper()
recvCh := make(chan *gatewayv1.GatewayEvent, 1)
errCh := make(chan error, 1)
go func() {
event, err := stream.Recv()
if err != nil {
errCh <- err
return
}
recvCh <- event
}()
select {
case event := <-recvCh:
require.FailNowf(t, "unexpected push event delivered", "%+v", event)
case <-time.After(100 * time.Millisecond):
cancel()
case err := <-errCh:
require.FailNowf(t, "stream closed unexpectedly", "%v", err)
}
}
func pushClientPrivateKey() ed25519.PrivateKey {
seed := sha256.Sum256([]byte("gateway-push-grpc-test-client"))
return ed25519.NewKeyFromSeed(seed[:])
}
func pushClientPublicKeyBase64() string {
return base64.StdEncoding.EncodeToString(pushClientPrivateKey().Public().(ed25519.PublicKey))
}
func pushResponseSignerPublicKey() ed25519.PublicKey {
seed := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
return ed25519.NewKeyFromSeed(seed[:]).Public().(ed25519.PublicKey)
}
func buildGameTurnReadyPayload(t *testing.T, gameID string, turnNumber int64) []byte {
t.Helper()
builder := flatbuffers.NewBuilder(64)
gameIDOffset := builder.CreateString(gameID)
notificationfbs.GameTurnReadyEventStart(builder)
notificationfbs.GameTurnReadyEventAddGameId(builder, gameIDOffset)
notificationfbs.GameTurnReadyEventAddTurnNumber(builder, turnNumber)
offset := notificationfbs.GameTurnReadyEventEnd(builder)
notificationfbs.FinishGameTurnReadyEventBuffer(builder, offset)
return builder.FinishedBytes()
}
func assertDecodedGameTurnReadyPayload(t *testing.T, payload []byte, wantGameID string, wantTurnNumber int64) {
t.Helper()
event := notificationfbs.GetRootAsGameTurnReadyEvent(payload, 0)
require.Equal(t, wantGameID, string(event.GameId()))
require.Equal(t, wantTurnNumber, event.TurnNumber())
}
-347
View File
@@ -1,347 +0,0 @@
// Package events subscribes to internal session lifecycle streams used to keep
// the gateway hot-path session cache synchronized without per-request upstream
// lookups.
package events
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/session"
"galaxy/gateway/internal/telemetry"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
)
const sessionEventReadCount int64 = 128
// SessionRevocationHandler reacts to a successfully applied revoked session
// snapshot and may tear down active resources bound to that session.
type SessionRevocationHandler interface {
// RevokeDeviceSession tears down active resources bound to deviceSessionID.
RevokeDeviceSession(deviceSessionID string)
}
// RedisSessionSubscriber consumes full session snapshots from one Redis Stream
// and applies them to a process-local session snapshot store.
type RedisSessionSubscriber struct {
client *redis.Client
stream string
pingTimeout time.Duration
readBlockTimeout time.Duration
store session.SnapshotStore
revocationHandler SessionRevocationHandler
logger *zap.Logger
metrics *telemetry.Runtime
startedOnce sync.Once
started chan struct{}
}
// NewRedisSessionSubscriber constructs a Redis Stream subscriber that uses
// client and applies updates to store.
func NewRedisSessionSubscriber(client *redis.Client, sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore) (*RedisSessionSubscriber, error) {
return NewRedisSessionSubscriberWithObservability(client, sessionCfg, eventsCfg, store, nil, nil, nil)
}
// NewRedisSessionSubscriberWithRevocationHandler constructs a Redis Stream
// subscriber that uses client, applies updates to store, and optionally tears
// down active resources for revoked sessions.
func NewRedisSessionSubscriberWithRevocationHandler(client *redis.Client, sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler) (*RedisSessionSubscriber, error) {
return NewRedisSessionSubscriberWithObservability(client, sessionCfg, eventsCfg, store, revocationHandler, nil, nil)
}
// NewRedisSessionSubscriberWithObservability constructs a Redis Stream
// subscriber that also logs and counts malformed internal session events. The
// subscriber does not own the client; the runtime supplies a shared
// *redis.Client.
func NewRedisSessionSubscriberWithObservability(client *redis.Client, sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisSessionSubscriber, error) {
if client == nil {
return nil, errors.New("new redis session subscriber: nil redis client")
}
if sessionCfg.LookupTimeout <= 0 {
return nil, errors.New("new redis session subscriber: lookup timeout must be positive")
}
if strings.TrimSpace(eventsCfg.Stream) == "" {
return nil, errors.New("new redis session subscriber: stream must not be empty")
}
if eventsCfg.ReadBlockTimeout <= 0 {
return nil, errors.New("new redis session subscriber: read block timeout must be positive")
}
if store == nil {
return nil, errors.New("new redis session subscriber: nil session snapshot store")
}
if logger == nil {
logger = zap.NewNop()
}
return &RedisSessionSubscriber{
client: client,
stream: eventsCfg.Stream,
pingTimeout: sessionCfg.LookupTimeout,
readBlockTimeout: eventsCfg.ReadBlockTimeout,
store: store,
revocationHandler: revocationHandler,
logger: logger.Named("session_subscriber"),
metrics: metrics,
started: make(chan struct{}),
}, nil
}
// Run consumes session lifecycle events until ctx is canceled or Redis returns
// an unexpected error.
func (s *RedisSessionSubscriber) Run(ctx context.Context) error {
if s == nil || s.client == nil {
return errors.New("run redis session subscriber: nil subscriber")
}
if ctx == nil {
return errors.New("run redis session subscriber: nil context")
}
if err := ctx.Err(); err != nil {
return err
}
lastID, err := s.resolveStartID(ctx)
if err != nil {
return err
}
s.signalStarted()
for {
streams, err := s.client.XRead(ctx, &redis.XReadArgs{
Streams: []string{s.stream, lastID},
Count: sessionEventReadCount,
Block: s.readBlockTimeout,
}).Result()
switch {
case err == nil:
for _, stream := range streams {
for _, message := range stream.Messages {
s.applyMessage(message)
lastID = message.ID
}
}
continue
case errors.Is(err, redis.Nil):
continue
case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
return ctx.Err()
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded), errors.Is(err, redis.ErrClosed):
return fmt.Errorf("run redis session subscriber: %w", err)
default:
return fmt.Errorf("run redis session subscriber: %w", err)
}
}
}
func (s *RedisSessionSubscriber) resolveStartID(ctx context.Context) (string, error) {
messages, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
switch {
case err == nil:
case errors.Is(err, redis.Nil):
return "0-0", nil
default:
return "", fmt.Errorf("run redis session subscriber: resolve stream tail: %w", err)
}
if len(messages) == 0 {
return "0-0", nil
}
return messages[0].ID, nil
}
// Shutdown is a no-op kept for App framework compatibility. The blocking
// XRead loop terminates when its context is cancelled by the parent runtime,
// which also owns and closes the shared Redis client.
func (s *RedisSessionSubscriber) Shutdown(ctx context.Context) error {
if ctx == nil {
return errors.New("shutdown redis session subscriber: nil context")
}
return nil
}
// Close is a no-op kept for backwards-compatible cleanup wiring; the
// subscriber does not own the shared Redis client.
func (s *RedisSessionSubscriber) Close() error {
return nil
}
func (s *RedisSessionSubscriber) signalStarted() {
s.startedOnce.Do(func() {
close(s.started)
})
}
func (s *RedisSessionSubscriber) applyMessage(message redis.XMessage) {
record, err := decodeSessionRecordSnapshot(message.Values)
if err != nil {
s.logger.Warn("dropped malformed session event",
zap.String("stream", s.stream),
zap.String("message_id", message.ID),
zap.Error(err),
)
s.metrics.RecordInternalEventDrop(context.Background(),
attribute.String("component", "session_subscriber"),
attribute.String("reason", "malformed_event"),
)
if deviceSessionID, ok := extractDeviceSessionID(message.Values); ok {
s.store.Delete(deviceSessionID)
}
return
}
if err := s.store.Upsert(record); err != nil {
s.logger.Warn("dropped session snapshot after store failure",
zap.String("stream", s.stream),
zap.String("message_id", message.ID),
zap.String("device_session_id", record.DeviceSessionID),
zap.Error(err),
)
s.metrics.RecordInternalEventDrop(context.Background(),
attribute.String("component", "session_subscriber"),
attribute.String("reason", "store_failure"),
)
s.store.Delete(record.DeviceSessionID)
return
}
if record.Status == session.StatusRevoked && s.revocationHandler != nil {
s.revocationHandler.RevokeDeviceSession(record.DeviceSessionID)
}
}
func decodeSessionRecordSnapshot(values map[string]any) (session.Record, error) {
requiredKeys := map[string]struct{}{
"device_session_id": {},
"user_id": {},
"client_public_key": {},
"status": {},
}
optionalKeys := map[string]struct{}{
"revoked_at_ms": {},
}
for key := range values {
if _, ok := requiredKeys[key]; ok {
continue
}
if _, ok := optionalKeys[key]; ok {
continue
}
return session.Record{}, fmt.Errorf("decode session event: unsupported field %q", key)
}
deviceSessionID, err := requiredStringField(values, "device_session_id")
if err != nil {
return session.Record{}, err
}
userID, err := requiredStringField(values, "user_id")
if err != nil {
return session.Record{}, err
}
clientPublicKey, err := requiredStringField(values, "client_public_key")
if err != nil {
return session.Record{}, err
}
statusValue, err := requiredStringField(values, "status")
if err != nil {
return session.Record{}, err
}
record := session.Record{
DeviceSessionID: deviceSessionID,
UserID: userID,
ClientPublicKey: clientPublicKey,
Status: session.Status(statusValue),
}
if rawRevokedAtMS, ok := values["revoked_at_ms"]; ok {
revokedAtMS, err := parseInt64Field(rawRevokedAtMS, "revoked_at_ms")
if err != nil {
return session.Record{}, err
}
record.RevokedAtMS = &revokedAtMS
}
return record, nil
}
func extractDeviceSessionID(values map[string]any) (string, bool) {
value, ok := values["device_session_id"]
if !ok {
return "", false
}
deviceSessionID, err := coerceString(value)
if err != nil {
return "", false
}
if strings.TrimSpace(deviceSessionID) == "" {
return "", false
}
return deviceSessionID, true
}
func requiredStringField(values map[string]any, field string) (string, error) {
value, ok := values[field]
if !ok {
return "", fmt.Errorf("decode session event: missing %s", field)
}
stringValue, err := coerceString(value)
if err != nil {
return "", fmt.Errorf("decode session event: %s: %w", field, err)
}
if strings.TrimSpace(stringValue) == "" {
return "", fmt.Errorf("decode session event: %s must not be empty", field)
}
return stringValue, nil
}
func parseInt64Field(value any, field string) (int64, error) {
stringValue, err := coerceString(value)
if err != nil {
return 0, fmt.Errorf("decode session event: %s: %w", field, err)
}
parsed, err := strconv.ParseInt(strings.TrimSpace(stringValue), 10, 64)
if err != nil {
return 0, fmt.Errorf("decode session event: %s: %w", field, err)
}
return parsed, nil
}
func coerceString(value any) (string, error) {
switch typed := value.(type) {
case string:
return typed, nil
case []byte:
return string(typed), nil
case fmt.Stringer:
return typed.String(), nil
case int:
return strconv.Itoa(typed), nil
case int64:
return strconv.FormatInt(typed, 10), nil
case uint64:
return strconv.FormatUint(typed, 10), nil
default:
return "", fmt.Errorf("unsupported value type %T", value)
}
}
-381
View File
@@ -1,381 +0,0 @@
package events
import (
"context"
"sync"
"testing"
"time"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/session"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRedisSessionSubscriberAppliesActiveSnapshot(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := session.NewMemoryCache()
subscriber := newTestRedisSessionSubscriber(t, server, store)
running := runTestSubscriber(t, subscriber)
defer running.stop(t)
addSessionEvent(t, server, "gateway:session_events", map[string]string{
"device_session_id": "device-session-123",
"user_id": "user-123",
"client_public_key": "public-key-123",
"status": string(session.StatusActive),
})
require.Eventually(t, func() bool {
record, err := store.Lookup(context.Background(), "device-session-123")
if err != nil {
return false
}
return record.UserID == "user-123" && record.Status == session.StatusActive
}, time.Second, 10*time.Millisecond)
}
func TestRedisSessionSubscriberAppliesRevokedSnapshot(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := session.NewMemoryCache()
require.NoError(t, store.Upsert(session.Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: session.StatusActive,
}))
subscriber := newTestRedisSessionSubscriber(t, server, store)
running := runTestSubscriber(t, subscriber)
defer running.stop(t)
addSessionEvent(t, server, "gateway:session_events", map[string]string{
"device_session_id": "device-session-123",
"user_id": "user-123",
"client_public_key": "public-key-123",
"status": string(session.StatusRevoked),
"revoked_at_ms": "123456789",
})
require.Eventually(t, func() bool {
record, err := store.Lookup(context.Background(), "device-session-123")
if err != nil || record.RevokedAtMS == nil {
return false
}
return record.Status == session.StatusRevoked && *record.RevokedAtMS == 123456789
}, time.Second, 10*time.Millisecond)
}
func TestRedisSessionSubscriberRevokedSnapshotTriggersRevocationHandler(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := session.NewMemoryCache()
handler := &recordingSessionRevocationHandler{}
subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler)
running := runTestSubscriber(t, subscriber)
defer running.stop(t)
addSessionEvent(t, server, "gateway:session_events", map[string]string{
"device_session_id": "device-session-123",
"user_id": "user-123",
"client_public_key": "public-key-123",
"status": string(session.StatusRevoked),
"revoked_at_ms": "123456789",
})
require.Eventually(t, func() bool {
record, err := store.Lookup(context.Background(), "device-session-123")
if err != nil || record.Status != session.StatusRevoked {
return false
}
return assert.ObjectsAreEqual([]string{"device-session-123"}, handler.revocations())
}, time.Second, 10*time.Millisecond)
}
func TestRedisSessionSubscriberActiveSnapshotDoesNotTriggerRevocationHandler(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := session.NewMemoryCache()
handler := &recordingSessionRevocationHandler{}
subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler)
running := runTestSubscriber(t, subscriber)
defer running.stop(t)
addSessionEvent(t, server, "gateway:session_events", map[string]string{
"device_session_id": "device-session-123",
"user_id": "user-123",
"client_public_key": "public-key-123",
"status": string(session.StatusActive),
})
assert.Never(t, func() bool {
return len(handler.revocations()) != 0
}, 100*time.Millisecond, 10*time.Millisecond)
}
func TestRedisSessionSubscriberStoreFailureDoesNotTriggerRevocationHandler(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
handler := &recordingSessionRevocationHandler{}
subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, failingSnapshotStore{}, handler)
running := runTestSubscriber(t, subscriber)
defer running.stop(t)
addSessionEvent(t, server, "gateway:session_events", map[string]string{
"device_session_id": "device-session-123",
"user_id": "user-123",
"client_public_key": "public-key-123",
"status": string(session.StatusRevoked),
"revoked_at_ms": "123456789",
})
assert.Never(t, func() bool {
return len(handler.revocations()) != 0
}, 100*time.Millisecond, 10*time.Millisecond)
}
func TestRedisSessionSubscriberLaterEventWins(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := session.NewMemoryCache()
subscriber := newTestRedisSessionSubscriber(t, server, store)
running := runTestSubscriber(t, subscriber)
defer running.stop(t)
addSessionEvent(t, server, "gateway:session_events", map[string]string{
"device_session_id": "device-session-123",
"user_id": "user-123",
"client_public_key": "public-key-123",
"status": string(session.StatusActive),
})
addSessionEvent(t, server, "gateway:session_events", map[string]string{
"device_session_id": "device-session-123",
"user_id": "user-456",
"client_public_key": "public-key-456",
"status": string(session.StatusActive),
})
require.Eventually(t, func() bool {
record, err := store.Lookup(context.Background(), "device-session-123")
if err != nil {
return false
}
return record.UserID == "user-456" && record.ClientPublicKey == "public-key-456"
}, time.Second, 10*time.Millisecond)
}
func TestRedisSessionSubscriberMalformedEventEvictsAndContinues(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := session.NewMemoryCache()
require.NoError(t, store.Upsert(session.Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: session.StatusActive,
}))
subscriber := newTestRedisSessionSubscriber(t, server, store)
running := runTestSubscriber(t, subscriber)
defer running.stop(t)
addSessionEvent(t, server, "gateway:session_events", map[string]string{
"device_session_id": "device-session-123",
"user_id": "user-123",
"client_public_key": "public-key-123",
"status": "paused",
})
require.Eventually(t, func() bool {
_, err := store.Lookup(context.Background(), "device-session-123")
return err != nil
}, time.Second, 10*time.Millisecond)
addSessionEvent(t, server, "gateway:session_events", map[string]string{
"device_session_id": "device-session-123",
"user_id": "user-456",
"client_public_key": "public-key-456",
"status": string(session.StatusActive),
})
require.Eventually(t, func() bool {
record, err := store.Lookup(context.Background(), "device-session-123")
if err != nil {
return false
}
return record.UserID == "user-456" && record.Status == session.StatusActive
}, time.Second, 10*time.Millisecond)
}
func TestRedisSessionSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
store := session.NewMemoryCache()
subscriber := newTestRedisSessionSubscriber(t, server, store)
ctx, cancel := context.WithCancel(context.Background())
resultCh := make(chan error, 1)
go func() {
resultCh <- subscriber.Run(ctx)
}()
select {
case <-subscriber.started:
case <-time.After(time.Second):
require.FailNow(t, "subscriber did not start")
}
cancel()
require.NoError(t, subscriber.Shutdown(context.Background()))
select {
case err := <-resultCh:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
require.FailNow(t, "subscriber did not stop after shutdown")
}
}
func newTestRedisSessionSubscriber(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore) *RedisSessionSubscriber {
t.Helper()
return newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, nil)
}
func newTestRedisSessionSubscriberWithRevocationHandler(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore, revocationHandler SessionRevocationHandler) *RedisSessionSubscriber {
t.Helper()
client := newTestRedisClient(t, server)
subscriber, err := NewRedisSessionSubscriberWithRevocationHandler(
client,
config.SessionCacheRedisConfig{
KeyPrefix: "gateway:session:",
LookupTimeout: 250 * time.Millisecond,
},
config.SessionEventsRedisConfig{
Stream: "gateway:session_events",
ReadBlockTimeout: 25 * time.Millisecond,
},
store,
revocationHandler,
)
require.NoError(t, err)
return subscriber
}
func newTestRedisClient(t *testing.T, server *miniredis.Miniredis) *redis.Client {
t.Helper()
client := redis.NewClient(&redis.Options{
Addr: server.Addr(),
Protocol: 2,
DisableIdentity: true,
})
t.Cleanup(func() {
assert.NoError(t, client.Close())
})
return client
}
type recordingSessionRevocationHandler struct {
mu sync.Mutex
revokedIDs []string
}
func (h *recordingSessionRevocationHandler) RevokeDeviceSession(deviceSessionID string) {
h.mu.Lock()
h.revokedIDs = append(h.revokedIDs, deviceSessionID)
h.mu.Unlock()
}
func (h *recordingSessionRevocationHandler) revocations() []string {
h.mu.Lock()
defer h.mu.Unlock()
return append([]string(nil), h.revokedIDs...)
}
type failingSnapshotStore struct{}
func (failingSnapshotStore) Lookup(context.Context, string) (session.Record, error) {
return session.Record{}, session.ErrNotFound
}
func (failingSnapshotStore) Upsert(session.Record) error {
return context.DeadlineExceeded
}
func (failingSnapshotStore) Delete(string) {}
func addSessionEvent(t *testing.T, server *miniredis.Miniredis, stream string, fields map[string]string) {
t.Helper()
values := make([]string, 0, len(fields)*2)
for key, value := range fields {
values = append(values, key, value)
}
_, err := server.XAdd(stream, "*", values)
require.NoError(t, err)
}
type runningSubscriber struct {
cancel context.CancelFunc
resultCh chan error
stopOnce bool
}
func runTestSubscriber(t *testing.T, subscriber *RedisSessionSubscriber) runningSubscriber {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
resultCh := make(chan error, 1)
go func() {
resultCh <- subscriber.Run(ctx)
}()
select {
case <-subscriber.started:
case <-time.After(time.Second):
require.FailNow(t, "subscriber did not start")
}
return runningSubscriber{
cancel: cancel,
resultCh: resultCh,
}
}
func (r runningSubscriber) stop(t *testing.T) {
t.Helper()
r.cancel()
select {
case err := <-r.resultCh:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
require.FailNow(t, "subscriber did not stop")
}
}
+1 -1
View File
@@ -8,7 +8,7 @@ import (
"strings"
"time"
"galaxy/gateway/internal/authn"
"galaxy/gateway/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/downstream"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
@@ -7,7 +7,7 @@ import (
"testing"
"time"
"galaxy/gateway/internal/authn"
"galaxy/gateway/authn"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/testutil"
+1 -1
View File
@@ -4,7 +4,7 @@ import (
"context"
"errors"
"galaxy/gateway/internal/authn"
"galaxy/gateway/authn"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
+1 -1
View File
@@ -6,7 +6,7 @@ import (
"crypto/sha256"
"errors"
"galaxy/gateway/internal/authn"
"galaxy/gateway/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/logging"
"galaxy/gateway/internal/push"
+1 -1
View File
@@ -5,7 +5,7 @@ import (
"context"
"crypto/sha256"
"galaxy/gateway/internal/authn"
"galaxy/gateway/authn"
"galaxy/gateway/internal/clock"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
gatewayfbs "galaxy/schema/fbs/gateway"
+1 -1
View File
@@ -8,7 +8,7 @@ import (
"net"
"sync"
"galaxy/gateway/internal/authn"
"galaxy/gateway/authn"
"galaxy/gateway/internal/clock"
"galaxy/gateway/internal/config"
"galaxy/gateway/internal/downstream"
+1 -1
View File
@@ -4,7 +4,7 @@ import (
"context"
"errors"
"galaxy/gateway/internal/authn"
"galaxy/gateway/authn"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
"google.golang.org/grpc"
@@ -9,7 +9,7 @@ import (
"encoding/pem"
"time"
"galaxy/gateway/internal/authn"
"galaxy/gateway/authn"
"galaxy/gateway/internal/downstream"
"galaxy/gateway/internal/session"
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
+22
View File
@@ -273,6 +273,28 @@ func (h *Hub) RevokeDeviceSession(deviceSessionID string) {
}
}
// RevokeAllForUser closes every active subscription bound to userID,
// regardless of device-session id. Used when backend emits a
// SessionInvalidation that targets every session of a user.
func (h *Hub) RevokeAllForUser(userID string) {
if h == nil {
return
}
userID = strings.TrimSpace(userID)
if userID == "" {
return
}
h.mu.RLock()
targets := cloneSubscriptions(h.byUser[userID])
h.mu.RUnlock()
for _, target := range targets {
h.unregister(target.id, ErrSubscriptionRevoked)
}
}
// Shutdown closes every active subscription because the gateway is shutting
// down.
func (h *Hub) Shutdown() {
@@ -1,232 +0,0 @@
package restapi
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)
const (
authServiceSendEmailCodePath = "/api/v1/public/auth/send-email-code"
authServiceConfirmEmailCodePath = "/api/v1/public/auth/confirm-email-code"
)
// HTTPAuthServiceClient implements AuthServiceClient over the Auth / Session
// Service public HTTP API using strict JSON request and response decoding.
type HTTPAuthServiceClient struct {
baseURL string
httpClient *http.Client
}
type authServiceErrorEnvelope struct {
Error *authServiceErrorBody `json:"error"`
}
type authServiceErrorBody struct {
Code string `json:"code"`
Message string `json:"message"`
}
// NewHTTPAuthServiceClient constructs an AuthServiceClient that delegates the
// gateway public-auth routes to the Auth / Session Service public HTTP API at
// baseURL. The resulting client relies only on the caller-provided context for
// cancellation and timeout control.
func NewHTTPAuthServiceClient(baseURL string) (*HTTPAuthServiceClient, error) {
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, errors.New("new auth service HTTP client: default transport is not *http.Transport")
}
return newHTTPAuthServiceClient(baseURL, &http.Client{
Transport: transport.Clone(),
})
}
func newHTTPAuthServiceClient(baseURL string, httpClient *http.Client) (*HTTPAuthServiceClient, error) {
if httpClient == nil {
return nil, errors.New("new auth service HTTP client: http client must not be nil")
}
trimmedBaseURL := strings.TrimSpace(baseURL)
if trimmedBaseURL == "" {
return nil, errors.New("new auth service HTTP client: base URL must not be empty")
}
parsedBaseURL, err := url.Parse(strings.TrimRight(trimmedBaseURL, "/"))
if err != nil {
return nil, fmt.Errorf("new auth service HTTP client: parse base URL: %w", err)
}
if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
return nil, errors.New("new auth service HTTP client: base URL must be absolute")
}
return &HTTPAuthServiceClient{
baseURL: parsedBaseURL.String(),
httpClient: httpClient,
}, nil
}
// Close releases idle HTTP connections owned by the client transport.
func (c *HTTPAuthServiceClient) Close() error {
if c == nil || c.httpClient == nil {
return nil
}
type idleCloser interface {
CloseIdleConnections()
}
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
transport.CloseIdleConnections()
}
return nil
}
// SendEmailCode delegates the public send-email-code route to the configured
// Auth / Session Service public HTTP API.
func (c *HTTPAuthServiceClient) SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error) {
payload, statusCode, err := c.doJSONRequest(ctx, authServiceSendEmailCodePath, input, map[string]string{
"Accept-Language": resolvePreferredLanguage(input.PreferredLanguage),
})
if err != nil {
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: %w", err)
}
switch {
case statusCode == http.StatusOK:
var result SendEmailCodeResult
if err := decodeStrictJSONPayload(payload, &result); err != nil {
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: decode success response: %w", err)
}
if err := validateSendEmailCodeResult(&result); err != nil {
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: %w", err)
}
return result, nil
case statusCode >= 400 && statusCode <= 599:
authErr, err := decodeAuthServiceError(statusCode, payload)
if err != nil {
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: %w", err)
}
return SendEmailCodeResult{}, authErr
default:
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: unexpected HTTP status %d", statusCode)
}
}
// ConfirmEmailCode delegates the public confirm-email-code route to the
// configured Auth / Session Service public HTTP API.
func (c *HTTPAuthServiceClient) ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
payload, statusCode, err := c.doJSONRequest(ctx, authServiceConfirmEmailCodePath, input, nil)
if err != nil {
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: %w", err)
}
switch {
case statusCode == http.StatusOK:
var result ConfirmEmailCodeResult
if err := decodeStrictJSONPayload(payload, &result); err != nil {
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: decode success response: %w", err)
}
if err := validateConfirmEmailCodeResult(&result); err != nil {
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: %w", err)
}
return result, nil
case statusCode >= 400 && statusCode <= 599:
authErr, err := decodeAuthServiceError(statusCode, payload)
if err != nil {
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: %w", err)
}
return ConfirmEmailCodeResult{}, authErr
default:
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: unexpected HTTP status %d", statusCode)
}
}
func (c *HTTPAuthServiceClient) doJSONRequest(ctx context.Context, path string, requestBody any, headers map[string]string) ([]byte, int, error) {
if c == nil || c.httpClient == nil {
return nil, 0, errors.New("nil client")
}
if ctx == nil {
return nil, 0, errors.New("nil context")
}
if err := ctx.Err(); err != nil {
return nil, 0, err
}
payload, err := json.Marshal(requestBody)
if err != nil {
return nil, 0, fmt.Errorf("marshal request body: %w", err)
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+path, bytes.NewReader(payload))
if err != nil {
return nil, 0, fmt.Errorf("build request: %w", err)
}
request.Header.Set("Content-Type", "application/json")
for key, value := range headers {
if strings.TrimSpace(value) == "" {
continue
}
request.Header.Set(key, value)
}
response, err := c.httpClient.Do(request)
if err != nil {
return nil, 0, err
}
defer response.Body.Close()
responsePayload, err := io.ReadAll(response.Body)
if err != nil {
return nil, 0, fmt.Errorf("read response body: %w", err)
}
return responsePayload, response.StatusCode, nil
}
func decodeAuthServiceError(statusCode int, payload []byte) (*AuthServiceError, error) {
var envelope authServiceErrorEnvelope
if err := decodeStrictJSONPayload(payload, &envelope); err != nil {
return nil, fmt.Errorf("decode error response: %w", err)
}
if envelope.Error == nil {
return nil, errors.New("decode error response: missing error object")
}
return &AuthServiceError{
StatusCode: statusCode,
Code: envelope.Error.Code,
Message: envelope.Error.Message,
}, nil
}
func decodeStrictJSONPayload(payload []byte, target any) error {
decoder := json.NewDecoder(bytes.NewReader(payload))
decoder.DisallowUnknownFields()
if err := decoder.Decode(target); err != nil {
return err
}
if err := decoder.Decode(&struct{}{}); err != io.EOF {
if err == nil {
return errors.New("unexpected trailing JSON input")
}
return err
}
return nil
}
var _ AuthServiceClient = (*HTTPAuthServiceClient)(nil)
@@ -1,369 +0,0 @@
package restapi
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewHTTPAuthServiceClient(t *testing.T) {
t.Parallel()
tests := []struct {
name string
baseURL string
wantErr string
}{
{
name: "success",
baseURL: " http://127.0.0.1:8080/ ",
},
{
name: "empty base url",
wantErr: "base URL must not be empty",
},
{
name: "relative base url",
baseURL: "/authsession",
wantErr: "base URL must be absolute",
},
{
name: "malformed base url",
baseURL: "://bad",
wantErr: "parse base URL",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client, err := NewHTTPAuthServiceClient(tt.baseURL)
if tt.wantErr != "" {
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErr)
return
}
require.NoError(t, err)
assert.Equal(t, "http://127.0.0.1:8080", client.baseURL)
assert.NoError(t, client.Close())
})
}
}
func TestHTTPAuthServiceClientSendEmailCodeSuccess(t *testing.T) {
t.Parallel()
var requestContentType string
var requestAcceptLanguage string
var requestBody string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, authServiceSendEmailCodePath, r.URL.Path)
requestContentType = r.Header.Get("Content-Type")
requestAcceptLanguage = r.Header.Get("Accept-Language")
payload, err := io.ReadAll(r.Body)
require.NoError(t, err)
requestBody = string(payload)
w.Header().Set("Content-Type", "application/json")
_, err = io.WriteString(w, `{"challenge_id":"challenge-123"}`)
require.NoError(t, err)
}))
defer server.Close()
client := newTestHTTPAuthServiceClient(t, server)
result, err := client.SendEmailCode(context.Background(), SendEmailCodeInput{
Email: "pilot@example.com",
PreferredLanguage: "fr-FR",
})
require.NoError(t, err)
assert.Equal(t, SendEmailCodeResult{ChallengeID: "challenge-123"}, result)
assert.Equal(t, "application/json", requestContentType)
assert.Equal(t, "fr-FR", requestAcceptLanguage)
assert.JSONEq(t, `{"email":"pilot@example.com"}`, requestBody)
}
func TestHTTPAuthServiceClientSendEmailCodeDefaultsAcceptLanguageToEnglish(t *testing.T) {
t.Parallel()
var requestAcceptLanguage string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestAcceptLanguage = r.Header.Get("Accept-Language")
w.Header().Set("Content-Type", "application/json")
_, err := io.WriteString(w, `{"challenge_id":"challenge-123"}`)
require.NoError(t, err)
}))
defer server.Close()
client := newTestHTTPAuthServiceClient(t, server)
_, err := client.SendEmailCode(context.Background(), SendEmailCodeInput{Email: "pilot@example.com"})
require.NoError(t, err)
assert.Equal(t, "en", requestAcceptLanguage)
}
func TestHTTPAuthServiceClientConfirmEmailCodeSuccess(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, authServiceConfirmEmailCodePath, r.URL.Path)
payload, err := io.ReadAll(r.Body)
require.NoError(t, err)
assert.JSONEq(t, `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key","time_zone":"Europe/Kaliningrad"}`, string(payload))
w.Header().Set("Content-Type", "application/json")
_, err = io.WriteString(w, `{"device_session_id":"device-session-123"}`)
require.NoError(t, err)
}))
defer server.Close()
client := newTestHTTPAuthServiceClient(t, server)
result, err := client.ConfirmEmailCode(context.Background(), ConfirmEmailCodeInput{
ChallengeID: "challenge-123",
Code: "123456",
ClientPublicKey: "public-key",
TimeZone: "Europe/Kaliningrad",
})
require.NoError(t, err)
assert.Equal(t, ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"}, result)
}
func TestHTTPAuthServiceClientProjectsAuthServiceErrors(t *testing.T) {
t.Parallel()
tests := []struct {
name string
statusCode int
responseBody string
call func(*HTTPAuthServiceClient) error
wantStatusCode int
wantCode string
wantMessage string
}{
{
name: "send email code error",
statusCode: http.StatusServiceUnavailable,
responseBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
call: func(client *HTTPAuthServiceClient) error {
_, err := client.SendEmailCode(context.Background(), SendEmailCodeInput{Email: "pilot@example.com"})
return err
},
wantStatusCode: http.StatusServiceUnavailable,
wantCode: "service_unavailable",
wantMessage: "service is unavailable",
},
{
name: "confirm email code error",
statusCode: http.StatusConflict,
responseBody: `{"error":{"code":"session_limit_exceeded","message":"active session limit would be exceeded"}}`,
call: func(client *HTTPAuthServiceClient) error {
_, err := client.ConfirmEmailCode(context.Background(), ConfirmEmailCodeInput{
ChallengeID: "challenge-123",
Code: "123456",
ClientPublicKey: "public-key",
TimeZone: "Europe/Kaliningrad",
})
return err
},
wantStatusCode: http.StatusConflict,
wantCode: "session_limit_exceeded",
wantMessage: "active session limit would be exceeded",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(tt.statusCode)
_, err := io.WriteString(w, tt.responseBody)
require.NoError(t, err)
}))
defer server.Close()
client := newTestHTTPAuthServiceClient(t, server)
err := tt.call(client)
require.Error(t, err)
var authErr *AuthServiceError
require.ErrorAs(t, err, &authErr)
assert.Equal(t, tt.wantStatusCode, authErr.StatusCode)
assert.Equal(t, tt.wantCode, authErr.Code)
assert.Equal(t, tt.wantMessage, authErr.Message)
})
}
}
func TestHTTPAuthServiceClientRejectsMalformedPayloads(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
statusCode int
responseBody string
wantErr string
}{
{
name: "send email code rejects unknown success field",
path: authServiceSendEmailCodePath,
statusCode: http.StatusOK,
responseBody: `{"challenge_id":"challenge-123","extra":true}`,
wantErr: "decode success response",
},
{
name: "confirm email code rejects empty success field",
path: authServiceConfirmEmailCodePath,
statusCode: http.StatusOK,
responseBody: `{"device_session_id":" "}`,
wantErr: "empty device_session_id",
},
{
name: "rejects missing error object",
path: authServiceSendEmailCodePath,
statusCode: http.StatusBadRequest,
responseBody: `{}`,
wantErr: "missing error object",
},
{
name: "rejects malformed error envelope",
path: authServiceConfirmEmailCodePath,
statusCode: http.StatusBadRequest,
responseBody: `{"error":{"code":"invalid_code","message":"confirmation code is invalid","extra":true}}`,
wantErr: "decode error response",
},
{
name: "rejects unexpected status",
path: authServiceSendEmailCodePath,
statusCode: http.StatusCreated,
responseBody: `{"challenge_id":"challenge-123"}`,
wantErr: "unexpected HTTP status 201",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, tt.path, r.URL.Path)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(tt.statusCode)
_, err := io.WriteString(w, tt.responseBody)
require.NoError(t, err)
}))
defer server.Close()
client := newTestHTTPAuthServiceClient(t, server)
var err error
switch tt.path {
case authServiceSendEmailCodePath:
_, err = client.SendEmailCode(context.Background(), SendEmailCodeInput{Email: "pilot@example.com"})
default:
_, err = client.ConfirmEmailCode(context.Background(), ConfirmEmailCodeInput{
ChallengeID: "challenge-123",
Code: "123456",
ClientPublicKey: "public-key",
TimeZone: "Europe/Kaliningrad",
})
}
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErr)
assert.NotErrorAs(t, err, new(*AuthServiceError))
})
}
}
func TestHTTPAuthServiceClientUsesCallerContext(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"challenge_id":"challenge-123"}`)
}))
defer server.Close()
client := newTestHTTPAuthServiceClient(t, server)
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
_, err := client.SendEmailCode(ctx, SendEmailCodeInput{Email: "pilot@example.com"})
require.Error(t, err)
assert.ErrorContains(t, err, "send email code via auth service")
assert.True(t, errors.Is(err, context.DeadlineExceeded))
}
func TestHTTPAuthServiceClientRejectsNilContext(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.FailNow(t, "unexpected request", r.URL.Path)
}))
defer server.Close()
client := newTestHTTPAuthServiceClient(t, server)
_, err := client.SendEmailCode(nil, SendEmailCodeInput{Email: "pilot@example.com"})
require.Error(t, err)
assert.ErrorContains(t, err, "nil context")
}
func newTestHTTPAuthServiceClient(t *testing.T, server *httptest.Server) *HTTPAuthServiceClient {
t.Helper()
client, err := newHTTPAuthServiceClient(server.URL, server.Client())
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, client.Close())
})
return client
}
func TestDecodeStrictJSONPayloadRejectsTrailingJSON(t *testing.T) {
t.Parallel()
var target struct {
Value string `json:"value"`
}
err := decodeStrictJSONPayload([]byte(`{"value":"ok"}{}`), &target)
require.Error(t, err)
assert.Equal(t, "unexpected trailing JSON input", err.Error())
}
func TestDecodeAuthServiceErrorPreservesBlankFieldsForLaterNormalization(t *testing.T) {
t.Parallel()
authErr, err := decodeAuthServiceError(http.StatusBadGateway, []byte(`{"error":{"code":" ","message":" "}}`))
require.NoError(t, err)
assert.Equal(t, http.StatusBadGateway, authErr.StatusCode)
assert.True(t, strings.TrimSpace(authErr.Code) == "")
assert.True(t, strings.TrimSpace(authErr.Message) == "")
}
+50
View File
@@ -0,0 +1,50 @@
package session
import (
"context"
"errors"
"fmt"
)
// BackendLookup describes the slice of `backendclient.RESTClient`
// SessionCache depends on. The narrow interface keeps this package free
// of any backendclient import.
type BackendLookup interface {
LookupSession(ctx context.Context, deviceSessionID string) (Record, error)
}
// BackendCache resolves authenticated device sessions by issuing one
// synchronous REST call to backend per request. The canonical implementation replaces the
// previous Redis-backed projection with this thin wrapper; gateway no
// longer keeps a process-local snapshot. See ARCHITECTURE.md §11
// «backend (sync REST), no Redis projection».
type BackendCache struct {
backend BackendLookup
}
// NewBackendCache constructs a Cache that delegates every Lookup to
// backend over REST. backend must not be nil.
func NewBackendCache(backend BackendLookup) (*BackendCache, error) {
if backend == nil {
return nil, errors.New("session.NewBackendCache: backend lookup must not be nil")
}
return &BackendCache{backend: backend}, nil
}
// Lookup resolves deviceSessionID via backend. ErrNotFound is forwarded
// unchanged so callers can keep using the existing equality check.
func (c *BackendCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
if c == nil {
return Record{}, errors.New("session backend cache: nil cache")
}
if c.backend == nil {
return Record{}, errors.New("session backend cache: nil backend lookup")
}
rec, err := c.backend.LookupSession(ctx, deviceSessionID)
if err != nil {
return Record{}, fmt.Errorf("session backend cache: %w", err)
}
return rec, nil
}
var _ Cache = (*BackendCache)(nil)
-88
View File
@@ -1,88 +0,0 @@
package session
import (
"context"
"errors"
"fmt"
"strings"
"sync"
)
// MemoryCache stores session record snapshots in process-local memory. It is
// intended for the authenticated gateway hot path and deliberately keeps no
// TTL or size-based eviction policy.
type MemoryCache struct {
mu sync.RWMutex
records map[string]Record
}
// NewMemoryCache constructs an empty process-local session snapshot store.
func NewMemoryCache() *MemoryCache {
return &MemoryCache{
records: make(map[string]Record),
}
}
// Lookup resolves deviceSessionID from the process-local snapshot map.
func (c *MemoryCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
if c == nil {
return Record{}, errors.New("lookup session from in-memory cache: nil cache")
}
if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
return Record{}, errors.New("lookup session from in-memory cache: nil context")
}
if strings.TrimSpace(deviceSessionID) == "" {
return Record{}, errors.New("lookup session from in-memory cache: empty device session id")
}
c.mu.RLock()
record, ok := c.records[deviceSessionID]
c.mu.RUnlock()
if !ok {
return Record{}, fmt.Errorf("lookup session from in-memory cache: %w", ErrNotFound)
}
return cloneRecord(record), nil
}
// Upsert stores record in the process-local snapshot map after validating the
// same session invariants expected from the Redis-backed cache.
func (c *MemoryCache) Upsert(record Record) error {
if c == nil {
return errors.New("upsert session into in-memory cache: nil cache")
}
if err := validateRecord(record.DeviceSessionID, record); err != nil {
return fmt.Errorf("upsert session into in-memory cache: %w", err)
}
cloned := cloneRecord(record)
c.mu.Lock()
c.records[record.DeviceSessionID] = cloned
c.mu.Unlock()
return nil
}
// Delete removes the local snapshot for deviceSessionID when one exists.
func (c *MemoryCache) Delete(deviceSessionID string) {
if c == nil || strings.TrimSpace(deviceSessionID) == "" {
return
}
c.mu.Lock()
delete(c.records, deviceSessionID)
c.mu.Unlock()
}
func cloneRecord(record Record) Record {
cloned := record
if record.RevokedAtMS != nil {
value := *record.RevokedAtMS
cloned.RevokedAtMS = &value
}
return cloned
}
var _ SnapshotStore = (*MemoryCache)(nil)
-68
View File
@@ -1,68 +0,0 @@
package session
import (
"context"
"errors"
"fmt"
)
// ReadThroughCache resolves authenticated sessions from a process-local
// SnapshotStore first and falls back to another Cache only on a local miss.
type ReadThroughCache struct {
local SnapshotStore
fallback Cache
}
// NewReadThroughCache constructs a hot-path cache that seeds local snapshots
// from fallback on demand.
func NewReadThroughCache(local SnapshotStore, fallback Cache) (*ReadThroughCache, error) {
if local == nil {
return nil, errors.New("new read-through session cache: nil local cache")
}
if fallback == nil {
return nil, errors.New("new read-through session cache: nil fallback cache")
}
return &ReadThroughCache{
local: local,
fallback: fallback,
}, nil
}
// Lookup resolves deviceSessionID from local first, then performs one fallback
// lookup on a local miss and seeds the local cache with the returned snapshot.
func (c *ReadThroughCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
if c == nil {
return Record{}, errors.New("lookup session from read-through cache: nil cache")
}
record, err := c.local.Lookup(ctx, deviceSessionID)
switch {
case err == nil:
return record, nil
case !errors.Is(err, ErrNotFound):
return Record{}, fmt.Errorf("lookup session from read-through cache: %w", err)
}
record, err = c.fallback.Lookup(ctx, deviceSessionID)
if err != nil {
return Record{}, err
}
if err := c.local.Upsert(record); err != nil {
return Record{}, fmt.Errorf("lookup session from read-through cache: seed local cache: %w", err)
}
return cloneRecord(record), nil
}
// Local returns the mutable process-local snapshot store used by c.
func (c *ReadThroughCache) Local() SnapshotStore {
if c == nil {
return nil
}
return c.local
}
var _ Cache = (*ReadThroughCache)(nil)
@@ -1,176 +0,0 @@
package session
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMemoryCacheLookupReturnsClonedRecord(t *testing.T) {
t.Parallel()
cache := NewMemoryCache()
revokedAtMS := int64(123456789)
require.NoError(t, cache.Upsert(Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: StatusRevoked,
RevokedAtMS: &revokedAtMS,
}))
record, err := cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
require.NotNil(t, record.RevokedAtMS)
*record.RevokedAtMS = 1
stored, err := cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
require.NotNil(t, stored.RevokedAtMS)
assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
}
func TestReadThroughCacheLocalHitSkipsFallback(t *testing.T) {
t.Parallel()
local := NewMemoryCache()
require.NoError(t, local.Upsert(Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: StatusActive,
}))
fallback := &recordingCache{
lookupFunc: func(context.Context, string) (Record, error) {
return Record{}, errors.New("fallback should not be called")
},
}
cache, err := NewReadThroughCache(local, fallback)
require.NoError(t, err)
record, err := cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
assert.Equal(t, Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: StatusActive,
}, record)
assert.Equal(t, 0, fallback.lookupCalls)
}
func TestReadThroughCacheFallbackSeedsLocalCache(t *testing.T) {
t.Parallel()
local := NewMemoryCache()
fallback := &recordingCache{
lookupFunc: func(context.Context, string) (Record, error) {
return Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: StatusActive,
}, nil
},
}
cache, err := NewReadThroughCache(local, fallback)
require.NoError(t, err)
record, err := cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
assert.Equal(t, 1, fallback.lookupCalls)
assert.Equal(t, "user-123", record.UserID)
record, err = cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
assert.Equal(t, 1, fallback.lookupCalls)
assert.Equal(t, "user-123", record.UserID)
}
func TestReadThroughCacheKeepsRevokedSnapshotLocal(t *testing.T) {
t.Parallel()
revokedAtMS := int64(123456789)
local := NewMemoryCache()
fallback := &recordingCache{
lookupFunc: func(context.Context, string) (Record, error) {
return Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: StatusRevoked,
RevokedAtMS: &revokedAtMS,
}, nil
},
}
cache, err := NewReadThroughCache(local, fallback)
require.NoError(t, err)
record, err := cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
require.NotNil(t, record.RevokedAtMS)
assert.Equal(t, StatusRevoked, record.Status)
assert.Equal(t, 1, fallback.lookupCalls)
record, err = cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
require.NotNil(t, record.RevokedAtMS)
assert.Equal(t, StatusRevoked, record.Status)
assert.Equal(t, revokedAtMS, *record.RevokedAtMS)
assert.Equal(t, 1, fallback.lookupCalls)
}
func TestReadThroughCacheReturnsClonedFallbackRecord(t *testing.T) {
t.Parallel()
revokedAtMS := int64(123456789)
local := NewMemoryCache()
fallback := &recordingCache{
lookupFunc: func(context.Context, string) (Record, error) {
return Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: StatusRevoked,
RevokedAtMS: &revokedAtMS,
}, nil
},
}
cache, err := NewReadThroughCache(local, fallback)
require.NoError(t, err)
record, err := cache.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
require.NotNil(t, record.RevokedAtMS)
*record.RevokedAtMS = 1
stored, err := local.Lookup(context.Background(), "device-session-123")
require.NoError(t, err)
require.NotNil(t, stored.RevokedAtMS)
assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
}
type recordingCache struct {
lookupCalls int
lookupFunc func(context.Context, string) (Record, error)
}
func (c *recordingCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
c.lookupCalls++
if c.lookupFunc != nil {
return c.lookupFunc(ctx, deviceSessionID)
}
return Record{}, errors.New("lookup is not implemented")
}
-150
View File
@@ -1,150 +0,0 @@
package session
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"time"
"galaxy/gateway/internal/config"
"github.com/redis/go-redis/v9"
)
// RedisCache implements Cache with Redis GET lookups over strict JSON session
// records.
type RedisCache struct {
client *redis.Client
keyPrefix string
lookupTimeout time.Duration
}
type redisRecord struct {
DeviceSessionID string `json:"device_session_id"`
UserID string `json:"user_id"`
ClientPublicKey string `json:"client_public_key"`
Status Status `json:"status"`
RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"`
}
// NewRedisCache constructs a Redis-backed SessionCache that uses client and
// applies the namespace and timeout settings from cfg. The cache does not own
// the client; the runtime supplies a shared *redis.Client.
func NewRedisCache(client *redis.Client, cfg config.SessionCacheRedisConfig) (*RedisCache, error) {
if client == nil {
return nil, errors.New("new redis session cache: nil redis client")
}
if strings.TrimSpace(cfg.KeyPrefix) == "" {
return nil, errors.New("new redis session cache: redis key prefix must not be empty")
}
if cfg.LookupTimeout <= 0 {
return nil, errors.New("new redis session cache: lookup timeout must be positive")
}
return &RedisCache{
client: client,
keyPrefix: cfg.KeyPrefix,
lookupTimeout: cfg.LookupTimeout,
}, nil
}
// Lookup resolves deviceSessionID from Redis, validates the cached JSON
// payload strictly, and returns the decoded session record.
func (c *RedisCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
if c == nil || c.client == nil {
return Record{}, errors.New("lookup session from redis: nil cache")
}
if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
return Record{}, errors.New("lookup session from redis: nil context")
}
if strings.TrimSpace(deviceSessionID) == "" {
return Record{}, errors.New("lookup session from redis: empty device session id")
}
lookupCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout)
defer cancel()
payload, err := c.client.Get(lookupCtx, c.lookupKey(deviceSessionID)).Bytes()
switch {
case errors.Is(err, redis.Nil):
return Record{}, fmt.Errorf("lookup session from redis: %w", ErrNotFound)
case err != nil:
return Record{}, fmt.Errorf("lookup session from redis: %w", err)
}
record, err := decodeRedisRecord(deviceSessionID, payload)
if err != nil {
return Record{}, fmt.Errorf("lookup session from redis: %w", err)
}
return record, nil
}
func (c *RedisCache) lookupKey(deviceSessionID string) string {
return c.keyPrefix + deviceSessionID
}
func decodeRedisRecord(expectedDeviceSessionID string, payload []byte) (Record, error) {
decoder := json.NewDecoder(bytes.NewReader(payload))
decoder.DisallowUnknownFields()
var stored redisRecord
if err := decoder.Decode(&stored); err != nil {
return Record{}, fmt.Errorf("decode redis session record: %w", err)
}
if err := decoder.Decode(&struct{}{}); err != io.EOF {
if err == nil {
return Record{}, errors.New("decode redis session record: unexpected trailing JSON input")
}
return Record{}, fmt.Errorf("decode redis session record: %w", err)
}
record := Record{
DeviceSessionID: stored.DeviceSessionID,
UserID: stored.UserID,
ClientPublicKey: stored.ClientPublicKey,
Status: stored.Status,
RevokedAtMS: cloneOptionalInt64(stored.RevokedAtMS),
}
if err := validateRecord(expectedDeviceSessionID, record); err != nil {
return Record{}, err
}
return record, nil
}
func validateRecord(expectedDeviceSessionID string, record Record) error {
if record.DeviceSessionID == "" {
return errors.New("session record device_session_id must not be empty")
}
if record.DeviceSessionID != expectedDeviceSessionID {
return fmt.Errorf("session record device_session_id %q does not match requested %q", record.DeviceSessionID, expectedDeviceSessionID)
}
if record.UserID == "" {
return errors.New("session record user_id must not be empty")
}
if record.ClientPublicKey == "" {
return errors.New("session record client_public_key must not be empty")
}
if !record.Status.IsKnown() {
return fmt.Errorf("session record status %q is unsupported", record.Status)
}
return nil
}
func cloneOptionalInt64(value *int64) *int64 {
if value == nil {
return nil
}
cloned := *value
return &cloned
}
var _ Cache = (*RedisCache)(nil)
-317
View File
@@ -1,317 +0,0 @@
package session
import (
"context"
"encoding/json"
"errors"
"testing"
"time"
"galaxy/gateway/internal/config"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func newRedisClient(t *testing.T, server *miniredis.Miniredis) *redis.Client {
t.Helper()
client := redis.NewClient(&redis.Options{
Addr: server.Addr(),
Protocol: 2,
DisableIdentity: true,
})
t.Cleanup(func() {
assert.NoError(t, client.Close())
})
return client
}
func TestNewRedisCache(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
client := newRedisClient(t, server)
validCfg := config.SessionCacheRedisConfig{
KeyPrefix: "gateway:session:",
LookupTimeout: 250 * time.Millisecond,
}
tests := []struct {
name string
client *redis.Client
cfg config.SessionCacheRedisConfig
wantErr string
}{
{name: "valid config", client: client, cfg: validCfg},
{name: "nil client", client: nil, cfg: validCfg, wantErr: "nil redis client"},
{
name: "empty key prefix",
client: client,
cfg: config.SessionCacheRedisConfig{LookupTimeout: 250 * time.Millisecond},
wantErr: "redis key prefix must not be empty",
},
{
name: "non-positive lookup timeout",
client: client,
cfg: config.SessionCacheRedisConfig{KeyPrefix: "gateway:session:"},
wantErr: "lookup timeout must be positive",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
cache, err := NewRedisCache(tt.client, tt.cfg)
if tt.wantErr != "" {
require.Error(t, err)
require.ErrorContains(t, err, tt.wantErr)
return
}
require.NoError(t, err)
require.NotNil(t, cache)
})
}
}
func TestRedisCacheLookup(t *testing.T) {
t.Parallel()
revokedAtMS := int64(123456789)
tests := []struct {
name string
cfg config.SessionCacheRedisConfig
requestID string
seed func(*testing.T, *miniredis.Miniredis, config.SessionCacheRedisConfig)
want Record
wantErrIs error
wantErrText string
assertErrText string
}{
{
name: "active cache hit",
requestID: "device-session-123",
cfg: config.SessionCacheRedisConfig{
KeyPrefix: "gateway:session:",
},
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
t.Helper()
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-123", redisRecord{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: StatusActive,
})
},
want: Record{
DeviceSessionID: "device-session-123",
UserID: "user-123",
ClientPublicKey: "public-key-123",
Status: StatusActive,
},
},
{
name: "missing session",
requestID: "device-session-404",
cfg: config.SessionCacheRedisConfig{
KeyPrefix: "gateway:session:",
},
wantErrIs: ErrNotFound,
assertErrText: "session cache record not found",
},
{
name: "revoked session",
requestID: "device-session-revoked",
cfg: config.SessionCacheRedisConfig{
KeyPrefix: "gateway:session:",
},
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
t.Helper()
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-revoked", redisRecord{
DeviceSessionID: "device-session-revoked",
UserID: "user-777",
ClientPublicKey: "public-key-777",
Status: StatusRevoked,
RevokedAtMS: &revokedAtMS,
})
},
want: Record{
DeviceSessionID: "device-session-revoked",
UserID: "user-777",
ClientPublicKey: "public-key-777",
Status: StatusRevoked,
RevokedAtMS: &revokedAtMS,
},
},
{
name: "malformed json",
requestID: "device-session-bad-json",
cfg: config.SessionCacheRedisConfig{
KeyPrefix: "gateway:session:",
},
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
t.Helper()
server.Set(cfg.KeyPrefix+"device-session-bad-json", "{")
},
wantErrText: "decode redis session record",
},
{
name: "unknown status",
requestID: "device-session-unknown-status",
cfg: config.SessionCacheRedisConfig{
KeyPrefix: "gateway:session:",
},
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
t.Helper()
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-unknown-status", redisRecord{
DeviceSessionID: "device-session-unknown-status",
UserID: "user-1",
ClientPublicKey: "public-key-1",
Status: Status("paused"),
})
},
wantErrText: `status "paused" is unsupported`,
},
{
name: "missing required field",
requestID: "device-session-missing-user",
cfg: config.SessionCacheRedisConfig{
KeyPrefix: "gateway:session:",
},
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
t.Helper()
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-missing-user", redisRecord{
DeviceSessionID: "device-session-missing-user",
ClientPublicKey: "public-key-1",
Status: StatusActive,
})
},
wantErrText: "user_id must not be empty",
},
{
name: "device session id mismatch",
requestID: "device-session-requested",
cfg: config.SessionCacheRedisConfig{
KeyPrefix: "gateway:session:",
},
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
t.Helper()
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-requested", redisRecord{
DeviceSessionID: "device-session-other",
UserID: "user-1",
ClientPublicKey: "public-key-1",
Status: StatusActive,
})
},
wantErrText: `does not match requested "device-session-requested"`,
},
{
name: "key prefix is honored",
requestID: "device-session-prefixed",
cfg: config.SessionCacheRedisConfig{
KeyPrefix: "custom:session:",
},
seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
t.Helper()
setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-prefixed", redisRecord{
DeviceSessionID: "device-session-prefixed",
UserID: "user-prefixed",
ClientPublicKey: "public-key-prefixed",
Status: StatusActive,
})
setRedisSessionRecord(t, server, "gateway:session:device-session-prefixed", redisRecord{
DeviceSessionID: "device-session-prefixed",
UserID: "wrong-user",
ClientPublicKey: "wrong-key",
Status: StatusRevoked,
})
},
want: Record{
DeviceSessionID: "device-session-prefixed",
UserID: "user-prefixed",
ClientPublicKey: "public-key-prefixed",
Status: StatusActive,
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
cfg := tt.cfg
cfg.LookupTimeout = 250 * time.Millisecond
if tt.seed != nil {
tt.seed(t, server, cfg)
}
cache := newTestRedisCache(t, server, cfg)
record, err := cache.Lookup(context.Background(), tt.requestID)
if tt.wantErrIs != nil || tt.wantErrText != "" {
require.Error(t, err)
if tt.wantErrIs != nil {
assert.ErrorIs(t, err, tt.wantErrIs)
}
if tt.wantErrText != "" {
assert.ErrorContains(t, err, tt.wantErrText)
}
if tt.assertErrText != "" {
assert.ErrorContains(t, err, tt.assertErrText)
}
return
}
require.NoError(t, err)
assert.Equal(t, tt.want, record)
})
}
}
func newTestRedisCache(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) *RedisCache {
t.Helper()
if cfg.KeyPrefix == "" {
cfg.KeyPrefix = "gateway:session:"
}
if cfg.LookupTimeout == 0 {
cfg.LookupTimeout = 250 * time.Millisecond
}
cache, err := NewRedisCache(newRedisClient(t, server), cfg)
require.NoError(t, err)
return cache
}
func setRedisSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record redisRecord) {
t.Helper()
payload, err := json.Marshal(record)
require.NoError(t, err)
server.Set(key, string(payload))
}
func TestRedisCacheLookupNilContext(t *testing.T) {
t.Parallel()
server := miniredis.RunT(t)
cache := newTestRedisCache(t, server, config.SessionCacheRedisConfig{})
_, err := cache.Lookup(context.TODO(), "device-session-123")
require.Error(t, err)
assert.False(t, errors.Is(err, ErrNotFound))
assert.ErrorContains(t, err, "nil context")
}
+4 -15
View File
@@ -13,27 +13,16 @@ var (
ErrNotFound = errors.New("session cache record not found")
)
// Cache resolves authenticated device-session state from the gateway hot-path
// cache.
// Cache resolves authenticated device-session state from the gateway
// hot path. The implementation dropped the previous Redis projection: the only
// implementation is *BackendCache, which calls backend's
// `/api/v1/internal/sessions/{id}` synchronously per request.
type Cache interface {
// Lookup returns the cached record for deviceSessionID. Implementations must
// wrap ErrNotFound when the cache does not contain the requested record.
Lookup(ctx context.Context, deviceSessionID string) (Record, error)
}
// SnapshotStore stores mutable session record snapshots inside one gateway
// process and exposes the same read contract as Cache for the hot path.
type SnapshotStore interface {
Cache
// Upsert stores record under record.DeviceSessionID, replacing any previous
// snapshot for that session.
Upsert(record Record) error
// Delete removes the local snapshot for deviceSessionID when it exists.
Delete(deviceSessionID string)
}
// Status identifies the cached lifecycle state of a device session.
type Status string