feat: backend service
This commit is contained in:
@@ -1,80 +0,0 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const (
|
||||
// EventDomainMarkerV1 binds the v1 server event signature to the Galaxy
|
||||
// gateway transport contract.
|
||||
EventDomainMarkerV1 = "galaxy-event-v1"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidEventSignature reports that a gateway stream event signature is
|
||||
// not a raw Ed25519 signature for the canonical event signing input.
|
||||
ErrInvalidEventSignature = errors.New("invalid event signature")
|
||||
)
|
||||
|
||||
// EventSigningFields contains the canonical v1 stream-event fields that are
|
||||
// bound into the server signing input.
|
||||
type EventSigningFields struct {
|
||||
// EventType identifies the stable client-facing event category.
|
||||
EventType string
|
||||
|
||||
// EventID is the stable event correlation identifier.
|
||||
EventID string
|
||||
|
||||
// TimestampMS carries the server event timestamp in milliseconds.
|
||||
TimestampMS int64
|
||||
|
||||
// RequestID optionally correlates the event to the opening client request.
|
||||
RequestID string
|
||||
|
||||
// TraceID optionally carries the client-supplied tracing correlation value.
|
||||
TraceID string
|
||||
|
||||
// PayloadHash is the raw SHA-256 digest of event payload bytes.
|
||||
PayloadHash []byte
|
||||
}
|
||||
|
||||
// BuildEventSigningInput returns the canonical byte sequence the v1 gateway
|
||||
// stream-event signature covers. String and byte fields are length-prefixed
|
||||
// with uvarint(len(field)) followed by raw bytes, while TimestampMS is
|
||||
// appended as an 8-byte big-endian uint64.
|
||||
func BuildEventSigningInput(fields EventSigningFields) []byte {
|
||||
size := len(EventDomainMarkerV1) +
|
||||
len(fields.EventType) +
|
||||
len(fields.EventID) +
|
||||
len(fields.RequestID) +
|
||||
len(fields.TraceID) +
|
||||
len(fields.PayloadHash) +
|
||||
(6 * binary.MaxVarintLen64) +
|
||||
8
|
||||
|
||||
buf := make([]byte, 0, size)
|
||||
buf = appendLengthPrefixedString(buf, EventDomainMarkerV1)
|
||||
buf = appendLengthPrefixedString(buf, fields.EventType)
|
||||
buf = appendLengthPrefixedString(buf, fields.EventID)
|
||||
buf = binary.BigEndian.AppendUint64(buf, uint64(fields.TimestampMS))
|
||||
buf = appendLengthPrefixedString(buf, fields.RequestID)
|
||||
buf = appendLengthPrefixedString(buf, fields.TraceID)
|
||||
buf = appendLengthPrefixedBytes(buf, fields.PayloadHash)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// VerifyEventSignature verifies that signature authenticates fields under
|
||||
// publicKey using the canonical v1 event signing input.
|
||||
func VerifyEventSignature(publicKey ed25519.PublicKey, signature []byte, fields EventSigningFields) error {
|
||||
if len(publicKey) != ed25519.PublicKeySize || len(signature) != ed25519.SignatureSize {
|
||||
return ErrInvalidEventSignature
|
||||
}
|
||||
if !ed25519.Verify(publicKey, BuildEventSigningInput(fields), signature) {
|
||||
return ErrInvalidEventSignature
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,111 +0,0 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildEventSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := EventSigningFields{
|
||||
EventType: "gateway.server_time",
|
||||
EventID: "request-123",
|
||||
TimestampMS: 123456789,
|
||||
RequestID: "request-123",
|
||||
TraceID: "trace-123",
|
||||
PayloadHash: mustSHA256([]byte("payload")),
|
||||
}
|
||||
|
||||
baseInput := BuildEventSigningInput(base)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(EventSigningFields) EventSigningFields
|
||||
}{
|
||||
{
|
||||
name: "event type",
|
||||
mutate: func(fields EventSigningFields) EventSigningFields {
|
||||
fields.EventType = "gateway.other"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "event id",
|
||||
mutate: func(fields EventSigningFields) EventSigningFields {
|
||||
fields.EventID = "request-456"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "timestamp",
|
||||
mutate: func(fields EventSigningFields) EventSigningFields {
|
||||
fields.TimestampMS++
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request id",
|
||||
mutate: func(fields EventSigningFields) EventSigningFields {
|
||||
fields.RequestID = "request-456"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "trace id",
|
||||
mutate: func(fields EventSigningFields) EventSigningFields {
|
||||
fields.TraceID = "trace-456"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "payload hash",
|
||||
mutate: func(fields EventSigningFields) EventSigningFields {
|
||||
fields.PayloadHash = mustSHA256([]byte("other"))
|
||||
return fields
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mutated := BuildEventSigningInput(tt.mutate(base))
|
||||
assert.False(t, bytes.Equal(baseInput, mutated))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSignAndVerifyEventSignature covers the happy-path round trip: an event
// signed by Ed25519ResponseSigner.SignEvent must verify under the signer's
// public key, and mutating a signed field afterwards must invalidate the
// signature with the ErrInvalidEventSignature sentinel.
func TestSignAndVerifyEventSignature(t *testing.T) {
	t.Parallel()

	_, privateKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	signer, err := NewEd25519ResponseSigner(privateKey)
	require.NoError(t, err)

	fields := EventSigningFields{
		EventType:   "gateway.server_time",
		EventID:     "request-123",
		TimestampMS: 123456789,
		RequestID:   "request-123",
		TraceID:     "trace-123",
		PayloadHash: mustSHA256([]byte("payload")),
	}

	signature, err := signer.SignEvent(fields)
	require.NoError(t, err)
	require.NoError(t, VerifyEventSignature(signer.PublicKey(), signature, fields))

	// Any signed field change must break verification under the old signature.
	fields.TraceID = "changed"
	require.ErrorIs(t, VerifyEventSignature(signer.PublicKey(), signature, fields), ErrInvalidEventSignature)
}
|
||||
@@ -1,101 +0,0 @@
|
||||
// Package authn defines authenticated transport helpers shared by the gateway
|
||||
// edge verification pipeline.
|
||||
package authn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
const (
|
||||
// RequestDomainMarkerV1 binds the v1 client request signature to the Galaxy
|
||||
// gateway transport contract.
|
||||
RequestDomainMarkerV1 = "galaxy-request-v1"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidPayloadHash reports that payloadHash is not a raw SHA-256 digest.
|
||||
ErrInvalidPayloadHash = errors.New("payload_hash must be a 32-byte SHA-256 digest")
|
||||
|
||||
// ErrPayloadHashMismatch reports that payloadHash does not match payloadBytes.
|
||||
ErrPayloadHashMismatch = errors.New("payload_hash does not match payload_bytes")
|
||||
)
|
||||
|
||||
// RequestSigningFields contains the canonical v1 request fields that are bound
|
||||
// into the client signing input after the gateway validates and normalizes the
|
||||
// request envelope.
|
||||
type RequestSigningFields struct {
|
||||
// ProtocolVersion identifies the transport envelope version.
|
||||
ProtocolVersion string
|
||||
|
||||
// DeviceSessionID identifies the authenticated device session bound to the
|
||||
// request.
|
||||
DeviceSessionID string
|
||||
|
||||
// MessageType is the stable downstream routing key.
|
||||
MessageType string
|
||||
|
||||
// TimestampMS carries the client request timestamp in milliseconds.
|
||||
TimestampMS int64
|
||||
|
||||
// RequestID is the transport correlation and anti-replay identifier.
|
||||
RequestID string
|
||||
|
||||
// PayloadHash is the raw SHA-256 digest of payload bytes.
|
||||
PayloadHash []byte
|
||||
}
|
||||
|
||||
// BuildRequestSigningInput returns the canonical byte sequence the v1 client
|
||||
// request signature covers. String and byte fields are length-prefixed with
|
||||
// uvarint(len(field)) followed by raw bytes, while TimestampMS is appended as
|
||||
// an 8-byte big-endian uint64. The caller is expected to pass fields that have
|
||||
// already passed earlier envelope validation.
|
||||
func BuildRequestSigningInput(fields RequestSigningFields) []byte {
|
||||
size := len(RequestDomainMarkerV1) +
|
||||
len(fields.ProtocolVersion) +
|
||||
len(fields.DeviceSessionID) +
|
||||
len(fields.MessageType) +
|
||||
len(fields.RequestID) +
|
||||
len(fields.PayloadHash) +
|
||||
(6 * binary.MaxVarintLen64) +
|
||||
8
|
||||
|
||||
buf := make([]byte, 0, size)
|
||||
buf = appendLengthPrefixedString(buf, RequestDomainMarkerV1)
|
||||
buf = appendLengthPrefixedString(buf, fields.ProtocolVersion)
|
||||
buf = appendLengthPrefixedString(buf, fields.DeviceSessionID)
|
||||
buf = appendLengthPrefixedString(buf, fields.MessageType)
|
||||
buf = binary.BigEndian.AppendUint64(buf, uint64(fields.TimestampMS))
|
||||
buf = appendLengthPrefixedString(buf, fields.RequestID)
|
||||
buf = appendLengthPrefixedBytes(buf, fields.PayloadHash)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// VerifyPayloadHash checks that payloadHash is the raw SHA-256 digest of
|
||||
// payloadBytes. Empty payloadBytes are valid and must use sha256.Sum256(nil).
|
||||
func VerifyPayloadHash(payloadBytes, payloadHash []byte) error {
|
||||
if len(payloadHash) != sha256.Size {
|
||||
return ErrInvalidPayloadHash
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(payloadBytes)
|
||||
if !bytes.Equal(sum[:], payloadHash) {
|
||||
return ErrPayloadHashMismatch
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendLengthPrefixedString(dst []byte, value string) []byte {
|
||||
return appendLengthPrefixedBytes(dst, []byte(value))
|
||||
}
|
||||
|
||||
func appendLengthPrefixedBytes(dst []byte, value []byte) []byte {
|
||||
dst = binary.AppendUvarint(dst, uint64(len(value)))
|
||||
dst = append(dst, value...)
|
||||
|
||||
return dst
|
||||
}
|
||||
@@ -1,163 +0,0 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVerifyPayloadHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payloadSum := sha256.Sum256([]byte("payload"))
|
||||
emptySum := sha256.Sum256(nil)
|
||||
otherSum := sha256.Sum256([]byte("other"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
payloadHash []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "matches non-empty payload",
|
||||
payload: []byte("payload"),
|
||||
payloadHash: payloadSum[:],
|
||||
},
|
||||
{
|
||||
name: "matches empty payload",
|
||||
payload: nil,
|
||||
payloadHash: emptySum[:],
|
||||
},
|
||||
{
|
||||
name: "rejects digest with invalid length",
|
||||
payload: []byte("payload"),
|
||||
payloadHash: []byte("short"),
|
||||
wantErr: ErrInvalidPayloadHash,
|
||||
},
|
||||
{
|
||||
name: "rejects digest mismatch",
|
||||
payload: []byte("payload"),
|
||||
payloadHash: otherSum[:],
|
||||
wantErr: ErrPayloadHashMismatch,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := VerifyPayloadHash(tt.payload, tt.payloadHash)
|
||||
if tt.wantErr == nil {
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildRequestSigningInput pins the exact canonical byte layout of the v1
// request signing input against a golden hex vector, guarding the signing
// wire contract against accidental re-encoding changes.
func TestBuildRequestSigningInput(t *testing.T) {
	t.Parallel()

	fields := RequestSigningFields{
		ProtocolVersion: "v1",
		DeviceSessionID: "device-session-123",
		MessageType:     "fleet.move",
		TimestampMS:     123456789,
		RequestID:       "request-123",
		PayloadHash:     mustSHA256([]byte("payload")),
	}

	got := BuildRequestSigningInput(fields)

	// Golden vector: uvarint length prefix + raw bytes for each string field,
	// the 8-byte big-endian timestamp, then the prefixed SHA-256 payload digest.
	want, err := hex.DecodeString("1167616c6178792d726571756573742d7631027631126465766963652d73657373696f6e2d3132330a666c6565742e6d6f766500000000075bcd150b726571756573742d31323320239f59ed55e737c77147cf55ad0c1b030b6d7ee748a7426952f9b852d5a935e5")
	require.NoError(t, err)
	assert.Equal(t, want, got)
}
|
||||
|
||||
func TestBuildRequestSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := RequestSigningFields{
|
||||
ProtocolVersion: "v1",
|
||||
DeviceSessionID: "device-session-123",
|
||||
MessageType: "fleet.move",
|
||||
TimestampMS: 123456789,
|
||||
RequestID: "request-123",
|
||||
PayloadHash: mustSHA256([]byte("payload")),
|
||||
}
|
||||
|
||||
baseInput := BuildRequestSigningInput(base)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(RequestSigningFields) RequestSigningFields
|
||||
}{
|
||||
{
|
||||
name: "protocol version",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.ProtocolVersion = "v2"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "device session id",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.DeviceSessionID = "device-session-456"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message type",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.MessageType = "fleet.attack"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "timestamp",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.TimestampMS++
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request id",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.RequestID = "request-456"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "payload hash",
|
||||
mutate: func(fields RequestSigningFields) RequestSigningFields {
|
||||
fields.PayloadHash = mustSHA256([]byte("other"))
|
||||
return fields
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mutated := BuildRequestSigningInput(tt.mutate(base))
|
||||
assert.False(t, bytes.Equal(baseInput, mutated))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustSHA256(payload []byte) []byte {
|
||||
sum := sha256.Sum256(payload)
|
||||
return sum[:]
|
||||
}
|
||||
@@ -1,189 +0,0 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"crypto/x509"
|
||||
"encoding/binary"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
const (
|
||||
// ResponseDomainMarkerV1 binds the v1 server response signature to the
|
||||
// Galaxy gateway transport contract.
|
||||
ResponseDomainMarkerV1 = "galaxy-response-v1"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidResponsePrivateKeyPEM reports that the configured response
|
||||
// signer private key is not a strict PKCS#8 PEM-encoded private key.
|
||||
ErrInvalidResponsePrivateKeyPEM = errors.New("response signer private key is not a valid PKCS#8 PEM block")
|
||||
|
||||
// ErrInvalidResponsePrivateKey reports that the configured response signer
|
||||
// private key is not an Ed25519 private key.
|
||||
ErrInvalidResponsePrivateKey = errors.New("response signer private key must be an Ed25519 PKCS#8 private key")
|
||||
|
||||
// ErrInvalidResponseSignature reports that a server response signature is
|
||||
// not a raw Ed25519 signature for the canonical response signing input.
|
||||
ErrInvalidResponseSignature = errors.New("invalid response signature")
|
||||
)
|
||||
|
||||
// ResponseSigningFields contains the canonical v1 response fields that are
|
||||
// bound into the server signing input.
|
||||
type ResponseSigningFields struct {
|
||||
// ProtocolVersion identifies the transport envelope version.
|
||||
ProtocolVersion string
|
||||
|
||||
// RequestID is the transport correlation identifier copied from the
|
||||
// authenticated request.
|
||||
RequestID string
|
||||
|
||||
// TimestampMS carries the server response timestamp in milliseconds.
|
||||
TimestampMS int64
|
||||
|
||||
// ResultCode is the opaque downstream result code returned to the client.
|
||||
ResultCode string
|
||||
|
||||
// PayloadHash is the raw SHA-256 digest of response payload bytes.
|
||||
PayloadHash []byte
|
||||
}
|
||||
|
||||
// ResponseSigner signs authenticated unary responses and client-facing stream
|
||||
// events with one server-side key.
|
||||
type ResponseSigner interface {
|
||||
// SignResponse returns the raw Ed25519 signature for the canonical response
|
||||
// signing input built from fields.
|
||||
SignResponse(fields ResponseSigningFields) ([]byte, error)
|
||||
|
||||
// SignEvent returns the raw Ed25519 signature for the canonical event
|
||||
// signing input built from fields.
|
||||
SignEvent(fields EventSigningFields) ([]byte, error)
|
||||
}
|
||||
|
||||
// Ed25519ResponseSigner signs authenticated responses with one Ed25519 private
|
||||
// key loaded during process startup.
|
||||
type Ed25519ResponseSigner struct {
|
||||
privateKey ed25519.PrivateKey
|
||||
}
|
||||
|
||||
// NewEd25519ResponseSigner validates privateKey and constructs a signer using
|
||||
// a defensive key copy.
|
||||
func NewEd25519ResponseSigner(privateKey ed25519.PrivateKey) (*Ed25519ResponseSigner, error) {
|
||||
if len(privateKey) != ed25519.PrivateKeySize {
|
||||
return nil, ErrInvalidResponsePrivateKey
|
||||
}
|
||||
|
||||
return &Ed25519ResponseSigner{
|
||||
privateKey: bytes.Clone(privateKey),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoadEd25519ResponseSignerFromPEMFile loads a strict PKCS#8 PEM-encoded
|
||||
// Ed25519 private key from path and constructs a signer.
|
||||
func LoadEd25519ResponseSignerFromPEMFile(path string) (*Ed25519ResponseSigner, error) {
|
||||
pemBytes, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response signer private key PEM: %w", err)
|
||||
}
|
||||
|
||||
signer, err := ParseEd25519ResponseSignerPEM(pemBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
// ParseEd25519ResponseSignerPEM parses one strict PKCS#8 PEM-encoded Ed25519
|
||||
// private key and constructs a signer from it.
|
||||
func ParseEd25519ResponseSignerPEM(pemBytes []byte) (*Ed25519ResponseSigner, error) {
|
||||
block, rest := pem.Decode(pemBytes)
|
||||
if block == nil || block.Type != "PRIVATE KEY" || len(bytes.TrimSpace(rest)) > 0 {
|
||||
return nil, ErrInvalidResponsePrivateKeyPEM
|
||||
}
|
||||
|
||||
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidResponsePrivateKeyPEM
|
||||
}
|
||||
|
||||
privateKey, ok := parsedKey.(ed25519.PrivateKey)
|
||||
if !ok {
|
||||
return nil, ErrInvalidResponsePrivateKey
|
||||
}
|
||||
|
||||
return NewEd25519ResponseSigner(privateKey)
|
||||
}
|
||||
|
||||
// PublicKey returns the Ed25519 public key that corresponds to the configured
|
||||
// response signer private key.
|
||||
func (s *Ed25519ResponseSigner) PublicKey() ed25519.PublicKey {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
publicKey, _ := s.privateKey.Public().(ed25519.PublicKey)
|
||||
return bytes.Clone(publicKey)
|
||||
}
|
||||
|
||||
// SignResponse signs the canonical v1 response signing input built from
|
||||
// fields.
|
||||
func (s *Ed25519ResponseSigner) SignResponse(fields ResponseSigningFields) ([]byte, error) {
|
||||
if s == nil || len(s.privateKey) != ed25519.PrivateKeySize {
|
||||
return nil, ErrInvalidResponsePrivateKey
|
||||
}
|
||||
|
||||
signature := ed25519.Sign(s.privateKey, BuildResponseSigningInput(fields))
|
||||
return bytes.Clone(signature), nil
|
||||
}
|
||||
|
||||
// SignEvent signs the canonical v1 stream-event signing input built from
|
||||
// fields.
|
||||
func (s *Ed25519ResponseSigner) SignEvent(fields EventSigningFields) ([]byte, error) {
|
||||
if s == nil || len(s.privateKey) != ed25519.PrivateKeySize {
|
||||
return nil, ErrInvalidResponsePrivateKey
|
||||
}
|
||||
|
||||
signature := ed25519.Sign(s.privateKey, BuildEventSigningInput(fields))
|
||||
return bytes.Clone(signature), nil
|
||||
}
|
||||
|
||||
// BuildResponseSigningInput returns the canonical byte sequence the v1 server
|
||||
// response signature covers. String and byte fields are length-prefixed with
|
||||
// uvarint(len(field)) followed by raw bytes, while TimestampMS is appended as
|
||||
// an 8-byte big-endian uint64.
|
||||
func BuildResponseSigningInput(fields ResponseSigningFields) []byte {
|
||||
size := len(ResponseDomainMarkerV1) +
|
||||
len(fields.ProtocolVersion) +
|
||||
len(fields.RequestID) +
|
||||
len(fields.ResultCode) +
|
||||
len(fields.PayloadHash) +
|
||||
(5 * binary.MaxVarintLen64) +
|
||||
8
|
||||
|
||||
buf := make([]byte, 0, size)
|
||||
buf = appendLengthPrefixedString(buf, ResponseDomainMarkerV1)
|
||||
buf = appendLengthPrefixedString(buf, fields.ProtocolVersion)
|
||||
buf = appendLengthPrefixedString(buf, fields.RequestID)
|
||||
buf = binary.BigEndian.AppendUint64(buf, uint64(fields.TimestampMS))
|
||||
buf = appendLengthPrefixedString(buf, fields.ResultCode)
|
||||
buf = appendLengthPrefixedBytes(buf, fields.PayloadHash)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// VerifyResponseSignature verifies that signature authenticates fields under
|
||||
// publicKey using the canonical v1 response signing input.
|
||||
func VerifyResponseSignature(publicKey ed25519.PublicKey, signature []byte, fields ResponseSigningFields) error {
|
||||
if len(publicKey) != ed25519.PublicKeySize || len(signature) != ed25519.SignatureSize {
|
||||
return ErrInvalidResponseSignature
|
||||
}
|
||||
if !ed25519.Verify(publicKey, BuildResponseSigningInput(fields), signature) {
|
||||
return ErrInvalidResponseSignature
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildResponseSigningInputChangesWhenSignedFieldChanges(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := ResponseSigningFields{
|
||||
ProtocolVersion: "v1",
|
||||
RequestID: "request-123",
|
||||
TimestampMS: 123456789,
|
||||
ResultCode: "ok",
|
||||
PayloadHash: mustSHA256([]byte("payload")),
|
||||
}
|
||||
|
||||
baseInput := BuildResponseSigningInput(base)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutate func(ResponseSigningFields) ResponseSigningFields
|
||||
}{
|
||||
{
|
||||
name: "protocol version",
|
||||
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
|
||||
fields.ProtocolVersion = "v2"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request id",
|
||||
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
|
||||
fields.RequestID = "request-456"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "timestamp",
|
||||
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
|
||||
fields.TimestampMS++
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "result code",
|
||||
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
|
||||
fields.ResultCode = "denied"
|
||||
return fields
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "payload hash",
|
||||
mutate: func(fields ResponseSigningFields) ResponseSigningFields {
|
||||
fields.PayloadHash = mustSHA256([]byte("other"))
|
||||
return fields
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mutated := BuildResponseSigningInput(tt.mutate(base))
|
||||
assert.False(t, bytes.Equal(baseInput, mutated))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseEd25519ResponseSignerPEMRejectsMalformedPEM checks that input that
// is not PEM at all maps to ErrInvalidResponsePrivateKeyPEM.
func TestParseEd25519ResponseSignerPEMRejectsMalformedPEM(t *testing.T) {
	t.Parallel()

	_, err := ParseEd25519ResponseSignerPEM([]byte("not-pem"))
	require.ErrorIs(t, err, ErrInvalidResponsePrivateKeyPEM)
}

// TestParseEd25519ResponseSignerPEMRejectsNonPKCS8PEM checks that a valid
// PKCS#8 body under a non-"PRIVATE KEY" PEM block type is rejected, enforcing
// the strict block-type requirement.
func TestParseEd25519ResponseSignerPEMRejectsNonPKCS8PEM(t *testing.T) {
	t.Parallel()

	_, privateKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	pemBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
	require.NoError(t, err)

	block := pem.Block{
		Type:  "ED25519 PRIVATE KEY",
		Bytes: pemBytes,
	}

	_, err = ParseEd25519ResponseSignerPEM(pem.EncodeToMemory(&block))
	require.ErrorIs(t, err, ErrInvalidResponsePrivateKeyPEM)
}

// TestParseEd25519ResponseSignerPEMRejectsNonEd25519Key checks that a
// well-formed PKCS#8 key of the wrong algorithm (RSA) maps to
// ErrInvalidResponsePrivateKey rather than the PEM error.
func TestParseEd25519ResponseSignerPEMRejectsNonEd25519Key(t *testing.T) {
	t.Parallel()

	privateKey, err := rsa.GenerateKey(rand.Reader, 1024)
	require.NoError(t, err)

	pemBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
	require.NoError(t, err)

	_, err = ParseEd25519ResponseSignerPEM(pem.EncodeToMemory(&pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: pemBytes,
	}))
	require.ErrorIs(t, err, ErrInvalidResponsePrivateKey)
}
|
||||
|
||||
// TestSignAndVerifyResponseSignature covers the happy-path round trip: a
// response signed by Ed25519ResponseSigner.SignResponse must verify under the
// signer's public key, and mutating a signed field afterwards must invalidate
// the signature with the ErrInvalidResponseSignature sentinel.
func TestSignAndVerifyResponseSignature(t *testing.T) {
	t.Parallel()

	_, privateKey, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	signer, err := NewEd25519ResponseSigner(privateKey)
	require.NoError(t, err)

	fields := ResponseSigningFields{
		ProtocolVersion: "v1",
		RequestID:       "request-123",
		TimestampMS:     123456789,
		ResultCode:      "ok",
		PayloadHash:     mustSHA256([]byte("payload")),
	}

	signature, err := signer.SignResponse(fields)
	require.NoError(t, err)
	require.NoError(t, VerifyResponseSignature(signer.PublicKey(), signature, fields))

	// Any signed field change must break verification under the old signature.
	fields.ResultCode = "changed"
	require.ErrorIs(t, VerifyResponseSignature(signer.PublicKey(), signature, fields), ErrInvalidResponseSignature)
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidClientPublicKey reports that cached client public key material
|
||||
// is not a base64-encoded raw Ed25519 public key.
|
||||
ErrInvalidClientPublicKey = errors.New("client_public_key is not a valid base64-encoded Ed25519 public key")
|
||||
|
||||
// ErrInvalidRequestSignature reports that a request signature is not a raw
|
||||
// Ed25519 signature for the canonical request signing input.
|
||||
ErrInvalidRequestSignature = errors.New("invalid request signature")
|
||||
)
|
||||
|
||||
// VerifyRequestSignature validates the base64-encoded raw Ed25519 public key
|
||||
// from session cache, builds the canonical v1 signing input from fields, and
|
||||
// verifies that signature authenticates the request.
|
||||
func VerifyRequestSignature(clientPublicKey string, signature []byte, fields RequestSigningFields) error {
|
||||
publicKey, err := decodeClientPublicKey(clientPublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(signature) != ed25519.SignatureSize {
|
||||
return ErrInvalidRequestSignature
|
||||
}
|
||||
if !ed25519.Verify(publicKey, BuildRequestSigningInput(fields), signature) {
|
||||
return ErrInvalidRequestSignature
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeClientPublicKey(value string) (ed25519.PublicKey, error) {
|
||||
decoded, err := base64.StdEncoding.Strict().DecodeString(value)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidClientPublicKey
|
||||
}
|
||||
if len(decoded) != ed25519.PublicKeySize {
|
||||
return nil, ErrInvalidClientPublicKey
|
||||
}
|
||||
|
||||
return ed25519.PublicKey(decoded), nil
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVerifyRequestSignature(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clientPrivateKey := newTestPrivateKey("primary")
|
||||
clientPublicKey := clientPrivateKey.Public().(ed25519.PublicKey)
|
||||
otherPrivateKey := newTestPrivateKey("other")
|
||||
|
||||
fields := RequestSigningFields{
|
||||
ProtocolVersion: "v1",
|
||||
DeviceSessionID: "device-session-123",
|
||||
MessageType: "fleet.move",
|
||||
TimestampMS: 123456789,
|
||||
RequestID: "request-123",
|
||||
PayloadHash: mustSHA256([]byte("payload")),
|
||||
}
|
||||
|
||||
signature := ed25519.Sign(clientPrivateKey, BuildRequestSigningInput(fields))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
clientPublicKey string
|
||||
signature []byte
|
||||
fields RequestSigningFields
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "valid signature",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
|
||||
signature: signature,
|
||||
fields: fields,
|
||||
},
|
||||
{
|
||||
name: "message type change rejects signature",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
|
||||
signature: signature,
|
||||
fields: func() RequestSigningFields {
|
||||
mutated := fields
|
||||
mutated.MessageType = "fleet.attack"
|
||||
return mutated
|
||||
}(),
|
||||
wantErr: ErrInvalidRequestSignature,
|
||||
},
|
||||
{
|
||||
name: "request id change rejects signature",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
|
||||
signature: signature,
|
||||
fields: func() RequestSigningFields {
|
||||
mutated := fields
|
||||
mutated.RequestID = "request-456"
|
||||
return mutated
|
||||
}(),
|
||||
wantErr: ErrInvalidRequestSignature,
|
||||
},
|
||||
{
|
||||
name: "payload hash change rejects signature",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
|
||||
signature: signature,
|
||||
fields: func() RequestSigningFields {
|
||||
mutated := fields
|
||||
mutated.PayloadHash = mustSHA256([]byte("other"))
|
||||
return mutated
|
||||
}(),
|
||||
wantErr: ErrInvalidRequestSignature,
|
||||
},
|
||||
{
|
||||
name: "wrong key rejects signature",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(otherPrivateKey.Public().(ed25519.PublicKey)),
|
||||
signature: signature,
|
||||
fields: fields,
|
||||
wantErr: ErrInvalidRequestSignature,
|
||||
},
|
||||
{
|
||||
name: "bit flipped signature rejects",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
|
||||
signature: func() []byte {
|
||||
corrupted := append([]byte(nil), signature...)
|
||||
corrupted[0] ^= 0xff
|
||||
return corrupted
|
||||
}(),
|
||||
fields: fields,
|
||||
wantErr: ErrInvalidRequestSignature,
|
||||
},
|
||||
{
|
||||
name: "invalid signature length rejects",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString(clientPublicKey),
|
||||
signature: signature[:len(signature)-1],
|
||||
fields: fields,
|
||||
wantErr: ErrInvalidRequestSignature,
|
||||
},
|
||||
{
|
||||
name: "invalid base64 public key rejects",
|
||||
clientPublicKey: "%%%not-base64%%%",
|
||||
signature: signature,
|
||||
fields: fields,
|
||||
wantErr: ErrInvalidClientPublicKey,
|
||||
},
|
||||
{
|
||||
name: "invalid public key length rejects",
|
||||
clientPublicKey: base64.StdEncoding.EncodeToString([]byte("short")),
|
||||
signature: signature,
|
||||
fields: fields,
|
||||
wantErr: ErrInvalidClientPublicKey,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := VerifyRequestSignature(tt.clientPublicKey, tt.signature, tt.fields)
|
||||
if tt.wantErr == nil {
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newTestPrivateKey derives a deterministic Ed25519 private key from the
// given label so signature tests are reproducible across runs.
func newTestPrivateKey(label string) ed25519.PrivateKey {
	digest := sha256.Sum256([]byte("gateway-authn-signature-test-" + label))
	return ed25519.NewKeyFromSeed(digest[:])
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package backendclient
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config carries the backend endpoint addresses and the identity this
// gateway instance presents to backend. Every field is required for a
// working client; a zero Config fails Validate. Callers that
// intentionally run without a backend simply skip NewClient.
type Config struct {
	// HTTPBaseURL is the absolute base URL of the backend HTTP listener
	// (`/api/v1/{public,user,internal}/*`). Required.
	HTTPBaseURL string

	// GRPCPushURL is the dial target of the backend `Push.SubscribePush`
	// listener (`host:port`). Required.
	GRPCPushURL string

	// GatewayClientID is the durable identifier this gateway instance
	// presents to backend in `GatewaySubscribeRequest.gateway_client_id`.
	// Required.
	GatewayClientID string

	// HTTPTimeout bounds individual REST calls. Must be positive.
	HTTPTimeout time.Duration

	// PushReconnectBaseBackoff is the starting delay between reconnect
	// attempts of `Push.SubscribePush`. Must be positive.
	PushReconnectBaseBackoff time.Duration

	// PushReconnectMaxBackoff caps the exponential reconnect delay.
	// Must be >= PushReconnectBaseBackoff.
	PushReconnectMaxBackoff time.Duration
}

// Validate returns a descriptive error for the first missing or
// malformed field, or nil when cfg is usable. Checks run in field
// declaration order so error messages are stable.
func (cfg Config) Validate() error {
	base := strings.TrimSpace(cfg.HTTPBaseURL)
	if base == "" {
		return errors.New("backendclient: HTTPBaseURL must not be empty")
	}
	// Trailing slashes are tolerated; only scheme and host are checked.
	parsed, err := url.Parse(strings.TrimRight(base, "/"))
	if err != nil {
		return fmt.Errorf("backendclient: parse HTTPBaseURL: %w", err)
	}
	if parsed.Scheme == "" || parsed.Host == "" {
		return errors.New("backendclient: HTTPBaseURL must be absolute")
	}
	if strings.TrimSpace(cfg.GRPCPushURL) == "" {
		return errors.New("backendclient: GRPCPushURL must not be empty")
	}
	if strings.TrimSpace(cfg.GatewayClientID) == "" {
		return errors.New("backendclient: GatewayClientID must not be empty")
	}
	if cfg.HTTPTimeout <= 0 {
		return errors.New("backendclient: HTTPTimeout must be positive")
	}
	if cfg.PushReconnectBaseBackoff <= 0 {
		return errors.New("backendclient: PushReconnectBaseBackoff must be positive")
	}
	if cfg.PushReconnectMaxBackoff < cfg.PushReconnectBaseBackoff {
		return errors.New("backendclient: PushReconnectMaxBackoff must be >= PushReconnectBaseBackoff")
	}
	return nil
}
|
||||
|
||||
// Client aggregates the REST and gRPC adapters that talk to backend.
// One value is shared across the gateway process; all methods are safe
// for concurrent use.
type Client struct {
	rest *RESTClient // REST adapter for /api/v1/* routes
	push *PushClient // gRPC adapter for Push.SubscribePush
}

// NewClient constructs a Client that targets the configured backend.
// Both adapters are built eagerly here: any construction error from
// NewRESTClient or NewPushClient aborts the whole Client. Building the
// push adapter does not open a connection — dialing happens only when
// its Run method is invoked — so construction is side-effect free.
func NewClient(cfg Config) (*Client, error) {
	// Validate once up front; both adapter constructors also validate,
	// but failing here gives the caller one clear error.
	if err := cfg.Validate(); err != nil {
		return nil, err
	}
	rest, err := NewRESTClient(cfg)
	if err != nil {
		return nil, err
	}
	push, err := NewPushClient(cfg)
	if err != nil {
		return nil, err
	}
	return &Client{rest: rest, push: push}, nil
}
|
||||
|
||||
// REST returns the REST adapter. The returned value is nil when the
|
||||
// Client was constructed without a backend; callers must guard.
|
||||
func (c *Client) REST() *RESTClient {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.rest
|
||||
}
|
||||
|
||||
// Push returns the gRPC push adapter. The returned value is nil when
|
||||
// the Client was constructed without a backend.
|
||||
func (c *Client) Push() *PushClient {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.push
|
||||
}
|
||||
|
||||
// Close releases idle HTTP connections and closes the gRPC push
|
||||
// connection. Safe to call multiple times.
|
||||
func (c *Client) Close() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
var firstErr error
|
||||
if c.rest != nil {
|
||||
if err := c.rest.Close(); err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if c.push != nil {
|
||||
if err := c.push.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
// Package backendclient is the gateway-side adapter to the consolidated
|
||||
// `backend` service. It bundles every gateway → backend conversation:
|
||||
//
|
||||
// - public REST (`/api/v1/public/auth/*`) used by the public auth
|
||||
// surface,
|
||||
// - internal REST (`/api/v1/internal/sessions/*`,
|
||||
// `/api/v1/internal/users/*/account-internal`) used by the
|
||||
// authenticated request pipeline,
|
||||
// - authenticated user REST (`/api/v1/user/*`) used by the gRPC
|
||||
// downstream router after envelope verification,
|
||||
// - gRPC `Push.SubscribePush` used to receive `client_event` and
|
||||
// `session_invalidation` frames from backend.
|
||||
//
|
||||
// One env-driven Config describes the backend endpoint and the gateway
|
||||
// client identity. A single Client value is wired by `cmd/gateway` and
|
||||
// shared by all consumers (rest API public auth handler, gRPC session
|
||||
// cache, downstream user/lobby routes, and the push subscriber).
|
||||
package backendclient
|
||||
@@ -0,0 +1,197 @@
|
||||
package backendclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"galaxy/gateway/internal/downstream"
|
||||
lobbymodel "galaxy/model/lobby"
|
||||
"galaxy/transcoder"
|
||||
)
|
||||
|
||||
const (
|
||||
lobbyResultCodeOK = "ok"
|
||||
defaultLobbyErrorCodeInvalid = "invalid_request"
|
||||
defaultLobbyErrorCodeNoSubj = "subject_not_found"
|
||||
defaultLobbyErrorCodeForbid = "forbidden"
|
||||
defaultLobbyErrorCodeConfl = "conflict"
|
||||
defaultLobbyErrorCodeIntErr = "internal_error"
|
||||
)
|
||||
|
||||
var stableLobbyErrorMessages = map[string]string{
|
||||
defaultLobbyErrorCodeInvalid: "request is invalid",
|
||||
defaultLobbyErrorCodeNoSubj: "subject not found",
|
||||
defaultLobbyErrorCodeForbid: "operation is forbidden for the calling user",
|
||||
defaultLobbyErrorCodeConfl: "request conflicts with current state",
|
||||
defaultLobbyErrorCodeIntErr: "internal server error",
|
||||
}
|
||||
|
||||
// ExecuteLobbyCommand routes one authenticated lobby command into
// backend's `/api/v1/user/lobby/*` endpoints.
//
// The command's MessageType selects the endpoint; PayloadBytes is
// transcoded/validated before any network call so malformed payloads
// fail fast. Unknown message types are rejected with an error.
func (c *RESTClient) ExecuteLobbyCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	// Guard against a half-constructed client or missing HTTP transport.
	if c == nil || c.httpClient == nil {
		return downstream.UnaryResult{}, errors.New("backendclient: execute lobby command: nil client")
	}
	if ctx == nil {
		return downstream.UnaryResult{}, errors.New("backendclient: execute lobby command: nil context")
	}
	// Fail fast when the caller's context is already cancelled or expired.
	if err := ctx.Err(); err != nil {
		return downstream.UnaryResult{}, err
	}
	// The user id becomes the trusted identity header in c.do; it must
	// never be blank for an authenticated route.
	if strings.TrimSpace(command.UserID) == "" {
		return downstream.UnaryResult{}, errors.New("backendclient: execute lobby command: user_id must not be empty")
	}

	switch command.MessageType {
	case lobbymodel.MessageTypeMyGamesList:
		// The request payload carries no parameters; decoding it only
		// validates the payload shape, the value itself is discarded.
		if _, err := transcoder.PayloadToMyGamesListRequest(command.PayloadBytes); err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute lobby command %q: %w", command.MessageType, err)
		}
		return c.executeLobbyMyGames(ctx, command.UserID)
	case lobbymodel.MessageTypeOpenEnrollment:
		req, err := transcoder.PayloadToOpenEnrollmentRequest(command.PayloadBytes)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute lobby command %q: %w", command.MessageType, err)
		}
		return c.executeLobbyOpenEnrollment(ctx, command.UserID, req)
	default:
		return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute lobby command: unsupported message type %q", command.MessageType)
	}
}
|
||||
|
||||
// executeLobbyMyGames fetches the calling user's game list from backend
// and re-encodes it onto the gateway payload format. Non-200 responses
// go through the shared lobby error projection.
func (c *RESTClient) executeLobbyMyGames(ctx context.Context, userID string) (downstream.UnaryResult, error) {
	body, status, err := c.do(ctx, http.MethodGet, c.baseURL+"/api/v1/user/lobby/my/games", userID, nil)
	if err != nil {
		return downstream.UnaryResult{}, fmt.Errorf("execute lobby.my.games.list: %w", err)
	}
	if status == http.StatusOK {
		// decodeStrictJSON is expected to reject malformed bodies;
		// a decode failure here is a backend contract violation.
		var response lobbymodel.MyGamesListResponse
		if err := decodeStrictJSON(body, &response); err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("decode success response: %w", err)
		}
		payloadBytes, err := transcoder.MyGamesListResponseToPayload(&response)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("encode success response payload: %w", err)
		}
		return downstream.UnaryResult{
			ResultCode:   lobbyResultCodeOK,
			PayloadBytes: payloadBytes,
		}, nil
	}
	// All error statuses share the common lobby error projection.
	return projectLobbyErrorResponse(status, body)
}
|
||||
|
||||
// executeLobbyOpenEnrollment POSTs an open-enrollment request for one
// game and projects backend's response onto the gateway wire shape.
// req must carry a non-blank GameID; the id is path-escaped into the
// target URL.
func (c *RESTClient) executeLobbyOpenEnrollment(ctx context.Context, userID string, req *lobbymodel.OpenEnrollmentRequest) (downstream.UnaryResult, error) {
	if req == nil || strings.TrimSpace(req.GameID) == "" {
		return downstream.UnaryResult{}, errors.New("execute lobby.game.open-enrollment: game_id must not be empty")
	}
	target := c.baseURL + "/api/v1/user/lobby/games/" + url.PathEscape(req.GameID) + "/open-enrollment"
	// The endpoint takes no request body; an empty JSON object is sent.
	body, status, err := c.do(ctx, http.MethodPost, target, userID, struct{}{})
	if err != nil {
		return downstream.UnaryResult{}, fmt.Errorf("execute lobby.game.open-enrollment: %w", err)
	}
	if status == http.StatusOK {
		// Backend returns the full LobbyGameDetail; gateway projects the
		// minimal {game_id, status} pair onto the existing wire shape.
		// The decode is deliberately non-strict so extra detail fields
		// from backend are ignored rather than rejected.
		var detail struct {
			GameID string `json:"game_id"`
			Status string `json:"status"`
		}
		if err := json.NewDecoder(bytes.NewReader(body)).Decode(&detail); err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("decode success response: %w", err)
		}
		payloadBytes, err := transcoder.OpenEnrollmentResponseToPayload(&lobbymodel.OpenEnrollmentResponse{
			GameID: detail.GameID,
			Status: detail.Status,
		})
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("encode success response payload: %w", err)
		}
		return downstream.UnaryResult{
			ResultCode:   lobbyResultCodeOK,
			PayloadBytes: payloadBytes,
		}, nil
	}
	// All error statuses share the common lobby error projection.
	return projectLobbyErrorResponse(status, body)
}
|
||||
|
||||
// projectLobbyErrorResponse maps a non-200 backend lobby response onto
// the gateway wire contract.
//
// 503 becomes downstream.ErrDownstreamUnavailable so callers can apply
// their unavailable handling; other 4xx/5xx bodies are decoded,
// normalized, and re-encoded as a lobby error payload whose result code
// is the normalized error code. Any other status (1xx/3xx, or 2xx the
// caller did not handle) surfaces as a plain error.
func projectLobbyErrorResponse(statusCode int, payload []byte) (downstream.UnaryResult, error) {
	switch {
	case statusCode == http.StatusServiceUnavailable:
		return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
	case statusCode >= 400 && statusCode <= 599:
		errResp, err := decodeLobbyError(statusCode, payload)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("decode error response: %w", err)
		}
		payloadBytes, err := transcoder.LobbyErrorResponseToPayload(errResp)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("encode error response payload: %w", err)
		}
		// The normalized code doubles as the gateway-level result code.
		return downstream.UnaryResult{
			ResultCode:   errResp.Error.Code,
			PayloadBytes: payloadBytes,
		}, nil
	default:
		return downstream.UnaryResult{}, fmt.Errorf("unexpected HTTP status %d", statusCode)
	}
}
|
||||
|
||||
func decodeLobbyError(statusCode int, payload []byte) (*lobbymodel.ErrorResponse, error) {
|
||||
var response lobbymodel.ErrorResponse
|
||||
decoder := json.NewDecoder(bytes.NewReader(payload))
|
||||
decoder.DisallowUnknownFields()
|
||||
if err := decoder.Decode(&response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
return nil, errors.New("unexpected trailing JSON input")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
response.Error.Code = normalizeLobbyErrorCode(statusCode, response.Error.Code)
|
||||
response.Error.Message = normalizeLobbyErrorMessage(response.Error.Code, response.Error.Message)
|
||||
if strings.TrimSpace(response.Error.Code) == "" {
|
||||
return nil, errors.New("missing error code")
|
||||
}
|
||||
if strings.TrimSpace(response.Error.Message) == "" {
|
||||
return nil, errors.New("missing error message")
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func normalizeLobbyErrorCode(statusCode int, code string) string {
|
||||
if trimmed := strings.TrimSpace(code); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
switch statusCode {
|
||||
case http.StatusBadRequest:
|
||||
return defaultLobbyErrorCodeInvalid
|
||||
case http.StatusForbidden:
|
||||
return defaultLobbyErrorCodeForbid
|
||||
case http.StatusNotFound:
|
||||
return defaultLobbyErrorCodeNoSubj
|
||||
case http.StatusConflict:
|
||||
return defaultLobbyErrorCodeConfl
|
||||
default:
|
||||
return defaultLobbyErrorCodeIntErr
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeLobbyErrorMessage(code, message string) string {
|
||||
if trimmed := strings.TrimSpace(message); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
if stable, ok := stableLobbyErrorMessages[code]; ok {
|
||||
return stable
|
||||
}
|
||||
return stableLobbyErrorMessages[defaultLobbyErrorCodeIntErr]
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package backendclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SendEmailCodeInput is the public REST and adapter payload used to
|
||||
// request a login code for a single e-mail address.
|
||||
type SendEmailCodeInput struct {
|
||||
Email string `json:"email"`
|
||||
PreferredLanguage string `json:"-"`
|
||||
}
|
||||
|
||||
// SendEmailCodeResult is the public REST and adapter payload returned
|
||||
// after backend creates a login challenge.
|
||||
type SendEmailCodeResult struct {
|
||||
ChallengeID string `json:"challenge_id"`
|
||||
}
|
||||
|
||||
// ConfirmEmailCodeInput is the public REST and adapter payload used to
|
||||
// complete a previously issued login challenge.
|
||||
type ConfirmEmailCodeInput struct {
|
||||
ChallengeID string `json:"challenge_id"`
|
||||
Code string `json:"code"`
|
||||
ClientPublicKey string `json:"client_public_key"`
|
||||
TimeZone string `json:"time_zone"`
|
||||
}
|
||||
|
||||
// ConfirmEmailCodeResult is the public REST and adapter payload
|
||||
// returned after backend creates a device session.
|
||||
type ConfirmEmailCodeResult struct {
|
||||
DeviceSessionID string `json:"device_session_id"`
|
||||
}
|
||||
|
||||
// AuthError lets a public REST handler project a stable error envelope
// without re-deriving backend semantics. StatusCode is the HTTP status
// the gateway should return; Code and Message form the JSON envelope.
type AuthError struct {
	StatusCode int
	Code       string
	Message    string
}

// Error renders the projected auth error for logs. A nil receiver is
// tolerated and renders as the empty string.
func (e *AuthError) Error() string {
	if e != nil {
		return fmt.Sprintf("backendclient auth error: status=%d code=%s message=%s", e.StatusCode, e.Code, e.Message)
	}
	return ""
}
|
||||
|
||||
// SendEmailCode delegates the public send-email-code route to backend.
//
// It POSTs input to `/api/v1/public/auth/send-email-code` with an
// Accept-Language header (the header helper drops it when the value is
// blank). A 200 response must carry a non-empty challenge_id; 4xx/5xx
// bodies are projected into *AuthError so the public handler can emit a
// stable envelope; every other status is an unexpected error.
func (c *RESTClient) SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error) {
	if strings.TrimSpace(input.Email) == "" {
		return SendEmailCodeResult{}, errors.New("backendclient: send email code: email must not be empty")
	}
	// Third argument ("") means no trusted user id header — this is a
	// public, unauthenticated route.
	body, status, err := c.doWithHeaders(ctx, http.MethodPost, c.baseURL+"/api/v1/public/auth/send-email-code", "", input, map[string]string{
		"Accept-Language": resolvePreferredLanguage(input.PreferredLanguage),
	})
	if err != nil {
		return SendEmailCodeResult{}, fmt.Errorf("backendclient: send email code: %w", err)
	}
	switch {
	case status == http.StatusOK:
		var result SendEmailCodeResult
		if err := decodeStrictJSON(body, &result); err != nil {
			return SendEmailCodeResult{}, fmt.Errorf("backendclient: send email code: decode success response: %w", err)
		}
		// A 200 without a challenge id is a backend contract violation.
		if strings.TrimSpace(result.ChallengeID) == "" {
			return SendEmailCodeResult{}, errors.New("backendclient: send email code: challenge_id must not be empty")
		}
		return result, nil
	case status >= 400 && status <= 599:
		// Decode failures surface as plain errors; only a well-formed
		// backend envelope becomes a typed *AuthError.
		authErr, decodeErr := decodeAuthError(status, body)
		if decodeErr != nil {
			return SendEmailCodeResult{}, fmt.Errorf("backendclient: send email code: %w", decodeErr)
		}
		return SendEmailCodeResult{}, authErr
	default:
		return SendEmailCodeResult{}, fmt.Errorf("backendclient: send email code: unexpected HTTP status %d", status)
	}
}
|
||||
|
||||
// ConfirmEmailCode delegates the public confirm-email-code route to
// backend.
//
// It POSTs input to `/api/v1/public/auth/confirm-email-code`. A 200
// response must carry a non-empty device_session_id; 4xx/5xx bodies are
// projected into *AuthError; every other status is an unexpected error.
// Only ChallengeID is validated locally — backend validates the code,
// client public key, and time zone.
func (c *RESTClient) ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
	if strings.TrimSpace(input.ChallengeID) == "" {
		return ConfirmEmailCodeResult{}, errors.New("backendclient: confirm email code: challenge_id must not be empty")
	}
	// Public route: no user id header, no extra headers.
	body, status, err := c.doWithHeaders(ctx, http.MethodPost, c.baseURL+"/api/v1/public/auth/confirm-email-code", "", input, nil)
	if err != nil {
		return ConfirmEmailCodeResult{}, fmt.Errorf("backendclient: confirm email code: %w", err)
	}
	switch {
	case status == http.StatusOK:
		var result ConfirmEmailCodeResult
		if err := decodeStrictJSON(body, &result); err != nil {
			return ConfirmEmailCodeResult{}, fmt.Errorf("backendclient: confirm email code: decode success response: %w", err)
		}
		// A 200 without a session id is a backend contract violation.
		if strings.TrimSpace(result.DeviceSessionID) == "" {
			return ConfirmEmailCodeResult{}, errors.New("backendclient: confirm email code: device_session_id must not be empty")
		}
		return result, nil
	case status >= 400 && status <= 599:
		// Decode failures surface as plain errors; only a well-formed
		// backend envelope becomes a typed *AuthError.
		authErr, decodeErr := decodeAuthError(status, body)
		if decodeErr != nil {
			return ConfirmEmailCodeResult{}, fmt.Errorf("backendclient: confirm email code: %w", decodeErr)
		}
		return ConfirmEmailCodeResult{}, authErr
	default:
		return ConfirmEmailCodeResult{}, fmt.Errorf("backendclient: confirm email code: unexpected HTTP status %d", status)
	}
}
|
||||
|
||||
// resolvePreferredLanguage normalizes the caller-supplied preferred
// language into an Accept-Language header value. An unset or
// whitespace-only input yields the empty string, which the HTTP request
// helpers treat as "omit the header".
func resolvePreferredLanguage(preferred string) string {
	lang := strings.TrimSpace(preferred)
	return lang
}
|
||||
|
||||
type authErrorEnvelope struct {
|
||||
Error *authErrorBody `json:"error"`
|
||||
}
|
||||
|
||||
type authErrorBody struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func decodeAuthError(statusCode int, payload []byte) (*AuthError, error) {
|
||||
var envelope authErrorEnvelope
|
||||
if err := decodeStrictJSON(payload, &envelope); err != nil {
|
||||
return nil, fmt.Errorf("decode error response: %w", err)
|
||||
}
|
||||
if envelope.Error == nil {
|
||||
return nil, errors.New("decode error response: missing error object")
|
||||
}
|
||||
return &AuthError{
|
||||
StatusCode: statusCode,
|
||||
Code: envelope.Error.Code,
|
||||
Message: envelope.Error.Message,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,266 @@
|
||||
// PushClient — gateway-side gRPC consumer of `Push.SubscribePush`.
|
||||
//
|
||||
// One PushClient is wired for the gateway lifecycle. Run keeps the
|
||||
// subscription open, reconnects on every transport error with
|
||||
// exponential backoff (capped at PushReconnectMaxBackoff), and forwards
|
||||
// every received PushEvent to the configured EventHandler. The cursor
|
||||
// of the last successfully handled event is remembered in process
|
||||
// memory only (see `backend/README.md` and `backend/docs/` D2). On reconnect
|
||||
// it is replayed back to backend so any events still in the freshness-
|
||||
// window ring are received exactly once.
|
||||
package backendclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
pushv1 "galaxy/backend/proto/push/v1"
|
||||
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// EventHandler receives every PushEvent successfully drained from the
|
||||
// backend stream. Implementations must be concurrency-safe and must not
|
||||
// block; PushClient owns the calling goroutine and waits for Handle to
|
||||
// return before reading the next event.
|
||||
type EventHandler interface {
|
||||
Handle(context.Context, *pushv1.PushEvent)
|
||||
}
|
||||
|
||||
// EventHandlerFunc adapts a plain function to the EventHandler
|
||||
// contract.
|
||||
type EventHandlerFunc func(context.Context, *pushv1.PushEvent)
|
||||
|
||||
// Handle implements EventHandler.
|
||||
func (f EventHandlerFunc) Handle(ctx context.Context, ev *pushv1.PushEvent) { f(ctx, ev) }
|
||||
|
||||
// PushClient is the gRPC adapter that owns the long-lived
|
||||
// SubscribePush stream.
|
||||
type PushClient struct {
|
||||
cfg Config
|
||||
dialOpts []grpc.DialOption
|
||||
clock func() time.Time
|
||||
sleep func(context.Context, time.Duration) error
|
||||
logger *zap.Logger
|
||||
handler EventHandler
|
||||
|
||||
mu sync.Mutex
|
||||
cursor string
|
||||
|
||||
connMu sync.Mutex
|
||||
conn *grpc.ClientConn
|
||||
}
|
||||
|
||||
// NewPushClient constructs a PushClient after validating cfg. The
// default dial options use INSECURE transport credentials plus an
// OpenTelemetry gRPC stats handler; deployments behind TLS must replace
// them via WithDialOptions before calling Run. No connection is opened
// here — dialing happens in Run. The handler must be installed with
// WithHandler before Run; logger defaults to a no-op.
func NewPushClient(cfg Config) (*PushClient, error) {
	if err := cfg.Validate(); err != nil {
		return nil, err
	}
	return &PushClient{
		cfg: cfg,
		dialOpts: []grpc.DialOption{
			grpc.WithTransportCredentials(insecure.NewCredentials()),
			grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
		},
		// clock and sleep are injectable seams for tests.
		clock: time.Now,
		sleep: defaultSleep,
		logger: zap.NewNop(),
	}, nil
}
|
||||
|
||||
// WithDialOptions overrides the default dial options used when opening
|
||||
// the gRPC connection. Tests typically pass `grpc.WithContextDialer` so
|
||||
// `grpc.NewClient` connects to a `bufconn` listener.
|
||||
func (c *PushClient) WithDialOptions(opts ...grpc.DialOption) *PushClient {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
c.dialOpts = append([]grpc.DialOption(nil), opts...)
|
||||
return c
|
||||
}
|
||||
|
||||
// WithLogger replaces the structured logger.
|
||||
func (c *PushClient) WithLogger(logger *zap.Logger) *PushClient {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
c.logger = logger.Named("push_client")
|
||||
return c
|
||||
}
|
||||
|
||||
// WithHandler installs the EventHandler. Run returns an error if no
|
||||
// handler has been installed.
|
||||
func (c *PushClient) WithHandler(handler EventHandler) *PushClient {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
c.handler = handler
|
||||
return c
|
||||
}
|
||||
|
||||
// Cursor returns the cursor of the last event delivered to the handler.
|
||||
// Useful for tests and operator inspection. Returns the empty string
|
||||
// before any event has been processed.
|
||||
func (c *PushClient) Cursor() string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.cursor
|
||||
}
|
||||
|
||||
// Run opens the SubscribePush stream and forwards events until ctx is
|
||||
// cancelled. Network errors are retried with exponential backoff up to
|
||||
// PushReconnectMaxBackoff; ctx cancellation is the only terminal exit.
|
||||
func (c *PushClient) Run(ctx context.Context) error {
|
||||
if c == nil {
|
||||
return errors.New("backendclient.PushClient.Run: nil client")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("backendclient.PushClient.Run: nil context")
|
||||
}
|
||||
if c.handler == nil {
|
||||
return errors.New("backendclient.PushClient.Run: handler is required")
|
||||
}
|
||||
|
||||
conn, err := grpc.NewClient(c.cfg.GRPCPushURL, c.dialOpts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backendclient.PushClient.Run: dial backend push: %w", err)
|
||||
}
|
||||
c.connMu.Lock()
|
||||
c.conn = conn
|
||||
c.connMu.Unlock()
|
||||
defer func() {
|
||||
c.connMu.Lock()
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
c.connMu.Unlock()
|
||||
}()
|
||||
|
||||
pushAPI := pushv1.NewPushClient(conn)
|
||||
backoff := c.cfg.PushReconnectBaseBackoff
|
||||
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := c.runOnce(ctx, pushAPI)
|
||||
switch {
|
||||
case err == nil, errors.Is(err, context.Canceled):
|
||||
return ctx.Err()
|
||||
case status.Code(err) == codes.Aborted:
|
||||
c.logger.Info("backend replaced push subscription; reconnecting")
|
||||
case errors.Is(err, io.EOF):
|
||||
c.logger.Info("backend push stream closed; reconnecting")
|
||||
default:
|
||||
c.logger.Warn("backend push stream error; reconnecting",
|
||||
zap.Error(err),
|
||||
zap.Duration("backoff", backoff),
|
||||
)
|
||||
}
|
||||
|
||||
if err := c.sleep(ctx, jitter(backoff)); err != nil {
|
||||
return err
|
||||
}
|
||||
backoff = nextBackoff(backoff, c.cfg.PushReconnectMaxBackoff)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown is a no-op kept for `app.Component` compatibility. The
|
||||
// SubscribePush call exits when its parent context is cancelled.
|
||||
func (c *PushClient) Shutdown(_ context.Context) error { return nil }
|
||||
|
||||
// Close closes the underlying gRPC connection if it is open. Idempotent.
|
||||
func (c *PushClient) Close() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
c.connMu.Lock()
|
||||
defer c.connMu.Unlock()
|
||||
if c.conn == nil {
|
||||
return nil
|
||||
}
|
||||
err := c.conn.Close()
|
||||
c.conn = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// runOnce opens a single SubscribePush stream, replaying the last
// handled cursor so backend can resend events still inside its
// freshness window, and drains events until the stream errors.
//
// The handler is invoked before the cursor advances, and only events
// carrying a non-empty cursor move it — so after a reconnect any event
// whose cursor was never recorded is delivered again (at-least-once).
func (c *PushClient) runOnce(ctx context.Context, pushAPI pushv1.PushClient) error {
	stream, err := pushAPI.SubscribePush(ctx, &pushv1.GatewaySubscribeRequest{
		GatewayClientId: c.cfg.GatewayClientID,
		Cursor:          c.Cursor(),
	})
	if err != nil {
		return fmt.Errorf("subscribe push: %w", err)
	}

	for {
		ev, err := stream.Recv()
		if err != nil {
			// io.EOF, Aborted, and transport errors are all surfaced to
			// Run, which decides how to log and when to reconnect.
			return err
		}
		c.handler.Handle(ctx, ev)
		if cursor := ev.GetCursor(); cursor != "" {
			c.setCursor(cursor)
		}
	}
}
|
||||
|
||||
func (c *PushClient) setCursor(cursor string) {
|
||||
c.mu.Lock()
|
||||
c.cursor = cursor
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// nextBackoff doubles current and clamps the result to limit. A
// non-positive doubled value (zero input or multiplication overflow)
// also clamps to limit so the retry delay can never collapse to zero.
// The second parameter was renamed from `max`, which shadowed the
// Go 1.21 builtin of the same name.
func nextBackoff(current, limit time.Duration) time.Duration {
	doubled := current * 2
	if doubled > limit || doubled <= 0 {
		return limit
	}
	return doubled
}
|
||||
|
||||
// jitter returns d with ±20% multiplicative noise so multiple gateway
|
||||
// instances do not retry in lockstep after a backend restart.
|
||||
func jitter(d time.Duration) time.Duration {
|
||||
if d <= 0 {
|
||||
return d
|
||||
}
|
||||
noise := 1 + (rand.Float64()-0.5)*0.4
|
||||
return time.Duration(float64(d) * noise)
|
||||
}
|
||||
|
||||
// defaultSleep blocks for d or until ctx is done, whichever comes
// first, returning ctx.Err() when interrupted. Non-positive durations
// return immediately without consulting ctx.
func defaultSleep(ctx context.Context, d time.Duration) error {
	if d <= 0 {
		return nil
	}
	t := time.NewTimer(d)
	defer t.Stop()
	select {
	case <-t.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package backendclient_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
backendpush "galaxy/backend/push"
|
||||
pushv1 "galaxy/backend/proto/push/v1"
|
||||
"galaxy/gateway/internal/backendclient"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
)
|
||||
|
||||
// bufconnPushService starts an in-process backend push.Service backed by
|
||||
// a *grpc.Server on a bufconn listener and returns the dial option that
|
||||
// gateway PushClient should use to connect to it.
|
||||
type bufconnPushService struct {
|
||||
Service *backendpush.Service
|
||||
dial func(context.Context, string) (net.Conn, error)
|
||||
stop func()
|
||||
}
|
||||
|
||||
func newBufconnPushService(t *testing.T) *bufconnPushService {
|
||||
t.Helper()
|
||||
|
||||
service, err := backendpush.NewService(backendpush.ServiceConfig{
|
||||
FreshnessWindow: time.Minute,
|
||||
RingCapacity: 16,
|
||||
PerConnBuffer: 8,
|
||||
}, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
listener := bufconn.Listen(1 << 16)
|
||||
server := grpc.NewServer()
|
||||
pushv1.RegisterPushServer(server, service)
|
||||
|
||||
go func() {
|
||||
_ = server.Serve(listener)
|
||||
}()
|
||||
|
||||
stop := func() {
|
||||
service.Close()
|
||||
server.Stop()
|
||||
_ = listener.Close()
|
||||
}
|
||||
t.Cleanup(stop)
|
||||
|
||||
return &bufconnPushService{
|
||||
Service: service,
|
||||
dial: func(_ context.Context, _ string) (net.Conn, error) { return listener.Dial() },
|
||||
stop: stop,
|
||||
}
|
||||
}
|
||||
|
||||
// TestPushClientDeliversClientEventsAndAdvancesCursor drives a real push
// client/service pair over bufconn and checks that a published client event
// reaches the registered handler with its metadata intact and that the
// client's resume cursor advances.
func TestPushClientDeliversClientEventsAndAdvancesCursor(t *testing.T) {
	t.Parallel()

	svc := newBufconnPushService(t)

	type received struct {
		event  *pushv1.PushEvent
		cursor string
	}
	out := make(chan received, 4)

	cfg := backendclient.Config{
		HTTPBaseURL:              "http://example.invalid",
		GRPCPushURL:              "passthrough://bufconn",
		GatewayClientID:          "gw-1",
		HTTPTimeout:              time.Second,
		PushReconnectBaseBackoff: 10 * time.Millisecond,
		PushReconnectMaxBackoff:  100 * time.Millisecond,
	}
	client, err := backendclient.NewPushClient(cfg)
	require.NoError(t, err)
	client.WithDialOptions(
		grpc.WithContextDialer(svc.dial),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	client.WithHandler(backendclient.EventHandlerFunc(func(_ context.Context, ev *pushv1.PushEvent) {
		out <- received{event: ev, cursor: ev.GetCursor()}
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var (
		runErr error
		wg     sync.WaitGroup
	)
	wg.Add(1)
	go func() {
		defer wg.Done()
		runErr = client.Run(ctx)
	}()

	// Wait for backend service to register the subscription.
	require.Eventually(t, func() bool { return svc.Service.SubscriberCount() == 1 }, time.Second, 10*time.Millisecond)

	userID := uuid.New()
	require.NoError(t, svc.Service.PublishClientEvent(context.Background(), userID, nil, "lobby.invite.received", map[string]any{"x": 1.0}, "evt-1", "req-1", "trace-1"))

	select {
	case got := <-out:
		ce := got.event.GetClientEvent()
		require.NotNil(t, ce)
		assert.Equal(t, userID.String(), ce.GetUserId())
		assert.Equal(t, "lobby.invite.received", ce.GetKind())
		assert.Equal(t, "evt-1", ce.GetEventId())
		assert.Equal(t, "req-1", ce.GetRequestId())
		assert.Equal(t, "trace-1", ce.GetTraceId())
		assert.NotEmpty(t, got.cursor)
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for client event")
	}

	require.Eventually(t, func() bool { return client.Cursor() != "" }, time.Second, 10*time.Millisecond)

	cancel()
	wg.Wait()
	// context.Canceled is the expected result of Run after cancellation.
	if runErr != nil && runErr != context.Canceled {
		t.Fatalf("unexpected run error: %v", runErr)
	}
}
|
||||
@@ -0,0 +1,256 @@
|
||||
package backendclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/session"
|
||||
)
|
||||
|
||||
// HeaderUserID is the trusted gateway → backend identity header.
|
||||
const HeaderUserID = "X-User-Id"
|
||||
|
||||
// errSessionNotFound is the public error returned by LookupSession when
// backend reports HTTP 404 for a device session id. It wraps
// session.ErrNotFound so callers can keep using the existing typed
// equality check (errors.Is) at the gateway hot path.
func errSessionNotFound() error {
	return fmt.Errorf("backendclient: lookup session: %w", session.ErrNotFound)
}
|
||||
|
||||
// RESTClient owns the gateway's HTTP conversation with backend.
//
// All methods are safe for concurrent use.
type RESTClient struct {
	baseURL    string       // absolute backend base URL, trailing slash trimmed
	httpClient *http.Client // shared client; per-call budget via its Timeout
}
|
||||
|
||||
// NewRESTClient constructs a RESTClient targeting the backend HTTP
// listener configured in cfg. It fails when cfg.HTTPBaseURL is not an
// absolute URL.
func NewRESTClient(cfg Config) (*RESTClient, error) {
	// Clone the default transport so its pooling defaults are inherited
	// without sharing idle-connection state with other clients.
	transport, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		return nil, errors.New("backendclient: default HTTP transport is not *http.Transport")
	}
	parsed, err := url.Parse(strings.TrimRight(strings.TrimSpace(cfg.HTTPBaseURL), "/"))
	if err != nil {
		return nil, fmt.Errorf("backendclient: parse HTTPBaseURL: %w", err)
	}
	if parsed.Scheme == "" || parsed.Host == "" {
		return nil, errors.New("backendclient: HTTPBaseURL must be absolute")
	}
	return &RESTClient{
		baseURL: parsed.String(),
		httpClient: &http.Client{
			Transport: transport.Clone(),
			// NOTE(review): a zero cfg.HTTPTimeout disables the client
			// timeout entirely — confirm Config validation enforces a
			// positive value upstream.
			Timeout: cfg.HTTPTimeout,
		},
	}, nil
}
|
||||
|
||||
// Close releases idle HTTP connections owned by the client transport.
|
||||
func (c *RESTClient) Close() error {
|
||||
if c == nil || c.httpClient == nil {
|
||||
return nil
|
||||
}
|
||||
type idleCloser interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LookupSession resolves deviceSessionID against
// `GET /api/v1/internal/sessions/{device_session_id}`.
// Returns session.ErrNotFound (wrapped) when backend reports 404.
func (c *RESTClient) LookupSession(ctx context.Context, deviceSessionID string) (session.Record, error) {
	if c == nil || c.httpClient == nil {
		return session.Record{}, errors.New("backendclient: nil REST client")
	}
	if strings.TrimSpace(deviceSessionID) == "" {
		return session.Record{}, errors.New("backendclient: lookup session: device_session_id must not be empty")
	}

	// PathEscape guards against ids containing path separators or spaces.
	target := c.baseURL + "/api/v1/internal/sessions/" + url.PathEscape(deviceSessionID)
	body, status, err := c.do(ctx, http.MethodGet, target, "", nil)
	if err != nil {
		return session.Record{}, fmt.Errorf("backendclient: lookup session: %w", err)
	}

	switch {
	case status == http.StatusOK:
		// decodeDeviceSession also rejects a body whose id mismatches.
		return decodeDeviceSession(deviceSessionID, body)
	case status == http.StatusNotFound:
		return session.Record{}, errSessionNotFound()
	default:
		return session.Record{}, fmt.Errorf("backendclient: lookup session: unexpected HTTP status %d", status)
	}
}
|
||||
|
||||
// RevokeSession asks backend to revoke a single device session by id.
|
||||
func (c *RESTClient) RevokeSession(ctx context.Context, deviceSessionID string) error {
|
||||
if strings.TrimSpace(deviceSessionID) == "" {
|
||||
return errors.New("backendclient: revoke session: device_session_id must not be empty")
|
||||
}
|
||||
target := c.baseURL + "/api/v1/internal/sessions/" + url.PathEscape(deviceSessionID) + "/revoke"
|
||||
_, status, err := c.do(ctx, http.MethodPost, target, "", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backendclient: revoke session: %w", err)
|
||||
}
|
||||
if status == http.StatusOK || status == http.StatusNoContent {
|
||||
return nil
|
||||
}
|
||||
if status == http.StatusNotFound {
|
||||
return errSessionNotFound()
|
||||
}
|
||||
return fmt.Errorf("backendclient: revoke session: unexpected HTTP status %d", status)
|
||||
}
|
||||
|
||||
// RevokeAllSessionsForUser asks backend to revoke every active device
|
||||
// session belonging to userID.
|
||||
func (c *RESTClient) RevokeAllSessionsForUser(ctx context.Context, userID string) error {
|
||||
if strings.TrimSpace(userID) == "" {
|
||||
return errors.New("backendclient: revoke-all sessions: user_id must not be empty")
|
||||
}
|
||||
target := c.baseURL + "/api/v1/internal/sessions/users/" + url.PathEscape(userID) + "/revoke-all"
|
||||
_, status, err := c.do(ctx, http.MethodPost, target, "", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("backendclient: revoke-all sessions: %w", err)
|
||||
}
|
||||
if status == http.StatusOK || status == http.StatusNoContent {
|
||||
return nil
|
||||
}
|
||||
if status == http.StatusNotFound {
|
||||
return errSessionNotFound()
|
||||
}
|
||||
return fmt.Errorf("backendclient: revoke-all sessions: unexpected HTTP status %d", status)
|
||||
}
|
||||
|
||||
// do executes a JSON request and reads the response body. userID, when
// non-empty, is sent as the X-User-Id header (required for `/api/v1/user/*`).
// It is a thin wrapper over doWithHeaders with no extra headers.
func (c *RESTClient) do(ctx context.Context, method, target, userID string, body any) ([]byte, int, error) {
	return c.doWithHeaders(ctx, method, target, userID, body, nil)
}
|
||||
|
||||
// doWithHeaders is the shared transport entry point. body, when non-nil, is
// JSON-encoded and sent with a Content-Type header. extraHeaders are applied
// after Content-Type/X-User-Id; entries whose value trims to empty are
// skipped (treated as "header absent") so callers can pass optional language
// tags etc. Returns the full response body and HTTP status code.
func (c *RESTClient) doWithHeaders(ctx context.Context, method, target, userID string, body any, extraHeaders map[string]string) ([]byte, int, error) {
	if c == nil || c.httpClient == nil {
		return nil, 0, errors.New("nil REST client")
	}
	if ctx == nil {
		return nil, 0, errors.New("nil context")
	}

	var reader io.Reader
	if body != nil {
		buf, err := json.Marshal(body)
		if err != nil {
			return nil, 0, fmt.Errorf("marshal request body: %w", err)
		}
		reader = bytes.NewReader(buf)
	}

	req, err := http.NewRequestWithContext(ctx, method, target, reader)
	if err != nil {
		return nil, 0, fmt.Errorf("build request: %w", err)
	}
	if body != nil {
		req.Header.Set("Content-Type", "application/json")
	}
	if userID != "" {
		req.Header.Set(HeaderUserID, userID)
	}
	for key, value := range extraHeaders {
		// Empty values are skipped rather than sent as blank headers.
		if strings.TrimSpace(value) == "" {
			continue
		}
		req.Header.Set(key, value)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, 0, err
	}
	defer resp.Body.Close()
	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, resp.StatusCode, fmt.Errorf("read response body: %w", err)
	}
	return payload, resp.StatusCode, nil
}
|
||||
|
||||
// deviceSessionWire mirrors backend openapi `DeviceSession`. It is the
// strict decode target for LookupSession responses; only a subset of the
// fields is projected into session.Record (see decodeDeviceSession).
type deviceSessionWire struct {
	DeviceSessionID string     `json:"device_session_id"`
	UserID          string     `json:"user_id"`
	Status          string     `json:"status"`
	ClientPublicKey string     `json:"client_public_key,omitempty"`
	CreatedAt       time.Time  `json:"created_at"`
	RevokedAt       *time.Time `json:"revoked_at,omitempty"`
	LastSeenAt      *time.Time `json:"last_seen_at,omitempty"`
}
|
||||
|
||||
// decodeDeviceSession strictly decodes a backend DeviceSession payload and
// projects it into a session.Record. It rejects payloads whose
// device_session_id differs from the one requested, unknown statuses, and
// active records that lack a client public key.
func decodeDeviceSession(expectedDeviceSessionID string, payload []byte) (session.Record, error) {
	var wire deviceSessionWire
	if err := decodeStrictJSON(payload, &wire); err != nil {
		return session.Record{}, fmt.Errorf("decode device session: %w", err)
	}

	if strings.TrimSpace(wire.DeviceSessionID) == "" {
		return session.Record{}, errors.New("decode device session: device_session_id must not be empty")
	}
	// Defensive cross-check: the record must describe the session we asked for.
	if wire.DeviceSessionID != expectedDeviceSessionID {
		return session.Record{}, fmt.Errorf("decode device session: device_session_id %q does not match requested %q", wire.DeviceSessionID, expectedDeviceSessionID)
	}
	if strings.TrimSpace(wire.UserID) == "" {
		return session.Record{}, errors.New("decode device session: user_id must not be empty")
	}

	status := session.Status(strings.TrimSpace(wire.Status))
	if !status.IsKnown() {
		return session.Record{}, fmt.Errorf("decode device session: status %q is unsupported", wire.Status)
	}
	if status == session.StatusActive && strings.TrimSpace(wire.ClientPublicKey) == "" {
		return session.Record{}, errors.New("decode device session: active record missing client_public_key")
	}

	record := session.Record{
		DeviceSessionID: wire.DeviceSessionID,
		UserID:          wire.UserID,
		ClientPublicKey: wire.ClientPublicKey,
		Status:          status,
	}
	// Only revocation time is projected; created_at/last_seen_at are decoded
	// but not carried into the gateway record.
	if wire.RevokedAt != nil {
		ms := wire.RevokedAt.UnixMilli()
		record.RevokedAtMS = &ms
	}
	return record, nil
}
|
||||
|
||||
// decodeStrictJSON unmarshals payload into target, rejecting unknown object
// fields and any trailing JSON content after the first document.
func decodeStrictJSON(payload []byte, target any) error {
	dec := json.NewDecoder(bytes.NewReader(payload))
	dec.DisallowUnknownFields()
	if err := dec.Decode(target); err != nil {
		return err
	}
	// A second decode must hit EOF; anything else means trailing input.
	switch err := dec.Decode(&struct{}{}); err {
	case io.EOF:
		return nil
	case nil:
		return errors.New("unexpected trailing JSON input")
	default:
		return err
	}
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package backendclient_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/backendclient"
|
||||
"galaxy/gateway/internal/session"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// newRESTClient builds a RESTClient pointed at the given httptest server,
// with short timeouts suitable for tests. Close is registered on t.
func newRESTClient(t *testing.T, server *httptest.Server) *backendclient.RESTClient {
	t.Helper()
	cfg := backendclient.Config{
		HTTPBaseURL:              server.URL,
		GRPCPushURL:              "passthrough://test",
		GatewayClientID:          "test-gateway",
		HTTPTimeout:              time.Second,
		PushReconnectBaseBackoff: 10 * time.Millisecond,
		PushReconnectMaxBackoff:  100 * time.Millisecond,
	}
	client, err := backendclient.NewRESTClient(cfg)
	require.NoError(t, err)
	t.Cleanup(func() { _ = client.Close() })
	return client
}
|
||||
|
||||
// TestRESTClientLookupSessionReturnsActiveRecord verifies the happy path:
// a 200 response with an active session decodes into a full session.Record.
func TestRESTClientLookupSessionReturnsActiveRecord(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodGet, r.Method)
		require.Equal(t, "/api/v1/internal/sessions/device-1", r.URL.Path)
		writeJSON(t, w, http.StatusOK, map[string]any{
			"device_session_id": "device-1",
			"user_id":           "user-1",
			"status":            "active",
			"client_public_key": "pk-1",
			"created_at":        "2026-04-01T00:00:00Z",
		})
	}))
	t.Cleanup(server.Close)

	client := newRESTClient(t, server)
	rec, err := client.LookupSession(context.Background(), "device-1")
	require.NoError(t, err)
	assert.Equal(t, session.Record{
		DeviceSessionID: "device-1",
		UserID:          "user-1",
		ClientPublicKey: "pk-1",
		Status:          session.StatusActive,
	}, rec)
}
|
||||
|
||||
// TestRESTClientLookupSessionReturnsRevokedRecord checks that a revoked
// session decodes with its revoked_at timestamp projected to millis.
func TestRESTClientLookupSessionReturnsRevokedRecord(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		writeJSON(t, w, http.StatusOK, map[string]any{
			"device_session_id": "device-2",
			"user_id":           "user-2",
			"status":            "revoked",
			"client_public_key": "pk-2",
			"created_at":        "2026-04-01T00:00:00Z",
			"revoked_at":        "2026-04-01T00:01:00Z",
		})
	}))
	t.Cleanup(server.Close)

	client := newRESTClient(t, server)
	rec, err := client.LookupSession(context.Background(), "device-2")
	require.NoError(t, err)
	assert.Equal(t, session.StatusRevoked, rec.Status)
	require.NotNil(t, rec.RevokedAtMS)
	assert.Equal(t, time.Date(2026, 4, 1, 0, 1, 0, 0, time.UTC).UnixMilli(), *rec.RevokedAtMS)
}
|
||||
|
||||
// TestRESTClientLookupSessionMapsNotFound asserts that an HTTP 404 is
// surfaced as (a wrap of) session.ErrNotFound for errors.Is callers.
func TestRESTClientLookupSessionMapsNotFound(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		writeJSON(t, w, http.StatusNotFound, map[string]any{"error": map[string]any{"code": "subject_not_found", "message": "missing"}})
	}))
	t.Cleanup(server.Close)

	client := newRESTClient(t, server)
	_, err := client.LookupSession(context.Background(), "missing")
	require.Error(t, err)
	assert.True(t, errors.Is(err, session.ErrNotFound))
}
|
||||
|
||||
// TestRESTClientLookupSessionRejectsMismatchedID asserts that a 200 body
// describing a different session id than the one requested is rejected.
func TestRESTClientLookupSessionRejectsMismatchedID(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		writeJSON(t, w, http.StatusOK, map[string]any{
			"device_session_id": "other",
			"user_id":           "user-1",
			"status":            "active",
			"client_public_key": "pk-1",
			"created_at":        "2026-04-01T00:00:00Z",
		})
	}))
	t.Cleanup(server.Close)

	client := newRESTClient(t, server)
	_, err := client.LookupSession(context.Background(), "device-1")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "does not match requested")
}
|
||||
|
||||
// TestRESTClientSendEmailCodeForwardsAcceptLanguage checks that the
// optional preferred language reaches backend as an Accept-Language header.
func TestRESTClientSendEmailCodeForwardsAcceptLanguage(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, http.MethodPost, r.Method)
		require.Equal(t, "/api/v1/public/auth/send-email-code", r.URL.Path)
		require.Equal(t, "ru-RU", r.Header.Get("Accept-Language"))
		writeJSON(t, w, http.StatusOK, map[string]any{"challenge_id": "challenge-1"})
	}))
	t.Cleanup(server.Close)

	client := newRESTClient(t, server)
	out, err := client.SendEmailCode(context.Background(), backendclient.SendEmailCodeInput{
		Email:             "user@example.com",
		PreferredLanguage: "ru-RU",
	})
	require.NoError(t, err)
	assert.Equal(t, "challenge-1", out.ChallengeID)
}
|
||||
|
||||
// TestRESTClientSendEmailCodeProjectsAuthError checks that a 4xx error body
// is projected into a typed *AuthError carrying status, code, and message.
func TestRESTClientSendEmailCodeProjectsAuthError(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		writeJSON(t, w, http.StatusBadRequest, map[string]any{
			"error": map[string]any{"code": "invalid_request", "message": "bad email"},
		})
	}))
	t.Cleanup(server.Close)

	client := newRESTClient(t, server)
	_, err := client.SendEmailCode(context.Background(), backendclient.SendEmailCodeInput{Email: "user@example.com"})
	require.Error(t, err)
	var authErr *backendclient.AuthError
	require.ErrorAs(t, err, &authErr)
	assert.Equal(t, http.StatusBadRequest, authErr.StatusCode)
	assert.Equal(t, "invalid_request", authErr.Code)
	assert.Equal(t, "bad email", authErr.Message)
}
|
||||
|
||||
// TestRESTClientConfirmEmailCodeReturnsDeviceSession checks that the
// confirm call posts the challenge payload and returns the new session id.
func TestRESTClientConfirmEmailCodeReturnsDeviceSession(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		require.Equal(t, "/api/v1/public/auth/confirm-email-code", r.URL.Path)

		var body backendclient.ConfirmEmailCodeInput
		require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
		assert.Equal(t, "challenge-1", body.ChallengeID)
		writeJSON(t, w, http.StatusOK, map[string]any{"device_session_id": "device-1"})
	}))
	t.Cleanup(server.Close)

	client := newRESTClient(t, server)
	out, err := client.ConfirmEmailCode(context.Background(), backendclient.ConfirmEmailCodeInput{
		ChallengeID: "challenge-1",
		Code:        "12345",
	})
	require.NoError(t, err)
	assert.Equal(t, "device-1", out.DeviceSessionID)
}
|
||||
|
||||
// writeJSON encodes body as JSON into w with the given status code,
// failing the test on encoding errors.
func writeJSON(t *testing.T, w http.ResponseWriter, status int, body any) {
	t.Helper()
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	require.NoError(t, json.NewEncoder(w).Encode(body))
}
|
||||
|
||||
// guard keeps the strings import referenced so the file compiles even when
// no test in it uses strings directly.
// NOTE(review): the previous comment claimed this guards the testify
// dependency; it actually anchors the strings import — consider dropping
// both the guard and the import if strings stays unused.
var _ = strings.TrimSpace
|
||||
@@ -0,0 +1,67 @@
|
||||
package backendclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"galaxy/gateway/internal/downstream"
|
||||
lobbymodel "galaxy/model/lobby"
|
||||
usermodel "galaxy/model/user"
|
||||
)
|
||||
|
||||
// UserRoutes returns the authenticated `user.*` downstream routes served by
// backend. When client is nil every route resolves to a
// dependency-unavailable client so the static router still recognises the
// message types (requests then fail with ErrDownstreamUnavailable).
func UserRoutes(client *RESTClient) map[string]downstream.Client {
	target := downstream.Client(unavailableClient{})
	if client != nil {
		target = userCommandClient{rest: client}
	}
	return map[string]downstream.Client{
		usermodel.MessageTypeGetMyAccount:     target,
		usermodel.MessageTypeUpdateMyProfile:  target,
		usermodel.MessageTypeUpdateMySettings: target,
	}
}
|
||||
|
||||
// LobbyRoutes returns the authenticated `lobby.*` downstream routes served
// by backend. When client is nil every route resolves to a
// dependency-unavailable client (requests then fail with
// ErrDownstreamUnavailable while the message types stay routable).
func LobbyRoutes(client *RESTClient) map[string]downstream.Client {
	target := downstream.Client(unavailableClient{})
	if client != nil {
		target = lobbyCommandClient{rest: client}
	}
	return map[string]downstream.Client{
		lobbymodel.MessageTypeMyGamesList:    target,
		lobbymodel.MessageTypeOpenEnrollment: target,
	}
}
|
||||
|
||||
// unavailableClient is the fallback downstream.Client used when no REST
// client is configured; every command fails with ErrDownstreamUnavailable.
type unavailableClient struct{}

func (unavailableClient) ExecuteCommand(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
}
|
||||
|
||||
// userCommandClient adapts a *RESTClient to downstream.Client for the
// `user.*` message types.
type userCommandClient struct {
	rest *RESTClient
}

func (c userCommandClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	return c.rest.ExecuteUserCommand(ctx, command)
}
|
||||
|
||||
// lobbyCommandClient adapts a *RESTClient to downstream.Client for the
// `lobby.*` message types.
type lobbyCommandClient struct {
	rest *RESTClient
}

func (c lobbyCommandClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	return c.rest.ExecuteLobbyCommand(ctx, command)
}
|
||||
|
||||
// Compile-time checks that every route target satisfies downstream.Client.
var (
	_ downstream.Client = unavailableClient{}
	_ downstream.Client = userCommandClient{}
	_ downstream.Client = lobbyCommandClient{}
)
|
||||
@@ -0,0 +1,166 @@
|
||||
package backendclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"galaxy/gateway/internal/downstream"
|
||||
usermodel "galaxy/model/user"
|
||||
"galaxy/transcoder"
|
||||
)
|
||||
|
||||
const (
	// userCommandResultCodeOK is the result code projected for HTTP 200.
	userCommandResultCodeOK = "ok"
	// defaultUserErrorCode is substituted when backend omits an error code
	// for an HTTP status with no dedicated mapping.
	defaultUserErrorCode = "internal_error"
)

// stableUserErrorMessages supplies fallback human-readable messages when a
// backend error body arrives without a message for a known code.
var stableUserErrorMessages = map[string]string{
	"invalid_request":    "request is invalid",
	"subject_not_found":  "subject not found",
	"conflict":           "request conflicts with current state",
	defaultUserErrorCode: "internal server error",
}
|
||||
|
||||
// ExecuteUserCommand routes one authenticated user-surface command into
// backend's `/api/v1/user/*` endpoints. The function is registered for the
// message types listed in `galaxy/model/user`; payloads are validated via
// the transcoder before any HTTP call is made.
func (c *RESTClient) ExecuteUserCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	if c == nil || c.httpClient == nil {
		return downstream.UnaryResult{}, errors.New("backendclient: execute user command: nil client")
	}
	if ctx == nil {
		return downstream.UnaryResult{}, errors.New("backendclient: execute user command: nil context")
	}
	// Fail fast on an already-cancelled context before decoding payloads.
	if err := ctx.Err(); err != nil {
		return downstream.UnaryResult{}, err
	}
	if strings.TrimSpace(command.UserID) == "" {
		return downstream.UnaryResult{}, errors.New("backendclient: execute user command: user_id must not be empty")
	}

	switch command.MessageType {
	case usermodel.MessageTypeGetMyAccount:
		// Decode purely for validation; the GET carries no request body.
		if _, err := transcoder.PayloadToGetMyAccountRequest(command.PayloadBytes); err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute user command %q: %w", command.MessageType, err)
		}
		return c.executeUserAccountGet(ctx, command.UserID)
	case usermodel.MessageTypeUpdateMyProfile:
		req, err := transcoder.PayloadToUpdateMyProfileRequest(command.PayloadBytes)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute user command %q: %w", command.MessageType, err)
		}
		return c.executeUserAccountUpdateProfile(ctx, command.UserID, req)
	case usermodel.MessageTypeUpdateMySettings:
		req, err := transcoder.PayloadToUpdateMySettingsRequest(command.PayloadBytes)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute user command %q: %w", command.MessageType, err)
		}
		return c.executeUserAccountUpdateSettings(ctx, command.UserID, req)
	default:
		return downstream.UnaryResult{}, fmt.Errorf("backendclient: execute user command: unsupported message type %q", command.MessageType)
	}
}
|
||||
|
||||
// executeUserAccountGet fetches the caller's account via
// `GET /api/v1/user/account` and projects the response.
func (c *RESTClient) executeUserAccountGet(ctx context.Context, userID string) (downstream.UnaryResult, error) {
	body, status, err := c.do(ctx, http.MethodGet, c.baseURL+"/api/v1/user/account", userID, nil)
	if err != nil {
		return downstream.UnaryResult{}, fmt.Errorf("execute user.account.get: %w", err)
	}
	return projectUserResponse(status, body)
}
|
||||
|
||||
// executeUserAccountUpdateProfile applies a profile patch via
// `PATCH /api/v1/user/account/profile` and projects the response.
func (c *RESTClient) executeUserAccountUpdateProfile(ctx context.Context, userID string, req *usermodel.UpdateMyProfileRequest) (downstream.UnaryResult, error) {
	body, status, err := c.do(ctx, http.MethodPatch, c.baseURL+"/api/v1/user/account/profile", userID, req)
	if err != nil {
		return downstream.UnaryResult{}, fmt.Errorf("execute user.profile.update: %w", err)
	}
	return projectUserResponse(status, body)
}
|
||||
|
||||
// executeUserAccountUpdateSettings applies a settings patch via
// `PATCH /api/v1/user/account/settings` and projects the response.
func (c *RESTClient) executeUserAccountUpdateSettings(ctx context.Context, userID string, req *usermodel.UpdateMySettingsRequest) (downstream.UnaryResult, error) {
	body, status, err := c.do(ctx, http.MethodPatch, c.baseURL+"/api/v1/user/account/settings", userID, req)
	if err != nil {
		return downstream.UnaryResult{}, fmt.Errorf("execute user.settings.update: %w", err)
	}
	return projectUserResponse(status, body)
}
|
||||
|
||||
// projectUserResponse maps one backend HTTP response to a downstream
// UnaryResult: 200 → "ok" with the transcoded account payload, 503 →
// ErrDownstreamUnavailable, other 4xx/5xx → a result carrying the
// normalized error code and transcoded error payload. Statuses outside
// those ranges are surfaced as errors.
func projectUserResponse(statusCode int, payload []byte) (downstream.UnaryResult, error) {
	switch {
	case statusCode == http.StatusOK:
		var response usermodel.AccountResponse
		if err := decodeStrictJSON(payload, &response); err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("decode success response: %w", err)
		}
		payloadBytes, err := transcoder.AccountResponseToPayload(&response)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("encode success response payload: %w", err)
		}
		return downstream.UnaryResult{
			ResultCode:   userCommandResultCodeOK,
			PayloadBytes: payloadBytes,
		}, nil
	// 503 is checked before the generic 4xx/5xx arm so it maps to the
	// typed unavailable error rather than a projected error payload.
	case statusCode == http.StatusServiceUnavailable:
		return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
	case statusCode >= 400 && statusCode <= 599:
		errResp, err := decodeUserError(statusCode, payload)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("decode error response: %w", err)
		}
		payloadBytes, err := transcoder.ErrorResponseToPayload(errResp)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("encode error response payload: %w", err)
		}
		return downstream.UnaryResult{
			ResultCode:   errResp.Error.Code,
			PayloadBytes: payloadBytes,
		}, nil
	default:
		return downstream.UnaryResult{}, fmt.Errorf("unexpected HTTP status %d", statusCode)
	}
}
|
||||
|
||||
// decodeUserError strictly decodes a backend error body and normalizes its
// code (falling back to an HTTP-status-derived code) and message (falling
// back to a stable per-code message). A still-empty code or message after
// normalization is reported as an error.
func decodeUserError(statusCode int, payload []byte) (*usermodel.ErrorResponse, error) {
	var response usermodel.ErrorResponse
	if err := decodeStrictJSON(payload, &response); err != nil {
		return nil, err
	}
	// Normalize code first: message fallback is keyed by the final code.
	response.Error.Code = normalizeUserErrorCode(statusCode, response.Error.Code)
	response.Error.Message = normalizeUserErrorMessage(response.Error.Code, response.Error.Message)
	if strings.TrimSpace(response.Error.Code) == "" {
		return nil, errors.New("missing error code")
	}
	if strings.TrimSpace(response.Error.Message) == "" {
		return nil, errors.New("missing error message")
	}
	return &response, nil
}
|
||||
|
||||
func normalizeUserErrorCode(statusCode int, code string) string {
|
||||
if trimmed := strings.TrimSpace(code); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
switch statusCode {
|
||||
case http.StatusBadRequest:
|
||||
return "invalid_request"
|
||||
case http.StatusNotFound:
|
||||
return "subject_not_found"
|
||||
case http.StatusConflict:
|
||||
return "conflict"
|
||||
default:
|
||||
return defaultUserErrorCode
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeUserErrorMessage(code, message string) string {
|
||||
if trimmed := strings.TrimSpace(message); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
if stable, ok := stableUserErrorMessages[code]; ok {
|
||||
return stable
|
||||
}
|
||||
return stableUserErrorMessages[defaultUserErrorCode]
|
||||
}
|
||||
+176
-312
@@ -44,20 +44,34 @@ const (
|
||||
// configures the timeout budget used for public auth upstream calls.
|
||||
publicAuthUpstreamTimeoutEnvVar = "GATEWAY_PUBLIC_AUTH_UPSTREAM_TIMEOUT"
|
||||
|
||||
// authServiceBaseURLEnvVar names the environment variable that configures
|
||||
// the optional Auth / Session Service public HTTP base URL used by gateway
|
||||
// public-auth delegation.
|
||||
authServiceBaseURLEnvVar = "GATEWAY_AUTH_SERVICE_BASE_URL"
|
||||
// backendHTTPURLEnvVar names the environment variable that configures
|
||||
// the absolute base URL of the consolidated backend HTTP listener used
|
||||
// for public auth, internal session lookup, and authenticated user /
|
||||
// lobby commands.
|
||||
backendHTTPURLEnvVar = "GATEWAY_BACKEND_HTTP_URL"
|
||||
|
||||
// userServiceBaseURLEnvVar names the environment variable that configures
|
||||
// the optional User Service internal HTTP base URL used by authenticated
|
||||
// gateway self-service delegation.
|
||||
userServiceBaseURLEnvVar = "GATEWAY_USER_SERVICE_BASE_URL"
|
||||
// backendGRPCPushURLEnvVar names the environment variable that
|
||||
// configures the dial target of backend's gRPC `Push.SubscribePush`
|
||||
// listener.
|
||||
backendGRPCPushURLEnvVar = "GATEWAY_BACKEND_GRPC_PUSH_URL"
|
||||
|
||||
// lobbyServiceBaseURLEnvVar names the environment variable that configures
|
||||
// the optional Game Lobby public HTTP base URL used by authenticated
|
||||
// gateway platform-command delegation.
|
||||
lobbyServiceBaseURLEnvVar = "GATEWAY_LOBBY_SERVICE_BASE_URL"
|
||||
// backendGatewayClientIDEnvVar names the environment variable that
|
||||
// configures the durable identifier this gateway instance presents to
|
||||
// backend in `GatewaySubscribeRequest.gateway_client_id`.
|
||||
backendGatewayClientIDEnvVar = "GATEWAY_BACKEND_GATEWAY_CLIENT_ID"
|
||||
|
||||
// backendHTTPTimeoutEnvVar names the environment variable that
|
||||
// configures the per-call timeout applied to backend HTTP requests.
|
||||
backendHTTPTimeoutEnvVar = "GATEWAY_BACKEND_HTTP_TIMEOUT"
|
||||
|
||||
// backendPushReconnectBaseBackoffEnvVar names the environment variable
|
||||
// that configures the starting delay between reconnect attempts of the
|
||||
// gRPC SubscribePush stream.
|
||||
backendPushReconnectBaseBackoffEnvVar = "GATEWAY_BACKEND_PUSH_RECONNECT_BASE_BACKOFF"
|
||||
|
||||
// backendPushReconnectMaxBackoffEnvVar names the environment variable
|
||||
// that configures the upper bound for exponential reconnect delays.
|
||||
backendPushReconnectMaxBackoffEnvVar = "GATEWAY_BACKEND_PUSH_RECONNECT_MAX_BACKOFF"
|
||||
|
||||
// adminHTTPAddrEnvVar names the environment variable that configures the
|
||||
// private admin HTTP listener address. When it is empty, the admin listener
|
||||
@@ -152,14 +166,6 @@ const (
|
||||
// rate-limit burst.
|
||||
authenticatedGRPCMessageClassRateLimitBurstEnvVar = "GATEWAY_AUTHENTICATED_GRPC_ANTI_ABUSE_MESSAGE_CLASS_RATE_LIMIT_BURST"
|
||||
|
||||
// sessionCacheRedisKeyPrefixEnvVar names the environment variable that
|
||||
// configures the Redis key prefix used for SessionCache records.
|
||||
sessionCacheRedisKeyPrefixEnvVar = "GATEWAY_SESSION_CACHE_REDIS_KEY_PREFIX"
|
||||
|
||||
// sessionCacheRedisLookupTimeoutEnvVar names the environment variable that
|
||||
// configures the timeout used for SessionCache Redis lookups.
|
||||
sessionCacheRedisLookupTimeoutEnvVar = "GATEWAY_SESSION_CACHE_REDIS_LOOKUP_TIMEOUT"
|
||||
|
||||
// replayRedisKeyPrefixEnvVar names the environment variable that configures
|
||||
// the Redis key prefix used for authenticated replay reservations.
|
||||
replayRedisKeyPrefixEnvVar = "GATEWAY_REPLAY_REDIS_KEY_PREFIX"
|
||||
@@ -169,24 +175,6 @@ const (
|
||||
// startup connectivity checks.
|
||||
replayRedisReserveTimeoutEnvVar = "GATEWAY_REPLAY_REDIS_RESERVE_TIMEOUT"
|
||||
|
||||
// sessionEventsRedisStreamEnvVar names the environment variable that
|
||||
// configures the Redis Stream key consumed for session lifecycle updates.
|
||||
sessionEventsRedisStreamEnvVar = "GATEWAY_SESSION_EVENTS_REDIS_STREAM"
|
||||
|
||||
// sessionEventsRedisReadBlockTimeoutEnvVar names the environment variable
|
||||
// that configures the blocking read timeout used by the session event
|
||||
// subscriber.
|
||||
sessionEventsRedisReadBlockTimeoutEnvVar = "GATEWAY_SESSION_EVENTS_REDIS_READ_BLOCK_TIMEOUT"
|
||||
|
||||
// clientEventsRedisStreamEnvVar names the environment variable that
|
||||
// configures the Redis Stream key consumed for client-facing push events.
|
||||
clientEventsRedisStreamEnvVar = "GATEWAY_CLIENT_EVENTS_REDIS_STREAM"
|
||||
|
||||
// clientEventsRedisReadBlockTimeoutEnvVar names the environment variable
|
||||
// that configures the blocking read timeout used by the client-event
|
||||
// subscriber.
|
||||
clientEventsRedisReadBlockTimeoutEnvVar = "GATEWAY_CLIENT_EVENTS_REDIS_READ_BLOCK_TIMEOUT"
|
||||
|
||||
// responseSignerPrivateKeyPEMPathEnvVar names the environment variable that
|
||||
// configures the path to the PKCS#8 PEM-encoded Ed25519 private key used to
|
||||
// sign authenticated unary responses and stream events.
|
||||
@@ -293,13 +281,13 @@ const (
|
||||
defaultPublicHTTPAddr = ":8080"
|
||||
|
||||
defaultPublicHTTPReadHeaderTimeout = 2 * time.Second
|
||||
defaultPublicHTTPReadTimeout = 10 * time.Second
|
||||
defaultPublicHTTPIdleTimeout = time.Minute
|
||||
defaultPublicAuthUpstreamTimeout = 3 * time.Second
|
||||
defaultPublicHTTPReadTimeout = 10 * time.Second
|
||||
defaultPublicHTTPIdleTimeout = time.Minute
|
||||
defaultPublicAuthUpstreamTimeout = 3 * time.Second
|
||||
|
||||
defaultAdminHTTPReadHeaderTimeout = 2 * time.Second
|
||||
defaultAdminHTTPReadTimeout = 10 * time.Second
|
||||
defaultAdminHTTPIdleTimeout = time.Minute
|
||||
defaultAdminHTTPReadTimeout = 10 * time.Second
|
||||
defaultAdminHTTPIdleTimeout = time.Minute
|
||||
|
||||
// defaultAuthenticatedGRPCAddr is applied when
|
||||
// authenticatedGRPCAddrEnvVar is absent.
|
||||
@@ -307,48 +295,46 @@ const (
|
||||
|
||||
defaultAuthenticatedGRPCConnectionTimeout = 5 * time.Second
|
||||
defaultAuthenticatedGRPCDownstreamTimeout = 5 * time.Second
|
||||
defaultAuthenticatedGRPCFreshnessWindow = 5 * time.Minute
|
||||
defaultAuthenticatedGRPCFreshnessWindow = 5 * time.Minute
|
||||
|
||||
defaultAuthenticatedGRPCIPRateLimitRequests = 120
|
||||
defaultAuthenticatedGRPCIPRateLimitBurst = 40
|
||||
defaultAuthenticatedGRPCIPRateLimitBurst = 40
|
||||
|
||||
defaultAuthenticatedGRPCSessionRateLimitRequests = 60
|
||||
defaultAuthenticatedGRPCSessionRateLimitBurst = 20
|
||||
defaultAuthenticatedGRPCSessionRateLimitBurst = 20
|
||||
|
||||
defaultAuthenticatedGRPCUserRateLimitRequests = 120
|
||||
defaultAuthenticatedGRPCUserRateLimitBurst = 40
|
||||
defaultAuthenticatedGRPCUserRateLimitBurst = 40
|
||||
|
||||
defaultAuthenticatedGRPCMessageClassRateLimitRequests = 60
|
||||
defaultAuthenticatedGRPCMessageClassRateLimitBurst = 20
|
||||
defaultAuthenticatedGRPCMessageClassRateLimitBurst = 20
|
||||
|
||||
defaultSessionCacheRedisKeyPrefix = "gateway:session:"
|
||||
defaultSessionCacheRedisLookupTimeout = 250 * time.Millisecond
|
||||
|
||||
defaultReplayRedisKeyPrefix = "gateway:replay:"
|
||||
defaultReplayRedisKeyPrefix = "gateway:replay:"
|
||||
defaultReplayRedisReserveTimeout = 250 * time.Millisecond
|
||||
|
||||
defaultSessionEventsRedisReadBlockTimeout = time.Second
|
||||
defaultClientEventsRedisReadBlockTimeout = time.Second
|
||||
defaultBackendHTTPTimeout = 5 * time.Second
|
||||
defaultBackendPushReconnectBaseBackoff = 250 * time.Millisecond
|
||||
defaultBackendPushReconnectMaxBackoff = 30 * time.Second
|
||||
|
||||
defaultPublicAuthMaxBodyBytes = int64(8192)
|
||||
|
||||
defaultPublicAuthRateLimitRequests = 30
|
||||
defaultPublicAuthRateLimitBurst = 10
|
||||
defaultPublicAuthRateLimitBurst = 10
|
||||
|
||||
defaultBrowserBootstrapRateLimitRequests = 60
|
||||
defaultBrowserBootstrapRateLimitBurst = 20
|
||||
defaultBrowserBootstrapRateLimitBurst = 20
|
||||
|
||||
defaultBrowserAssetRateLimitRequests = 300
|
||||
defaultBrowserAssetRateLimitBurst = 80
|
||||
defaultBrowserAssetRateLimitBurst = 80
|
||||
|
||||
defaultPublicMiscRateLimitRequests = 30
|
||||
defaultPublicMiscRateLimitBurst = 10
|
||||
defaultPublicMiscRateLimitBurst = 10
|
||||
|
||||
defaultSendEmailCodeIdentityRateLimitRequests = 3
|
||||
defaultSendEmailCodeIdentityRateLimitBurst = 1
|
||||
defaultSendEmailCodeIdentityRateLimitBurst = 1
|
||||
|
||||
defaultConfirmEmailCodeIdentityRateLimitRequests = 6
|
||||
defaultConfirmEmailCodeIdentityRateLimitBurst = 2
|
||||
defaultConfirmEmailCodeIdentityRateLimitBurst = 2
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -462,31 +448,35 @@ type PublicHTTPConfig struct {
|
||||
AntiAbuse PublicHTTPAntiAbuseConfig
|
||||
}
|
||||
|
||||
// AuthServiceConfig describes the optional public-auth upstream used by the
|
||||
// gateway runtime.
|
||||
type AuthServiceConfig struct {
|
||||
// BaseURL is the absolute base URL of the Auth / Session Service public
|
||||
// HTTP API. When BaseURL is empty, the gateway keeps using its built-in
|
||||
// unavailable public-auth adapter.
|
||||
BaseURL string
|
||||
}
|
||||
// BackendConfig describes the consolidated backend service the gateway
|
||||
// talks to. Every authenticated and public HTTP request is forwarded to
|
||||
// `HTTPBaseURL`; the gRPC `Push.SubscribePush` stream is opened against
|
||||
// `GRPCPushURL`.
|
||||
type BackendConfig struct {
|
||||
// HTTPBaseURL is the absolute base URL of the backend HTTP listener
|
||||
// (`/api/v1/{public,user,internal}/*`). Required.
|
||||
HTTPBaseURL string
|
||||
|
||||
// UserServiceConfig describes the optional authenticated self-service upstream
|
||||
// used by the gateway runtime.
|
||||
type UserServiceConfig struct {
|
||||
// BaseURL is the absolute base URL of the User Service internal HTTP API.
|
||||
// When BaseURL is empty, the gateway keeps using its built-in unavailable
|
||||
// downstream adapter for the reserved `user.*` routes.
|
||||
BaseURL string
|
||||
}
|
||||
// GRPCPushURL is the dial target of the backend `Push.SubscribePush`
|
||||
// listener (`host:port`). Required.
|
||||
GRPCPushURL string
|
||||
|
||||
// LobbyServiceConfig describes the optional authenticated platform-command
|
||||
// upstream used by the gateway runtime.
|
||||
type LobbyServiceConfig struct {
|
||||
// BaseURL is the absolute base URL of the Game Lobby public HTTP API.
|
||||
// When BaseURL is empty, the gateway keeps using its built-in unavailable
|
||||
// downstream adapter for the reserved `lobby.*` routes.
|
||||
BaseURL string
|
||||
// GatewayClientID is the durable identifier this gateway instance
|
||||
// presents to backend in `GatewaySubscribeRequest.gateway_client_id`.
|
||||
// Required.
|
||||
GatewayClientID string
|
||||
|
||||
// HTTPTimeout bounds individual REST calls. Must be positive.
|
||||
HTTPTimeout time.Duration
|
||||
|
||||
// PushReconnectBaseBackoff is the starting delay between reconnect
|
||||
// attempts of `Push.SubscribePush`. Must be positive.
|
||||
PushReconnectBaseBackoff time.Duration
|
||||
|
||||
// PushReconnectMaxBackoff is the upper bound for exponential
|
||||
// reconnect delays. Must be greater than or equal to
|
||||
// PushReconnectBaseBackoff.
|
||||
PushReconnectMaxBackoff time.Duration
|
||||
}
|
||||
|
||||
// AdminHTTPConfig describes the private operational HTTP listener used for
|
||||
@@ -531,18 +521,6 @@ type AuthenticatedGRPCConfig struct {
|
||||
AntiAbuse AuthenticatedGRPCAntiAbuseConfig
|
||||
}
|
||||
|
||||
// SessionCacheRedisConfig describes the namespace and timeout used for
|
||||
// authenticated SessionCache lookups. Connection topology is shared with the
|
||||
// other Redis-backed gateway components and lives on Config.Redis (see
|
||||
// `pkg/redisconn`).
|
||||
type SessionCacheRedisConfig struct {
|
||||
// KeyPrefix is prepended to every SessionCache Redis key.
|
||||
KeyPrefix string
|
||||
|
||||
// LookupTimeout bounds individual SessionCache Redis operations.
|
||||
LookupTimeout time.Duration
|
||||
}
|
||||
|
||||
// ReplayRedisConfig describes the Redis namespace and timeout used for
|
||||
// authenticated replay reservations.
|
||||
type ReplayRedisConfig struct {
|
||||
@@ -553,29 +531,6 @@ type ReplayRedisConfig struct {
|
||||
ReserveTimeout time.Duration
|
||||
}
|
||||
|
||||
// SessionEventsRedisConfig describes the Redis Stream consumed by the gateway
|
||||
// to keep the process-local session cache synchronized with session lifecycle
|
||||
// updates.
|
||||
type SessionEventsRedisConfig struct {
|
||||
// Stream is the Redis Stream key carrying full session snapshot events.
|
||||
Stream string
|
||||
|
||||
// ReadBlockTimeout bounds one blocking XREAD call so shutdown remains
|
||||
// responsive even when the stream is idle.
|
||||
ReadBlockTimeout time.Duration
|
||||
}
|
||||
|
||||
// ClientEventsRedisConfig describes the Redis Stream consumed by the gateway
|
||||
// to deliver client-facing events to active push streams.
|
||||
type ClientEventsRedisConfig struct {
|
||||
// Stream is the Redis Stream key carrying client-facing event entries.
|
||||
Stream string
|
||||
|
||||
// ReadBlockTimeout bounds one blocking XREAD call so shutdown remains
|
||||
// responsive even when the stream is idle.
|
||||
ReadBlockTimeout time.Duration
|
||||
}
|
||||
|
||||
// ResponseSignerConfig describes the private-key material used to sign
|
||||
// authenticated unary responses and stream events.
|
||||
type ResponseSignerConfig struct {
|
||||
@@ -603,17 +558,10 @@ type Config struct {
|
||||
// PublicHTTP configures the public unauthenticated REST listener.
|
||||
PublicHTTP PublicHTTPConfig
|
||||
|
||||
// AuthService configures the optional public-auth delegation to the Auth /
|
||||
// Session Service.
|
||||
AuthService AuthServiceConfig
|
||||
|
||||
// UserService configures the optional authenticated self-service
|
||||
// delegation to User Service.
|
||||
UserService UserServiceConfig
|
||||
|
||||
// LobbyService configures the optional authenticated platform-command
|
||||
// delegation to Game Lobby.
|
||||
LobbyService LobbyServiceConfig
|
||||
// Backend configures the consolidated backend the gateway forwards
|
||||
// every public auth and authenticated user/lobby request to and the
|
||||
// gRPC `Push.SubscribePush` stream consumed for inbound events.
|
||||
Backend BackendConfig
|
||||
|
||||
// AdminHTTP configures the optional private admin listener used for metrics
|
||||
// exposure.
|
||||
@@ -622,25 +570,16 @@ type Config struct {
|
||||
// AuthenticatedGRPC configures the authenticated gRPC listener.
|
||||
AuthenticatedGRPC AuthenticatedGRPCConfig
|
||||
|
||||
// Redis carries the master/replica/password connection topology shared by
|
||||
// every gateway Redis component, sourced from the GATEWAY_REDIS_*
|
||||
// environment variables managed by `pkg/redisconn`.
|
||||
// Redis carries the master/replica/password connection topology used
|
||||
// by the anti-replay reservation store, sourced from the
|
||||
// GATEWAY_REDIS_* environment variables managed by `pkg/redisconn`.
|
||||
// The implementation dropped session cache projection and the two Redis
|
||||
// Streams; Redis is now used only for replay reservations.
|
||||
Redis redisconn.Config
|
||||
|
||||
// SessionCacheRedis configures the Redis-backed authenticated SessionCache.
|
||||
SessionCacheRedis SessionCacheRedisConfig
|
||||
|
||||
// ReplayRedis configures the Redis-backed authenticated ReplayStore.
|
||||
ReplayRedis ReplayRedisConfig
|
||||
|
||||
// SessionEventsRedis configures the Redis Stream consumed for session cache
|
||||
// updates and revocations.
|
||||
SessionEventsRedis SessionEventsRedisConfig
|
||||
|
||||
// ClientEventsRedis configures the Redis Stream consumed for client-facing
|
||||
// push delivery.
|
||||
ClientEventsRedis ClientEventsRedisConfig
|
||||
|
||||
// ResponseSigner configures the authenticated response and event signer
|
||||
// loaded during startup.
|
||||
ResponseSigner ResponseSignerConfig
|
||||
@@ -650,53 +589,53 @@ type Config struct {
|
||||
// for the public REST surface.
|
||||
func DefaultPublicHTTPConfig() PublicHTTPConfig {
|
||||
return PublicHTTPConfig{
|
||||
Addr: defaultPublicHTTPAddr,
|
||||
ReadHeaderTimeout: defaultPublicHTTPReadHeaderTimeout,
|
||||
ReadTimeout: defaultPublicHTTPReadTimeout,
|
||||
IdleTimeout: defaultPublicHTTPIdleTimeout,
|
||||
Addr: defaultPublicHTTPAddr,
|
||||
ReadHeaderTimeout: defaultPublicHTTPReadHeaderTimeout,
|
||||
ReadTimeout: defaultPublicHTTPReadTimeout,
|
||||
IdleTimeout: defaultPublicHTTPIdleTimeout,
|
||||
AuthUpstreamTimeout: defaultPublicAuthUpstreamTimeout,
|
||||
AntiAbuse: PublicHTTPAntiAbuseConfig{
|
||||
PublicAuth: PublicRoutePolicyConfig{
|
||||
MaxBodyBytes: defaultPublicAuthMaxBodyBytes,
|
||||
RateLimit: PublicRateLimitConfig{
|
||||
Requests: defaultPublicAuthRateLimitRequests,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultPublicAuthRateLimitBurst,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultPublicAuthRateLimitBurst,
|
||||
},
|
||||
},
|
||||
BrowserBootstrap: PublicRoutePolicyConfig{
|
||||
RateLimit: PublicRateLimitConfig{
|
||||
Requests: defaultBrowserBootstrapRateLimitRequests,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultBrowserBootstrapRateLimitBurst,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultBrowserBootstrapRateLimitBurst,
|
||||
},
|
||||
},
|
||||
BrowserAsset: PublicRoutePolicyConfig{
|
||||
RateLimit: PublicRateLimitConfig{
|
||||
Requests: defaultBrowserAssetRateLimitRequests,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultBrowserAssetRateLimitBurst,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultBrowserAssetRateLimitBurst,
|
||||
},
|
||||
},
|
||||
PublicMisc: PublicRoutePolicyConfig{
|
||||
RateLimit: PublicRateLimitConfig{
|
||||
Requests: defaultPublicMiscRateLimitRequests,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultPublicMiscRateLimitBurst,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultPublicMiscRateLimitBurst,
|
||||
},
|
||||
},
|
||||
SendEmailCodeIdentity: PublicAuthIdentityPolicyConfig{
|
||||
RateLimit: PublicRateLimitConfig{
|
||||
Requests: defaultSendEmailCodeIdentityRateLimitRequests,
|
||||
Window: defaultIdentityRateLimitWindow,
|
||||
Burst: defaultSendEmailCodeIdentityRateLimitBurst,
|
||||
Window: defaultIdentityRateLimitWindow,
|
||||
Burst: defaultSendEmailCodeIdentityRateLimitBurst,
|
||||
},
|
||||
},
|
||||
ConfirmEmailCodeIdentity: PublicAuthIdentityPolicyConfig{
|
||||
RateLimit: PublicRateLimitConfig{
|
||||
Requests: defaultConfirmEmailCodeIdentityRateLimitRequests,
|
||||
Window: defaultIdentityRateLimitWindow,
|
||||
Burst: defaultConfirmEmailCodeIdentityRateLimitBurst,
|
||||
Window: defaultIdentityRateLimitWindow,
|
||||
Burst: defaultConfirmEmailCodeIdentityRateLimitBurst,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -708,8 +647,8 @@ func DefaultPublicHTTPConfig() PublicHTTPConfig {
|
||||
func DefaultAdminHTTPConfig() AdminHTTPConfig {
|
||||
return AdminHTTPConfig{
|
||||
ReadHeaderTimeout: defaultAdminHTTPReadHeaderTimeout,
|
||||
ReadTimeout: defaultAdminHTTPReadTimeout,
|
||||
IdleTimeout: defaultAdminHTTPIdleTimeout,
|
||||
ReadTimeout: defaultAdminHTTPReadTimeout,
|
||||
IdleTimeout: defaultAdminHTTPIdleTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -717,30 +656,30 @@ func DefaultAdminHTTPConfig() AdminHTTPConfig {
|
||||
// anti-abuse settings for the authenticated gRPC surface.
|
||||
func DefaultAuthenticatedGRPCConfig() AuthenticatedGRPCConfig {
|
||||
return AuthenticatedGRPCConfig{
|
||||
Addr: defaultAuthenticatedGRPCAddr,
|
||||
Addr: defaultAuthenticatedGRPCAddr,
|
||||
ConnectionTimeout: defaultAuthenticatedGRPCConnectionTimeout,
|
||||
DownstreamTimeout: defaultAuthenticatedGRPCDownstreamTimeout,
|
||||
FreshnessWindow: defaultAuthenticatedGRPCFreshnessWindow,
|
||||
FreshnessWindow: defaultAuthenticatedGRPCFreshnessWindow,
|
||||
AntiAbuse: AuthenticatedGRPCAntiAbuseConfig{
|
||||
IP: AuthenticatedRateLimitConfig{
|
||||
Requests: defaultAuthenticatedGRPCIPRateLimitRequests,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultAuthenticatedGRPCIPRateLimitBurst,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultAuthenticatedGRPCIPRateLimitBurst,
|
||||
},
|
||||
Session: AuthenticatedRateLimitConfig{
|
||||
Requests: defaultAuthenticatedGRPCSessionRateLimitRequests,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultAuthenticatedGRPCSessionRateLimitBurst,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultAuthenticatedGRPCSessionRateLimitBurst,
|
||||
},
|
||||
User: AuthenticatedRateLimitConfig{
|
||||
Requests: defaultAuthenticatedGRPCUserRateLimitRequests,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultAuthenticatedGRPCUserRateLimitBurst,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultAuthenticatedGRPCUserRateLimitBurst,
|
||||
},
|
||||
MessageClass: AuthenticatedRateLimitConfig{
|
||||
Requests: defaultAuthenticatedGRPCMessageClassRateLimitRequests,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultAuthenticatedGRPCMessageClassRateLimitBurst,
|
||||
Window: defaultClassRateLimitWindow,
|
||||
Burst: defaultAuthenticatedGRPCMessageClassRateLimitBurst,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -751,39 +690,23 @@ func DefaultLoggingConfig() LoggingConfig {
|
||||
return LoggingConfig{Level: defaultLogLevel}
|
||||
}
|
||||
|
||||
// DefaultSessionCacheRedisConfig returns the default optional namespace and
|
||||
// timeout settings for the Redis-backed authenticated SessionCache.
|
||||
func DefaultSessionCacheRedisConfig() SessionCacheRedisConfig {
|
||||
return SessionCacheRedisConfig{
|
||||
KeyPrefix: defaultSessionCacheRedisKeyPrefix,
|
||||
LookupTimeout: defaultSessionCacheRedisLookupTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultReplayRedisConfig returns the default Redis key namespace and timeout
|
||||
// used for authenticated replay reservations.
|
||||
func DefaultReplayRedisConfig() ReplayRedisConfig {
|
||||
return ReplayRedisConfig{
|
||||
KeyPrefix: defaultReplayRedisKeyPrefix,
|
||||
KeyPrefix: defaultReplayRedisKeyPrefix,
|
||||
ReserveTimeout: defaultReplayRedisReserveTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultSessionEventsRedisConfig returns the default optional settings for the
|
||||
// session lifecycle event subscriber. Stream remains empty and must be
|
||||
// supplied explicitly.
|
||||
func DefaultSessionEventsRedisConfig() SessionEventsRedisConfig {
|
||||
return SessionEventsRedisConfig{
|
||||
ReadBlockTimeout: defaultSessionEventsRedisReadBlockTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultClientEventsRedisConfig returns the default optional settings for the
|
||||
// client-facing event subscriber. Stream remains empty and must be supplied
|
||||
// explicitly.
|
||||
func DefaultClientEventsRedisConfig() ClientEventsRedisConfig {
|
||||
return ClientEventsRedisConfig{
|
||||
ReadBlockTimeout: defaultClientEventsRedisReadBlockTimeout,
|
||||
// DefaultBackendConfig returns the default backend settings used for the
|
||||
// gateway → backend HTTP and gRPC conversation. URL fields stay empty and
|
||||
// must be supplied explicitly via env vars.
|
||||
func DefaultBackendConfig() BackendConfig {
|
||||
return BackendConfig{
|
||||
HTTPTimeout: defaultBackendHTTPTimeout,
|
||||
PushReconnectBaseBackoff: defaultBackendPushReconnectBaseBackoff,
|
||||
PushReconnectMaxBackoff: defaultBackendPushReconnectMaxBackoff,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -793,44 +716,19 @@ func DefaultResponseSignerConfig() ResponseSignerConfig {
|
||||
return ResponseSignerConfig{}
|
||||
}
|
||||
|
||||
// DefaultAuthServiceConfig returns the default public-auth upstream settings.
|
||||
// The zero value keeps the built-in unavailable adapter active.
|
||||
func DefaultAuthServiceConfig() AuthServiceConfig {
|
||||
return AuthServiceConfig{}
|
||||
}
|
||||
|
||||
// DefaultUserServiceConfig returns the default authenticated self-service
|
||||
// upstream settings. The zero value keeps the built-in unavailable adapter
|
||||
// active for reserved `user.*` routes.
|
||||
func DefaultUserServiceConfig() UserServiceConfig {
|
||||
return UserServiceConfig{}
|
||||
}
|
||||
|
||||
// DefaultLobbyServiceConfig returns the default authenticated platform-command
|
||||
// upstream settings. The zero value keeps the built-in unavailable adapter
|
||||
// active for reserved `lobby.*` routes.
|
||||
func DefaultLobbyServiceConfig() LobbyServiceConfig {
|
||||
return LobbyServiceConfig{}
|
||||
}
|
||||
|
||||
// LoadFromEnv loads Config from the process environment, applies defaults for
|
||||
// omitted settings, and validates the resulting values.
|
||||
func LoadFromEnv() (Config, error) {
|
||||
cfg := Config{
|
||||
ShutdownTimeout: defaultShutdownTimeout,
|
||||
Logging: DefaultLoggingConfig(),
|
||||
PublicHTTP: DefaultPublicHTTPConfig(),
|
||||
AuthService: DefaultAuthServiceConfig(),
|
||||
UserService: DefaultUserServiceConfig(),
|
||||
LobbyService: DefaultLobbyServiceConfig(),
|
||||
AdminHTTP: DefaultAdminHTTPConfig(),
|
||||
AuthenticatedGRPC: DefaultAuthenticatedGRPCConfig(),
|
||||
Redis: redisconn.DefaultConfig(),
|
||||
SessionCacheRedis: DefaultSessionCacheRedisConfig(),
|
||||
ReplayRedis: DefaultReplayRedisConfig(),
|
||||
SessionEventsRedis: DefaultSessionEventsRedisConfig(),
|
||||
ClientEventsRedis: DefaultClientEventsRedisConfig(),
|
||||
ResponseSigner: DefaultResponseSignerConfig(),
|
||||
ShutdownTimeout: defaultShutdownTimeout,
|
||||
Logging: DefaultLoggingConfig(),
|
||||
PublicHTTP: DefaultPublicHTTPConfig(),
|
||||
Backend: DefaultBackendConfig(),
|
||||
AdminHTTP: DefaultAdminHTTPConfig(),
|
||||
AuthenticatedGRPC: DefaultAuthenticatedGRPCConfig(),
|
||||
Redis: redisconn.DefaultConfig(),
|
||||
ReplayRedis: DefaultReplayRedisConfig(),
|
||||
ResponseSigner: DefaultResponseSignerConfig(),
|
||||
}
|
||||
|
||||
rawShutdownTimeout, ok := os.LookupEnv(shutdownTimeoutEnvVar)
|
||||
@@ -876,20 +774,30 @@ func LoadFromEnv() (Config, error) {
|
||||
}
|
||||
cfg.PublicHTTP.AuthUpstreamTimeout = publicAuthUpstreamTimeout
|
||||
|
||||
rawAuthServiceBaseURL, ok := os.LookupEnv(authServiceBaseURLEnvVar)
|
||||
if ok {
|
||||
cfg.AuthService.BaseURL = rawAuthServiceBaseURL
|
||||
if v, ok := os.LookupEnv(backendHTTPURLEnvVar); ok {
|
||||
cfg.Backend.HTTPBaseURL = v
|
||||
}
|
||||
|
||||
rawUserServiceBaseURL, ok := os.LookupEnv(userServiceBaseURLEnvVar)
|
||||
if ok {
|
||||
cfg.UserService.BaseURL = rawUserServiceBaseURL
|
||||
if v, ok := os.LookupEnv(backendGRPCPushURLEnvVar); ok {
|
||||
cfg.Backend.GRPCPushURL = v
|
||||
}
|
||||
|
||||
rawLobbyServiceBaseURL, ok := os.LookupEnv(lobbyServiceBaseURLEnvVar)
|
||||
if ok {
|
||||
cfg.LobbyService.BaseURL = rawLobbyServiceBaseURL
|
||||
if v, ok := os.LookupEnv(backendGatewayClientIDEnvVar); ok {
|
||||
cfg.Backend.GatewayClientID = v
|
||||
}
|
||||
backendHTTPTimeout, err := loadDurationEnvWithDefault(backendHTTPTimeoutEnvVar, cfg.Backend.HTTPTimeout)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
cfg.Backend.HTTPTimeout = backendHTTPTimeout
|
||||
backendPushReconnectBaseBackoff, err := loadDurationEnvWithDefault(backendPushReconnectBaseBackoffEnvVar, cfg.Backend.PushReconnectBaseBackoff)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
cfg.Backend.PushReconnectBaseBackoff = backendPushReconnectBaseBackoff
|
||||
backendPushReconnectMaxBackoff, err := loadDurationEnvWithDefault(backendPushReconnectMaxBackoffEnvVar, cfg.Backend.PushReconnectMaxBackoff)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
cfg.Backend.PushReconnectMaxBackoff = backendPushReconnectMaxBackoff
|
||||
|
||||
rawAdminHTTPAddr, ok := os.LookupEnv(adminHTTPAddrEnvVar)
|
||||
if ok {
|
||||
@@ -987,17 +895,6 @@ func LoadFromEnv() (Config, error) {
|
||||
}
|
||||
cfg.Redis = redisConn
|
||||
|
||||
rawSessionCacheRedisKeyPrefix, ok := os.LookupEnv(sessionCacheRedisKeyPrefixEnvVar)
|
||||
if ok {
|
||||
cfg.SessionCacheRedis.KeyPrefix = rawSessionCacheRedisKeyPrefix
|
||||
}
|
||||
|
||||
sessionCacheRedisLookupTimeout, err := loadDurationEnvWithDefault(sessionCacheRedisLookupTimeoutEnvVar, cfg.SessionCacheRedis.LookupTimeout)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
cfg.SessionCacheRedis.LookupTimeout = sessionCacheRedisLookupTimeout
|
||||
|
||||
rawReplayRedisKeyPrefix, ok := os.LookupEnv(replayRedisKeyPrefixEnvVar)
|
||||
if ok {
|
||||
cfg.ReplayRedis.KeyPrefix = rawReplayRedisKeyPrefix
|
||||
@@ -1009,28 +906,6 @@ func LoadFromEnv() (Config, error) {
|
||||
}
|
||||
cfg.ReplayRedis.ReserveTimeout = replayRedisReserveTimeout
|
||||
|
||||
rawSessionEventsRedisStream, ok := os.LookupEnv(sessionEventsRedisStreamEnvVar)
|
||||
if ok {
|
||||
cfg.SessionEventsRedis.Stream = rawSessionEventsRedisStream
|
||||
}
|
||||
|
||||
sessionEventsRedisReadBlockTimeout, err := loadDurationEnvWithDefault(sessionEventsRedisReadBlockTimeoutEnvVar, cfg.SessionEventsRedis.ReadBlockTimeout)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
cfg.SessionEventsRedis.ReadBlockTimeout = sessionEventsRedisReadBlockTimeout
|
||||
|
||||
rawClientEventsRedisStream, ok := os.LookupEnv(clientEventsRedisStreamEnvVar)
|
||||
if ok {
|
||||
cfg.ClientEventsRedis.Stream = rawClientEventsRedisStream
|
||||
}
|
||||
|
||||
clientEventsRedisReadBlockTimeout, err := loadDurationEnvWithDefault(clientEventsRedisReadBlockTimeoutEnvVar, cfg.ClientEventsRedis.ReadBlockTimeout)
|
||||
if err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
cfg.ClientEventsRedis.ReadBlockTimeout = clientEventsRedisReadBlockTimeout
|
||||
|
||||
rawSignerKeyPath, ok := os.LookupEnv(responseSignerPrivateKeyPEMPathEnvVar)
|
||||
if ok {
|
||||
cfg.ResponseSigner.PrivateKeyPEMPath = rawSignerKeyPath
|
||||
@@ -1127,27 +1002,34 @@ func LoadFromEnv() (Config, error) {
|
||||
if cfg.PublicHTTP.AuthUpstreamTimeout <= 0 {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be positive", publicAuthUpstreamTimeoutEnvVar)
|
||||
}
|
||||
cfg.AuthService.BaseURL = strings.TrimSpace(cfg.AuthService.BaseURL)
|
||||
if cfg.AuthService.BaseURL != "" {
|
||||
parsedAuthServiceBaseURL, err := url.Parse(cfg.AuthService.BaseURL)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("load gateway config: parse %s: %w", authServiceBaseURLEnvVar, err)
|
||||
}
|
||||
if parsedAuthServiceBaseURL.Scheme == "" || parsedAuthServiceBaseURL.Host == "" {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be an absolute URL", authServiceBaseURLEnvVar)
|
||||
}
|
||||
cfg.AuthService.BaseURL = strings.TrimRight(parsedAuthServiceBaseURL.String(), "/")
|
||||
cfg.Backend.HTTPBaseURL = strings.TrimSpace(cfg.Backend.HTTPBaseURL)
|
||||
if cfg.Backend.HTTPBaseURL == "" {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", backendHTTPURLEnvVar)
|
||||
}
|
||||
cfg.UserService.BaseURL = strings.TrimSpace(cfg.UserService.BaseURL)
|
||||
if cfg.UserService.BaseURL != "" {
|
||||
parsedUserServiceBaseURL, err := url.Parse(cfg.UserService.BaseURL)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("load gateway config: parse %s: %w", userServiceBaseURLEnvVar, err)
|
||||
}
|
||||
if parsedUserServiceBaseURL.Scheme == "" || parsedUserServiceBaseURL.Host == "" {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be an absolute URL", userServiceBaseURLEnvVar)
|
||||
}
|
||||
cfg.UserService.BaseURL = strings.TrimRight(parsedUserServiceBaseURL.String(), "/")
|
||||
parsedBackendHTTP, err := url.Parse(strings.TrimRight(cfg.Backend.HTTPBaseURL, "/"))
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("load gateway config: parse %s: %w", backendHTTPURLEnvVar, err)
|
||||
}
|
||||
if parsedBackendHTTP.Scheme == "" || parsedBackendHTTP.Host == "" {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be an absolute URL", backendHTTPURLEnvVar)
|
||||
}
|
||||
cfg.Backend.HTTPBaseURL = parsedBackendHTTP.String()
|
||||
cfg.Backend.GRPCPushURL = strings.TrimSpace(cfg.Backend.GRPCPushURL)
|
||||
if cfg.Backend.GRPCPushURL == "" {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", backendGRPCPushURLEnvVar)
|
||||
}
|
||||
cfg.Backend.GatewayClientID = strings.TrimSpace(cfg.Backend.GatewayClientID)
|
||||
if cfg.Backend.GatewayClientID == "" {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", backendGatewayClientIDEnvVar)
|
||||
}
|
||||
if cfg.Backend.HTTPTimeout <= 0 {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be positive", backendHTTPTimeoutEnvVar)
|
||||
}
|
||||
if cfg.Backend.PushReconnectBaseBackoff <= 0 {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be positive", backendPushReconnectBaseBackoffEnvVar)
|
||||
}
|
||||
if cfg.Backend.PushReconnectMaxBackoff < cfg.Backend.PushReconnectBaseBackoff {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be >= %s", backendPushReconnectMaxBackoffEnvVar, backendPushReconnectBaseBackoffEnvVar)
|
||||
}
|
||||
if addr := strings.TrimSpace(cfg.AdminHTTP.Addr); addr != "" {
|
||||
cfg.AdminHTTP.Addr = addr
|
||||
@@ -1208,30 +1090,12 @@ func LoadFromEnv() (Config, error) {
|
||||
if err := cfg.Redis.Validate(); err != nil {
|
||||
return Config{}, fmt.Errorf("load gateway config: redis: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(cfg.SessionCacheRedis.KeyPrefix) == "" {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", sessionCacheRedisKeyPrefixEnvVar)
|
||||
}
|
||||
if cfg.SessionCacheRedis.LookupTimeout <= 0 {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be positive", sessionCacheRedisLookupTimeoutEnvVar)
|
||||
}
|
||||
if strings.TrimSpace(cfg.ReplayRedis.KeyPrefix) == "" {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", replayRedisKeyPrefixEnvVar)
|
||||
}
|
||||
if cfg.ReplayRedis.ReserveTimeout <= 0 {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be positive", replayRedisReserveTimeoutEnvVar)
|
||||
}
|
||||
if strings.TrimSpace(cfg.SessionEventsRedis.Stream) == "" {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", sessionEventsRedisStreamEnvVar)
|
||||
}
|
||||
if cfg.SessionEventsRedis.ReadBlockTimeout <= 0 {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be positive", sessionEventsRedisReadBlockTimeoutEnvVar)
|
||||
}
|
||||
if strings.TrimSpace(cfg.ClientEventsRedis.Stream) == "" {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", clientEventsRedisStreamEnvVar)
|
||||
}
|
||||
if cfg.ClientEventsRedis.ReadBlockTimeout <= 0 {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be positive", clientEventsRedisReadBlockTimeoutEnvVar)
|
||||
}
|
||||
if strings.TrimSpace(cfg.ResponseSigner.PrivateKeyPEMPath) == "" {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must not be empty", responseSignerPrivateKeyPEMPathEnvVar)
|
||||
}
|
||||
|
||||
+159
-1477
File diff suppressed because it is too large
Load Diff
@@ -1,329 +0,0 @@
|
||||
// Package lobbyservice implements the authenticated Gateway -> Game Lobby
|
||||
// downstream adapter. It forwards verified authenticated commands as
|
||||
// trusted-internal HTTP requests against Game Lobby's public REST surface,
|
||||
// transporting the calling user identity through the `X-User-Id` header.
|
||||
package lobbyservice
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"galaxy/gateway/internal/downstream"
|
||||
lobbymodel "galaxy/model/lobby"
|
||||
"galaxy/transcoder"
|
||||
)
|
||||
|
||||
const (
	// Trusted Game Lobby public REST routes targeted by this adapter.
	myGamesListPath          = "/api/v1/lobby/my/games"
	openEnrollmentPathFormat = "/api/v1/lobby/games/%s/open-enrollment"

	// resultCodeOK is the gateway result code for a successful downstream call.
	resultCodeOK = "ok"

	// Fallback error codes substituted by normalizeErrorCode when the
	// downstream error envelope omits a code.
	defaultErrorCodeBadRequest    = "invalid_request"
	defaultErrorCodeNotFound      = "subject_not_found"
	defaultErrorCodeForbidden     = "forbidden"
	defaultErrorCodeConflict      = "conflict"
	defaultErrorCodeInternalError = "internal_error"

	// headerCallingUserID transports the verified calling user identity on
	// trusted-internal requests to Game Lobby.
	headerCallingUserID = "X-User-Id"
)

// stableErrorMessages maps error codes to the stable fallback messages used
// by normalizeErrorMessage when the downstream error envelope omits a message.
var stableErrorMessages = map[string]string{
	defaultErrorCodeBadRequest:    "request is invalid",
	defaultErrorCodeNotFound:      "subject not found",
	defaultErrorCodeForbidden:     "operation is forbidden for the calling user",
	defaultErrorCodeConflict:      "request conflicts with current state",
	defaultErrorCodeInternalError: "internal server error",
}

// HTTPClient implements downstream.Client against the trusted Game Lobby
// public REST API while preserving FlatBuffers at the external authenticated
// gateway boundary.
type HTTPClient struct {
	// baseURL is the absolute, trailing-slash-trimmed base URL of the Game
	// Lobby public REST API (validated in newHTTPClient).
	baseURL string

	// httpClient performs the trusted-internal HTTP requests.
	httpClient *http.Client
}
|
||||
|
||||
// NewHTTPClient constructs one Game Lobby downstream client backed by the
|
||||
// public REST API at baseURL.
|
||||
func NewHTTPClient(baseURL string) (*HTTPClient, error) {
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
return nil, errors.New("new lobby service HTTP client: default transport is not *http.Transport")
|
||||
}
|
||||
|
||||
return newHTTPClient(baseURL, &http.Client{
|
||||
Transport: transport.Clone(),
|
||||
})
|
||||
}
|
||||
|
||||
func newHTTPClient(baseURL string, httpClient *http.Client) (*HTTPClient, error) {
|
||||
if httpClient == nil {
|
||||
return nil, errors.New("new lobby service HTTP client: http client must not be nil")
|
||||
}
|
||||
|
||||
trimmedBaseURL := strings.TrimSpace(baseURL)
|
||||
if trimmedBaseURL == "" {
|
||||
return nil, errors.New("new lobby service HTTP client: base URL must not be empty")
|
||||
}
|
||||
|
||||
parsedBaseURL, err := url.Parse(strings.TrimRight(trimmedBaseURL, "/"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new lobby service HTTP client: parse base URL: %w", err)
|
||||
}
|
||||
if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
|
||||
return nil, errors.New("new lobby service HTTP client: base URL must be absolute")
|
||||
}
|
||||
|
||||
return &HTTPClient{
|
||||
baseURL: parsedBaseURL.String(),
|
||||
httpClient: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases idle HTTP connections owned by the client transport.
|
||||
func (c *HTTPClient) Close() error {
|
||||
if c == nil || c.httpClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type idleCloser interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteCommand routes one authenticated gateway command to the matching
// trusted Game Lobby public REST route.
func (c *HTTPClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	if c == nil || c.httpClient == nil {
		return downstream.UnaryResult{}, errors.New("execute lobby service command: nil client")
	}
	if ctx == nil {
		return downstream.UnaryResult{}, errors.New("execute lobby service command: nil context")
	}
	// Fail fast on an already-cancelled context before any network work.
	if err := ctx.Err(); err != nil {
		return downstream.UnaryResult{}, err
	}
	// The user identity is forwarded as a trusted header; it must be present.
	if strings.TrimSpace(command.UserID) == "" {
		return downstream.UnaryResult{}, errors.New("execute lobby service command: user_id must not be empty")
	}

	switch command.MessageType {
	case lobbymodel.MessageTypeMyGamesList:
		// Decode only to validate the FlatBuffers payload; the request
		// carries no fields the HTTP call needs.
		if _, err := transcoder.PayloadToMyGamesListRequest(command.PayloadBytes); err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("execute lobby service command %q: %w", command.MessageType, err)
		}
		return c.executeMyGamesList(ctx, command.UserID)
	case lobbymodel.MessageTypeOpenEnrollment:
		request, err := transcoder.PayloadToOpenEnrollmentRequest(command.PayloadBytes)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("execute lobby service command %q: %w", command.MessageType, err)
		}
		return c.executeOpenEnrollment(ctx, command.UserID, request)
	default:
		return downstream.UnaryResult{}, fmt.Errorf("execute lobby service command: unsupported message type %q", command.MessageType)
	}
}
|
||||
|
||||
// executeMyGamesList performs the trusted GET for the caller's game list and
// transcodes a 200 response back into the gateway FlatBuffers payload.
// Non-200 statuses are projected through projectErrorResponse.
func (c *HTTPClient) executeMyGamesList(ctx context.Context, userID string) (downstream.UnaryResult, error) {
	payload, statusCode, err := c.doRequest(ctx, http.MethodGet, c.baseURL+myGamesListPath, userID, nil)
	if err != nil {
		return downstream.UnaryResult{}, fmt.Errorf("execute my games list: %w", err)
	}

	if statusCode == http.StatusOK {
		var response lobbymodel.MyGamesListResponse
		// Strict decode: unknown or trailing JSON from the downstream is an error.
		if err := decodeStrictJSONPayload(payload, &response); err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("decode success response: %w", err)
		}
		payloadBytes, err := transcoder.MyGamesListResponseToPayload(&response)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("encode success response payload: %w", err)
		}
		return downstream.UnaryResult{
			ResultCode:   resultCodeOK,
			PayloadBytes: payloadBytes,
		}, nil
	}

	return projectErrorResponse(statusCode, payload)
}
|
||||
|
||||
// executeOpenEnrollment performs the trusted POST that opens enrollment for
// one game and projects the downstream response into the minimal gateway
// (game_id, status) payload. Non-200 statuses go through projectErrorResponse.
func (c *HTTPClient) executeOpenEnrollment(ctx context.Context, userID string, request *lobbymodel.OpenEnrollmentRequest) (downstream.UnaryResult, error) {
	if request == nil || strings.TrimSpace(request.GameID) == "" {
		return downstream.UnaryResult{}, errors.New("execute open enrollment: game_id must not be empty")
	}

	// PathEscape guards against game IDs containing path metacharacters.
	target := c.baseURL + fmt.Sprintf(openEnrollmentPathFormat, url.PathEscape(request.GameID))
	// An empty JSON object body keeps the POST well-formed for the REST API.
	payload, statusCode, err := c.doRequest(ctx, http.MethodPost, target, userID, struct{}{})
	if err != nil {
		return downstream.UnaryResult{}, fmt.Errorf("execute open enrollment: %w", err)
	}

	if statusCode == http.StatusOK {
		// Lobby's open-enrollment endpoint returns the full game record;
		// the gateway boundary projects the minimal status pair. A plain
		// (non-strict) decode is used deliberately so extra record fields
		// are tolerated.
		var fullRecord struct {
			GameID string `json:"game_id"`
			Status string `json:"status"`
		}
		if err := json.Unmarshal(payload, &fullRecord); err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("decode success response: %w", err)
		}
		payloadBytes, err := transcoder.OpenEnrollmentResponseToPayload(&lobbymodel.OpenEnrollmentResponse{
			GameID: fullRecord.GameID,
			Status: fullRecord.Status,
		})
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("encode success response payload: %w", err)
		}
		return downstream.UnaryResult{
			ResultCode:   resultCodeOK,
			PayloadBytes: payloadBytes,
		}, nil
	}

	return projectErrorResponse(statusCode, payload)
}
|
||||
|
||||
// doRequest performs one trusted-internal HTTP call, attaching the calling
// user identity header, and returns the raw response body and status code.
// A non-nil requestBody is JSON-encoded and sent with a JSON content type.
func (c *HTTPClient) doRequest(ctx context.Context, method, targetURL, userID string, requestBody any) ([]byte, int, error) {
	if c == nil || c.httpClient == nil {
		return nil, 0, errors.New("nil client")
	}

	var bodyReader io.Reader
	if requestBody != nil {
		body, err := json.Marshal(requestBody)
		if err != nil {
			return nil, 0, fmt.Errorf("marshal request body: %w", err)
		}
		bodyReader = bytes.NewReader(body)
	}

	request, err := http.NewRequestWithContext(ctx, method, targetURL, bodyReader)
	if err != nil {
		return nil, 0, fmt.Errorf("build request: %w", err)
	}
	if requestBody != nil {
		request.Header.Set("Content-Type", "application/json")
	}
	request.Header.Set(headerCallingUserID, userID)

	response, err := c.httpClient.Do(request)
	if err != nil {
		// Returned unwrapped so context cancellation errors surface as-is.
		return nil, 0, err
	}
	defer response.Body.Close()

	// Bodies are small REST payloads; reading fully is acceptable here.
	payload, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, 0, fmt.Errorf("read response body: %w", err)
	}

	return payload, response.StatusCode, nil
}
|
||||
|
||||
// projectErrorResponse maps a non-200 downstream HTTP outcome to the gateway
// result contract: 503 becomes ErrDownstreamUnavailable (the body, if any, is
// ignored), other 4xx/5xx statuses are re-encoded as a stable error envelope,
// and anything else (e.g. 3xx) is an unexpected-status error.
func projectErrorResponse(statusCode int, payload []byte) (downstream.UnaryResult, error) {
	switch {
	case statusCode == http.StatusServiceUnavailable:
		return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
	case statusCode >= 400 && statusCode <= 599:
		errorResponse, err := decodeLobbyError(statusCode, payload)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("decode error response: %w", err)
		}
		payloadBytes, err := transcoder.LobbyErrorResponseToPayload(errorResponse)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("encode error response payload: %w", err)
		}
		// The downstream error code doubles as the gateway result code.
		return downstream.UnaryResult{
			ResultCode:   errorResponse.Error.Code,
			PayloadBytes: payloadBytes,
		}, nil
	default:
		return downstream.UnaryResult{}, fmt.Errorf("unexpected HTTP status %d", statusCode)
	}
}
|
||||
|
||||
// decodeLobbyError strictly decodes a downstream error envelope and fills in
// stable fallback code/message values derived from the HTTP status when the
// envelope omits them.
func decodeLobbyError(statusCode int, payload []byte) (*lobbymodel.ErrorResponse, error) {
	var response lobbymodel.ErrorResponse
	if err := decodeStrictJSONPayload(payload, &response); err != nil {
		return nil, err
	}

	response.Error.Code = normalizeErrorCode(statusCode, response.Error.Code)
	response.Error.Message = normalizeErrorMessage(response.Error.Code, response.Error.Message)

	// Defensive checks: both normalizers always return a non-empty value, so
	// these branches should be unreachable; kept as an invariant guard.
	if strings.TrimSpace(response.Error.Code) == "" {
		return nil, errors.New("missing error code")
	}
	if strings.TrimSpace(response.Error.Message) == "" {
		return nil, errors.New("missing error message")
	}

	return &response, nil
}
|
||||
|
||||
func normalizeErrorCode(statusCode int, code string) string {
|
||||
trimmed := strings.TrimSpace(code)
|
||||
if trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case http.StatusBadRequest:
|
||||
return defaultErrorCodeBadRequest
|
||||
case http.StatusForbidden:
|
||||
return defaultErrorCodeForbidden
|
||||
case http.StatusNotFound:
|
||||
return defaultErrorCodeNotFound
|
||||
case http.StatusConflict:
|
||||
return defaultErrorCodeConflict
|
||||
default:
|
||||
return defaultErrorCodeInternalError
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeErrorMessage(code, message string) string {
|
||||
trimmed := strings.TrimSpace(message)
|
||||
if trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
if stable, ok := stableErrorMessages[code]; ok {
|
||||
return stable
|
||||
}
|
||||
|
||||
return stableErrorMessages[defaultErrorCodeInternalError]
|
||||
}
|
||||
|
||||
// decodeStrictJSONPayload decodes exactly one JSON value from payload into
// target, rejecting unknown fields and any trailing input.
func decodeStrictJSONPayload(payload []byte, target any) error {
	dec := json.NewDecoder(bytes.NewReader(payload))
	dec.DisallowUnknownFields()

	if err := dec.Decode(target); err != nil {
		return err
	}

	// A second decode must hit EOF; anything else means trailing input.
	switch err := dec.Decode(&struct{}{}); err {
	case io.EOF:
		return nil
	case nil:
		return errors.New("unexpected trailing JSON input")
	default:
		return err
	}
}
|
||||
|
||||
// Compile-time check that *HTTPClient satisfies downstream.Client.
var _ downstream.Client = (*HTTPClient)(nil)
|
||||
@@ -1,212 +0,0 @@
|
||||
package lobbyservice_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/downstream/lobbyservice"
|
||||
lobbymodel "galaxy/model/lobby"
|
||||
"galaxy/transcoder"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestExecuteMyGamesListSuccess verifies the happy path: the client issues a
// trusted GET with the X-User-Id header and round-trips the downstream JSON
// into the gateway FlatBuffers payload with result code "ok".
func TestExecuteMyGamesListSuccess(t *testing.T) {
	t.Parallel()

	expectedResponse := lobbymodel.MyGamesListResponse{
		Items: []lobbymodel.GameSummary{
			{
				GameID:           "game-1",
				GameName:         "Nebula Clash",
				GameType:         "private",
				Status:           "draft",
				OwnerUserID:      "user-1",
				MinPlayers:       2,
				MaxPlayers:       8,
				EnrollmentEndsAt: time.Date(2026, 5, 1, 12, 0, 0, 0, time.UTC),
				CreatedAt:        time.Date(2026, 4, 28, 9, 0, 0, 0, time.UTC),
				UpdatedAt:        time.Date(2026, 4, 28, 9, 5, 0, 0, time.UTC),
			},
		},
	}

	// Fake downstream asserting method, path, and identity header.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodGet, r.Method)
		assert.Equal(t, "/api/v1/lobby/my/games", r.URL.Path)
		assert.Equal(t, "user-1", r.Header.Get("X-User-Id"))
		w.Header().Set("Content-Type", "application/json")
		require.NoError(t, json.NewEncoder(w).Encode(expectedResponse))
	}))
	t.Cleanup(server.Close)

	client, err := lobbyservice.NewHTTPClient(server.URL)
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, client.Close()) })

	requestBytes, err := transcoder.MyGamesListRequestToPayload(&lobbymodel.MyGamesListRequest{})
	require.NoError(t, err)

	result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
		MessageType:  lobbymodel.MessageTypeMyGamesList,
		UserID:       "user-1",
		PayloadBytes: requestBytes,
	})
	require.NoError(t, err)
	assert.Equal(t, "ok", result.ResultCode)

	// The returned payload must decode back into the original items.
	decoded, err := transcoder.PayloadToMyGamesListResponse(result.PayloadBytes)
	require.NoError(t, err)
	require.Len(t, decoded.Items, 1)
	assert.Equal(t, expectedResponse.Items[0].GameID, decoded.Items[0].GameID)
	assert.Equal(t, expectedResponse.Items[0].OwnerUserID, decoded.Items[0].OwnerUserID)
	assert.Equal(t, expectedResponse.Items[0].MinPlayers, decoded.Items[0].MinPlayers)
}
|
||||
|
||||
// TestExecuteOpenEnrollmentSuccess verifies that open-enrollment issues a
// trusted POST to the per-game route and projects the downstream record into
// the minimal (game_id, status) gateway payload.
func TestExecuteOpenEnrollmentSuccess(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodPost, r.Method)
		assert.Equal(t, "/api/v1/lobby/games/game-77/open-enrollment", r.URL.Path)
		assert.Equal(t, "owner-1", r.Header.Get("X-User-Id"))
		w.Header().Set("Content-Type", "application/json")
		require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
			"game_id": "game-77",
			"status":  "enrollment_open",
		}))
	}))
	t.Cleanup(server.Close)

	client, err := lobbyservice.NewHTTPClient(server.URL)
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, client.Close()) })

	requestBytes, err := transcoder.OpenEnrollmentRequestToPayload(&lobbymodel.OpenEnrollmentRequest{GameID: "game-77"})
	require.NoError(t, err)

	result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
		MessageType:  lobbymodel.MessageTypeOpenEnrollment,
		UserID:       "owner-1",
		PayloadBytes: requestBytes,
	})
	require.NoError(t, err)
	assert.Equal(t, "ok", result.ResultCode)

	decoded, err := transcoder.PayloadToOpenEnrollmentResponse(result.PayloadBytes)
	require.NoError(t, err)
	assert.Equal(t, "game-77", decoded.GameID)
	assert.Equal(t, "enrollment_open", decoded.Status)
}
|
||||
|
||||
// TestExecuteOpenEnrollmentForbiddenProjectsErrorEnvelope verifies that a 403
// with a downstream error envelope is surfaced as a non-error result whose
// result code and payload carry the downstream error code/message.
func TestExecuteOpenEnrollmentForbiddenProjectsErrorEnvelope(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusForbidden)
		require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
			"error": map[string]string{
				"code":    "forbidden",
				"message": "only the game owner may open enrollment",
			},
		}))
	}))
	t.Cleanup(server.Close)

	client, err := lobbyservice.NewHTTPClient(server.URL)
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, client.Close()) })

	requestBytes, err := transcoder.OpenEnrollmentRequestToPayload(&lobbymodel.OpenEnrollmentRequest{GameID: "game-77"})
	require.NoError(t, err)

	result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
		MessageType:  lobbymodel.MessageTypeOpenEnrollment,
		UserID:       "non-owner",
		PayloadBytes: requestBytes,
	})
	require.NoError(t, err)
	assert.Equal(t, "forbidden", result.ResultCode)

	decoded, err := transcoder.PayloadToLobbyErrorResponse(result.PayloadBytes)
	require.NoError(t, err)
	assert.Equal(t, "forbidden", decoded.Error.Code)
	assert.NotEmpty(t, decoded.Error.Message)
}
|
||||
|
||||
// TestExecuteCommandUnavailableProjectsErrUnavailable verifies that a 503
// from the downstream is mapped to downstream.ErrDownstreamUnavailable.
func TestExecuteCommandUnavailableProjectsErrUnavailable(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusServiceUnavailable)
	}))
	t.Cleanup(server.Close)

	client, err := lobbyservice.NewHTTPClient(server.URL)
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, client.Close()) })

	requestBytes, err := transcoder.MyGamesListRequestToPayload(&lobbymodel.MyGamesListRequest{})
	require.NoError(t, err)

	_, err = client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
		MessageType:  lobbymodel.MessageTypeMyGamesList,
		UserID:       "user-1",
		PayloadBytes: requestBytes,
	})
	require.Error(t, err)
	assert.True(t, errors.Is(err, downstream.ErrDownstreamUnavailable))
}
|
||||
|
||||
// TestExecuteCommandRejectsEmptyUserID verifies that the client rejects a
// command without a user identity before attempting any network call (the
// unreachable base URL would otherwise fail differently).
func TestExecuteCommandRejectsEmptyUserID(t *testing.T) {
	t.Parallel()

	client, err := lobbyservice.NewHTTPClient("http://127.0.0.1:1")
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, client.Close()) })

	requestBytes, err := transcoder.MyGamesListRequestToPayload(&lobbymodel.MyGamesListRequest{})
	require.NoError(t, err)

	_, err = client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
		MessageType:  lobbymodel.MessageTypeMyGamesList,
		UserID:       "",
		PayloadBytes: requestBytes,
	})
	require.Error(t, err)
	assert.True(t, strings.Contains(err.Error(), "user_id"), "error must mention user_id; got %q", err.Error())
}
|
||||
|
||||
// TestNewRoutesReservesUnavailableClientWhenBaseURLEmpty verifies that with
// an empty base URL the lobby routes are still registered but resolve to the
// dependency-unavailable client.
func TestNewRoutesReservesUnavailableClientWhenBaseURLEmpty(t *testing.T) {
	t.Parallel()

	routes, closeFn, err := lobbyservice.NewRoutes("")
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, closeFn()) })

	// Both stable lobby message types must remain reserved.
	require.Contains(t, routes, lobbymodel.MessageTypeMyGamesList)
	require.Contains(t, routes, lobbymodel.MessageTypeOpenEnrollment)

	requestBytes, err := transcoder.MyGamesListRequestToPayload(&lobbymodel.MyGamesListRequest{})
	require.NoError(t, err)

	_, err = routes[lobbymodel.MessageTypeMyGamesList].ExecuteCommand(
		context.Background(),
		downstream.AuthenticatedCommand{
			MessageType:  lobbymodel.MessageTypeMyGamesList,
			UserID:       "user-1",
			PayloadBytes: requestBytes,
		},
	)
	require.Error(t, err)
	assert.True(t, errors.Is(err, downstream.ErrDownstreamUnavailable))
}
|
||||
@@ -1,45 +0,0 @@
|
||||
package lobbyservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"galaxy/gateway/internal/downstream"
|
||||
lobbymodel "galaxy/model/lobby"
|
||||
)
|
||||
|
||||
// noOpClose is the close function returned when no HTTP client was created.
var noOpClose = func() error { return nil }

