386 lines
10 KiB
Go
386 lines
10 KiB
Go
package events
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"errors"
|
|
"net"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"galaxy/gateway/internal/app"
|
|
"galaxy/gateway/internal/authn"
|
|
"galaxy/gateway/internal/clock"
|
|
"galaxy/gateway/internal/config"
|
|
"galaxy/gateway/internal/downstream"
|
|
"galaxy/gateway/internal/grpcapi"
|
|
"galaxy/gateway/internal/replay"
|
|
"galaxy/gateway/internal/session"
|
|
gatewayv1 "galaxy/gateway/proto/galaxy/gateway/v1"
|
|
|
|
"github.com/alicebob/miniredis/v2"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// testNow is the fixed instant used throughout these tests: it backs the
// gateway's fixedClock and is the timestamp stamped on every test request.
var testNow = time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC)
|
|
|
|
func TestAuthenticatedGatewayWarmsLocalSessionCache(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
local := session.NewMemoryCache()
|
|
fallback := &countingSessionCache{
|
|
records: map[string]session.Record{
|
|
"device-session-123": newActiveSessionRecord("user-123"),
|
|
},
|
|
}
|
|
readThrough, err := session.NewReadThroughCache(local, fallback)
|
|
require.NoError(t, err)
|
|
|
|
subscriber := newTestRedisSessionSubscriber(t, server, local)
|
|
downstreamClient := &recordingDownstreamClient{}
|
|
addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
|
|
defer running.stop(t)
|
|
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
|
|
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 1, fallback.lookupCalls())
|
|
|
|
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 1, fallback.lookupCalls())
|
|
assert.Len(t, downstreamClient.commands(), 2)
|
|
}
|
|
|
|
func TestAuthenticatedGatewayUsesSessionUpdateEventWithoutFallbackLookup(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
local := session.NewMemoryCache()
|
|
fallback := &countingSessionCache{
|
|
records: map[string]session.Record{
|
|
"device-session-123": newActiveSessionRecord("user-123"),
|
|
},
|
|
}
|
|
readThrough, err := session.NewReadThroughCache(local, fallback)
|
|
require.NoError(t, err)
|
|
|
|
subscriber := newTestRedisSessionSubscriber(t, server, local)
|
|
downstreamClient := &recordingDownstreamClient{}
|
|
addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
|
|
defer running.stop(t)
|
|
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
|
|
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 1, fallback.lookupCalls())
|
|
|
|
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
|
"device_session_id": "device-session-123",
|
|
"user_id": "user-456",
|
|
"client_public_key": testClientPublicKeyBase64(),
|
|
"status": string(session.StatusActive),
|
|
})
|
|
|
|
require.Eventually(t, func() bool {
|
|
record, lookupErr := local.Lookup(context.Background(), "device-session-123")
|
|
return lookupErr == nil && record.UserID == "user-456"
|
|
}, time.Second, 10*time.Millisecond)
|
|
|
|
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 1, fallback.lookupCalls())
|
|
|
|
commands := downstreamClient.commands()
|
|
require.Len(t, commands, 2)
|
|
assert.Equal(t, "user-456", commands[1].UserID)
|
|
}
|
|
|
|
func TestAuthenticatedGatewayRejectsRevokedSessionAfterEventWithoutFallbackLookup(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := miniredis.RunT(t)
|
|
local := session.NewMemoryCache()
|
|
fallback := &countingSessionCache{
|
|
records: map[string]session.Record{
|
|
"device-session-123": newActiveSessionRecord("user-123"),
|
|
},
|
|
}
|
|
readThrough, err := session.NewReadThroughCache(local, fallback)
|
|
require.NoError(t, err)
|
|
|
|
subscriber := newTestRedisSessionSubscriber(t, server, local)
|
|
downstreamClient := &recordingDownstreamClient{}
|
|
addr, running := runAuthenticatedGateway(t, readThrough, subscriber, downstreamClient)
|
|
defer running.stop(t)
|
|
|
|
conn := dialGatewayClient(t, addr)
|
|
defer func() {
|
|
require.NoError(t, conn.Close())
|
|
}()
|
|
|
|
client := gatewayv1.NewEdgeGatewayClient(conn)
|
|
|
|
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-1"))
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 1, fallback.lookupCalls())
|
|
|
|
addSessionEvent(t, server, "gateway:session_events", map[string]string{
|
|
"device_session_id": "device-session-123",
|
|
"user_id": "user-123",
|
|
"client_public_key": testClientPublicKeyBase64(),
|
|
"status": string(session.StatusRevoked),
|
|
"revoked_at_ms": "123456789",
|
|
})
|
|
|
|
require.Eventually(t, func() bool {
|
|
record, lookupErr := local.Lookup(context.Background(), "device-session-123")
|
|
return lookupErr == nil && record.Status == session.StatusRevoked
|
|
}, time.Second, 10*time.Millisecond)
|
|
|
|
_, err = client.ExecuteCommand(context.Background(), newExecuteCommandRequest("request-2"))
|
|
require.Error(t, err)
|
|
assert.Equal(t, codes.FailedPrecondition, status.Code(err))
|
|
assert.Equal(t, "device session is revoked", status.Convert(err).Message())
|
|
assert.Equal(t, 1, fallback.lookupCalls())
|
|
}
|
|
|
|
// runningAuthenticatedGateway is the handle for a gateway application that is
// running in a background goroutine.
type runningAuthenticatedGateway struct {
	// cancel cancels the context the application is running under.
	cancel context.CancelFunc
	// resultCh receives the value returned by the application's Run call.
	resultCh chan error
}
|
|
|
|
func runAuthenticatedGateway(t *testing.T, sessionCache session.Cache, subscriber *RedisSessionSubscriber, downstreamClient downstream.Client) (string, runningAuthenticatedGateway) {
|
|
t.Helper()
|
|
|
|
addr := unusedTCPAddr(t)
|
|
grpcCfg := config.DefaultAuthenticatedGRPCConfig()
|
|
grpcCfg.Addr = addr
|
|
grpcCfg.FreshnessWindow = 5 * time.Minute
|
|
|
|
router := downstream.NewStaticRouter(map[string]downstream.Client{
|
|
"fleet.move": downstreamClient,
|
|
})
|
|
|
|
gateway := grpcapi.NewServer(grpcCfg, grpcapi.ServerDependencies{
|
|
Router: router,
|
|
ResponseSigner: newTestResponseSigner(t),
|
|
SessionCache: sessionCache,
|
|
ReplayStore: staticReplayStore{},
|
|
Clock: fixedClock{now: testNow},
|
|
})
|
|
|
|
application := app.New(
|
|
config.Config{
|
|
ShutdownTimeout: time.Second,
|
|
AuthenticatedGRPC: grpcCfg,
|
|
},
|
|
gateway,
|
|
subscriber,
|
|
)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
resultCh := make(chan error, 1)
|
|
go func() {
|
|
resultCh <- application.Run(ctx)
|
|
}()
|
|
|
|
select {
|
|
case <-subscriber.started:
|
|
case <-time.After(time.Second):
|
|
require.FailNow(t, "session subscriber did not start")
|
|
}
|
|
|
|
return addr, runningAuthenticatedGateway{
|
|
cancel: cancel,
|
|
resultCh: resultCh,
|
|
}
|
|
}
|
|
|
|
func (g runningAuthenticatedGateway) stop(t *testing.T) {
|
|
t.Helper()
|
|
|
|
g.cancel()
|
|
|
|
select {
|
|
case err := <-g.resultCh:
|
|
require.NoError(t, err)
|
|
case <-time.After(2 * time.Second):
|
|
require.FailNow(t, "gateway did not stop after cancellation")
|
|
}
|
|
}
|
|
|
|
func dialGatewayClient(t *testing.T, addr string) *grpc.ClientConn {
|
|
t.Helper()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := grpc.DialContext(
|
|
ctx,
|
|
addr,
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
grpc.WithBlock(),
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
return conn
|
|
}
|
|
|
|
func unusedTCPAddr(t *testing.T) string {
|
|
t.Helper()
|
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
addr := listener.Addr().String()
|
|
require.NoError(t, listener.Close())
|
|
|
|
return addr
|
|
}
|
|
|
|
func newExecuteCommandRequest(requestID string) *gatewayv1.ExecuteCommandRequest {
|
|
payloadBytes := []byte("payload")
|
|
payloadHash := sha256.Sum256(payloadBytes)
|
|
|
|
req := &gatewayv1.ExecuteCommandRequest{
|
|
ProtocolVersion: "v1",
|
|
DeviceSessionId: "device-session-123",
|
|
MessageType: "fleet.move",
|
|
TimestampMs: testNow.UnixMilli(),
|
|
RequestId: requestID,
|
|
PayloadBytes: payloadBytes,
|
|
PayloadHash: payloadHash[:],
|
|
TraceId: "trace-123",
|
|
}
|
|
req.Signature = ed25519.Sign(testClientPrivateKey(), authn.BuildRequestSigningInput(authn.RequestSigningFields{
|
|
ProtocolVersion: req.GetProtocolVersion(),
|
|
DeviceSessionID: req.GetDeviceSessionId(),
|
|
MessageType: req.GetMessageType(),
|
|
TimestampMS: req.GetTimestampMs(),
|
|
RequestID: req.GetRequestId(),
|
|
PayloadHash: req.GetPayloadHash(),
|
|
}))
|
|
|
|
return req
|
|
}
|
|
|
|
func newActiveSessionRecord(userID string) session.Record {
|
|
return session.Record{
|
|
DeviceSessionID: "device-session-123",
|
|
UserID: userID,
|
|
ClientPublicKey: testClientPublicKeyBase64(),
|
|
Status: session.StatusActive,
|
|
}
|
|
}
|
|
|
|
// testClientPrivateKey returns the Ed25519 private key used to sign test
// requests. The key is derived from the SHA-256 digest of a fixed label, so
// every call yields the same key.
func testClientPrivateKey() ed25519.PrivateKey {
	seed := sha256.Sum256([]byte("gateway-events-grpc-test-client"))
	return ed25519.NewKeyFromSeed(seed[:])
}
|
|
|
|
func testClientPublicKeyBase64() string {
|
|
return base64.StdEncoding.EncodeToString(testClientPrivateKey().Public().(ed25519.PublicKey))
|
|
}
|
|
|
|
func newTestResponseSigner(t *testing.T) authn.ResponseSigner {
|
|
t.Helper()
|
|
|
|
seed := sha256.Sum256([]byte("gateway-events-grpc-test-response"))
|
|
signer, err := authn.NewEd25519ResponseSigner(ed25519.NewKeyFromSeed(seed[:]))
|
|
require.NoError(t, err)
|
|
|
|
return signer
|
|
}
|
|
|
|
// fixedClock is a clock that always reports the same instant.
type fixedClock struct {
	// now is the instant returned by Now.
	now time.Time
}

// Now returns the instant the clock was constructed with.
func (c fixedClock) Now() time.Time {
	return c.now
}
|
|
|
|
// Compile-time check that fixedClock satisfies clock.Clock.
var _ clock.Clock = fixedClock{}
|
|
|
|
// staticReplayStore is a stateless replay store that accepts every
// reservation.
type staticReplayStore struct{}

// Reserve ignores all of its arguments and always returns nil.
func (staticReplayStore) Reserve(context.Context, string, string, time.Duration) error {
	return nil
}
|
|
|
|
// Compile-time check that staticReplayStore satisfies replay.Store.
var _ replay.Store = staticReplayStore{}
|
|
|
|
type countingSessionCache struct {
|
|
mu sync.Mutex
|
|
records map[string]session.Record
|
|
lookupCount int
|
|
}
|
|
|
|
func (c *countingSessionCache) Lookup(context.Context, string) (session.Record, error) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
c.lookupCount++
|
|
|
|
record, ok := c.records["device-session-123"]
|
|
if !ok {
|
|
return session.Record{}, errors.New("lookup session from counting cache: session cache record not found")
|
|
}
|
|
|
|
return record, nil
|
|
}
|
|
|
|
func (c *countingSessionCache) lookupCalls() int {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
return c.lookupCount
|
|
}
|
|
|
|
type recordingDownstreamClient struct {
|
|
mu sync.Mutex
|
|
captured []downstream.AuthenticatedCommand
|
|
}
|
|
|
|
func (c *recordingDownstreamClient) ExecuteCommand(_ context.Context, command downstream.AuthenticatedCommand) (downstream.UnaryResult, error) {
|
|
c.mu.Lock()
|
|
c.captured = append(c.captured, command)
|
|
c.mu.Unlock()
|
|
|
|
return downstream.UnaryResult{
|
|
ResultCode: "ok",
|
|
PayloadBytes: []byte("response"),
|
|
}, nil
|
|
}
|
|
|
|
func (c *recordingDownstreamClient) commands() []downstream.AuthenticatedCommand {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
cloned := make([]downstream.AuthenticatedCommand, len(c.captured))
|
|
copy(cloned, c.captured)
|
|
return cloned
|
|
}
|