// NewRoutes returns the reserved authenticated gateway routes owned by
// the Gateway -> Game Lobby boundary.
//
// When baseURL is empty, the returned routes still reserve the stable
// `lobby.*` message types but resolve them to a dependency-unavailable
// client so callers receive the transport-level unavailable outcome
// instead of a route-miss error.
func NewRoutes(baseURL string) (map[string]downstream.Client, func() error, error) {
	client := downstream.Client(unavailableClient{})
	closeFn := noOpClose

	// NOTE(review): the check is exact-empty, not TrimSpace-empty; a
	// whitespace-only baseURL falls through to NewHTTPClient and errors
	// there ("base URL must not be empty") — confirm this is intended.
	if baseURL != "" {
		httpClient, err := NewHTTPClient(baseURL)
		if err != nil {
			return nil, nil, err
		}

		client = httpClient
		closeFn = httpClient.Close
	}

	return map[string]downstream.Client{
		lobbymodel.MessageTypeMyGamesList:    client,
		lobbymodel.MessageTypeOpenEnrollment: client,
	}, closeFn, nil
}
|
||||
|
||||
// unavailableClient is the placeholder downstream.Client used when the Game
// Lobby base URL is not configured; every call reports unavailability.
type unavailableClient struct{}

// ExecuteCommand always fails with downstream.ErrDownstreamUnavailable.
func (unavailableClient) ExecuteCommand(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
}

// Compile-time check that unavailableClient satisfies downstream.Client.
var _ downstream.Client = unavailableClient{}
|
||||
@@ -1,311 +0,0 @@
|
||||
// Package userservice implements the authenticated Gateway -> User Service
|
||||
// self-service downstream adapter.
|
||||
package userservice
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"galaxy/gateway/internal/downstream"
|
||||
usermodel "galaxy/model/user"
|
||||
"galaxy/transcoder"
|
||||
)
|
||||
|
||||
const (
	// getMyAccountResultCodeOK is the gateway result code for a successful call.
	getMyAccountResultCodeOK = "ok"

	// Per-user route suffixes under the trusted internal users API
	// (see userPath for the full path construction).
	userServiceAccountPathSuffix  = "/account"
	userServiceProfilePathSuffix  = "/profile"
	userServiceSettingsPathSuffix = "/settings"
)

// stableErrorMessages maps error codes to the stable fallback messages used
// by normalizeErrorMessage when the downstream error envelope omits a message.
var stableErrorMessages = map[string]string{
	"invalid_request":   "request is invalid",
	"subject_not_found": "subject not found",
	"conflict":          "request conflicts with current state",
	"internal_error":    "internal server error",
}

// HTTPClient implements downstream.Client against the trusted internal User
// Service REST API while preserving FlatBuffers at the external authenticated
// gateway boundary.
type HTTPClient struct {
	// baseURL is the absolute, trailing-slash-trimmed base URL of the User
	// Service internal REST API (validated in newHTTPClient).
	baseURL string

	// httpClient performs the trusted-internal HTTP requests.
	httpClient *http.Client
}
|
||||
|
||||
// NewHTTPClient constructs one User Service downstream client backed by the
|
||||
// trusted internal REST API at baseURL.
|
||||
func NewHTTPClient(baseURL string) (*HTTPClient, error) {
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
return nil, errors.New("new user service HTTP client: default transport is not *http.Transport")
|
||||
}
|
||||
|
||||
return newHTTPClient(baseURL, &http.Client{
|
||||
Transport: transport.Clone(),
|
||||
})
|
||||
}
|
||||
|
||||
func newHTTPClient(baseURL string, httpClient *http.Client) (*HTTPClient, error) {
|
||||
if httpClient == nil {
|
||||
return nil, errors.New("new user service HTTP client: http client must not be nil")
|
||||
}
|
||||
|
||||
trimmedBaseURL := strings.TrimSpace(baseURL)
|
||||
if trimmedBaseURL == "" {
|
||||
return nil, errors.New("new user service HTTP client: base URL must not be empty")
|
||||
}
|
||||
|
||||
parsedBaseURL, err := url.Parse(strings.TrimRight(trimmedBaseURL, "/"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new user service HTTP client: parse base URL: %w", err)
|
||||
}
|
||||
if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
|
||||
return nil, errors.New("new user service HTTP client: base URL must be absolute")
|
||||
}
|
||||
|
||||
return &HTTPClient{
|
||||
baseURL: parsedBaseURL.String(),
|
||||
httpClient: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases idle HTTP connections owned by the client transport.
|
||||
func (c *HTTPClient) Close() error {
|
||||
if c == nil || c.httpClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type idleCloser interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteCommand routes one authenticated gateway command to the matching
// trusted internal User Service self-service route.
func (c *HTTPClient) ExecuteCommand(ctx context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
	if c == nil || c.httpClient == nil {
		return downstream.UnaryResult{}, errors.New("execute user service command: nil client")
	}
	if ctx == nil {
		return downstream.UnaryResult{}, errors.New("execute user service command: nil context")
	}
	// Fail fast on an already-cancelled context before any network work.
	if err := ctx.Err(); err != nil {
		return downstream.UnaryResult{}, err
	}
	// The user identity selects the per-user route; it must be present.
	if strings.TrimSpace(command.UserID) == "" {
		return downstream.UnaryResult{}, errors.New("execute user service command: user_id must not be empty")
	}

	switch command.MessageType {
	case usermodel.MessageTypeGetMyAccount:
		// Decode only to validate the FlatBuffers payload; the request
		// carries no fields the HTTP call needs.
		if _, err := transcoder.PayloadToGetMyAccountRequest(command.PayloadBytes); err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("execute user service command %q: %w", command.MessageType, err)
		}
		return c.executeGetMyAccount(ctx, command.UserID)
	case usermodel.MessageTypeUpdateMyProfile:
		request, err := transcoder.PayloadToUpdateMyProfileRequest(command.PayloadBytes)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("execute user service command %q: %w", command.MessageType, err)
		}
		return c.executeUpdateMyProfile(ctx, command.UserID, request)
	case usermodel.MessageTypeUpdateMySettings:
		request, err := transcoder.PayloadToUpdateMySettingsRequest(command.PayloadBytes)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("execute user service command %q: %w", command.MessageType, err)
		}
		return c.executeUpdateMySettings(ctx, command.UserID, request)
	default:
		return downstream.UnaryResult{}, fmt.Errorf("execute user service command: unsupported message type %q", command.MessageType)
	}
}
|
||||
|
||||
func (c *HTTPClient) executeGetMyAccount(ctx context.Context, userID string) (downstream.UnaryResult, error) {
|
||||
payload, statusCode, err := c.doRequest(ctx, http.MethodGet, c.userPath(userID, userServiceAccountPathSuffix), nil)
|
||||
if err != nil {
|
||||
return downstream.UnaryResult{}, fmt.Errorf("execute get my account: %w", err)
|
||||
}
|
||||
|
||||
return projectResponse(statusCode, payload)
|
||||
}
|
||||
|
||||
func (c *HTTPClient) executeUpdateMyProfile(ctx context.Context, userID string, request *usermodel.UpdateMyProfileRequest) (downstream.UnaryResult, error) {
|
||||
payload, statusCode, err := c.doRequest(ctx, http.MethodPost, c.userPath(userID, userServiceProfilePathSuffix), request)
|
||||
if err != nil {
|
||||
return downstream.UnaryResult{}, fmt.Errorf("execute update my profile: %w", err)
|
||||
}
|
||||
|
||||
return projectResponse(statusCode, payload)
|
||||
}
|
||||
|
||||
func (c *HTTPClient) executeUpdateMySettings(ctx context.Context, userID string, request *usermodel.UpdateMySettingsRequest) (downstream.UnaryResult, error) {
|
||||
payload, statusCode, err := c.doRequest(ctx, http.MethodPost, c.userPath(userID, userServiceSettingsPathSuffix), request)
|
||||
if err != nil {
|
||||
return downstream.UnaryResult{}, fmt.Errorf("execute update my settings: %w", err)
|
||||
}
|
||||
|
||||
return projectResponse(statusCode, payload)
|
||||
}
|
||||
|
||||
// doRequest performs one trusted-internal HTTP call and returns the raw
// response body and status code. A non-nil requestBody is JSON-encoded and
// sent with a JSON content type. Unlike the lobby adapter, the user identity
// travels in the URL path, not a header.
func (c *HTTPClient) doRequest(ctx context.Context, method string, targetURL string, requestBody any) ([]byte, int, error) {
	if c == nil || c.httpClient == nil {
		return nil, 0, errors.New("nil client")
	}

	var bodyReader io.Reader
	if requestBody != nil {
		payload, err := json.Marshal(requestBody)
		if err != nil {
			return nil, 0, fmt.Errorf("marshal request body: %w", err)
		}
		bodyReader = bytes.NewReader(payload)
	}

	request, err := http.NewRequestWithContext(ctx, method, targetURL, bodyReader)
	if err != nil {
		return nil, 0, fmt.Errorf("build request: %w", err)
	}
	if requestBody != nil {
		request.Header.Set("Content-Type", "application/json")
	}

	response, err := c.httpClient.Do(request)
	if err != nil {
		// Returned unwrapped so context cancellation errors surface as-is.
		return nil, 0, err
	}
	defer response.Body.Close()

	// Bodies are small REST payloads; reading fully is acceptable here.
	payload, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, 0, fmt.Errorf("read response body: %w", err)
	}

	return payload, response.StatusCode, nil
}
|
||||
|
||||
func (c *HTTPClient) userPath(userID string, suffix string) string {
|
||||
return c.baseURL + "/api/v1/internal/users/" + url.PathEscape(userID) + suffix
|
||||
}
|
||||
|
||||
// projectResponse maps a downstream HTTP outcome to the gateway result
// contract: 200 is strictly decoded as an AccountResponse and re-encoded as
// the gateway payload, 503 becomes ErrDownstreamUnavailable, other 4xx/5xx
// statuses become a stable error envelope, and anything else is an
// unexpected-status error.
func projectResponse(statusCode int, payload []byte) (downstream.UnaryResult, error) {
	switch {
	case statusCode == http.StatusOK:
		var response usermodel.AccountResponse
		// Strict decode: unknown or trailing JSON from the downstream is an error.
		if err := decodeStrictJSONPayload(payload, &response); err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("decode success response: %w", err)
		}

		payloadBytes, err := transcoder.AccountResponseToPayload(&response)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("encode success response payload: %w", err)
		}

		return downstream.UnaryResult{
			ResultCode:   getMyAccountResultCodeOK,
			PayloadBytes: payloadBytes,
		}, nil
	case statusCode == http.StatusServiceUnavailable:
		return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
	case statusCode >= 400 && statusCode <= 599:
		errorResponse, err := decodeUserServiceError(statusCode, payload)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("decode error response: %w", err)
		}

		payloadBytes, err := transcoder.ErrorResponseToPayload(errorResponse)
		if err != nil {
			return downstream.UnaryResult{}, fmt.Errorf("encode error response payload: %w", err)
		}

		// The downstream error code doubles as the gateway result code.
		return downstream.UnaryResult{
			ResultCode:   errorResponse.Error.Code,
			PayloadBytes: payloadBytes,
		}, nil
	default:
		return downstream.UnaryResult{}, fmt.Errorf("unexpected HTTP status %d", statusCode)
	}
}
|
||||
|
||||
// decodeUserServiceError strictly decodes a downstream error envelope and
// fills in stable fallback code/message values derived from the HTTP status
// when the envelope omits them.
func decodeUserServiceError(statusCode int, payload []byte) (*usermodel.ErrorResponse, error) {
	var response usermodel.ErrorResponse
	if err := decodeStrictJSONPayload(payload, &response); err != nil {
		return nil, err
	}

	response.Error.Code = normalizeErrorCode(statusCode, response.Error.Code)
	response.Error.Message = normalizeErrorMessage(response.Error.Code, response.Error.Message)

	// Defensive checks: both normalizers always return a non-empty value, so
	// these branches should be unreachable; kept as an invariant guard.
	if strings.TrimSpace(response.Error.Code) == "" {
		return nil, errors.New("missing error code")
	}
	if strings.TrimSpace(response.Error.Message) == "" {
		return nil, errors.New("missing error message")
	}

	return &response, nil
}
|
||||
|
||||
// normalizeErrorCode returns the trimmed upstream error code when one is
// present, otherwise derives a stable code from the HTTP status.
func normalizeErrorCode(statusCode int, code string) string {
	if normalized := strings.TrimSpace(code); normalized != "" {
		return normalized
	}

	// Blank code: fall back to a status-derived stable identifier.
	switch statusCode {
	case http.StatusBadRequest:
		return "invalid_request"
	case http.StatusNotFound:
		return "subject_not_found"
	case http.StatusConflict:
		return "conflict"
	}
	return "internal_error"
}
|
||||
|
||||
func normalizeErrorMessage(code string, message string) string {
|
||||
trimmed := strings.TrimSpace(message)
|
||||
if trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
if stable, ok := stableErrorMessages[code]; ok {
|
||||
return stable
|
||||
}
|
||||
|
||||
return stableErrorMessages["internal_error"]
|
||||
}
|
||||
|
||||
// decodeStrictJSONPayload decodes payload into target, rejecting unknown
// object fields and any trailing JSON after the first value.
//
// Decoder errors are returned verbatim so callers can add their own context.
func decodeStrictJSONPayload(payload []byte, target any) error {
	decoder := json.NewDecoder(bytes.NewReader(payload))
	decoder.DisallowUnknownFields()

	if err := decoder.Decode(target); err != nil {
		return err
	}

	// A strict payload holds exactly one JSON value: the next decode must
	// report end of input. Use errors.Is rather than == so wrapped EOF
	// values are recognized as well.
	switch err := decoder.Decode(&struct{}{}); {
	case errors.Is(err, io.EOF):
		return nil
	case err == nil:
		return errors.New("unexpected trailing JSON input")
	default:
		return err
	}
}
|
||||
|
||||
var _ downstream.Client = (*HTTPClient)(nil)
|
||||
@@ -1,400 +0,0 @@
|
||||
package userservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/downstream"
|
||||
usermodel "galaxy/model/user"
|
||||
"galaxy/transcoder"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewHTTPClient verifies the constructor's base-URL handling: surrounding
// whitespace and a trailing slash are normalized away, while empty and
// relative URLs are rejected with descriptive errors.
func TestNewHTTPClient(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		baseURL string
		wantURL string
		wantErr string
	}{
		{
			name:    "absolute URL is normalized",
			baseURL: " http://127.0.0.1:8081/ ",
			wantURL: "http://127.0.0.1:8081",
		},
		{
			name:    "empty base URL is rejected",
			baseURL: " ",
			wantErr: "base URL must not be empty",
		},
		{
			name:    "relative base URL is rejected",
			baseURL: "/relative",
			wantErr: "base URL must be absolute",
		},
	}

	for _, tt := range tests {
		tt := tt // pin loop variable for parallel subtests (pre-Go 1.22 semantics)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			client, err := NewHTTPClient(tt.baseURL)
			if tt.wantErr != "" {
				require.Error(t, err)
				assert.Contains(t, err.Error(), tt.wantErr)
				return
			}

			require.NoError(t, err)
			assert.Equal(t, tt.wantURL, client.baseURL)
		})
	}
}
|
||||
|
||||
// TestHTTPClientExecuteGetMyAccountSuccess checks the happy path: a GET to
// the per-user account endpoint whose 200 body round-trips through the
// transcoder into an OK unary result.
func TestHTTPClientExecuteGetMyAccountSuccess(t *testing.T) {
	t.Parallel()

	wantResponse := sampleAccountResponse()
	server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
		// Assert the wire-level contract before answering.
		require.Equal(t, http.MethodGet, request.Method)
		require.Equal(t, "/api/v1/internal/users/user-123/account", request.URL.Path)
		require.NoError(t, json.NewEncoder(writer).Encode(wantResponse))
	}))
	defer server.Close()

	client := newTestHTTPClient(t, server)
	payload, err := transcoder.GetMyAccountRequestToPayload(&usermodel.GetMyAccountRequest{})
	require.NoError(t, err)

	result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
		UserID:       "user-123",
		MessageType:  usermodel.MessageTypeGetMyAccount,
		PayloadBytes: payload,
	})
	require.NoError(t, err)
	assert.Equal(t, getMyAccountResultCodeOK, result.ResultCode)

	// The projected payload must decode back to the exact served response.
	decoded, err := transcoder.PayloadToAccountResponse(result.PayloadBytes)
	require.NoError(t, err)
	assert.Equal(t, wantResponse, decoded)
}
|
||||
|
||||
// TestHTTPClientExecuteUpdateMyProfileProjectsConflict checks that a 409 from
// the profile endpoint is projected as a "conflict" result (not a Go error)
// and that the error body survives the payload round trip.
func TestHTTPClientExecuteUpdateMyProfileProjectsConflict(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
		require.Equal(t, http.MethodPost, request.Method)
		require.Equal(t, "/api/v1/internal/users/user-123/profile", request.URL.Path)

		body, err := io.ReadAll(request.Body)
		require.NoError(t, err)
		require.JSONEq(t, `{"display_name":"NovaPrime"}`, string(body))

		writer.WriteHeader(http.StatusConflict)
		require.NoError(t, json.NewEncoder(writer).Encode(&usermodel.ErrorResponse{
			Error: usermodel.ErrorBody{
				Code:    "conflict",
				Message: "request conflicts with current state",
			},
		}))
	}))
	defer server.Close()

	client := newTestHTTPClient(t, server)
	payload, err := transcoder.UpdateMyProfileRequestToPayload(&usermodel.UpdateMyProfileRequest{DisplayName: "NovaPrime"})
	require.NoError(t, err)

	result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
		UserID:       "user-123",
		MessageType:  usermodel.MessageTypeUpdateMyProfile,
		PayloadBytes: payload,
	})
	// Projected errors are results, not transport failures.
	require.NoError(t, err)
	assert.Equal(t, "conflict", result.ResultCode)

	decoded, err := transcoder.PayloadToErrorResponse(result.PayloadBytes)
	require.NoError(t, err)
	assert.Equal(t, &usermodel.ErrorResponse{
		Error: usermodel.ErrorBody{
			Code:    "conflict",
			Message: "request conflicts with current state",
		},
	}, decoded)
}
|
||||
|
||||
// TestHTTPClientExecuteUpdateMySettingsProjectsInvalidRequest checks that a
// 400 from the settings endpoint is projected as an "invalid_request" result
// and that the request body is posted as the expected JSON.
func TestHTTPClientExecuteUpdateMySettingsProjectsInvalidRequest(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
		require.Equal(t, http.MethodPost, request.Method)
		require.Equal(t, "/api/v1/internal/users/user-123/settings", request.URL.Path)

		body, err := io.ReadAll(request.Body)
		require.NoError(t, err)
		require.JSONEq(t, `{"preferred_language":"bad","time_zone":"Mars/Base"}`, string(body))

		writer.WriteHeader(http.StatusBadRequest)
		require.NoError(t, json.NewEncoder(writer).Encode(&usermodel.ErrorResponse{
			Error: usermodel.ErrorBody{
				Code:    "invalid_request",
				Message: "request is invalid",
			},
		}))
	}))
	defer server.Close()

	client := newTestHTTPClient(t, server)
	payload, err := transcoder.UpdateMySettingsRequestToPayload(&usermodel.UpdateMySettingsRequest{
		PreferredLanguage: "bad",
		TimeZone:          "Mars/Base",
	})
	require.NoError(t, err)

	result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
		UserID:       "user-123",
		MessageType:  usermodel.MessageTypeUpdateMySettings,
		PayloadBytes: payload,
	})
	require.NoError(t, err)
	assert.Equal(t, "invalid_request", result.ResultCode)

	decoded, err := transcoder.PayloadToErrorResponse(result.PayloadBytes)
	require.NoError(t, err)
	assert.Equal(t, "invalid_request", decoded.Error.Code)
}
|
||||
|
||||
// TestHTTPClientExecuteCommandProjectsSubjectNotFound checks that a 404 error
// body is projected into a "subject_not_found" result code.
func TestHTTPClientExecuteCommandProjectsSubjectNotFound(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
		writer.WriteHeader(http.StatusNotFound)
		require.NoError(t, json.NewEncoder(writer).Encode(&usermodel.ErrorResponse{
			Error: usermodel.ErrorBody{
				Code:    "subject_not_found",
				Message: "subject not found",
			},
		}))
	}))
	defer server.Close()

	client := newTestHTTPClient(t, server)
	payload, err := transcoder.GetMyAccountRequestToPayload(&usermodel.GetMyAccountRequest{})
	require.NoError(t, err)

	result, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
		UserID:       "user-missing",
		MessageType:  usermodel.MessageTypeGetMyAccount,
		PayloadBytes: payload,
	})
	require.NoError(t, err)
	assert.Equal(t, "subject_not_found", result.ResultCode)
}
|
||||
|
||||
// TestHTTPClientExecuteCommandMaps503ToUnavailable checks that a 503 response
// surfaces as the ErrDownstreamUnavailable sentinel error — its JSON body is
// deliberately ignored by the client.
func TestHTTPClientExecuteCommandMaps503ToUnavailable(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
		writer.WriteHeader(http.StatusServiceUnavailable)
		require.NoError(t, json.NewEncoder(writer).Encode(&usermodel.ErrorResponse{
			Error: usermodel.ErrorBody{
				Code:    "service_unavailable",
				Message: "service is unavailable",
			},
		}))
	}))
	defer server.Close()

	client := newTestHTTPClient(t, server)
	payload, err := transcoder.GetMyAccountRequestToPayload(&usermodel.GetMyAccountRequest{})
	require.NoError(t, err)

	_, err = client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
		UserID:       "user-123",
		MessageType:  usermodel.MessageTypeGetMyAccount,
		PayloadBytes: payload,
	})
	require.Error(t, err)
	assert.ErrorIs(t, err, downstream.ErrDownstreamUnavailable)
}
|
||||
|
||||
// TestHTTPClientExecuteCommandUsesCallerContext checks that the caller's
// context governs the request: a server that never responds must cause the
// call to fail with context.DeadlineExceeded at the caller's timeout.
func TestHTTPClientExecuteCommandUsesCallerContext(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
		// Hang until the client gives up; completion is driven by the
		// request context, which derives from the caller's ctx.
		<-request.Context().Done()
	}))
	defer server.Close()

	client := newTestHTTPClient(t, server)
	payload, err := transcoder.GetMyAccountRequestToPayload(&usermodel.GetMyAccountRequest{})
	require.NoError(t, err)

	ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond)
	defer cancel()

	_, err = client.ExecuteCommand(ctx, downstream.AuthenticatedCommand{
		UserID:       "user-123",
		MessageType:  usermodel.MessageTypeGetMyAccount,
		PayloadBytes: payload,
	})
	require.Error(t, err)
	assert.ErrorIs(t, err, context.DeadlineExceeded)
}
|
||||
|
||||
// TestHTTPClientExecuteCommandRejectsMalformedSuccessPayload checks that a
// 200 body with an unknown field is rejected by the strict decoder rather
// than silently projected.
func TestHTTPClientExecuteCommandRejectsMalformedSuccessPayload(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
		_, _ = writer.Write([]byte(`{"account":{"user_id":"user-123","unexpected":true}}`))
	}))
	defer server.Close()

	client := newTestHTTPClient(t, server)
	payload, err := transcoder.GetMyAccountRequestToPayload(&usermodel.GetMyAccountRequest{})
	require.NoError(t, err)

	_, err = client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
		UserID:       "user-123",
		MessageType:  usermodel.MessageTypeGetMyAccount,
		PayloadBytes: payload,
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "decode success response")
}
|
||||
|
||||
// TestHTTPClientExecuteCommandRejectsUnsupportedMessageType checks that a
// message type outside the client's routing table fails fast without ever
// reaching the server.
func TestHTTPClientExecuteCommandRejectsUnsupportedMessageType(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.NotFoundHandler())
	defer server.Close()

	client := newTestHTTPClient(t, server)

	_, err := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
		UserID:       "user-123",
		MessageType:  "user.unsupported",
		PayloadBytes: []byte("payload"),
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "unsupported message type")
}
|
||||
|
||||
// TestNewRoutesReserveUserMessageTypesWhenUnconfigured checks that with an
// empty base URL, all user.* message types still route (no route-miss) but
// each resolved client reports ErrDownstreamUnavailable.
func TestNewRoutesReserveUserMessageTypesWhenUnconfigured(t *testing.T) {
	t.Parallel()

	routes, closeFn, err := NewRoutes("")
	require.NoError(t, err)
	require.NoError(t, closeFn())

	router := downstream.NewStaticRouter(routes)
	for _, messageType := range []string{
		usermodel.MessageTypeGetMyAccount,
		usermodel.MessageTypeUpdateMyProfile,
		usermodel.MessageTypeUpdateMySettings,
	} {
		client, routeErr := router.Route(messageType)
		require.NoError(t, routeErr)

		_, execErr := client.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{
			UserID:      "user-123",
			MessageType: messageType,
		})
		require.Error(t, execErr)
		assert.ErrorIs(t, execErr, downstream.ErrDownstreamUnavailable)
	}
}
|
||||
|
||||
// TestUnavailableClientReturnsDownstreamUnavailable checks the placeholder
// client's single behavior: every call yields ErrDownstreamUnavailable.
func TestUnavailableClientReturnsDownstreamUnavailable(t *testing.T) {
	t.Parallel()

	_, err := unavailableClient{}.ExecuteCommand(context.Background(), downstream.AuthenticatedCommand{})
	require.Error(t, err)
	assert.ErrorIs(t, err, downstream.ErrDownstreamUnavailable)
}
|
||||
|
||||
// newTestHTTPClient builds an HTTPClient bound to the given test server,
// reusing the server's preconfigured *http.Client.
func newTestHTTPClient(t *testing.T, server *httptest.Server) *HTTPClient {
	t.Helper()

	client, err := newHTTPClient(server.URL, server.Client())
	require.NoError(t, err)
	return client
}
|
||||
|
||||
// sampleAccountResponse returns a fully-populated, deterministic account
// fixture (fixed timestamps, one active sanction with expiry, one active
// limit) used to verify lossless round-tripping through the transcoder.
func sampleAccountResponse() *usermodel.AccountResponse {
	// Fixed instant keeps the fixture byte-stable across runs.
	now := time.Date(2026, time.April, 9, 10, 0, 0, 0, time.UTC)
	expiresAt := now.Add(30 * 24 * time.Hour)

	return &usermodel.AccountResponse{
		Account: usermodel.Account{
			UserID:            "user-123",
			Email:             "pilot@example.com",
			UserName:          "player-abcdefgh",
			DisplayName:       "PilotNova",
			PreferredLanguage: "en",
			TimeZone:          "Europe/Kaliningrad",
			DeclaredCountry:   "DE",
			Entitlement: usermodel.EntitlementSnapshot{
				PlanCode:   "free",
				IsPaid:     false,
				Source:     "auth_registration",
				Actor:      usermodel.ActorRef{Type: "service", ID: "user-service"},
				ReasonCode: "initial_free_entitlement",
				StartsAt:   now,
				UpdatedAt:  now,
			},
			ActiveSanctions: []usermodel.ActiveSanction{
				{
					SanctionCode: "profile_update_block",
					Scope:        "lobby",
					ReasonCode:   "manual_block",
					Actor:        usermodel.ActorRef{Type: "admin", ID: "admin-1"},
					AppliedAt:    now,
					ExpiresAt:    &expiresAt,
				},
			},
			ActiveLimits: []usermodel.ActiveLimit{
				{
					LimitCode:  "max_owned_private_games",
					Value:      3,
					ReasonCode: "manual_override",
					Actor:      usermodel.ActorRef{Type: "admin", ID: "admin-1"},
					AppliedAt:  now,
				},
			},
			CreatedAt: now,
			UpdatedAt: now,
		},
	}
}
|
||||
|
||||
// TestDecodeUserServiceErrorNormalizesBlankFields checks that whitespace-only
// code/message fields are replaced by the status-derived code and its stable
// message.
func TestDecodeUserServiceErrorNormalizesBlankFields(t *testing.T) {
	t.Parallel()

	response, err := decodeUserServiceError(http.StatusBadRequest, []byte(`{"error":{"code":" ","message":" "}}`))
	require.NoError(t, err)
	assert.Equal(t, "invalid_request", response.Error.Code)
	assert.Equal(t, "request is invalid", response.Error.Message)
}
|
||||
|
||||
// TestHTTPClientExecuteCommandRejectsNilContext checks that a nil context is
// rejected with an explicit error instead of panicking inside net/http.
func TestHTTPClientExecuteCommandRejectsNilContext(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.NotFoundHandler())
	defer server.Close()

	client := newTestHTTPClient(t, server)

	// Passing nil deliberately; the client must guard against it.
	_, err := client.ExecuteCommand(nil, downstream.AuthenticatedCommand{})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "nil context")
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package userservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"galaxy/gateway/internal/downstream"
|
||||
usermodel "galaxy/model/user"
|
||||
)
|
||||
|
||||
var noOpClose = func() error { return nil }
|
||||
|
||||
// NewRoutes returns the reserved authenticated gateway routes owned by the
|
||||
// Gateway -> User self-service boundary.
|
||||
//
|
||||
// When baseURL is empty, the returned routes still reserve the stable
|
||||
// `user.*` message types but resolve them to a dependency-unavailable client
|
||||
// so callers receive the transport-level unavailable outcome instead of a
|
||||
// route-miss error.
|
||||
func NewRoutes(baseURL string) (map[string]downstream.Client, func() error, error) {
|
||||
client := downstream.Client(unavailableClient{})
|
||||
closeFn := noOpClose
|
||||
|
||||
if baseURL != "" {
|
||||
httpClient, err := NewHTTPClient(baseURL)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
client = httpClient
|
||||
closeFn = httpClient.Close
|
||||
}
|
||||
|
||||
return map[string]downstream.Client{
|
||||
usermodel.MessageTypeGetMyAccount: client,
|
||||
usermodel.MessageTypeUpdateMyProfile: client,
|
||||
usermodel.MessageTypeUpdateMySettings: client,
|
||||
}, closeFn, nil
|
||||
}
|
||||
|
||||
type unavailableClient struct{}
|
||||
|
||||
func (unavailableClient) ExecuteCommand(context.Context, downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
||||
return downstream.UnaryResult{}, downstream.ErrDownstreamUnavailable
|
||||
}
|
||||
|
||||
var _ downstream.Client = unavailableClient{}
|
||||
@@ -1,299 +0,0 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const clientEventReadCount int64 = 128
|
||||
|
||||
// ClientEventPublisher accepts decoded client-facing events from the internal
// event subscriber. It is the seam between stream consumption and push
// delivery.
type ClientEventPublisher interface {
	// Publish fans out event to the currently active push streams.
	Publish(event push.Event)
}
|
||||
|
||||
// RedisClientEventSubscriber consumes client-facing events from one Redis
// Stream and forwards them to the configured publisher.
type RedisClientEventSubscriber struct {
	client           *redis.Client      // shared Redis client owned by the runtime
	stream           string             // stream key consumed by Run
	pingTimeout      time.Duration      // taken from the session-cache lookup timeout
	readBlockTimeout time.Duration      // XRead blocking window per poll
	publisher        ClientEventPublisher
	logger           *zap.Logger
	metrics          *telemetry.Runtime // optional; may be nil (see constructors)

	startedOnce sync.Once     // guards the one-time close of started
	started     chan struct{} // closed once Run has resolved its start ID
}
|
||||
|
||||
// NewRedisClientEventSubscriber constructs a Redis Stream subscriber that uses
|
||||
// client and forwards decoded client-facing events to publisher.
|
||||
func NewRedisClientEventSubscriber(client *redis.Client, sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher) (*RedisClientEventSubscriber, error) {
|
||||
return NewRedisClientEventSubscriberWithObservability(client, sessionCfg, eventsCfg, publisher, nil, nil)
|
||||
}
|
||||
|
||||
// NewRedisClientEventSubscriberWithObservability constructs a Redis Stream
|
||||
// subscriber that also records malformed or dropped internal events. The
|
||||
// subscriber does not own the client; the runtime supplies a shared
|
||||
// *redis.Client.
|
||||
func NewRedisClientEventSubscriberWithObservability(client *redis.Client, sessionCfg config.SessionCacheRedisConfig, eventsCfg config.ClientEventsRedisConfig, publisher ClientEventPublisher, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisClientEventSubscriber, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("new redis client event subscriber: nil redis client")
|
||||
}
|
||||
if sessionCfg.LookupTimeout <= 0 {
|
||||
return nil, errors.New("new redis client event subscriber: lookup timeout must be positive")
|
||||
}
|
||||
if strings.TrimSpace(eventsCfg.Stream) == "" {
|
||||
return nil, errors.New("new redis client event subscriber: stream must not be empty")
|
||||
}
|
||||
if eventsCfg.ReadBlockTimeout <= 0 {
|
||||
return nil, errors.New("new redis client event subscriber: read block timeout must be positive")
|
||||
}
|
||||
if publisher == nil {
|
||||
return nil, errors.New("new redis client event subscriber: nil publisher")
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return &RedisClientEventSubscriber{
|
||||
client: client,
|
||||
stream: eventsCfg.Stream,
|
||||
pingTimeout: sessionCfg.LookupTimeout,
|
||||
readBlockTimeout: eventsCfg.ReadBlockTimeout,
|
||||
publisher: publisher,
|
||||
logger: logger.Named("client_event_subscriber"),
|
||||
metrics: metrics,
|
||||
started: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Run consumes client-facing events until ctx is canceled or Redis returns an
|
||||
// unexpected error.
|
||||
func (s *RedisClientEventSubscriber) Run(ctx context.Context) error {
|
||||
if s == nil || s.client == nil {
|
||||
return errors.New("run redis client event subscriber: nil subscriber")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("run redis client event subscriber: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lastID, err := s.resolveStartID(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.signalStarted()
|
||||
|
||||
for {
|
||||
streams, err := s.client.XRead(ctx, &redis.XReadArgs{
|
||||
Streams: []string{s.stream, lastID},
|
||||
Count: clientEventReadCount,
|
||||
Block: s.readBlockTimeout,
|
||||
}).Result()
|
||||
switch {
|
||||
case err == nil:
|
||||
for _, stream := range streams {
|
||||
for _, message := range stream.Messages {
|
||||
s.publishMessage(message)
|
||||
lastID = message.ID
|
||||
}
|
||||
}
|
||||
continue
|
||||
case errors.Is(err, redis.Nil):
|
||||
continue
|
||||
case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
|
||||
return ctx.Err()
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded), errors.Is(err, redis.ErrClosed):
|
||||
return fmt.Errorf("run redis client event subscriber: %w", err)
|
||||
default:
|
||||
return fmt.Errorf("run redis client event subscriber: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RedisClientEventSubscriber) resolveStartID(ctx context.Context) (string, error) {
|
||||
messages, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
|
||||
switch {
|
||||
case err == nil:
|
||||
case errors.Is(err, redis.Nil):
|
||||
return "0-0", nil
|
||||
default:
|
||||
return "", fmt.Errorf("run redis client event subscriber: resolve stream tail: %w", err)
|
||||
}
|
||||
|
||||
if len(messages) == 0 {
|
||||
return "0-0", nil
|
||||
}
|
||||
|
||||
return messages[0].ID, nil
|
||||
}
|
||||
|
||||
// Shutdown is a no-op kept for App framework compatibility. The blocking
|
||||
// XRead loop terminates when its context is cancelled by the parent runtime,
|
||||
// which also owns and closes the shared Redis client.
|
||||
func (s *RedisClientEventSubscriber) Shutdown(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("shutdown redis client event subscriber: nil context")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close is a no-op kept for backwards-compatible cleanup wiring; the
// subscriber does not own the shared Redis client, so there is nothing to
// release here.
func (s *RedisClientEventSubscriber) Close() error {
	return nil
}
|
||||
|
||||
// signalStarted closes the started channel exactly once, letting tests and
// callers wait until Run has resolved its start ID and entered the read loop.
func (s *RedisClientEventSubscriber) signalStarted() {
	s.startedOnce.Do(func() {
		close(s.started)
	})
}
|
||||
|
||||
func (s *RedisClientEventSubscriber) publishMessage(message redis.XMessage) {
|
||||
event, err := decodeClientEvent(message.Values)
|
||||
if err != nil {
|
||||
s.logger.Warn("dropped malformed client event",
|
||||
zap.String("stream", s.stream),
|
||||
zap.String("message_id", message.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
s.metrics.RecordInternalEventDrop(context.Background(),
|
||||
attribute.String("component", "client_event_subscriber"),
|
||||
attribute.String("reason", "malformed_event"),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
s.publisher.Publish(event)
|
||||
}
|
||||
|
||||
// decodeClientEvent converts a raw Redis stream field map into a push.Event.
//
// The schema is closed: any field outside the required set (user_id,
// event_type, event_id, payload_bytes) and the optional set
// (device_session_id, request_id, trace_id) rejects the whole event.
// Required fields are validated in a fixed order so error reporting is
// deterministic for missing/invalid fields.
func decodeClientEvent(values map[string]any) (push.Event, error) {
	requiredKeys := map[string]struct{}{
		"user_id":       {},
		"event_type":    {},
		"event_id":      {},
		"payload_bytes": {},
	}
	optionalKeys := map[string]struct{}{
		"device_session_id": {},
		"request_id":        {},
		"trace_id":          {},
	}

	// Reject unknown fields first; with several unknown keys the reported
	// one depends on map iteration order.
	for key := range values {
		if _, ok := requiredKeys[key]; ok {
			continue
		}
		if _, ok := optionalKeys[key]; ok {
			continue
		}

		return push.Event{}, fmt.Errorf("decode client event: unsupported field %q", key)
	}

	userID, err := requiredStringField(values, "user_id")
	if err != nil {
		return push.Event{}, err
	}
	eventType, err := requiredStringField(values, "event_type")
	if err != nil {
		return push.Event{}, err
	}
	eventID, err := requiredStringField(values, "event_id")
	if err != nil {
		return push.Event{}, err
	}
	payloadBytes, err := requiredBytesField(values, "payload_bytes")
	if err != nil {
		return push.Event{}, err
	}

	event := push.Event{
		UserID:       userID,
		EventType:    eventType,
		EventID:      eventID,
		PayloadBytes: payloadBytes,
	}

	// Only device_session_id is trimmed; request/trace IDs pass through
	// unchanged.
	if deviceSessionID, ok, err := optionalStringField(values, "device_session_id"); err != nil {
		return push.Event{}, err
	} else if ok {
		event.DeviceSessionID = strings.TrimSpace(deviceSessionID)
	}

	if requestID, ok, err := optionalStringField(values, "request_id"); err != nil {
		return push.Event{}, err
	} else if ok {
		event.RequestID = requestID
	}

	if traceID, ok, err := optionalStringField(values, "trace_id"); err != nil {
		return push.Event{}, err
	} else if ok {
		event.TraceID = traceID
	}

	return event, nil
}
|
||||
|
||||
func requiredBytesField(values map[string]any, field string) ([]byte, error) {
|
||||
value, ok := values[field]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("decode client event: missing %s", field)
|
||||
}
|
||||
|
||||
byteValue, err := coerceBytes(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode client event: %s: %w", field, err)
|
||||
}
|
||||
|
||||
return byteValue, nil
|
||||
}
|
||||
|
||||
func optionalStringField(values map[string]any, field string) (string, bool, error) {
|
||||
value, ok := values[field]
|
||||
if !ok {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
stringValue, err := coerceString(value)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("decode client event: %s: %w", field, err)
|
||||
}
|
||||
|
||||
return stringValue, true, nil
|
||||
}
|
||||
|
||||
// coerceBytes converts a Redis stream value to an owned byte slice. Strings
// are copied by conversion; byte slices are cloned so the result never
// aliases the decoder's buffer. Any other type is rejected.
func coerceBytes(value any) ([]byte, error) {
	if text, ok := value.(string); ok {
		return []byte(text), nil
	}
	if raw, ok := value.([]byte); ok {
		return bytes.Clone(raw), nil
	}
	return nil, fmt.Errorf("unsupported type %T", value)
}
|
||||
@@ -1,289 +0,0 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/testutil"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRedisClientEventSubscriberPublishesValidEvent checks that a stream
// entry carrying all required and optional fields is decoded and published
// exactly once with every field preserved.
func TestRedisClientEventSubscriberPublishesValidEvent(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":           "user-123",
		"device_session_id": "device-session-123",
		"event_type":        "fleet.updated",
		"event_id":          "event-123",
		"payload_bytes":     []byte("payload-123"),
		"request_id":        "request-123",
		"trace_id":          "trace-123",
	})

	require.Eventually(t, func() bool {
		return len(publisher.events()) == 1
	}, time.Second, 10*time.Millisecond)

	assert.Equal(t, []push.Event{{
		UserID:          "user-123",
		DeviceSessionID: "device-session-123",
		EventType:       "fleet.updated",
		EventID:         "event-123",
		PayloadBytes:    []byte("payload-123"),
		RequestID:       "request-123",
		TraceID:         "trace-123",
	}}, publisher.events())
}
|
||||
|
||||
// TestRedisClientEventSubscriberSkipsMalformedEventAndContinues checks that
// an entry with an unsupported field is dropped while later well-formed
// entries are still delivered.
func TestRedisClientEventSubscriberSkipsMalformedEventAndContinues(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)

	// First entry is rejected by the closed-schema decoder ("unexpected").
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-bad",
		"payload_bytes": []byte("payload-bad"),
		"unexpected":    "boom",
	})
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-good",
		"payload_bytes": []byte("payload-good"),
	})

	require.Eventually(t, func() bool {
		events := publisher.events()
		return len(events) == 1 && events[0].EventID == "event-good"
	}, time.Second, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisClientEventSubscriberStartsFromCurrentTail checks that entries
// published before Run starts are never delivered, while entries published
// afterwards are.
func TestRedisClientEventSubscriberStartsFromCurrentTail(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}

	// Published before the subscriber starts: must be skipped.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-old",
		"payload_bytes": []byte("payload-old"),
	})

	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)
	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)

	assert.Never(t, func() bool {
		return len(publisher.events()) > 0
	}, 100*time.Millisecond, 10*time.Millisecond)

	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-new",
		"payload_bytes": []byte("payload-new"),
	})

	require.Eventually(t, func() bool {
		events := publisher.events()
		return len(events) == 1 && events[0].EventID == "event-new"
	}, time.Second, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisClientEventSubscriberShutdownInterruptsBlockingRead ensures that
// cancelling the run context and calling Shutdown unblocks the subscriber's
// blocking stream read promptly, and that Run returns context.Canceled.
func TestRedisClientEventSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	subscriber := newTestRedisClientEventSubscriber(t, server, publisher)

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()

	// Wait until Run signals readiness so cancellation actually interrupts
	// an in-flight blocking read rather than pre-start setup.
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}

	cancel()
	require.NoError(t, subscriber.Shutdown(context.Background()))

	select {
	case err := <-resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop after shutdown")
	}
}
|
||||
|
||||
// TestRedisClientEventSubscriberLogsAndCountsMalformedEvents checks the
// observability contract for malformed stream entries: the entry is dropped,
// a warning is logged, and the internal-event-drop counter is incremented
// with this subscriber's component and reason labels.
func TestRedisClientEventSubscriberLogsAndCountsMalformedEvents(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	publisher := &recordingClientEventPublisher{}
	logger, logBuffer := testutil.NewObservedLogger(t)
	telemetryRuntime := testutil.NewTelemetryRuntime(t, logger)

	subscriber, err := NewRedisClientEventSubscriberWithObservability(
		newTestRedisClient(t, server),
		config.SessionCacheRedisConfig{
			KeyPrefix:     "gateway:session:",
			LookupTimeout: 250 * time.Millisecond,
		},
		config.ClientEventsRedisConfig{
			Stream:           "gateway:client_events",
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		publisher,
		logger,
		telemetryRuntime,
	)
	require.NoError(t, err)

	running := runTestClientEventSubscriber(t, subscriber)
	defer running.stop(t)

	// "unexpected" is not part of the client-event field schema, so this
	// entry must be classified as malformed.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":       "user-123",
		"event_type":    "fleet.updated",
		"event_id":      "event-bad",
		"payload_bytes": []byte("payload-bad"),
		"unexpected":    "boom",
	})

	require.Eventually(t, func() bool {
		return strings.Contains(logBuffer.String(), "dropped malformed client event")
	}, time.Second, 10*time.Millisecond)

	// The drop must also surface in the scraped metrics with stable labels.
	metricsText := testutil.ScrapeMetrics(t, telemetryRuntime.Handler())
	assert.Contains(t, metricsText, `gateway_internal_event_drops_total`)
	assert.Contains(t, metricsText, `component="client_event_subscriber"`)
	assert.Contains(t, metricsText, `reason="malformed_event"`)
}
|
||||
|
||||
// newTestRedisClientEventSubscriber builds a RedisClientEventSubscriber
// wired to the given miniredis instance, using short timeouts so tests
// complete quickly. It fails the test if construction returns an error.
func newTestRedisClientEventSubscriber(t *testing.T, server *miniredis.Miniredis, publisher ClientEventPublisher) *RedisClientEventSubscriber {
	t.Helper()

	subscriber, err := NewRedisClientEventSubscriber(
		newTestRedisClient(t, server),
		config.SessionCacheRedisConfig{
			KeyPrefix:     "gateway:session:",
			LookupTimeout: 250 * time.Millisecond,
		},
		config.ClientEventsRedisConfig{
			Stream:           "gateway:client_events",
			ReadBlockTimeout: 25 * time.Millisecond,
		},
		publisher,
	)
	require.NoError(t, err)

	return subscriber
}
|
||||
|
||||
func addClientEvent(t *testing.T, server *miniredis.Miniredis, stream string, values map[string]any) {
|
||||
t.Helper()
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: server.Addr(),
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
})
|
||||
defer func() {
|
||||
assert.NoError(t, client.Close())
|
||||
}()
|
||||
|
||||
err := client.XAdd(context.Background(), &redis.XAddArgs{
|
||||
Stream: stream,
|
||||
Values: values,
|
||||
}).Err()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// runningClientEventSubscriber is a handle to a subscriber started by
// runTestClientEventSubscriber.
type runningClientEventSubscriber struct {
	cancel   context.CancelFunc // cancels the context passed to Run
	resultCh chan error         // receives Run's return value exactly once
}
|
||||
|
||||
// runTestClientEventSubscriber starts subscriber.Run in a goroutine, waits
// for the subscriber to signal readiness, and returns a handle the caller
// must stop (see runningClientEventSubscriber.stop) to avoid a leaked
// goroutine.
func runTestClientEventSubscriber(t *testing.T, subscriber *RedisClientEventSubscriber) runningClientEventSubscriber {
	t.Helper()

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- subscriber.Run(ctx)
	}()

	// Block until Run is actually consuming the stream so callers can add
	// events without racing subscriber startup.
	select {
	case <-subscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not start")
	}

	return runningClientEventSubscriber{
		cancel:   cancel,
		resultCh: resultCh,
	}
}
|
||||
|
||||
// stop cancels the subscriber's run context and waits up to one second for
// Run to return, requiring it to exit with context.Canceled.
func (r runningClientEventSubscriber) stop(t *testing.T) {
	t.Helper()

	r.cancel()

	select {
	case err := <-r.resultCh:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		require.FailNow(t, "subscriber did not stop")
	}
}
|
||||
|
||||
// recordingClientEventPublisher is a thread-safe ClientEventPublisher test
// double that records every published event for later inspection.
type recordingClientEventPublisher struct {
	mu      sync.Mutex   // guards records
	records []push.Event // published events in arrival order
}
|
||||
|
||||
func (p *recordingClientEventPublisher) Publish(event push.Event) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.records = append(p.records, event)
|
||||
}
|
||||
|
||||
func (p *recordingClientEventPublisher) events() []push.Event {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
cloned := make([]push.Event, len(p.records))
|
||||
copy(cloned, p.records)
|
||||
return cloned
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
// Package events translates inbound `pushv1.PushEvent` frames received
|
||||
// from backend into actions on the gateway-side push hub. It replaces
|
||||
// the Stage <6.2 Redis Stream subscribers (`session_events`,
|
||||
// `client_events`) with a single dispatcher driven by the gRPC
|
||||
// SubscribePush stream.
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
pushv1 "galaxy/backend/proto/push/v1"
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SessionInvalidator closes every active push subscription bound to a single
// device session, or every session of a user, when the backend emits a
// SessionInvalidation frame. *push.Hub satisfies this contract.
type SessionInvalidator interface {
	// RevokeDeviceSession closes the subscriptions of one device session.
	RevokeDeviceSession(deviceSessionID string)
	// RevokeAllForUser closes the subscriptions of every session of a user.
	RevokeAllForUser(userID string)
}
|
||||
|
||||
// EventPublisher fans out a translated client event to active push
// subscriptions. *push.Hub satisfies this contract.
type EventPublisher interface {
	// Publish delivers event to all matching subscriptions.
	Publish(event push.Event)
}
|
||||
|
||||
// Dispatcher converts inbound `pushv1.PushEvent` frames into either a
// hub Publish or a hub revocation. Malformed frames are dropped and
// counted via telemetry; observability mirrors the previous
// RecordInternalEventDrop semantics.
type Dispatcher struct {
	publisher   EventPublisher     // fan-out target for client events
	invalidator SessionInvalidator // revocation target for invalidations
	logger      *zap.Logger        // never nil after NewDispatcher
	metrics     *telemetry.Runtime // may be nil; drops are then uncounted
}
|
||||
|
||||
// NewDispatcher constructs a Dispatcher. publisher and invalidator are
|
||||
// required; logger and metrics may be nil.
|
||||
func NewDispatcher(publisher EventPublisher, invalidator SessionInvalidator, logger *zap.Logger, metrics *telemetry.Runtime) *Dispatcher {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &Dispatcher{
|
||||
publisher: publisher,
|
||||
invalidator: invalidator,
|
||||
logger: logger.Named("push_dispatcher"),
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle implements backendclient.EventHandler. It is safe for
|
||||
// concurrent use; the caller serialises ev within its goroutine.
|
||||
func (d *Dispatcher) Handle(ctx context.Context, ev *pushv1.PushEvent) {
|
||||
if d == nil || ev == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch kind := ev.GetKind().(type) {
|
||||
case *pushv1.PushEvent_ClientEvent:
|
||||
d.handleClientEvent(ctx, kind.ClientEvent, ev.GetCursor())
|
||||
case *pushv1.PushEvent_SessionInvalidation:
|
||||
d.handleSessionInvalidation(kind.SessionInvalidation)
|
||||
default:
|
||||
d.logger.Warn("dropped malformed push event",
|
||||
zap.String("cursor", ev.GetCursor()),
|
||||
zap.String("reason", "unknown_kind"),
|
||||
)
|
||||
d.recordDrop(ctx, "unknown_kind")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) handleClientEvent(ctx context.Context, ce *pushv1.ClientEvent, cursor string) {
|
||||
if ce == nil || d.publisher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(ce.GetUserId())
|
||||
kind := strings.TrimSpace(ce.GetKind())
|
||||
eventID := strings.TrimSpace(ce.GetEventId())
|
||||
if userID == "" || kind == "" || eventID == "" {
|
||||
d.logger.Warn("dropped malformed client event",
|
||||
zap.String("cursor", cursor),
|
||||
zap.String("user_id", userID),
|
||||
zap.String("kind", kind),
|
||||
zap.String("event_id", eventID),
|
||||
)
|
||||
d.recordDrop(ctx, "malformed_client_event")
|
||||
return
|
||||
}
|
||||
|
||||
d.publisher.Publish(push.Event{
|
||||
UserID: userID,
|
||||
DeviceSessionID: strings.TrimSpace(ce.GetDeviceSessionId()),
|
||||
EventType: kind,
|
||||
EventID: eventID,
|
||||
PayloadBytes: cloneBytes(ce.GetPayload()),
|
||||
RequestID: ce.GetRequestId(),
|
||||
TraceID: ce.GetTraceId(),
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Dispatcher) handleSessionInvalidation(si *pushv1.SessionInvalidation) {
|
||||
if si == nil || d.invalidator == nil {
|
||||
return
|
||||
}
|
||||
|
||||
userID := strings.TrimSpace(si.GetUserId())
|
||||
deviceSessionID := strings.TrimSpace(si.GetDeviceSessionId())
|
||||
|
||||
switch {
|
||||
case deviceSessionID != "":
|
||||
d.invalidator.RevokeDeviceSession(deviceSessionID)
|
||||
case userID != "":
|
||||
d.invalidator.RevokeAllForUser(userID)
|
||||
default:
|
||||
d.logger.Warn("dropped malformed session_invalidation: user_id and device_session_id both empty")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) recordDrop(ctx context.Context, reason string) {
|
||||
if d.metrics == nil {
|
||||
return
|
||||
}
|
||||
d.metrics.RecordInternalEventDrop(ctx,
|
||||
attribute.String("component", "push_dispatcher"),
|
||||
attribute.String("reason", reason),
|
||||
)
|
||||
}
|
||||
|
||||
// cloneBytes returns an independent copy of in, or nil when in is empty.
// Callers rely on the result never aliasing the input buffer.
func cloneBytes(in []byte) []byte {
	if len(in) == 0 {
		return nil
	}
	return append(make([]byte, 0, len(in)), in...)
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package events_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
pushv1 "galaxy/backend/proto/push/v1"
|
||||
"galaxy/gateway/internal/events"
|
||||
"galaxy/gateway/internal/push"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// capturePublisher is a thread-safe EventPublisher test double that records
// every published event.
type capturePublisher struct {
	mu     sync.Mutex   // guards events
	events []push.Event // published events in arrival order
}
|
||||
|
||||
func (c *capturePublisher) Publish(event push.Event) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.events = append(c.events, event)
|
||||
}
|
||||
|
||||
func (c *capturePublisher) snapshot() []push.Event {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
out := make([]push.Event, len(c.events))
|
||||
copy(out, c.events)
|
||||
return out
|
||||
}
|
||||
|
||||
// captureInvalidator is a thread-safe SessionInvalidator test double that
// records every revocation it receives.
type captureInvalidator struct {
	mu      sync.Mutex
	devices []string // IDs passed to RevokeDeviceSession, in order
	users   []string // IDs passed to RevokeAllForUser, in order
}
|
||||
|
||||
func (c *captureInvalidator) RevokeDeviceSession(id string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.devices = append(c.devices, id)
|
||||
}
|
||||
|
||||
func (c *captureInvalidator) RevokeAllForUser(id string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.users = append(c.users, id)
|
||||
}
|
||||
|
||||
func (c *captureInvalidator) snapshot() ([]string, []string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
d := append([]string(nil), c.devices...)
|
||||
u := append([]string(nil), c.users...)
|
||||
return d, u
|
||||
}
|
||||
|
||||
// TestDispatcherForwardsClientEventToPublisher verifies that a well-formed
// ClientEvent frame is translated field-for-field into a push.Event and
// handed to the publisher, without touching the invalidator.
func TestDispatcherForwardsClientEventToPublisher(t *testing.T) {
	t.Parallel()

	pub := &capturePublisher{}
	inv := &captureInvalidator{}
	disp := events.NewDispatcher(pub, inv, nil, nil)

	disp.Handle(context.Background(), &pushv1.PushEvent{
		Cursor: "00000000000000000001",
		Kind: &pushv1.PushEvent_ClientEvent{
			ClientEvent: &pushv1.ClientEvent{
				UserId:          "user-1",
				DeviceSessionId: "device-1",
				Kind:            "lobby.invite.received",
				Payload:         []byte(`{"x":1}`),
				EventId:         "route-1",
				RequestId:       "req-1",
				TraceId:         "trace-1",
			},
		},
	})

	// Every proto field must survive the translation unchanged.
	got := pub.snapshot()
	require.Len(t, got, 1)
	assert.Equal(t, push.Event{
		UserID:          "user-1",
		DeviceSessionID: "device-1",
		EventType:       "lobby.invite.received",
		EventID:         "route-1",
		PayloadBytes:    []byte(`{"x":1}`),
		RequestID:       "req-1",
		TraceID:         "trace-1",
	}, got[0])

	// A client event must never trigger session revocation.
	devices, users := inv.snapshot()
	assert.Empty(t, devices)
	assert.Empty(t, users)
}
|
||||
|
||||
// TestDispatcherDropsClientEventMissingEventID verifies that a client event
// without an event ID is treated as malformed and never reaches the
// publisher.
func TestDispatcherDropsClientEventMissingEventID(t *testing.T) {
	t.Parallel()

	pub := &capturePublisher{}
	disp := events.NewDispatcher(pub, &captureInvalidator{}, nil, nil)

	// EventId is deliberately omitted.
	disp.Handle(context.Background(), &pushv1.PushEvent{
		Kind: &pushv1.PushEvent_ClientEvent{
			ClientEvent: &pushv1.ClientEvent{
				UserId: "user-1",
				Kind:   "lobby.invite.received",
			},
		},
	})

	assert.Empty(t, pub.snapshot())
}
|
||||
|
||||
// TestDispatcherSessionInvalidationByDeviceID verifies that when an
// invalidation carries both a user ID and a device session ID, the
// device-session revocation takes precedence over a user-wide fan-out.
func TestDispatcherSessionInvalidationByDeviceID(t *testing.T) {
	t.Parallel()

	inv := &captureInvalidator{}
	disp := events.NewDispatcher(&capturePublisher{}, inv, nil, nil)

	disp.Handle(context.Background(), &pushv1.PushEvent{
		Kind: &pushv1.PushEvent_SessionInvalidation{
			SessionInvalidation: &pushv1.SessionInvalidation{
				UserId:          "user-1",
				DeviceSessionId: "device-1",
				Reason:          "auth.revoke_session",
			},
		},
	})

	devices, users := inv.snapshot()
	assert.Equal(t, []string{"device-1"}, devices)
	assert.Empty(t, users)
}
|
||||
|
||||
// TestDispatcherSessionInvalidationFanOutForUser verifies that an
// invalidation carrying only a user ID revokes every session of that user
// and no individual device session.
func TestDispatcherSessionInvalidationFanOutForUser(t *testing.T) {
	t.Parallel()

	inv := &captureInvalidator{}
	disp := events.NewDispatcher(&capturePublisher{}, inv, nil, nil)

	disp.Handle(context.Background(), &pushv1.PushEvent{
		Kind: &pushv1.PushEvent_SessionInvalidation{
			SessionInvalidation: &pushv1.SessionInvalidation{
				UserId: "user-1",
				Reason: "auth.revoke_all_for_user",
			},
		},
	})

	devices, users := inv.snapshot()
	assert.Empty(t, devices)
	assert.Equal(t, []string{"user-1"}, users)
}
|
||||
@@ -1,396 +0,0 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/grpcapi"
|
||||
"galaxy/gateway/internal/replay"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var testNow = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
func TestAuthenticatedGatewayWarmsLocalSessionCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
local := session.NewMemoryCache()
|
||||
fallback := &countingSessionCache{
|
||||
records: map[string]session.Record{
|
||||
"device-session-123": newActiveSessionRecord("user-123"),
|
||||
},
|
||||
}
|
||||
readThrough, err := session.NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
subscriber := newTestRedisSessionSubscriber(t, server, local)
|
||||
downstreamClient := &recordingDownstreamClient{}
|
||||
addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
|
||||
defer running.stop(t)
|
||||
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls())
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls())
|
||||
assert.Len(t, downstreamClient.commands(), 2)
|
||||
}
|
||||
|
||||
func TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
local := session.NewMemoryCache()
|
||||
fallback := &countingSessionCache{
|
||||
records: map[string]session.Record{
|
||||
"device-session-123": newActiveSessionRecord("user-123"),
|
||||
},
|
||||
}
|
||||
readThrough, err := session.NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
subscriber := newTestRedisSessionSubscriber(t, server, local)
|
||||
downstreamClient := &recordingDownstreamClient{}
|
||||
addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
|
||||
defer running.stop(t)
|
||||
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls())
|
||||
|
||||
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
||||
"device_session_id": "device-session-123",
|
||||
"user_id": "user-456",
|
||||
"client_public_key": testClientPublicKeyBase64(),
|
||||
"status": string(session.StatusActive),
|
||||
})
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
record, lookupErr := local.Lookup(context.Background(), "device-session-123")
|
||||
return lookupErr == nil && record.UserID == "user-456"
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls())
|
||||
|
||||
commands := downstreamClient.commands()
|
||||
require.Len(t, commands, 2)
|
||||
assert.Equal(t, "user-456", commands[1].UserID)
|
||||
}
|
||||
|
||||
func TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
local := session.NewMemoryCache()
|
||||
fallback := &countingSessionCache{
|
||||
records: map[string]session.Record{
|
||||
"device-session-123": newActiveSessionRecord("user-123"),
|
||||
},
|
||||
}
|
||||
readThrough, err := session.NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
subscriber := newTestRedisSessionSubscriber(t, server, local)
|
||||
downstreamClient := &recordingDownstreamClient{}
|
||||
addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
|
||||
defer running.stop(t)
|
||||
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls())
|
||||
|
||||
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
||||
"device_session_id": "device-session-123",
|
||||
"user_id": "user-123",
|
||||
"client_public_key": testClientPublicKeyBase64(),
|
||||
"status": string(session.StatusRevoked),
|
||||
"revoked_at_ms": "123456789",
|
||||
})
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
record, lookupErr := local.Lookup(context.Background(), "device-session-123")
|
||||
return lookupErr == nil && record.Status == session.StatusRevoked
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
|
||||
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
||||
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
|
||||
assert.Equal(t, 1, fallback.lookupCalls())
|
||||
}
|
||||
|
||||
type runningAuthenticatedGateway struct {
|
||||
cancel context.CancelFunc
|
||||
resultCh chan error
|
||||
}
|
||||
|
||||
func runAuthenticatedGateway(t *testing.T, sessionCache session.Cache, subscriber *RedisSessionSubscriber, downstreamClient downstream.Client) (string, runningAuthenticatedGateway) {
|
||||
t.Helper()
|
||||
|
||||
addr := unusedTCPAddr(t)
|
||||
grpcCfg := config.DefaultAuthenticatedGRPCConfig()
|
||||
grpcCfg.Addr = addr
|
||||
grpcCfg.FreshnessWindow = 5 * time.Minute
|
||||
|
||||
router := downstream.NewStaticRouter(map[string]downstream.Client{
|
||||
"fleet.move": downstreamClient,
|
||||
})
|
||||
|
||||
gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
|
||||
Router: router,
|
||||
ResponseSigner: newTestResponseSigner(t),
|
||||
SessionCache: sessionCache,
|
||||
ReplayStore: staticReplayStore{},
|
||||
Clock: fixedClock{now: testNow},
|
||||
})
|
||||
|
||||
application := app.New(
|
||||
config.Config{
|
||||
ShutdownTimeout: time.Second,
|
||||
AuthenticatedGRPC: grpcCfg,
|
||||
},
|
||||
gateway,
|
||||
subscriber,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
resultCh <- application.Run(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-subscriber.started:
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "session subscriber did not start")
|
||||
}
|
||||
|
||||
return addr, runningAuthenticatedGateway{
|
||||
cancel: cancel,
|
||||
resultCh: resultCh,
|
||||
}
|
||||
}
|
||||
|
||||
func (g runningAuthenticatedGateway) stop(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
g.cancel()
|
||||
|
||||
select {
|
||||
case err := <-g.resultCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(2 * time.Second):
|
||||
require.FailNow(t, "gateway did not stop after cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
|
||||
t.Helper()
|
||||
|
||||
var conn *grpc.ClientConn
|
||||
require.Eventually(t, func() bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
candidate, err := grpc.DialContext(
|
||||
ctx,
|
||||
addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
if err != nil {
|
||||
if candidate != nil {
|
||||
_ = candidate.Close()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
conn = candidate
|
||||
return true
|
||||
}, 2*time.Second, 10*time.Millisecond, "gateway did not accept gRPC connections")
|
||||
|
||||
return conn
|
||||
}
|
||||
|
||||
func unusedTCPAddr(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := listener.Addr().String()
|
||||
require.NoError(t, listener.Close())
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
func newExecuteCommandRequest(requestID string) *gatewayv1.ExecuteCommandRequest {
|
||||
payloadBytes := []byte("payload")
|
||||
payloadHash := sha256.Sum256(payloadBytes)
|
||||
|
||||
req := &gatewayv1.ExecuteCommandRequest{
|
||||
ProtocolVersion: "v1",
|
||||
DeviceSessionId: "device-session-123",
|
||||
MessageType: "fleet.move",
|
||||
TimestampMs: testNow.UnixMilli(),
|
||||
RequestId: requestID,
|
||||
PayloadBytes: payloadBytes,
|
||||
PayloadHash: payloadHash[:],
|
||||
TraceId: "trace-123",
|
||||
}
|
||||
req.Signature = ed25519.Sign(testClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
|
||||
ProtocolVersion: req.GetProtocolVersion(),
|
||||
DeviceSessionID: req.GetDeviceSessionId(),
|
||||
MessageType: req.GetMessageType(),
|
||||
TimestampMS: req.GetTimestampMs(),
|
||||
RequestID: req.GetRequestId(),
|
||||
PayloadHash: req.GetPayloadHash(),
|
||||
}))
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func newActiveSessionRecord(userID string) session.Record {
|
||||
return session.Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: userID,
|
||||
ClientPublicKey: testClientPublicKeyBase64(),
|
||||
Status: session.StatusActive,
|
||||
}
|
||||
}
|
||||
|
||||
func testClientPrivateKey() ed25519.PrivateKey {
|
||||
seed := sha256.Sum256([]byte("gateway-events-grpc-test-client"))
|
||||
return ed25519.NewKeyFromSeed(seed[:])
|
||||
}
|
||||
|
||||
func testClientPublicKeyBase64() string {
|
||||
return base64.StdEncoding.EncodeToString(testClientPrivateKey().Public().(ed25519.PublicKey))
|
||||
}
|
||||
|
||||
func newTestResponseSigner(t *testing.T) authn.ResponseSigner {
|
||||
t.Helper()
|
||||
|
||||
seed := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
|
||||
signer, err := authn.NewEd25519ResponseSigner(ed25519.NewKeyFromSeed(seed[:]))
|
||||
require.NoError(t, err)
|
||||
|
||||
return signer
|
||||
}
|
||||
|
||||
type fixedClock struct {
|
||||
now time.Time
|
||||
}
|
||||
|
||||
func (c fixedClock) Now() time.Time {
|
||||
return c.now
|
||||
}
|
||||
|
||||
var _ clock.Clock = fixedClock{}
|
||||
|
||||
type staticReplayStore struct{}
|
||||
|
||||
func (staticReplayStore) Reserve(context.Context, string, string, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ replay.Store = staticReplayStore{}
|
||||
|
||||
type countingSessionCache struct {
|
||||
mu sync.Mutex
|
||||
records map[string]session.Record
|
||||
lookupCount int
|
||||
}
|
||||
|
||||
func (c *countingSessionCache) Lookup(context.Context, string) (session.Record, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.lookupCount++
|
||||
|
||||
record, ok := c.records["device-session-123"]
|
||||
if !ok {
|
||||
return session.Record{}, errors.New("lookup session from counting cache: session cache record not found")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (c *countingSessionCache) lookupCalls() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
return c.lookupCount
|
||||
}
|
||||
|
||||
type recordingDownstreamClient struct {
|
||||
mu sync.Mutex
|
||||
captured []downstream.AuthenticatedCommand
|
||||
}
|
||||
|
||||
func (c *recordingDownstreamClient) ExecuteCommand(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
||||
c.mu.Lock()
|
||||
c.captured = append(c.captured, command)
|
||||
c.mu.Unlock()
|
||||
|
||||
return downstream.UnaryResult{
|
||||
ResultCode: "ok",
|
||||
PayloadBytes: []byte("response"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *recordingDownstreamClient) commands() []downstream.AuthenticatedCommand {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
cloned := make([]downstream.AuthenticatedCommand, len(c.captured))
|
||||
copy(cloned, c.captured)
|
||||
return cloned
|
||||
}
|
||||
@@ -1,447 +0,0 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/app"
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/grpcapi"
|
||||
"galaxy/gateway/internal/push"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
notificationfbs "galaxy/schema/fbs/notification"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
flatbuffers "github.com/google/flatbuffers/go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestSubscribeEventsFanOutsUserTargetedEventToAllUserSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
sessionCache := session.NewMemoryCache()
|
||||
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
|
||||
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))
|
||||
require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-3", "user-999")))
|
||||
|
||||
pushHub := push.NewHub(4)
|
||||
clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
|
||||
addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
|
||||
defer running.stop(t)
|
||||
|
||||
conn := dialGatewayClient(t, addr)
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
client := gatewayv1.NewEdgeGatewayClient(conn)
|
||||
|
||||
targetOneCtx, cancelTargetOne := context.WithCancel(context.Background())
|
||||
defer cancelTargetOne()
|
||||
targetOne, err := client.SubscribeEvents(targetOneCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
|
||||
require.NoError(t, err)
|
||||
assertPushBootstrapEvent(t, recvPushEvent(t, targetOne), "request-1", "trace-device-session-1")
|
||||
|
||||
targetTwoCtx, cancelTargetTwo := context.WithCancel(context.Background())
|
||||
defer cancelTargetTwo()
|
||||
targetTwo, err := client.SubscribeEvents(targetTwoCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
|
||||
require.NoError(t, err)
|
||||
assertPushBootstrapEvent(t, recvPushEvent(t, targetTwo), "request-2", "trace-device-session-2")
|
||||
|
||||
unrelatedCtx, cancelUnrelated := context.WithCancel(context.Background())
|
||||
defer cancelUnrelated()
|
||||
unrelated, err := client.SubscribeEvents(unrelatedCtx, newPushSubscribeEventsRequest("device-session-3", "request-3"))
|
||||
require.NoError(t, err)
|
||||
assertPushBootstrapEvent(t, recvPushEvent(t, unrelated), "request-3", "trace-device-session-3")
|
||||
|
||||
payloadBytes := buildGameTurnReadyPayload(t, "game-123", 54)
|
||||
addClientEvent(t, server, "gateway:client_events", map[string]any{
|
||||
"user_id": "user-123",
|
||||
"event_type": "game.turn.ready",
|
||||
"event_id": "event-123",
|
||||
"payload_bytes": payloadBytes,
|
||||
"request_id": "request-123",
|
||||
"trace_id": "trace-123",
|
||||
})
|
||||
|
||||
firstDelivered := recvPushEvent(t, targetOne)
|
||||
assertSignedPushEvent(t, firstDelivered, push.Event{
|
||||
UserID: "user-123",
|
||||
EventType: "game.turn.ready",
|
||||
EventID: "event-123",
|
||||
PayloadBytes: payloadBytes,
|
||||
RequestID: "request-123",
|
||||
TraceID: "trace-123",
|
||||
})
|
||||
assertDecodedGameTurnReadyPayload(t, firstDelivered.GetPayloadBytes(), "game-123", 54)
|
||||
|
||||
secondDelivered := recvPushEvent(t, targetTwo)
|
||||
assertSignedPushEvent(t, secondDelivered, push.Event{
|
||||
UserID: "user-123",
|
||||
EventType: "game.turn.ready",
|
||||
EventID: "event-123",
|
||||
PayloadBytes: payloadBytes,
|
||||
RequestID: "request-123",
|
||||
TraceID: "trace-123",
|
||||
})
|
||||
assertDecodedGameTurnReadyPayload(t, secondDelivered.GetPayloadBytes(), "game-123", 54)
|
||||
assertNoPushEvent(t, unrelated, cancelUnrelated)
|
||||
}
|
||||
|
||||
// TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession verifies
// that a client event carrying a device_session_id is delivered only to the
// stream subscribed for that session, not to the same user's other sessions.
func TestSubscribeEventsFanOutsSessionTargetedEventOnlyToMatchingSession(t *testing.T) {
	t.Parallel()

	// Two active sessions for the same user so targeting can be observed.
	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-2", "user-123")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	// Stream for the non-targeted session; it must stay quiet.
	otherCtx, cancelOther := context.WithCancel(context.Background())
	defer cancelOther()
	otherStream, err := client.SubscribeEvents(otherCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, otherStream), "request-1", "trace-device-session-1")

	// Stream for the targeted session.
	targetCtx, cancelTarget := context.WithCancel(context.Background())
	defer cancelTarget()
	targetStream, err := client.SubscribeEvents(targetCtx, newPushSubscribeEventsRequest("device-session-2", "request-2"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, targetStream), "request-2", "trace-device-session-2")

	// Publish one event scoped to device-session-2 only.
	addClientEvent(t, server, "gateway:client_events", map[string]any{
		"user_id":           "user-123",
		"device_session_id": "device-session-2",
		"event_type":        "fleet.updated",
		"event_id":          "event-456",
		"payload_bytes":     []byte("payload-456"),
	})

	assertSignedPushEvent(t, recvPushEvent(t, targetStream), push.Event{
		UserID:          "user-123",
		DeviceSessionID: "device-session-2",
		EventType:       "fleet.updated",
		EventID:         "event-456",
		PayloadBytes:    []byte("payload-456"),
	})
	// The sibling session of the same user must receive nothing.
	assertNoPushEvent(t, otherStream, cancelOther)
}
|
||||
|
||||
// TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen verifies that
// a revoked-session snapshot terminates the live stream with
// FailedPrecondition, and that re-subscribing the same session is rejected
// with the same status and message.
func TestSubscribeEventsClosesRevokedSessionStreamAndRejectsReopen(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	// Session subscriber wired with a revocation handler so a revoked
	// snapshot also tears down the live stream.
	sessionSubscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, sessionCache, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber, sessionSubscriber)
	defer running.stop(t)

	// Wait for the session subscriber to start so the revocation event
	// published below is not missed (it reads from the stream tail).
	select {
	case <-sessionSubscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "session subscriber did not start")
	}

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	streamCtx, cancelStream := context.WithCancel(context.Background())
	defer cancelStream()

	stream, err := client.SubscribeEvents(streamCtx, newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")

	// Publish the revoked snapshot for the subscribed session.
	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-1",
		"user_id":           "user-123",
		"client_public_key": pushClientPublicKeyBase64(),
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	// Wait until the revocation has been applied to the local cache.
	require.Eventually(t, func() bool {
		record, lookupErr := sessionCache.Lookup(context.Background(), "device-session-1")
		return lookupErr == nil && record.Status == session.StatusRevoked
	}, time.Second, 10*time.Millisecond)

	// The live stream must now terminate with FailedPrecondition.
	recvErrCh := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvErrCh <- recvErr
	}()

	select {
	case recvErr := <-recvErrCh:
		require.Error(t, recvErr)
		assert.Equal(t, codes.FailedPrecondition, status.Code(recvErr))
		assert.Equal(t, "device session is revoked", status.Convert(recvErr).Message())
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not close after revoke")
	}

	// Reopening for the revoked session must also fail. Depending on timing
	// the error may surface on SubscribeEvents itself or on the first Recv.
	reopened, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-2"))
	if err == nil {
		_, err = reopened.Recv()
	}

	require.Error(t, err)
	assert.Equal(t, codes.FailedPrecondition, status.Code(err))
	assert.Equal(t, "device session is revoked", status.Convert(err).Message())
}
|
||||
|
||||
// TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown verifies that
// canceling the gateway runtime terminates an active stream with Unavailable
// and the documented shutdown message.
func TestSubscribeEventsClosesActiveStreamWhenGatewayShutsDown(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	sessionCache := session.NewMemoryCache()
	require.NoError(t, sessionCache.Upsert(newPushActiveSessionRecord("device-session-1", "user-123")))

	pushHub := push.NewHub(4)
	clientSubscriber := newTestRedisClientEventSubscriber(t, server, pushHub)
	addr, running := runPushGateway(t, sessionCache, pushHub, clientSubscriber)
	defer running.stop(t)

	conn := dialGatewayClient(t, addr)
	defer func() {
		require.NoError(t, conn.Close())
	}()
	client := gatewayv1.NewEdgeGatewayClient(conn)

	stream, err := client.SubscribeEvents(context.Background(), newPushSubscribeEventsRequest("device-session-1", "request-1"))
	require.NoError(t, err)
	assertPushBootstrapEvent(t, recvPushEvent(t, stream), "request-1", "trace-device-session-1")

	// Park a reader on the stream, then shut the gateway down underneath it.
	recvErrCh := make(chan error, 1)
	go func() {
		_, recvErr := stream.Recv()
		recvErrCh <- recvErr
	}()

	running.cancel()

	select {
	case recvErr := <-recvErrCh:
		require.Error(t, recvErr)
		assert.Equal(t, codes.Unavailable, status.Code(recvErr))
		assert.Equal(t, "gateway is shutting down", status.Convert(recvErr).Message())
	case <-time.After(time.Second):
		require.FailNow(t, "stream did not close after gateway shutdown")
	}
}
|
||||
|
||||
// runPushGateway boots a full authenticated gateway application (gRPC server,
// push hub fan-out service, client event subscriber, plus any
// extraComponents) on an unused TCP port. It returns the listen address and a
// handle whose cancel/stop tears the application down.
func runPushGateway(t *testing.T, sessionCache session.Cache, pushHub *push.Hub, clientSubscriber *RedisClientEventSubscriber, extraComponents ...app.Component) (string, runningAuthenticatedGateway) {
	t.Helper()

	addr := unusedTCPAddr(t)
	grpcCfg := config.DefaultAuthenticatedGRPCConfig()
	grpcCfg.Addr = addr
	// Wide freshness window so signed requests built from the fixed testNow
	// stay valid for the whole test.
	grpcCfg.FreshnessWindow = 5 * time.Minute

	responseSigner := newTestResponseSigner(t)
	gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
		Service:        grpcapi.NewFanOutPushStreamService(pushHub, responseSigner, fixedClock{now: testNow}, zap.NewNop()),
		ResponseSigner: responseSigner,
		SessionCache:   sessionCache,
		ReplayStore:    staticReplayStore{},
		Clock:          fixedClock{now: testNow},
		PushHub:        pushHub,
	})

	components := []app.Component{gateway, clientSubscriber}
	components = append(components, extraComponents...)
	application := app.New(
		config.Config{
			ShutdownTimeout:   time.Second,
			AuthenticatedGRPC: grpcCfg,
		},
		components...,
	)

	ctx, cancel := context.WithCancel(context.Background())
	resultCh := make(chan error, 1)
	go func() {
		resultCh <- application.Run(ctx)
	}()

	// Do not hand control back until the client event subscriber is
	// consuming, otherwise events published early in a test could be missed.
	select {
	case <-clientSubscriber.started:
	case <-time.After(time.Second):
		require.FailNow(t, "client event subscriber did not start")
	}

	return addr, runningAuthenticatedGateway{
		cancel:   cancel,
		resultCh: resultCh,
	}
}
|
||||
|
||||
func newPushActiveSessionRecord(deviceSessionID string, userID string) session.Record {
|
||||
return session.Record{
|
||||
DeviceSessionID: deviceSessionID,
|
||||
UserID: userID,
|
||||
ClientPublicKey: pushClientPublicKeyBase64(),
|
||||
Status: session.StatusActive,
|
||||
}
|
||||
}
|
||||
|
||||
// newPushSubscribeEventsRequest builds a fully signed v1 SubscribeEvents
// request for deviceSessionID. The payload hash covers an empty payload and
// the trace id is derived deterministically from the session id.
func newPushSubscribeEventsRequest(deviceSessionID string, requestID string) *gatewayv1.SubscribeEventsRequest {
	payloadHash := sha256.Sum256(nil)
	traceID := "trace-" + deviceSessionID

	req := &gatewayv1.SubscribeEventsRequest{
		ProtocolVersion: "v1",
		DeviceSessionId: deviceSessionID,
		MessageType:     "gateway.subscribe",
		TimestampMs:     testNow.UnixMilli(),
		RequestId:       requestID,
		PayloadHash:     payloadHash[:],
		TraceId:         traceID,
	}
	// Sign the canonical request input with the shared test client key; note
	// that TraceId is intentionally absent from the signed fields.
	req.Signature = ed25519.Sign(pushClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
		ProtocolVersion: req.GetProtocolVersion(),
		DeviceSessionID: req.GetDeviceSessionId(),
		MessageType:     req.GetMessageType(),
		TimestampMS:     req.GetTimestampMs(),
		RequestID:       req.GetRequestId(),
		PayloadHash:     req.GetPayloadHash(),
	}))

	return req
}
|
||||
|
||||
// recvPushEvent reads the next event from stream, failing the test on any
// receive error.
func recvPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent]) *gatewayv1.GatewayEvent {
	t.Helper()

	event, err := stream.Recv()
	require.NoError(t, err)
	return event
}
|
||||
|
||||
// assertPushBootstrapEvent checks the initial gateway.server_time event every
// stream emits: correlation ids, payload-hash integrity, and the server's
// Ed25519 event signature.
func assertPushBootstrapEvent(t *testing.T, event *gatewayv1.GatewayEvent, wantRequestID string, wantTraceID string) {
	t.Helper()

	require.NotNil(t, event)
	assert.Equal(t, "gateway.server_time", event.GetEventType())
	// The bootstrap event reuses the request id as its event id.
	assert.Equal(t, wantRequestID, event.GetEventId())
	assert.Equal(t, wantRequestID, event.GetRequestId())
	assert.Equal(t, wantTraceID, event.GetTraceId())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))
}
|
||||
|
||||
// assertSignedPushEvent checks a delivered event against the expected push
// payload and verifies both the payload hash and the server event signature.
func assertSignedPushEvent(t *testing.T, event *gatewayv1.GatewayEvent, want push.Event) {
	t.Helper()

	require.NotNil(t, event)
	assert.Equal(t, want.EventType, event.GetEventType())
	assert.Equal(t, want.EventID, event.GetEventId())
	assert.Equal(t, want.RequestID, event.GetRequestId())
	assert.Equal(t, want.TraceID, event.GetTraceId())
	assert.Equal(t, want.PayloadBytes, event.GetPayloadBytes())
	require.NoError(t, authn.VerifyPayloadHash(event.GetPayloadBytes(), event.GetPayloadHash()))
	require.NoError(t, authn.VerifyEventSignature(pushResponseSignerPublicKey(), event.GetSignature(), authn.EventSigningFields{
		EventType:   event.GetEventType(),
		EventID:     event.GetEventId(),
		TimestampMS: event.GetTimestampMs(),
		RequestID:   event.GetRequestId(),
		TraceID:     event.GetTraceId(),
		PayloadHash: event.GetPayloadHash(),
	}))
}
|
||||
|
||||
// assertNoPushEvent asserts that stream delivers nothing within a short
// window, then cancels it. An early event or an early stream error both fail
// the test.
func assertNoPushEvent(t *testing.T, stream grpc.ServerStreamingClient[gatewayv1.GatewayEvent], cancel context.CancelFunc) {
	t.Helper()

	recvCh := make(chan *gatewayv1.GatewayEvent, 1)
	errCh := make(chan error, 1)
	go func() {
		event, err := stream.Recv()
		if err != nil {
			errCh <- err
			return
		}
		recvCh <- event
	}()

	select {
	case event := <-recvCh:
		require.FailNowf(t, "unexpected push event delivered", "%+v", event)
	case <-time.After(100 * time.Millisecond):
		// Quiet for the whole window: success. Cancel to unblock the reader
		// goroutine; its eventual error lands in the buffered errCh so the
		// goroutine cannot leak a blocked send.
		cancel()
	case err := <-errCh:
		require.FailNowf(t, "stream closed unexpectedly", "%v", err)
	}
}
|
||||
|
||||
// pushClientPrivateKey derives the deterministic Ed25519 client key shared by
// all push gateway tests; the seed is the SHA-256 of a fixed label.
func pushClientPrivateKey() ed25519.PrivateKey {
	digest := sha256.Sum256([]byte("gateway-push-grpc-test-client"))
	key := ed25519.NewKeyFromSeed(digest[:])
	return key
}
|
||||
|
||||
func pushClientPublicKeyBase64() string {
|
||||
return base64.StdEncoding.EncodeToString(pushClientPrivateKey().Public().(ed25519.PublicKey))
|
||||
}
|
||||
|
||||
// pushResponseSignerPublicKey derives the public half of the deterministic
// Ed25519 key the test response signer uses, so assertions can verify server
// event signatures.
func pushResponseSignerPublicKey() ed25519.PublicKey {
	digest := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
	privateKey := ed25519.NewKeyFromSeed(digest[:])
	return privateKey.Public().(ed25519.PublicKey)
}
|
||||
|
||||
// buildGameTurnReadyPayload encodes a GameTurnReadyEvent FlatBuffer carrying
// the given game id and turn number.
func buildGameTurnReadyPayload(t *testing.T, gameID string, turnNumber int64) []byte {
	t.Helper()

	builder := flatbuffers.NewBuilder(64)
	// FlatBuffers requires strings to be created before the table is started.
	gameIDOffset := builder.CreateString(gameID)

	notificationfbs.GameTurnReadyEventStart(builder)
	notificationfbs.GameTurnReadyEventAddGameId(builder, gameIDOffset)
	notificationfbs.GameTurnReadyEventAddTurnNumber(builder, turnNumber)
	offset := notificationfbs.GameTurnReadyEventEnd(builder)
	notificationfbs.FinishGameTurnReadyEventBuffer(builder, offset)

	return builder.FinishedBytes()
}
|
||||
|
||||
// assertDecodedGameTurnReadyPayload decodes payload as a GameTurnReadyEvent
// FlatBuffer and checks its fields survived the gateway round trip intact.
func assertDecodedGameTurnReadyPayload(t *testing.T, payload []byte, wantGameID string, wantTurnNumber int64) {
	t.Helper()

	event := notificationfbs.GetRootAsGameTurnReadyEvent(payload, 0)
	require.Equal(t, wantGameID, string(event.GameId()))
	require.Equal(t, wantTurnNumber, event.TurnNumber())
}
|
||||
@@ -1,347 +0,0 @@
|
||||
// Package events subscribes to internal session lifecycle streams used to keep
|
||||
// the gateway hot-path session cache synchronized without per-request upstream
|
||||
// lookups.
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/session"
|
||||
"galaxy/gateway/internal/telemetry"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// sessionEventReadCount bounds how many stream entries a single XRead call
// may return per poll of the session event loop.
const sessionEventReadCount int64 = 128
|
||||
|
||||
// SessionRevocationHandler reacts to a successfully applied revoked session
// snapshot and may tear down active resources bound to that session.
type SessionRevocationHandler interface {
	// RevokeDeviceSession tears down active resources bound to
	// deviceSessionID. NOTE(review): duplicate revocation snapshots for the
	// same session appear possible from the stream, so implementations should
	// tolerate repeated calls — confirm.
	RevokeDeviceSession(deviceSessionID string)
}
|
||||
|
||||
// RedisSessionSubscriber consumes full session snapshots from one Redis Stream
// and applies them to a process-local session snapshot store.
type RedisSessionSubscriber struct {
	client            *redis.Client // shared client; not owned by the subscriber
	stream            string        // Redis Stream key carrying session snapshots
	pingTimeout       time.Duration // taken from the session cache lookup timeout
	readBlockTimeout  time.Duration // XRead BLOCK duration per poll
	store             session.SnapshotStore
	revocationHandler SessionRevocationHandler // optional; may be nil
	logger            *zap.Logger
	metrics           *telemetry.Runtime // optional; may be nil

	startedOnce sync.Once     // guards the one-time close of started
	started     chan struct{} // closed once the start position is resolved
}
|
||||
|
||||
// NewRedisSessionSubscriber constructs a Redis Stream subscriber that uses
// client and applies updates to store. No revocation handler, logger, or
// metrics are wired; see the WithObservability variant for those.
func NewRedisSessionSubscriber(client *redis.Client, sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore) (*RedisSessionSubscriber, error) {
	return NewRedisSessionSubscriberWithObservability(client, sessionCfg, eventsCfg, store, nil, nil, nil)
}
|
||||
|
||||
// NewRedisSessionSubscriberWithRevocationHandler constructs a Redis Stream
// subscriber that uses client, applies updates to store, and optionally tears
// down active resources for revoked sessions via revocationHandler.
func NewRedisSessionSubscriberWithRevocationHandler(client *redis.Client, sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler) (*RedisSessionSubscriber, error) {
	return NewRedisSessionSubscriberWithObservability(client, sessionCfg, eventsCfg, store, revocationHandler, nil, nil)
}
|
||||
|
||||
// NewRedisSessionSubscriberWithObservability constructs a Redis Stream
|
||||
// subscriber that also logs and counts malformed internal session events. The
|
||||
// subscriber does not own the client; the runtime supplies a shared
|
||||
// *redis.Client.
|
||||
func NewRedisSessionSubscriberWithObservability(client *redis.Client, sessionCfg config.SessionCacheRedisConfig, eventsCfg config.SessionEventsRedisConfig, store session.SnapshotStore, revocationHandler SessionRevocationHandler, logger *zap.Logger, metrics *telemetry.Runtime) (*RedisSessionSubscriber, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("new redis session subscriber: nil redis client")
|
||||
}
|
||||
if sessionCfg.LookupTimeout <= 0 {
|
||||
return nil, errors.New("new redis session subscriber: lookup timeout must be positive")
|
||||
}
|
||||
if strings.TrimSpace(eventsCfg.Stream) == "" {
|
||||
return nil, errors.New("new redis session subscriber: stream must not be empty")
|
||||
}
|
||||
if eventsCfg.ReadBlockTimeout <= 0 {
|
||||
return nil, errors.New("new redis session subscriber: read block timeout must be positive")
|
||||
}
|
||||
if store == nil {
|
||||
return nil, errors.New("new redis session subscriber: nil session snapshot store")
|
||||
}
|
||||
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
|
||||
return &RedisSessionSubscriber{
|
||||
client: client,
|
||||
stream: eventsCfg.Stream,
|
||||
pingTimeout: sessionCfg.LookupTimeout,
|
||||
readBlockTimeout: eventsCfg.ReadBlockTimeout,
|
||||
store: store,
|
||||
revocationHandler: revocationHandler,
|
||||
logger: logger.Named("session_subscriber"),
|
||||
metrics: metrics,
|
||||
started: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Run consumes session lifecycle events until ctx is canceled or Redis returns
|
||||
// an unexpected error.
|
||||
func (s *RedisSessionSubscriber) Run(ctx context.Context) error {
|
||||
if s == nil || s.client == nil {
|
||||
return errors.New("run redis session subscriber: nil subscriber")
|
||||
}
|
||||
if ctx == nil {
|
||||
return errors.New("run redis session subscriber: nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lastID, err := s.resolveStartID(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.signalStarted()
|
||||
|
||||
for {
|
||||
streams, err := s.client.XRead(ctx, &redis.XReadArgs{
|
||||
Streams: []string{s.stream, lastID},
|
||||
Count: sessionEventReadCount,
|
||||
Block: s.readBlockTimeout,
|
||||
}).Result()
|
||||
switch {
|
||||
case err == nil:
|
||||
for _, stream := range streams {
|
||||
for _, message := range stream.Messages {
|
||||
s.applyMessage(message)
|
||||
lastID = message.ID
|
||||
}
|
||||
}
|
||||
continue
|
||||
case errors.Is(err, redis.Nil):
|
||||
continue
|
||||
case ctx.Err() != nil && (errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, redis.ErrClosed)):
|
||||
return ctx.Err()
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded), errors.Is(err, redis.ErrClosed):
|
||||
return fmt.Errorf("run redis session subscriber: %w", err)
|
||||
default:
|
||||
return fmt.Errorf("run redis session subscriber: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveStartID returns the stream ID to start reading after: the ID of the
// newest existing entry, or "0-0" when the stream is empty or absent. Reading
// after the tail means pre-existing snapshots are skipped and only events
// published after startup are consumed.
func (s *RedisSessionSubscriber) resolveStartID(ctx context.Context) (string, error) {
	// XRevRangeN(+, -, 1) fetches just the newest entry.
	messages, err := s.client.XRevRangeN(ctx, s.stream, "+", "-", 1).Result()
	switch {
	case err == nil:
	case errors.Is(err, redis.Nil):
		return "0-0", nil
	default:
		return "", fmt.Errorf("run redis session subscriber: resolve stream tail: %w", err)
	}

	if len(messages) == 0 {
		return "0-0", nil
	}

	return messages[0].ID, nil
}
|
||||
|
||||
// Shutdown is a no-op kept for App framework compatibility. The blocking
|
||||
// XRead loop terminates when its context is cancelled by the parent runtime,
|
||||
// which also owns and closes the shared Redis client.
|
||||
func (s *RedisSessionSubscriber) Shutdown(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
return errors.New("shutdown redis session subscriber: nil context")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close is a no-op kept for backwards-compatible cleanup wiring; the
// subscriber does not own the shared Redis client, so there is nothing to
// release here.
func (s *RedisSessionSubscriber) Close() error {
	return nil
}
|
||||
|
||||
// signalStarted closes the started channel exactly once, signaling that the
// subscriber has resolved its start position and is entering the read loop.
// sync.Once makes repeated calls safe.
func (s *RedisSessionSubscriber) signalStarted() {
	s.startedOnce.Do(func() {
		close(s.started)
	})
}
|
||||
|
||||
func (s *RedisSessionSubscriber) applyMessage(message redis.XMessage) {
|
||||
record, err := decodeSessionRecordSnapshot(message.Values)
|
||||
if err != nil {
|
||||
s.logger.Warn("dropped malformed session event",
|
||||
zap.String("stream", s.stream),
|
||||
zap.String("message_id", message.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
s.metrics.RecordInternalEventDrop(context.Background(),
|
||||
attribute.String("component", "session_subscriber"),
|
||||
attribute.String("reason", "malformed_event"),
|
||||
)
|
||||
if deviceSessionID, ok := extractDeviceSessionID(message.Values); ok {
|
||||
s.store.Delete(deviceSessionID)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.store.Upsert(record); err != nil {
|
||||
s.logger.Warn("dropped session snapshot after store failure",
|
||||
zap.String("stream", s.stream),
|
||||
zap.String("message_id", message.ID),
|
||||
zap.String("device_session_id", record.DeviceSessionID),
|
||||
zap.Error(err),
|
||||
)
|
||||
s.metrics.RecordInternalEventDrop(context.Background(),
|
||||
attribute.String("component", "session_subscriber"),
|
||||
attribute.String("reason", "store_failure"),
|
||||
)
|
||||
s.store.Delete(record.DeviceSessionID)
|
||||
return
|
||||
}
|
||||
|
||||
if record.Status == session.StatusRevoked && s.revocationHandler != nil {
|
||||
s.revocationHandler.RevokeDeviceSession(record.DeviceSessionID)
|
||||
}
|
||||
}
|
||||
|
||||
// decodeSessionRecordSnapshot converts one stream entry's field map into a
// session.Record. Any key outside the known required/optional sets fails the
// whole event, so schema drift surfaces as an explicit decode error rather
// than being silently ignored.
func decodeSessionRecordSnapshot(values map[string]any) (session.Record, error) {
	requiredKeys := map[string]struct{}{
		"device_session_id": {},
		"user_id":           {},
		"client_public_key": {},
		"status":            {},
	}
	optionalKeys := map[string]struct{}{
		"revoked_at_ms": {},
	}

	// Strict field whitelist: reject unknown keys up front.
	for key := range values {
		if _, ok := requiredKeys[key]; ok {
			continue
		}
		if _, ok := optionalKeys[key]; ok {
			continue
		}

		return session.Record{}, fmt.Errorf("decode session event: unsupported field %q", key)
	}

	// All four required fields must be present and non-blank strings.
	deviceSessionID, err := requiredStringField(values, "device_session_id")
	if err != nil {
		return session.Record{}, err
	}
	userID, err := requiredStringField(values, "user_id")
	if err != nil {
		return session.Record{}, err
	}
	clientPublicKey, err := requiredStringField(values, "client_public_key")
	if err != nil {
		return session.Record{}, err
	}
	statusValue, err := requiredStringField(values, "status")
	if err != nil {
		return session.Record{}, err
	}

	// NOTE(review): the status string is not validated against known
	// session.Status values here — an unknown status flows through to the
	// store; presumably the store rejects it. Confirm where that check lives.
	record := session.Record{
		DeviceSessionID: deviceSessionID,
		UserID:          userID,
		ClientPublicKey: clientPublicKey,
		Status:          session.Status(statusValue),
	}

	// revoked_at_ms is optional; when present it must parse as base-10 int64.
	if rawRevokedAtMS, ok := values["revoked_at_ms"]; ok {
		revokedAtMS, err := parseInt64Field(rawRevokedAtMS, "revoked_at_ms")
		if err != nil {
			return session.Record{}, err
		}
		record.RevokedAtMS = &revokedAtMS
	}

	return record, nil
}
|
||||
|
||||
func extractDeviceSessionID(values map[string]any) (string, bool) {
|
||||
value, ok := values["device_session_id"]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
deviceSessionID, err := coerceString(value)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
if strings.TrimSpace(deviceSessionID) == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return deviceSessionID, true
|
||||
}
|
||||
|
||||
func requiredStringField(values map[string]any, field string) (string, error) {
|
||||
value, ok := values[field]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("decode session event: missing %s", field)
|
||||
}
|
||||
|
||||
stringValue, err := coerceString(value)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode session event: %s: %w", field, err)
|
||||
}
|
||||
if strings.TrimSpace(stringValue) == "" {
|
||||
return "", fmt.Errorf("decode session event: %s must not be empty", field)
|
||||
}
|
||||
|
||||
return stringValue, nil
|
||||
}
|
||||
|
||||
func parseInt64Field(value any, field string) (int64, error) {
|
||||
stringValue, err := coerceString(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("decode session event: %s: %w", field, err)
|
||||
}
|
||||
|
||||
parsed, err := strconv.ParseInt(strings.TrimSpace(stringValue), 10, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("decode session event: %s: %w", field, err)
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// coerceString converts the dynamically typed values a Redis stream entry may
// carry into their string form, erroring on any unsupported type.
func coerceString(value any) (string, error) {
	switch v := value.(type) {
	case string:
		return v, nil
	case []byte:
		return string(v), nil
	case fmt.Stringer:
		return v.String(), nil
	case int:
		return strconv.Itoa(v), nil
	case int64:
		return strconv.FormatInt(v, 10), nil
	case uint64:
		return strconv.FormatUint(v, 10), nil
	}

	return "", fmt.Errorf("unsupported value type %T", value)
}
|
||||
@@ -1,381 +0,0 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/session"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRedisSessionSubscriberAppliesActiveSnapshot verifies that an active
// session snapshot published to the stream lands in the local store.
func TestRedisSessionSubscriberAppliesActiveSnapshot(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})

	// Application is asynchronous; poll the store until the snapshot shows.
	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil {
			return false
		}

		return record.UserID == "user-123" && record.Status == session.StatusActive
	}, time.Second, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisSessionSubscriberAppliesRevokedSnapshot verifies that a revoked
// snapshot overwrites an existing active record, including the revocation
// timestamp.
func TestRedisSessionSubscriberAppliesRevokedSnapshot(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	// Seed an active record that the revoked snapshot must replace.
	require.NoError(t, store.Upsert(session.Record{
		DeviceSessionID: "device-session-123",
		UserID:          "user-123",
		ClientPublicKey: "public-key-123",
		Status:          session.StatusActive,
	}))

	subscriber := newTestRedisSessionSubscriber(t, server, store)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	require.Eventually(t, func() bool {
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil || record.RevokedAtMS == nil {
			return false
		}

		return record.Status == session.StatusRevoked && *record.RevokedAtMS == 123456789
	}, time.Second, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisSessionSubscriberRevokedSnapshotTriggersRevocationHandler verifies
// that a successfully applied revoked snapshot notifies the revocation
// handler with the session id.
func TestRedisSessionSubscriberRevokedSnapshotTriggersRevocationHandler(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	handler := &recordingSessionRevocationHandler{}
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	require.Eventually(t, func() bool {
		// The store must reflect the revocation before the handler check is
		// meaningful (handler fires only after a successful Upsert).
		record, err := store.Lookup(context.Background(), "device-session-123")
		if err != nil || record.Status != session.StatusRevoked {
			return false
		}

		return assert.ObjectsAreEqual([]string{"device-session-123"}, handler.revocations())
	}, time.Second, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisSessionSubscriberActiveSnapshotDoesNotTriggerRevocationHandler
// verifies the revocation handler stays untouched for active snapshots.
func TestRedisSessionSubscriberActiveSnapshotDoesNotTriggerRevocationHandler(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	store := session.NewMemoryCache()
	handler := &recordingSessionRevocationHandler{}
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusActive),
	})

	// Negative assertion: the handler must remain empty for the whole window.
	assert.Never(t, func() bool {
		return len(handler.revocations()) != 0
	}, 100*time.Millisecond, 10*time.Millisecond)
}
|
||||
|
||||
// TestRedisSessionSubscriberStoreFailureDoesNotTriggerRevocationHandler
// verifies that when applying a revoked snapshot fails at the store, the
// revocation handler is NOT invoked — only successfully applied revocations
// tear down resources.
func TestRedisSessionSubscriberStoreFailureDoesNotTriggerRevocationHandler(t *testing.T) {
	t.Parallel()

	server := miniredis.RunT(t)
	handler := &recordingSessionRevocationHandler{}
	// failingSnapshotStore rejects every Upsert, forcing the store-failure path.
	subscriber := newTestRedisSessionSubscriberWithRevocationHandler(t, server, failingSnapshotStore{}, handler)
	running := runTestSubscriber(t, subscriber)
	defer running.stop(t)

	addSessionEvent(t, server, "gateway:session_events", map[string]string{
		"device_session_id": "device-session-123",
		"user_id":           "user-123",
		"client_public_key": "public-key-123",
		"status":            string(session.StatusRevoked),
		"revoked_at_ms":     "123456789",
	})

	assert.Never(t, func() bool {
		return len(handler.revocations()) != 0
	}, 100*time.Millisecond, 10*time.Millisecond)
}
|
||||
|
||||
func TestRedisSessionSubscriberLaterEventWins(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := session.NewMemoryCache()
|
||||
subscriber := newTestRedisSessionSubscriber(t, server, store)
|
||||
running := runTestSubscriber(t, subscriber)
|
||||
defer running.stop(t)
|
||||
|
||||
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
||||
"device_session_id": "device-session-123",
|
||||
"user_id": "user-123",
|
||||
"client_public_key": "public-key-123",
|
||||
"status": string(session.StatusActive),
|
||||
})
|
||||
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
||||
"device_session_id": "device-session-123",
|
||||
"user_id": "user-456",
|
||||
"client_public_key": "public-key-456",
|
||||
"status": string(session.StatusActive),
|
||||
})
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
record, err := store.Lookup(context.Background(), "device-session-123")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return record.UserID == "user-456" && record.ClientPublicKey == "public-key-456"
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestRedisSessionSubscriberMalformedEventEvictsAndContinues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := session.NewMemoryCache()
|
||||
require.NoError(t, store.Upsert(session.Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: session.StatusActive,
|
||||
}))
|
||||
|
||||
subscriber := newTestRedisSessionSubscriber(t, server, store)
|
||||
running := runTestSubscriber(t, subscriber)
|
||||
defer running.stop(t)
|
||||
|
||||
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
||||
"device_session_id": "device-session-123",
|
||||
"user_id": "user-123",
|
||||
"client_public_key": "public-key-123",
|
||||
"status": "paused",
|
||||
})
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := store.Lookup(context.Background(), "device-session-123")
|
||||
return err != nil
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
|
||||
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
||||
"device_session_id": "device-session-123",
|
||||
"user_id": "user-456",
|
||||
"client_public_key": "public-key-456",
|
||||
"status": string(session.StatusActive),
|
||||
})
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
record, err := store.Lookup(context.Background(), "device-session-123")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return record.UserID == "user-456" && record.Status == session.StatusActive
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestRedisSessionSubscriberShutdownInterruptsBlockingRead(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
store := session.NewMemoryCache()
|
||||
subscriber := newTestRedisSessionSubscriber(t, server, store)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
resultCh <- subscriber.Run(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-subscriber.started:
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "subscriber did not start")
|
||||
}
|
||||
|
||||
cancel()
|
||||
require.NoError(t, subscriber.Shutdown(context.Background()))
|
||||
|
||||
select {
|
||||
case err := <-resultCh:
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "subscriber did not stop after shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func newTestRedisSessionSubscriber(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore) *RedisSessionSubscriber {
|
||||
t.Helper()
|
||||
|
||||
return newTestRedisSessionSubscriberWithRevocationHandler(t, server, store, nil)
|
||||
}
|
||||
|
||||
func newTestRedisSessionSubscriberWithRevocationHandler(t *testing.T, server *miniredis.Miniredis, store session.SnapshotStore, revocationHandler SessionRevocationHandler) *RedisSessionSubscriber {
|
||||
t.Helper()
|
||||
|
||||
client := newTestRedisClient(t, server)
|
||||
|
||||
subscriber, err := NewRedisSessionSubscriberWithRevocationHandler(
|
||||
client,
|
||||
config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
LookupTimeout: 250 * time.Millisecond,
|
||||
},
|
||||
config.SessionEventsRedisConfig{
|
||||
Stream: "gateway:session_events",
|
||||
ReadBlockTimeout: 25 * time.Millisecond,
|
||||
},
|
||||
store,
|
||||
revocationHandler,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return subscriber
|
||||
}
|
||||
|
||||
func newTestRedisClient(t *testing.T, server *miniredis.Miniredis) *redis.Client {
|
||||
t.Helper()
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: server.Addr(),
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// recordingSessionRevocationHandler captures revoked device-session ids so
// tests can assert which revocations the subscriber delivered.
type recordingSessionRevocationHandler struct {
	mu         sync.Mutex
	revokedIDs []string
}

// RevokeDeviceSession records deviceSessionID; safe for concurrent use.
func (h *recordingSessionRevocationHandler) RevokeDeviceSession(deviceSessionID string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.revokedIDs = append(h.revokedIDs, deviceSessionID)
}

// revocations returns a snapshot copy of the recorded ids.
func (h *recordingSessionRevocationHandler) revocations() []string {
	h.mu.Lock()
	defer h.mu.Unlock()
	return append([]string(nil), h.revokedIDs...)
}
|
||||
|
||||
type failingSnapshotStore struct{}
|
||||
|
||||
func (failingSnapshotStore) Lookup(context.Context, string) (session.Record, error) {
|
||||
return session.Record{}, session.ErrNotFound
|
||||
}
|
||||
|
||||
func (failingSnapshotStore) Upsert(session.Record) error {
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
|
||||
func (failingSnapshotStore) Delete(string) {}
|
||||
|
||||
func addSessionEvent(t *testing.T, server *miniredis.Miniredis, stream string, fields map[string]string) {
|
||||
t.Helper()
|
||||
|
||||
values := make([]string, 0, len(fields)*2)
|
||||
for key, value := range fields {
|
||||
values = append(values, key, value)
|
||||
}
|
||||
|
||||
_, err := server.XAdd(stream, "*", values)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// runningSubscriber bundles the cancel func and result channel of a
// subscriber started via runTestSubscriber.
type runningSubscriber struct {
	cancel   context.CancelFunc
	resultCh chan error
	// NOTE(review): stopOnce is never read or written in this file —
	// confirm it is unused elsewhere and remove.
	stopOnce bool
}
|
||||
|
||||
func runTestSubscriber(t *testing.T, subscriber *RedisSessionSubscriber) runningSubscriber {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
resultCh := make(chan error, 1)
|
||||
go func() {
|
||||
resultCh <- subscriber.Run(ctx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-subscriber.started:
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "subscriber did not start")
|
||||
}
|
||||
|
||||
return runningSubscriber{
|
||||
cancel: cancel,
|
||||
resultCh: resultCh,
|
||||
}
|
||||
}
|
||||
|
||||
func (r runningSubscriber) stop(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
r.cancel()
|
||||
|
||||
select {
|
||||
case err := <-r.resultCh:
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
case <-time.After(time.Second):
|
||||
require.FailNow(t, "subscriber did not stop")
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/authn"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/testutil"
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/authn"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/logging"
|
||||
"galaxy/gateway/internal/push"
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
gatewayfbs "galaxy/schema/fbs/gateway"
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/authn"
|
||||
"galaxy/gateway/internal/clock"
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/authn"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"encoding/pem"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/authn"
|
||||
"galaxy/gateway/authn"
|
||||
"galaxy/gateway/internal/downstream"
|
||||
"galaxy/gateway/internal/session"
|
||||
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
||||
|
||||
@@ -273,6 +273,28 @@ func (h *Hub) RevokeDeviceSession(deviceSessionID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// RevokeAllForUser closes every active subscription bound to userID,
|
||||
// regardless of device-session id. Used when backend emits a
|
||||
// SessionInvalidation that targets every session of a user.
|
||||
func (h *Hub) RevokeAllForUser(userID string) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
|
||||
userID = strings.TrimSpace(userID)
|
||||
if userID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
targets := cloneSubscriptions(h.byUser[userID])
|
||||
h.mu.RUnlock()
|
||||
|
||||
for _, target := range targets {
|
||||
h.unregister(target.id, ErrSubscriptionRevoked)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown closes every active subscription because the gateway is shutting
|
||||
// down.
|
||||
func (h *Hub) Shutdown() {
|
||||
|
||||
@@ -1,232 +0,0 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
	// authServiceSendEmailCodePath is the Auth / Session Service public route
	// that issues an email confirmation code.
	authServiceSendEmailCodePath = "/api/v1/public/auth/send-email-code"
	// authServiceConfirmEmailCodePath is the public route that confirms a
	// previously issued email code.
	authServiceConfirmEmailCodePath = "/api/v1/public/auth/confirm-email-code"
)
|
||||
|
||||
// HTTPAuthServiceClient implements AuthServiceClient over the Auth / Session
// Service public HTTP API using strict JSON request and response decoding.
type HTTPAuthServiceClient struct {
	baseURL    string       // normalized absolute base URL without trailing slash
	httpClient *http.Client // transport used for every request
}

// authServiceErrorEnvelope mirrors the {"error":{...}} error payload shape.
type authServiceErrorEnvelope struct {
	Error *authServiceErrorBody `json:"error"`
}

// authServiceErrorBody carries the machine-readable code and human-readable
// message of an auth-service error response.
type authServiceErrorBody struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}
|
||||
|
||||
// NewHTTPAuthServiceClient constructs an AuthServiceClient that delegates the
|
||||
// gateway public-auth routes to the Auth / Session Service public HTTP API at
|
||||
// baseURL. The resulting client relies only on the caller-provided context for
|
||||
// cancellation and timeout control.
|
||||
func NewHTTPAuthServiceClient(baseURL string) (*HTTPAuthServiceClient, error) {
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
return nil, errors.New("new auth service HTTP client: default transport is not *http.Transport")
|
||||
}
|
||||
|
||||
return newHTTPAuthServiceClient(baseURL, &http.Client{
|
||||
Transport: transport.Clone(),
|
||||
})
|
||||
}
|
||||
|
||||
func newHTTPAuthServiceClient(baseURL string, httpClient *http.Client) (*HTTPAuthServiceClient, error) {
|
||||
if httpClient == nil {
|
||||
return nil, errors.New("new auth service HTTP client: http client must not be nil")
|
||||
}
|
||||
|
||||
trimmedBaseURL := strings.TrimSpace(baseURL)
|
||||
if trimmedBaseURL == "" {
|
||||
return nil, errors.New("new auth service HTTP client: base URL must not be empty")
|
||||
}
|
||||
|
||||
parsedBaseURL, err := url.Parse(strings.TrimRight(trimmedBaseURL, "/"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new auth service HTTP client: parse base URL: %w", err)
|
||||
}
|
||||
if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
|
||||
return nil, errors.New("new auth service HTTP client: base URL must be absolute")
|
||||
}
|
||||
|
||||
return &HTTPAuthServiceClient{
|
||||
baseURL: parsedBaseURL.String(),
|
||||
httpClient: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases idle HTTP connections owned by the client transport.
|
||||
func (c *HTTPAuthServiceClient) Close() error {
|
||||
if c == nil || c.httpClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type idleCloser interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendEmailCode delegates the public send-email-code route to the configured
|
||||
// Auth / Session Service public HTTP API.
|
||||
func (c *HTTPAuthServiceClient) SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error) {
|
||||
payload, statusCode, err := c.doJSONRequest(ctx, authServiceSendEmailCodePath, input, map[string]string{
|
||||
"Accept-Language": resolvePreferredLanguage(input.PreferredLanguage),
|
||||
})
|
||||
if err != nil {
|
||||
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: %w", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case statusCode == http.StatusOK:
|
||||
var result SendEmailCodeResult
|
||||
if err := decodeStrictJSONPayload(payload, &result); err != nil {
|
||||
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: decode success response: %w", err)
|
||||
}
|
||||
if err := validateSendEmailCodeResult(&result); err != nil {
|
||||
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
case statusCode >= 400 && statusCode <= 599:
|
||||
authErr, err := decodeAuthServiceError(statusCode, payload)
|
||||
if err != nil {
|
||||
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: %w", err)
|
||||
}
|
||||
|
||||
return SendEmailCodeResult{}, authErr
|
||||
default:
|
||||
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: unexpected HTTP status %d", statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// ConfirmEmailCode delegates the public confirm-email-code route to the
|
||||
// configured Auth / Session Service public HTTP API.
|
||||
func (c *HTTPAuthServiceClient) ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
|
||||
payload, statusCode, err := c.doJSONRequest(ctx, authServiceConfirmEmailCodePath, input, nil)
|
||||
if err != nil {
|
||||
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: %w", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case statusCode == http.StatusOK:
|
||||
var result ConfirmEmailCodeResult
|
||||
if err := decodeStrictJSONPayload(payload, &result); err != nil {
|
||||
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: decode success response: %w", err)
|
||||
}
|
||||
if err := validateConfirmEmailCodeResult(&result); err != nil {
|
||||
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
case statusCode >= 400 && statusCode <= 599:
|
||||
authErr, err := decodeAuthServiceError(statusCode, payload)
|
||||
if err != nil {
|
||||
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: %w", err)
|
||||
}
|
||||
|
||||
return ConfirmEmailCodeResult{}, authErr
|
||||
default:
|
||||
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: unexpected HTTP status %d", statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HTTPAuthServiceClient) doJSONRequest(ctx context.Context, path string, requestBody any, headers map[string]string) ([]byte, int, error) {
|
||||
if c == nil || c.httpClient == nil {
|
||||
return nil, 0, errors.New("nil client")
|
||||
}
|
||||
if ctx == nil {
|
||||
return nil, 0, errors.New("nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("marshal request body: %w", err)
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+path, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("build request: %w", err)
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
for key, value := range headers {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
continue
|
||||
}
|
||||
request.Header.Set(key, value)
|
||||
}
|
||||
|
||||
response, err := c.httpClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
responsePayload, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
|
||||
return responsePayload, response.StatusCode, nil
|
||||
}
|
||||
|
||||
func decodeAuthServiceError(statusCode int, payload []byte) (*AuthServiceError, error) {
|
||||
var envelope authServiceErrorEnvelope
|
||||
if err := decodeStrictJSONPayload(payload, &envelope); err != nil {
|
||||
return nil, fmt.Errorf("decode error response: %w", err)
|
||||
}
|
||||
if envelope.Error == nil {
|
||||
return nil, errors.New("decode error response: missing error object")
|
||||
}
|
||||
|
||||
return &AuthServiceError{
|
||||
StatusCode: statusCode,
|
||||
Code: envelope.Error.Code,
|
||||
Message: envelope.Error.Message,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// decodeStrictJSONPayload decodes payload into target, rejecting unknown
// fields and any trailing JSON after the first value.
func decodeStrictJSONPayload(payload []byte, target any) error {
	dec := json.NewDecoder(bytes.NewReader(payload))
	dec.DisallowUnknownFields()

	if err := dec.Decode(target); err != nil {
		return err
	}

	// A second decode must hit io.EOF; anything else means trailing input.
	switch err := dec.Decode(&struct{}{}); {
	case err == io.EOF:
		return nil
	case err == nil:
		return errors.New("unexpected trailing JSON input")
	default:
		return err
	}
}
|
||||
|
||||
var _ AuthServiceClient = (*HTTPAuthServiceClient)(nil)
|
||||
@@ -1,369 +0,0 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewHTTPAuthServiceClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
baseURL string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
baseURL: " http://127.0.0.1:8080/ ",
|
||||
},
|
||||
{
|
||||
name: "empty base url",
|
||||
wantErr: "base URL must not be empty",
|
||||
},
|
||||
{
|
||||
name: "relative base url",
|
||||
baseURL: "/authsession",
|
||||
wantErr: "base URL must be absolute",
|
||||
},
|
||||
{
|
||||
name: "malformed base url",
|
||||
baseURL: "://bad",
|
||||
wantErr: "parse base URL",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, err := NewHTTPAuthServiceClient(tt.baseURL)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://127.0.0.1:8080", client.baseURL)
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientSendEmailCodeSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var requestContentType string
|
||||
var requestAcceptLanguage string
|
||||
var requestBody string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, authServiceSendEmailCodePath, r.URL.Path)
|
||||
|
||||
requestContentType = r.Header.Get("Content-Type")
|
||||
requestAcceptLanguage = r.Header.Get("Accept-Language")
|
||||
payload, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
requestBody = string(payload)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = io.WriteString(w, `{"challenge_id":"challenge-123"}`)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
|
||||
result, err := client.SendEmailCode(context.Background(), SendEmailCodeInput{
|
||||
Email: "pilot@example.com",
|
||||
PreferredLanguage: "fr-FR",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, SendEmailCodeResult{ChallengeID: "challenge-123"}, result)
|
||||
assert.Equal(t, "application/json", requestContentType)
|
||||
assert.Equal(t, "fr-FR", requestAcceptLanguage)
|
||||
assert.JSONEq(t, `{"email":"pilot@example.com"}`, requestBody)
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientSendEmailCodeDefaultsAcceptLanguageToEnglish(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var requestAcceptLanguage string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestAcceptLanguage = r.Header.Get("Accept-Language")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err := io.WriteString(w, `{"challenge_id":"challenge-123"}`)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
|
||||
_, err := client.SendEmailCode(context.Background(), SendEmailCodeInput{Email: "pilot@example.com"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "en", requestAcceptLanguage)
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientConfirmEmailCodeSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, authServiceConfirmEmailCodePath, r.URL.Path)
|
||||
|
||||
payload, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key","time_zone":"Europe/Kaliningrad"}`, string(payload))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = io.WriteString(w, `{"device_session_id":"device-session-123"}`)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
|
||||
result, err := client.ConfirmEmailCode(context.Background(), ConfirmEmailCodeInput{
|
||||
ChallengeID: "challenge-123",
|
||||
Code: "123456",
|
||||
ClientPublicKey: "public-key",
|
||||
TimeZone: "Europe/Kaliningrad",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"}, result)
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientProjectsAuthServiceErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
responseBody string
|
||||
call func(*HTTPAuthServiceClient) error
|
||||
wantStatusCode int
|
||||
wantCode string
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "send email code error",
|
||||
statusCode: http.StatusServiceUnavailable,
|
||||
responseBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
|
||||
call: func(client *HTTPAuthServiceClient) error {
|
||||
_, err := client.SendEmailCode(context.Background(), SendEmailCodeInput{Email: "pilot@example.com"})
|
||||
return err
|
||||
},
|
||||
wantStatusCode: http.StatusServiceUnavailable,
|
||||
wantCode: "service_unavailable",
|
||||
wantMessage: "service is unavailable",
|
||||
},
|
||||
{
|
||||
name: "confirm email code error",
|
||||
statusCode: http.StatusConflict,
|
||||
responseBody: `{"error":{"code":"session_limit_exceeded","message":"active session limit would be exceeded"}}`,
|
||||
call: func(client *HTTPAuthServiceClient) error {
|
||||
_, err := client.ConfirmEmailCode(context.Background(), ConfirmEmailCodeInput{
|
||||
ChallengeID: "challenge-123",
|
||||
Code: "123456",
|
||||
ClientPublicKey: "public-key",
|
||||
TimeZone: "Europe/Kaliningrad",
|
||||
})
|
||||
return err
|
||||
},
|
||||
wantStatusCode: http.StatusConflict,
|
||||
wantCode: "session_limit_exceeded",
|
||||
wantMessage: "active session limit would be exceeded",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(tt.statusCode)
|
||||
_, err := io.WriteString(w, tt.responseBody)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
err := tt.call(client)
|
||||
require.Error(t, err)
|
||||
|
||||
var authErr *AuthServiceError
|
||||
require.ErrorAs(t, err, &authErr)
|
||||
assert.Equal(t, tt.wantStatusCode, authErr.StatusCode)
|
||||
assert.Equal(t, tt.wantCode, authErr.Code)
|
||||
assert.Equal(t, tt.wantMessage, authErr.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientRejectsMalformedPayloads(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
statusCode int
|
||||
responseBody string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "send email code rejects unknown success field",
|
||||
path: authServiceSendEmailCodePath,
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: `{"challenge_id":"challenge-123","extra":true}`,
|
||||
wantErr: "decode success response",
|
||||
},
|
||||
{
|
||||
name: "confirm email code rejects empty success field",
|
||||
path: authServiceConfirmEmailCodePath,
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: `{"device_session_id":" "}`,
|
||||
wantErr: "empty device_session_id",
|
||||
},
|
||||
{
|
||||
name: "rejects missing error object",
|
||||
path: authServiceSendEmailCodePath,
|
||||
statusCode: http.StatusBadRequest,
|
||||
responseBody: `{}`,
|
||||
wantErr: "missing error object",
|
||||
},
|
||||
{
|
||||
name: "rejects malformed error envelope",
|
||||
path: authServiceConfirmEmailCodePath,
|
||||
statusCode: http.StatusBadRequest,
|
||||
responseBody: `{"error":{"code":"invalid_code","message":"confirmation code is invalid","extra":true}}`,
|
||||
wantErr: "decode error response",
|
||||
},
|
||||
{
|
||||
name: "rejects unexpected status",
|
||||
path: authServiceSendEmailCodePath,
|
||||
statusCode: http.StatusCreated,
|
||||
responseBody: `{"challenge_id":"challenge-123"}`,
|
||||
wantErr: "unexpected HTTP status 201",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, tt.path, r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(tt.statusCode)
|
||||
_, err := io.WriteString(w, tt.responseBody)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
|
||||
var err error
|
||||
switch tt.path {
|
||||
case authServiceSendEmailCodePath:
|
||||
_, err = client.SendEmailCode(context.Background(), SendEmailCodeInput{Email: "pilot@example.com"})
|
||||
default:
|
||||
_, err = client.ConfirmEmailCode(context.Background(), ConfirmEmailCodeInput{
|
||||
ChallengeID: "challenge-123",
|
||||
Code: "123456",
|
||||
ClientPublicKey: "public-key",
|
||||
TimeZone: "Europe/Kaliningrad",
|
||||
})
|
||||
}
|
||||
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
assert.NotErrorAs(t, err, new(*AuthServiceError))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientUsesCallerContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"challenge_id":"challenge-123"}`)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := client.SendEmailCode(ctx, SendEmailCodeInput{Email: "pilot@example.com"})
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "send email code via auth service")
|
||||
assert.True(t, errors.Is(err, context.DeadlineExceeded))
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientRejectsNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.FailNow(t, "unexpected request", r.URL.Path)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
|
||||
_, err := client.SendEmailCode(nil, SendEmailCodeInput{Email: "pilot@example.com"})
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
|
||||
func newTestHTTPAuthServiceClient(t *testing.T, server *httptest.Server) *HTTPAuthServiceClient {
|
||||
t.Helper()
|
||||
|
||||
client, err := newHTTPAuthServiceClient(server.URL, server.Client())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func TestDecodeStrictJSONPayloadRejectsTrailingJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var target struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
err := decodeStrictJSONPayload([]byte(`{"value":"ok"}{}`), &target)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, "unexpected trailing JSON input", err.Error())
|
||||
}
|
||||
|
||||
func TestDecodeAuthServiceErrorPreservesBlankFieldsForLaterNormalization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authErr, err := decodeAuthServiceError(http.StatusBadGateway, []byte(`{"error":{"code":" ","message":" "}}`))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusBadGateway, authErr.StatusCode)
|
||||
assert.True(t, strings.TrimSpace(authErr.Code) == "")
|
||||
assert.True(t, strings.TrimSpace(authErr.Message) == "")
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// BackendLookup describes the slice of `backendclient.RESTClient`
// SessionCache depends on. The narrow interface keeps this package free
// of any backendclient import.
type BackendLookup interface {
	// LookupSession resolves deviceSessionID to its session Record.
	// Implementations are expected to surface ErrNotFound on a miss so
	// BackendCache.Lookup can forward it unchanged.
	LookupSession(ctx context.Context, deviceSessionID string) (Record, error)
}
|
||||
|
||||
// BackendCache resolves authenticated device sessions by issuing one
// synchronous REST call to backend per request. The canonical implementation replaces the
// previous Redis-backed projection with this thin wrapper; gateway no
// longer keeps a process-local snapshot. See ARCHITECTURE.md §11
// «backend (sync REST), no Redis projection».
type BackendCache struct {
	// backend performs the per-request session lookup over REST.
	backend BackendLookup
}
|
||||
|
||||
// NewBackendCache constructs a Cache that delegates every Lookup to
|
||||
// backend over REST. backend must not be nil.
|
||||
func NewBackendCache(backend BackendLookup) (*BackendCache, error) {
|
||||
if backend == nil {
|
||||
return nil, errors.New("session.NewBackendCache: backend lookup must not be nil")
|
||||
}
|
||||
return &BackendCache{backend: backend}, nil
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID via backend. ErrNotFound is forwarded
|
||||
// unchanged so callers can keep using the existing equality check.
|
||||
func (c *BackendCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil {
|
||||
return Record{}, errors.New("session backend cache: nil cache")
|
||||
}
|
||||
if c.backend == nil {
|
||||
return Record{}, errors.New("session backend cache: nil backend lookup")
|
||||
}
|
||||
rec, err := c.backend.LookupSession(ctx, deviceSessionID)
|
||||
if err != nil {
|
||||
return Record{}, fmt.Errorf("session backend cache: %w", err)
|
||||
}
|
||||
return rec, nil
|
||||
}
|
||||
|
||||
var _ Cache = (*BackendCache)(nil)
|
||||
@@ -1,88 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// MemoryCache stores session record snapshots in process-local memory. It is
// intended for the authenticated gateway hot path and deliberately keeps no
// TTL or size-based eviction policy.
type MemoryCache struct {
	// mu guards records for concurrent readers and writers.
	mu sync.RWMutex
	// records maps deviceSessionID to its cached snapshot.
	records map[string]Record
}
|
||||
|
||||
// NewMemoryCache constructs an empty process-local session snapshot store.
|
||||
func NewMemoryCache() *MemoryCache {
|
||||
return &MemoryCache{
|
||||
records: make(map[string]Record),
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID from the process-local snapshot map.
|
||||
func (c *MemoryCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil {
|
||||
return Record{}, errors.New("lookup session from in-memory cache: nil cache")
|
||||
}
|
||||
if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
|
||||
return Record{}, errors.New("lookup session from in-memory cache: nil context")
|
||||
}
|
||||
if strings.TrimSpace(deviceSessionID) == "" {
|
||||
return Record{}, errors.New("lookup session from in-memory cache: empty device session id")
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
record, ok := c.records[deviceSessionID]
|
||||
c.mu.RUnlock()
|
||||
if !ok {
|
||||
return Record{}, fmt.Errorf("lookup session from in-memory cache: %w", ErrNotFound)
|
||||
}
|
||||
|
||||
return cloneRecord(record), nil
|
||||
}
|
||||
|
||||
// Upsert stores record in the process-local snapshot map after validating the
|
||||
// same session invariants expected from the Redis-backed cache.
|
||||
func (c *MemoryCache) Upsert(record Record) error {
|
||||
if c == nil {
|
||||
return errors.New("upsert session into in-memory cache: nil cache")
|
||||
}
|
||||
if err := validateRecord(record.DeviceSessionID, record); err != nil {
|
||||
return fmt.Errorf("upsert session into in-memory cache: %w", err)
|
||||
}
|
||||
|
||||
cloned := cloneRecord(record)
|
||||
|
||||
c.mu.Lock()
|
||||
c.records[record.DeviceSessionID] = cloned
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes the local snapshot for deviceSessionID when one exists.
|
||||
func (c *MemoryCache) Delete(deviceSessionID string) {
|
||||
if c == nil || strings.TrimSpace(deviceSessionID) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
delete(c.records, deviceSessionID)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func cloneRecord(record Record) Record {
|
||||
cloned := record
|
||||
if record.RevokedAtMS != nil {
|
||||
value := *record.RevokedAtMS
|
||||
cloned.RevokedAtMS = &value
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
var _ SnapshotStore = (*MemoryCache)(nil)
|
||||
@@ -1,68 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ReadThroughCache resolves authenticated sessions from a process-local
// SnapshotStore first and falls back to another Cache only on a local miss.
type ReadThroughCache struct {
	// local is consulted first and seeded with fallback results.
	local SnapshotStore
	// fallback is queried once per local miss.
	fallback Cache
}
|
||||
|
||||
// NewReadThroughCache constructs a hot-path cache that seeds local snapshots
|
||||
// from fallback on demand.
|
||||
func NewReadThroughCache(local SnapshotStore, fallback Cache) (*ReadThroughCache, error) {
|
||||
if local == nil {
|
||||
return nil, errors.New("new read-through session cache: nil local cache")
|
||||
}
|
||||
if fallback == nil {
|
||||
return nil, errors.New("new read-through session cache: nil fallback cache")
|
||||
}
|
||||
|
||||
return &ReadThroughCache{
|
||||
local: local,
|
||||
fallback: fallback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID from local first, then performs one fallback
|
||||
// lookup on a local miss and seeds the local cache with the returned snapshot.
|
||||
func (c *ReadThroughCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil {
|
||||
return Record{}, errors.New("lookup session from read-through cache: nil cache")
|
||||
}
|
||||
|
||||
record, err := c.local.Lookup(ctx, deviceSessionID)
|
||||
switch {
|
||||
case err == nil:
|
||||
return record, nil
|
||||
case !errors.Is(err, ErrNotFound):
|
||||
return Record{}, fmt.Errorf("lookup session from read-through cache: %w", err)
|
||||
}
|
||||
|
||||
record, err = c.fallback.Lookup(ctx, deviceSessionID)
|
||||
if err != nil {
|
||||
return Record{}, err
|
||||
}
|
||||
|
||||
if err := c.local.Upsert(record); err != nil {
|
||||
return Record{}, fmt.Errorf("lookup session from read-through cache: seed local cache: %w", err)
|
||||
}
|
||||
|
||||
return cloneRecord(record), nil
|
||||
}
|
||||
|
||||
// Local returns the mutable process-local snapshot store used by c.
|
||||
func (c *ReadThroughCache) Local() SnapshotStore {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.local
|
||||
}
|
||||
|
||||
var _ Cache = (*ReadThroughCache)(nil)
|
||||
@@ -1,176 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMemoryCacheLookupReturnsClonedRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache := NewMemoryCache()
|
||||
revokedAtMS := int64(123456789)
|
||||
|
||||
require.NoError(t, cache.Upsert(Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
}))
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
|
||||
*record.RevokedAtMS = 1
|
||||
|
||||
stored, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stored.RevokedAtMS)
|
||||
assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheLocalHitSkipsFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
local := NewMemoryCache()
|
||||
require.NoError(t, local.Upsert(Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
}))
|
||||
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{}, errors.New("fallback should not be called")
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
}, record)
|
||||
assert.Equal(t, 0, fallback.lookupCalls)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheFallbackSeedsLocalCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
local := NewMemoryCache()
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusActive,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
assert.Equal(t, "user-123", record.UserID)
|
||||
|
||||
record, err = cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
assert.Equal(t, "user-123", record.UserID)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheKeepsRevokedSnapshotLocal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
revokedAtMS := int64(123456789)
|
||||
local := NewMemoryCache()
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
assert.Equal(t, StatusRevoked, record.Status)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
|
||||
record, err = cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
assert.Equal(t, StatusRevoked, record.Status)
|
||||
assert.Equal(t, revokedAtMS, *record.RevokedAtMS)
|
||||
assert.Equal(t, 1, fallback.lookupCalls)
|
||||
}
|
||||
|
||||
func TestReadThroughCacheReturnsClonedFallbackRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
revokedAtMS := int64(123456789)
|
||||
local := NewMemoryCache()
|
||||
fallback := &recordingCache{
|
||||
lookupFunc: func(context.Context, string) (Record, error) {
|
||||
return Record{
|
||||
DeviceSessionID: "device-session-123",
|
||||
UserID: "user-123",
|
||||
ClientPublicKey: "public-key-123",
|
||||
Status: StatusRevoked,
|
||||
RevokedAtMS: &revokedAtMS,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := NewReadThroughCache(local, fallback)
|
||||
require.NoError(t, err)
|
||||
|
||||
record, err := cache.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record.RevokedAtMS)
|
||||
|
||||
*record.RevokedAtMS = 1
|
||||
|
||||
stored, err := local.Lookup(context.Background(), "device-session-123")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stored.RevokedAtMS)
|
||||
assert.Equal(t, revokedAtMS, *stored.RevokedAtMS)
|
||||
}
|
||||
|
||||
// recordingCache is a Cache test double that counts Lookup calls and
// delegates to a configurable function.
type recordingCache struct {
	// lookupCalls counts how many times Lookup has been invoked.
	lookupCalls int
	// lookupFunc supplies Lookup's behavior; nil means "not implemented".
	lookupFunc func(context.Context, string) (Record, error)
}
|
||||
|
||||
func (c *recordingCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
c.lookupCalls++
|
||||
if c.lookupFunc != nil {
|
||||
return c.lookupFunc(ctx, deviceSessionID)
|
||||
}
|
||||
|
||||
return Record{}, errors.New("lookup is not implemented")
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisCache implements Cache with Redis GET lookups over strict JSON session
// records.
type RedisCache struct {
	// client is the shared Redis connection; the cache does not own it.
	client *redis.Client
	// keyPrefix namespaces session keys (key = keyPrefix + deviceSessionID).
	keyPrefix string
	// lookupTimeout bounds each Redis GET round trip.
	lookupTimeout time.Duration
}
|
||||
|
||||
// redisRecord is the strict JSON wire shape of a cached session record.
type redisRecord struct {
	DeviceSessionID string `json:"device_session_id"`
	UserID          string `json:"user_id"`
	ClientPublicKey string `json:"client_public_key"`
	Status          Status `json:"status"`
	// RevokedAtMS is optional (omitempty); nil when no revocation timestamp
	// is stored for the session.
	RevokedAtMS *int64 `json:"revoked_at_ms,omitempty"`
}
|
||||
|
||||
// NewRedisCache constructs a Redis-backed SessionCache that uses client and
|
||||
// applies the namespace and timeout settings from cfg. The cache does not own
|
||||
// the client; the runtime supplies a shared *redis.Client.
|
||||
func NewRedisCache(client *redis.Client, cfg config.SessionCacheRedisConfig) (*RedisCache, error) {
|
||||
if client == nil {
|
||||
return nil, errors.New("new redis session cache: nil redis client")
|
||||
}
|
||||
if strings.TrimSpace(cfg.KeyPrefix) == "" {
|
||||
return nil, errors.New("new redis session cache: redis key prefix must not be empty")
|
||||
}
|
||||
if cfg.LookupTimeout <= 0 {
|
||||
return nil, errors.New("new redis session cache: lookup timeout must be positive")
|
||||
}
|
||||
|
||||
return &RedisCache{
|
||||
client: client,
|
||||
keyPrefix: cfg.KeyPrefix,
|
||||
lookupTimeout: cfg.LookupTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Lookup resolves deviceSessionID from Redis, validates the cached JSON
|
||||
// payload strictly, and returns the decoded session record.
|
||||
func (c *RedisCache) Lookup(ctx context.Context, deviceSessionID string) (Record, error) {
|
||||
if c == nil || c.client == nil {
|
||||
return Record{}, errors.New("lookup session from redis: nil cache")
|
||||
}
|
||||
if ctx == nil || fmt.Sprint(ctx) == "context.TODO" {
|
||||
return Record{}, errors.New("lookup session from redis: nil context")
|
||||
}
|
||||
if strings.TrimSpace(deviceSessionID) == "" {
|
||||
return Record{}, errors.New("lookup session from redis: empty device session id")
|
||||
}
|
||||
|
||||
lookupCtx, cancel := context.WithTimeout(ctx, c.lookupTimeout)
|
||||
defer cancel()
|
||||
|
||||
payload, err := c.client.Get(lookupCtx, c.lookupKey(deviceSessionID)).Bytes()
|
||||
switch {
|
||||
case errors.Is(err, redis.Nil):
|
||||
return Record{}, fmt.Errorf("lookup session from redis: %w", ErrNotFound)
|
||||
case err != nil:
|
||||
return Record{}, fmt.Errorf("lookup session from redis: %w", err)
|
||||
}
|
||||
|
||||
record, err := decodeRedisRecord(deviceSessionID, payload)
|
||||
if err != nil {
|
||||
return Record{}, fmt.Errorf("lookup session from redis: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (c *RedisCache) lookupKey(deviceSessionID string) string {
|
||||
return c.keyPrefix + deviceSessionID
|
||||
}
|
||||
|
||||
func decodeRedisRecord(expectedDeviceSessionID string, payload []byte) (Record, error) {
|
||||
decoder := json.NewDecoder(bytes.NewReader(payload))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
var stored redisRecord
|
||||
if err := decoder.Decode(&stored); err != nil {
|
||||
return Record{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
return Record{}, errors.New("decode redis session record: unexpected trailing JSON input")
|
||||
}
|
||||
return Record{}, fmt.Errorf("decode redis session record: %w", err)
|
||||
}
|
||||
|
||||
record := Record{
|
||||
DeviceSessionID: stored.DeviceSessionID,
|
||||
UserID: stored.UserID,
|
||||
ClientPublicKey: stored.ClientPublicKey,
|
||||
Status: stored.Status,
|
||||
RevokedAtMS: cloneOptionalInt64(stored.RevokedAtMS),
|
||||
}
|
||||
|
||||
if err := validateRecord(expectedDeviceSessionID, record); err != nil {
|
||||
return Record{}, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func validateRecord(expectedDeviceSessionID string, record Record) error {
|
||||
if record.DeviceSessionID == "" {
|
||||
return errors.New("session record device_session_id must not be empty")
|
||||
}
|
||||
if record.DeviceSessionID != expectedDeviceSessionID {
|
||||
return fmt.Errorf("session record device_session_id %q does not match requested %q", record.DeviceSessionID, expectedDeviceSessionID)
|
||||
}
|
||||
if record.UserID == "" {
|
||||
return errors.New("session record user_id must not be empty")
|
||||
}
|
||||
if record.ClientPublicKey == "" {
|
||||
return errors.New("session record client_public_key must not be empty")
|
||||
}
|
||||
if !record.Status.IsKnown() {
|
||||
return fmt.Errorf("session record status %q is unsupported", record.Status)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cloneOptionalInt64 copies an optional timestamp so caller and cache never
// share the same *int64 storage; nil in yields nil out.
func cloneOptionalInt64(value *int64) *int64 {
	var out *int64
	if value != nil {
		v := *value
		out = &v
	}
	return out
}
|
||||
|
||||
// Compile-time proof that *RedisCache satisfies the Cache interface.
var _ Cache = (*RedisCache)(nil)
|
||||
@@ -1,317 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newRedisClient(t *testing.T, server *miniredis.Miniredis) *redis.Client {
|
||||
t.Helper()
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: server.Addr(),
|
||||
Protocol: 2,
|
||||
DisableIdentity: true,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func TestNewRedisCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
client := newRedisClient(t, server)
|
||||
|
||||
validCfg := config.SessionCacheRedisConfig{
|
||||
KeyPrefix: "gateway:session:",
|
||||
LookupTimeout: 250 * time.Millisecond,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
client *redis.Client
|
||||
cfg config.SessionCacheRedisConfig
|
||||
wantErr string
|
||||
}{
|
||||
{name: "valid config", client: client, cfg: validCfg},
|
||||
{name: "nil client", client: nil, cfg: validCfg, wantErr: "nil redis client"},
|
||||
{
|
||||
name: "empty key prefix",
|
||||
client: client,
|
||||
cfg: config.SessionCacheRedisConfig{LookupTimeout: 250 * time.Millisecond},
|
||||
wantErr: "redis key prefix must not be empty",
|
||||
},
|
||||
{
|
||||
name: "non-positive lookup timeout",
|
||||
client: client,
|
||||
cfg: config.SessionCacheRedisConfig{KeyPrefix: "gateway:session:"},
|
||||
wantErr: "lookup timeout must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cache, err := NewRedisCache(tt.client, tt.cfg)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedisCacheLookup table-tests the Redis lookup path: active and revoked
// cache hits, misses, malformed or invalid payloads, and key-prefix
// namespacing. Each case runs in parallel against its own miniredis server.
func TestRedisCacheLookup(t *testing.T) {
	t.Parallel()

	revokedAtMS := int64(123456789)

	tests := []struct {
		name string
		cfg  config.SessionCacheRedisConfig
		// requestID is the device session id passed to Lookup.
		requestID string
		// seed populates miniredis before the lookup; nil leaves it empty.
		seed func(*testing.T, *miniredis.Miniredis, config.SessionCacheRedisConfig)
		want Record
		// wantErrIs, wantErrText, assertErrText select the error assertions
		// performed for failure cases.
		wantErrIs     error
		wantErrText   string
		assertErrText string
	}{
		{
			name:      "active cache hit",
			requestID: "device-session-123",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-123", redisRecord{
					DeviceSessionID: "device-session-123",
					UserID:          "user-123",
					ClientPublicKey: "public-key-123",
					Status:          StatusActive,
				})
			},
			want: Record{
				DeviceSessionID: "device-session-123",
				UserID:          "user-123",
				ClientPublicKey: "public-key-123",
				Status:          StatusActive,
			},
		},
		{
			// No seed: the key is absent, so ErrNotFound must surface.
			name:      "missing session",
			requestID: "device-session-404",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			wantErrIs:     ErrNotFound,
			assertErrText: "session cache record not found",
		},
		{
			name:      "revoked session",
			requestID: "device-session-revoked",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-revoked", redisRecord{
					DeviceSessionID: "device-session-revoked",
					UserID:          "user-777",
					ClientPublicKey: "public-key-777",
					Status:          StatusRevoked,
					RevokedAtMS:     &revokedAtMS,
				})
			},
			want: Record{
				DeviceSessionID: "device-session-revoked",
				UserID:          "user-777",
				ClientPublicKey: "public-key-777",
				Status:          StatusRevoked,
				RevokedAtMS:     &revokedAtMS,
			},
		},
		{
			// Truncated JSON payload must fail the strict decode.
			name:      "malformed json",
			requestID: "device-session-bad-json",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				server.Set(cfg.KeyPrefix+"device-session-bad-json", "{")
			},
			wantErrText: "decode redis session record",
		},
		{
			name:      "unknown status",
			requestID: "device-session-unknown-status",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-unknown-status", redisRecord{
					DeviceSessionID: "device-session-unknown-status",
					UserID:          "user-1",
					ClientPublicKey: "public-key-1",
					Status:          Status("paused"),
				})
			},
			wantErrText: `status "paused" is unsupported`,
		},
		{
			// UserID is deliberately omitted to trip record validation.
			name:      "missing required field",
			requestID: "device-session-missing-user",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-missing-user", redisRecord{
					DeviceSessionID: "device-session-missing-user",
					ClientPublicKey: "public-key-1",
					Status:          StatusActive,
				})
			},
			wantErrText: "user_id must not be empty",
		},
		{
			// The stored record carries a different id than the one requested.
			name:      "device session id mismatch",
			requestID: "device-session-requested",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "gateway:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-requested", redisRecord{
					DeviceSessionID: "device-session-other",
					UserID:          "user-1",
					ClientPublicKey: "public-key-1",
					Status:          StatusActive,
				})
			},
			wantErrText: `does not match requested "device-session-requested"`,
		},
		{
			// Two records under different prefixes: only the configured
			// prefix's record may be returned.
			name:      "key prefix is honored",
			requestID: "device-session-prefixed",
			cfg: config.SessionCacheRedisConfig{
				KeyPrefix: "custom:session:",
			},
			seed: func(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) {
				t.Helper()
				setRedisSessionRecord(t, server, cfg.KeyPrefix+"device-session-prefixed", redisRecord{
					DeviceSessionID: "device-session-prefixed",
					UserID:          "user-prefixed",
					ClientPublicKey: "public-key-prefixed",
					Status:          StatusActive,
				})
				setRedisSessionRecord(t, server, "gateway:session:device-session-prefixed", redisRecord{
					DeviceSessionID: "device-session-prefixed",
					UserID:          "wrong-user",
					ClientPublicKey: "wrong-key",
					Status:          StatusRevoked,
				})
			},
			want: Record{
				DeviceSessionID: "device-session-prefixed",
				UserID:          "user-prefixed",
				ClientPublicKey: "public-key-prefixed",
				Status:          StatusActive,
			},
		},
	}

	for _, tt := range tests {
		// Pin the loop variable for the parallel subtest (pre-Go 1.22).
		tt := tt

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			server := miniredis.RunT(t)

			cfg := tt.cfg
			cfg.LookupTimeout = 250 * time.Millisecond

			if tt.seed != nil {
				tt.seed(t, server, cfg)
			}

			cache := newTestRedisCache(t, server, cfg)
			record, err := cache.Lookup(context.Background(), tt.requestID)
			if tt.wantErrIs != nil || tt.wantErrText != "" {
				require.Error(t, err)
				if tt.wantErrIs != nil {
					assert.ErrorIs(t, err, tt.wantErrIs)
				}
				if tt.wantErrText != "" {
					assert.ErrorContains(t, err, tt.wantErrText)
				}
				if tt.assertErrText != "" {
					assert.ErrorContains(t, err, tt.assertErrText)
				}
				return
			}

			require.NoError(t, err)
			assert.Equal(t, tt.want, record)
		})
	}
}
|
||||
|
||||
func newTestRedisCache(t *testing.T, server *miniredis.Miniredis, cfg config.SessionCacheRedisConfig) *RedisCache {
|
||||
t.Helper()
|
||||
|
||||
if cfg.KeyPrefix == "" {
|
||||
cfg.KeyPrefix = "gateway:session:"
|
||||
}
|
||||
if cfg.LookupTimeout == 0 {
|
||||
cfg.LookupTimeout = 250 * time.Millisecond
|
||||
}
|
||||
|
||||
cache, err := NewRedisCache(newRedisClient(t, server), cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
func setRedisSessionRecord(t *testing.T, server *miniredis.Miniredis, key string, record redisRecord) {
|
||||
t.Helper()
|
||||
|
||||
payload, err := json.Marshal(record)
|
||||
require.NoError(t, err)
|
||||
|
||||
server.Set(key, string(payload))
|
||||
}
|
||||
|
||||
func TestRedisCacheLookupNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := miniredis.RunT(t)
|
||||
cache := newTestRedisCache(t, server, config.SessionCacheRedisConfig{})
|
||||
|
||||
_, err := cache.Lookup(context.TODO(), "device-session-123")
|
||||
require.Error(t, err)
|
||||
assert.False(t, errors.Is(err, ErrNotFound))
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
@@ -13,27 +13,16 @@ var (
|
||||
ErrNotFound = errors.New("session cache record not found")
|
||||
)
|
||||
|
||||
// Cache resolves authenticated device-session state from the gateway
// hot path. The implementation dropped the previous Redis projection: the only
// implementation is *BackendCache, which calls backend's
// `/api/v1/internal/sessions/{id}` synchronously per request.
type Cache interface {
	// Lookup returns the cached record for deviceSessionID. Implementations must
	// wrap ErrNotFound when the cache does not contain the requested record.
	Lookup(ctx context.Context, deviceSessionID string) (Record, error)
}
|
||||
|
||||
// SnapshotStore stores mutable session record snapshots inside one gateway
// process and exposes the same read contract as Cache for the hot path.
type SnapshotStore interface {
	Cache

	// Upsert stores record under record.DeviceSessionID, replacing any previous
	// snapshot for that session. It returns an error when the record fails
	// validation.
	Upsert(record Record) error

	// Delete removes the local snapshot for deviceSessionID when it exists.
	// Deleting an absent session is a no-op.
	Delete(deviceSessionID string)
}
|
||||
|
||||
// Status identifies the cached lifecycle state of a device session.
type Status string
|
||||
|
||||
|
||||
Reference in New Issue
Block a user