tests: integration suite
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# Required startup settings.
|
||||
GATEWAY_SESSION_CACHE_REDIS_ADDR=127.0.0.1:6379
|
||||
GATEWAY_SESSION_EVENTS_REDIS_STREAM=gateway:session-events
|
||||
GATEWAY_CLIENT_EVENTS_REDIS_STREAM=gateway:client-events
|
||||
GATEWAY_SESSION_EVENTS_REDIS_STREAM=gateway:session_events
|
||||
GATEWAY_CLIENT_EVENTS_REDIS_STREAM=gateway:client_events
|
||||
GATEWAY_RESPONSE_SIGNER_PRIVATE_KEY_PEM_PATH=./secrets/response-signer.pem
|
||||
|
||||
# Main listeners.
|
||||
@@ -17,8 +17,9 @@ GATEWAY_AUTHENTICATED_GRPC_ADDR=127.0.0.1:9090
|
||||
# GATEWAY_REPLAY_REDIS_KEY_PREFIX=gateway:replay:
|
||||
# GATEWAY_SESSION_CACHE_REDIS_TLS_ENABLED=false
|
||||
|
||||
# Optional public-auth integration. Without an injected adapter the routes stay
|
||||
# mounted and return 503 service_unavailable.
|
||||
# Optional public-auth integration. Without a configured Auth / Session Service
|
||||
# base URL the routes stay mounted and return 503 service_unavailable.
|
||||
# GATEWAY_AUTH_SERVICE_BASE_URL=http://127.0.0.1:8081
|
||||
# GATEWAY_PUBLIC_AUTH_UPSTREAM_TIMEOUT=3s
|
||||
|
||||
# Optional shutdown and telemetry tuning.
|
||||
|
||||
+3
-2
@@ -21,13 +21,14 @@ Required startup environment variables:
|
||||
Optional integrations:
|
||||
|
||||
- `GATEWAY_ADMIN_HTTP_ADDR` enables the private `/metrics` listener;
|
||||
- an injected `AuthServiceClient` enables real public auth handling;
|
||||
- `GATEWAY_AUTH_SERVICE_BASE_URL` enables real public auth handling through
|
||||
Auth / Session Service public HTTP;
|
||||
- injected downstream routes are required for successful `ExecuteCommand`.
|
||||
|
||||
Operational caveats:
|
||||
|
||||
- public auth routes stay mounted and return `503 service_unavailable` until an
|
||||
auth adapter is wired;
|
||||
auth service base URL is configured;
|
||||
- authenticated gRPC starts without downstream routes, but `ExecuteCommand`
|
||||
returns gRPC `UNIMPLEMENTED` until routing is configured.
|
||||
|
||||
|
||||
@@ -25,6 +25,8 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var errNoopClose = func() error { return nil }
|
||||
|
||||
// main loads the gateway configuration, runs the process lifecycle, and exits
|
||||
// with a non-zero status when startup or runtime fails.
|
||||
func main() {
|
||||
@@ -53,8 +55,18 @@ func run(ctx context.Context) (err error) {
|
||||
return fmt.Errorf("build gateway telemetry: %w", err)
|
||||
}
|
||||
|
||||
publicRESTDeps, closePublicRESTDeps, err := newPublicRESTDependencies(cfg, logger, telemetryRuntime)
|
||||
if err != nil {
|
||||
_ = telemetryRuntime.Shutdown(context.Background())
|
||||
_ = logging.Sync(logger)
|
||||
return err
|
||||
}
|
||||
|
||||
grpcDeps, components, cleanup, err := newAuthenticatedGRPCDependencies(ctx, cfg, logger, telemetryRuntime)
|
||||
if err != nil {
|
||||
_ = closePublicRESTDeps()
|
||||
_ = telemetryRuntime.Shutdown(context.Background())
|
||||
_ = logging.Sync(logger)
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
@@ -63,16 +75,14 @@ func run(ctx context.Context) (err error) {
|
||||
|
||||
err = errors.Join(
|
||||
err,
|
||||
closePublicRESTDeps(),
|
||||
cleanup(),
|
||||
telemetryRuntime.Shutdown(shutdownCtx),
|
||||
logging.Sync(logger),
|
||||
)
|
||||
}()
|
||||
|
||||
restServer := restapi.NewServer(cfg.PublicHTTP, restapi.ServerDependencies{
|
||||
Logger: logger,
|
||||
Telemetry: telemetryRuntime,
|
||||
})
|
||||
restServer := restapi.NewServer(cfg.PublicHTTP, publicRESTDeps)
|
||||
grpcServer := grpcapi.NewServer(cfg.AuthenticatedGRPC, grpcDeps)
|
||||
|
||||
applicationComponents := []app.Component{
|
||||
@@ -96,6 +106,25 @@ func run(ctx context.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func newPublicRESTDependencies(cfg config.Config, logger *zap.Logger, telemetryRuntime *telemetry.Runtime) (restapi.ServerDependencies, func() error, error) {
|
||||
deps := restapi.ServerDependencies{
|
||||
Logger: logger,
|
||||
Telemetry: telemetryRuntime,
|
||||
}
|
||||
|
||||
if cfg.AuthService.BaseURL == "" {
|
||||
return deps, errNoopClose, nil
|
||||
}
|
||||
|
||||
authService, err := restapi.NewHTTPAuthServiceClient(cfg.AuthService.BaseURL)
|
||||
if err != nil {
|
||||
return restapi.ServerDependencies{}, nil, fmt.Errorf("build public REST dependencies: auth service client: %w", err)
|
||||
}
|
||||
|
||||
deps.AuthService = authService
|
||||
return deps, authService.Close, nil
|
||||
}
|
||||
|
||||
func newAuthenticatedGRPCDependencies(ctx context.Context, cfg config.Config, logger *zap.Logger, telemetryRuntime *telemetry.Runtime) (grpcapi.ServerDependencies, []app.Component, func() error, error) {
|
||||
responseSigner, err := authn.LoadEd25519ResponseSignerFromPEMFile(cfg.ResponseSigner.PrivateKeyPEMPath)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/gateway/internal/config"
|
||||
"galaxy/gateway/internal/restapi"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -20,6 +22,72 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestNewPublicRESTDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authServer := httptest.NewServer(nil)
|
||||
defer authServer.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.Config
|
||||
assert func(*testing.T, restapi.ServerDependencies)
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "default unavailable auth service when base url is empty",
|
||||
cfg: config.Config{},
|
||||
assert: func(t *testing.T, deps restapi.ServerDependencies) {
|
||||
t.Helper()
|
||||
assert.Nil(t, deps.AuthService)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "real auth service client when base url is configured",
|
||||
cfg: config.Config{
|
||||
AuthService: config.AuthServiceConfig{
|
||||
BaseURL: authServer.URL,
|
||||
},
|
||||
},
|
||||
assert: func(t *testing.T, deps restapi.ServerDependencies) {
|
||||
t.Helper()
|
||||
require.NotNil(t, deps.AuthService)
|
||||
_, ok := deps.AuthService.(*restapi.HTTPAuthServiceClient)
|
||||
assert.True(t, ok)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid auth service base url fails fast",
|
||||
cfg: config.Config{
|
||||
AuthService: config.AuthServiceConfig{
|
||||
BaseURL: "/relative",
|
||||
},
|
||||
},
|
||||
wantErr: "auth service client",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
deps, cleanup, err := newPublicRESTDependencies(tt.cfg, zap.NewNop(), nil)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cleanup)
|
||||
tt.assert(t, deps)
|
||||
assert.NoError(t, cleanup())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuthenticatedGRPCDependencies(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -39,6 +40,11 @@ const (
|
||||
// configures the timeout budget used for public auth upstream calls.
|
||||
publicAuthUpstreamTimeoutEnvVar = "GATEWAY_PUBLIC_AUTH_UPSTREAM_TIMEOUT"
|
||||
|
||||
// authServiceBaseURLEnvVar names the environment variable that configures
|
||||
// the optional Auth / Session Service public HTTP base URL used by gateway
|
||||
// public-auth delegation.
|
||||
authServiceBaseURLEnvVar = "GATEWAY_AUTH_SERVICE_BASE_URL"
|
||||
|
||||
// adminHTTPAddrEnvVar names the environment variable that configures the
|
||||
// private admin HTTP listener address. When it is empty, the admin listener
|
||||
// remains disabled.
|
||||
@@ -464,6 +470,15 @@ type PublicHTTPConfig struct {
|
||||
AntiAbuse PublicHTTPAntiAbuseConfig
|
||||
}
|
||||
|
||||
// AuthServiceConfig describes the optional public-auth upstream used by the
|
||||
// gateway runtime.
|
||||
type AuthServiceConfig struct {
|
||||
// BaseURL is the absolute base URL of the Auth / Session Service public
|
||||
// HTTP API. When BaseURL is empty, the gateway keeps using its built-in
|
||||
// unavailable public-auth adapter.
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
// AdminHTTPConfig describes the private operational HTTP listener used for
|
||||
// metrics exposure. The listener remains disabled when Addr is empty.
|
||||
type AdminHTTPConfig struct {
|
||||
@@ -591,6 +606,10 @@ type Config struct {
|
||||
// PublicHTTP configures the public unauthenticated REST listener.
|
||||
PublicHTTP PublicHTTPConfig
|
||||
|
||||
// AuthService configures the optional public-auth delegation to the Auth /
|
||||
// Session Service.
|
||||
AuthService AuthServiceConfig
|
||||
|
||||
// AdminHTTP configures the optional private admin listener used for metrics
|
||||
// exposure.
|
||||
AdminHTTP AdminHTTPConfig
|
||||
@@ -766,6 +785,12 @@ func DefaultResponseSignerConfig() ResponseSignerConfig {
|
||||
return ResponseSignerConfig{}
|
||||
}
|
||||
|
||||
// DefaultAuthServiceConfig returns the default public-auth upstream settings.
|
||||
// The zero value keeps the built-in unavailable adapter active.
|
||||
func DefaultAuthServiceConfig() AuthServiceConfig {
|
||||
return AuthServiceConfig{}
|
||||
}
|
||||
|
||||
// LoadFromEnv loads Config from the process environment, applies defaults for
|
||||
// omitted settings, and validates the resulting values.
|
||||
func LoadFromEnv() (Config, error) {
|
||||
@@ -773,6 +798,7 @@ func LoadFromEnv() (Config, error) {
|
||||
ShutdownTimeout: defaultShutdownTimeout,
|
||||
Logging: DefaultLoggingConfig(),
|
||||
PublicHTTP: DefaultPublicHTTPConfig(),
|
||||
AuthService: DefaultAuthServiceConfig(),
|
||||
AdminHTTP: DefaultAdminHTTPConfig(),
|
||||
AuthenticatedGRPC: DefaultAuthenticatedGRPCConfig(),
|
||||
SessionCacheRedis: DefaultSessionCacheRedisConfig(),
|
||||
@@ -825,6 +851,11 @@ func LoadFromEnv() (Config, error) {
|
||||
}
|
||||
cfg.PublicHTTP.AuthUpstreamTimeout = publicAuthUpstreamTimeout
|
||||
|
||||
rawAuthServiceBaseURL, ok := os.LookupEnv(authServiceBaseURLEnvVar)
|
||||
if ok {
|
||||
cfg.AuthService.BaseURL = rawAuthServiceBaseURL
|
||||
}
|
||||
|
||||
rawAdminHTTPAddr, ok := os.LookupEnv(adminHTTPAddrEnvVar)
|
||||
if ok {
|
||||
cfg.AdminHTTP.Addr = rawAdminHTTPAddr
|
||||
@@ -1082,6 +1113,17 @@ func LoadFromEnv() (Config, error) {
|
||||
if cfg.PublicHTTP.AuthUpstreamTimeout <= 0 {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be positive", publicAuthUpstreamTimeoutEnvVar)
|
||||
}
|
||||
cfg.AuthService.BaseURL = strings.TrimSpace(cfg.AuthService.BaseURL)
|
||||
if cfg.AuthService.BaseURL != "" {
|
||||
parsedAuthServiceBaseURL, err := url.Parse(cfg.AuthService.BaseURL)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("load gateway config: parse %s: %w", authServiceBaseURLEnvVar, err)
|
||||
}
|
||||
if parsedAuthServiceBaseURL.Scheme == "" || parsedAuthServiceBaseURL.Host == "" {
|
||||
return Config{}, fmt.Errorf("load gateway config: %s must be an absolute URL", authServiceBaseURLEnvVar)
|
||||
}
|
||||
cfg.AuthService.BaseURL = strings.TrimRight(parsedAuthServiceBaseURL.String(), "/")
|
||||
}
|
||||
if addr := strings.TrimSpace(cfg.AdminHTTP.Addr); addr != "" {
|
||||
cfg.AdminHTTP.Addr = addr
|
||||
}
|
||||
|
||||
@@ -24,6 +24,9 @@ func TestLoadFromEnv(t *testing.T) {
|
||||
customPublicHTTPAddr := new(string)
|
||||
*customPublicHTTPAddr = "127.0.0.1:9090"
|
||||
|
||||
customAuthServiceBaseURL := new(string)
|
||||
*customAuthServiceBaseURL = " http://127.0.0.1:8082/ "
|
||||
|
||||
customAuthenticatedGRPCAddr := new(string)
|
||||
*customAuthenticatedGRPCAddr = "127.0.0.1:9191"
|
||||
|
||||
@@ -76,6 +79,7 @@ func TestLoadFromEnv(t *testing.T) {
|
||||
name string
|
||||
shutdownTimeout *string
|
||||
publicHTTPAddr *string
|
||||
authServiceBaseURL *string
|
||||
authenticatedGRPCAddr *string
|
||||
authenticatedGRPCFreshnessWindow *string
|
||||
sessionCacheRedisAddr *string
|
||||
@@ -179,6 +183,40 @@ func TestLoadFromEnv(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom auth service base url",
|
||||
authServiceBaseURL: customAuthServiceBaseURL,
|
||||
sessionCacheRedisAddr: customSessionCacheRedisAddr,
|
||||
responseSignerPrivateKeyPEMPath: customResponseSignerPrivateKeyPEMPath,
|
||||
want: Config{
|
||||
ShutdownTimeout: 5 * time.Second,
|
||||
Logging: DefaultLoggingConfig(),
|
||||
PublicHTTP: DefaultPublicHTTPConfig(),
|
||||
AuthService: AuthServiceConfig{
|
||||
BaseURL: "http://127.0.0.1:8082",
|
||||
},
|
||||
AdminHTTP: DefaultAdminHTTPConfig(),
|
||||
AuthenticatedGRPC: DefaultAuthenticatedGRPCConfig(),
|
||||
SessionCacheRedis: SessionCacheRedisConfig{
|
||||
Addr: "127.0.0.1:6379",
|
||||
DB: defaultSessionCacheRedisDB,
|
||||
KeyPrefix: defaultSessionCacheRedisKeyPrefix,
|
||||
LookupTimeout: defaultSessionCacheRedisLookupTimeout,
|
||||
},
|
||||
ReplayRedis: DefaultReplayRedisConfig(),
|
||||
SessionEventsRedis: SessionEventsRedisConfig{
|
||||
Stream: "gateway:session_events",
|
||||
ReadBlockTimeout: defaultSessionEventsRedisReadBlockTimeout,
|
||||
},
|
||||
ClientEventsRedis: ClientEventsRedisConfig{
|
||||
Stream: "gateway:client_events",
|
||||
ReadBlockTimeout: defaultClientEventsRedisReadBlockTimeout,
|
||||
},
|
||||
ResponseSigner: ResponseSignerConfig{
|
||||
PrivateKeyPEMPath: *customResponseSignerPrivateKeyPEMPath,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom authenticated grpc address",
|
||||
authenticatedGRPCAddr: customAuthenticatedGRPCAddr,
|
||||
@@ -329,6 +367,7 @@ func TestLoadFromEnv(t *testing.T) {
|
||||
restoreEnvs(t,
|
||||
shutdownTimeoutEnvVar,
|
||||
publicHTTPAddrEnvVar,
|
||||
authServiceBaseURLEnvVar,
|
||||
authenticatedGRPCAddrEnvVar,
|
||||
authenticatedGRPCFreshnessWindowEnvVar,
|
||||
sessionCacheRedisAddrEnvVar,
|
||||
@@ -339,6 +378,7 @@ func TestLoadFromEnv(t *testing.T) {
|
||||
|
||||
setEnvValue(t, shutdownTimeoutEnvVar, tt.shutdownTimeout)
|
||||
setEnvValue(t, publicHTTPAddrEnvVar, tt.publicHTTPAddr)
|
||||
setEnvValue(t, authServiceBaseURLEnvVar, tt.authServiceBaseURL)
|
||||
setEnvValue(t, authenticatedGRPCAddrEnvVar, tt.authenticatedGRPCAddr)
|
||||
setEnvValue(t, authenticatedGRPCFreshnessWindowEnvVar, tt.authenticatedGRPCFreshnessWindow)
|
||||
setEnvValue(t, sessionCacheRedisAddrEnvVar, tt.sessionCacheRedisAddr)
|
||||
@@ -477,6 +517,70 @@ func TestLoadFromEnvOperationalSettings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromEnvAuthService(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
customSessionCacheRedisAddr := new(string)
|
||||
*customSessionCacheRedisAddr = "127.0.0.1:6379"
|
||||
|
||||
customSessionEventsRedisStream := new(string)
|
||||
*customSessionEventsRedisStream = "gateway:session_events"
|
||||
|
||||
customClientEventsRedisStream := new(string)
|
||||
*customClientEventsRedisStream = "gateway:client_events"
|
||||
|
||||
customResponseSignerPrivateKeyPEMPath := new(string)
|
||||
*customResponseSignerPrivateKeyPEMPath = writeTestResponseSignerPEMFile(t)
|
||||
|
||||
invalidRelativeURL := new(string)
|
||||
*invalidRelativeURL = "/authsession"
|
||||
|
||||
invalidURL := new(string)
|
||||
*invalidURL = "://bad"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
value *string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "relative url rejected",
|
||||
value: invalidRelativeURL,
|
||||
wantErr: authServiceBaseURLEnvVar + " must be an absolute URL",
|
||||
},
|
||||
{
|
||||
name: "malformed url rejected",
|
||||
value: invalidURL,
|
||||
wantErr: "parse " + authServiceBaseURLEnvVar,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
restoreEnvs(t,
|
||||
authServiceBaseURLEnvVar,
|
||||
sessionCacheRedisAddrEnvVar,
|
||||
sessionEventsRedisStreamEnvVar,
|
||||
clientEventsRedisStreamEnvVar,
|
||||
responseSignerPrivateKeyPEMPathEnvVar,
|
||||
)
|
||||
setEnvValue(t, authServiceBaseURLEnvVar, tt.value)
|
||||
setEnvValue(t, sessionCacheRedisAddrEnvVar, customSessionCacheRedisAddr)
|
||||
setEnvValue(t, sessionEventsRedisStreamEnvVar, customSessionEventsRedisStream)
|
||||
setEnvValue(t, clientEventsRedisStreamEnvVar, customClientEventsRedisStream)
|
||||
setEnvValue(t, responseSignerPrivateKeyPEMPathEnvVar, customResponseSignerPrivateKeyPEMPath)
|
||||
|
||||
_, err := LoadFromEnv()
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromEnvAuthenticatedGRPCAntiAbuse(t *testing.T) {
|
||||
customSessionCacheRedisAddr := new(string)
|
||||
*customSessionCacheRedisAddr = "127.0.0.1:6379"
|
||||
@@ -1212,6 +1316,7 @@ func operationalEnvVars() []string {
|
||||
publicHTTPReadTimeoutEnvVar,
|
||||
publicHTTPIdleTimeoutEnvVar,
|
||||
publicAuthUpstreamTimeoutEnvVar,
|
||||
authServiceBaseURLEnvVar,
|
||||
adminHTTPAddrEnvVar,
|
||||
adminHTTPReadHeaderTimeoutEnvVar,
|
||||
adminHTTPReadTimeoutEnvVar,
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
authServiceSendEmailCodePath = "/api/v1/public/auth/send-email-code"
|
||||
authServiceConfirmEmailCodePath = "/api/v1/public/auth/confirm-email-code"
|
||||
)
|
||||
|
||||
// HTTPAuthServiceClient implements AuthServiceClient over the Auth / Session
|
||||
// Service public HTTP API using strict JSON request and response decoding.
|
||||
type HTTPAuthServiceClient struct {
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
type authServiceErrorEnvelope struct {
|
||||
Error *authServiceErrorBody `json:"error"`
|
||||
}
|
||||
|
||||
type authServiceErrorBody struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// NewHTTPAuthServiceClient constructs an AuthServiceClient that delegates the
|
||||
// gateway public-auth routes to the Auth / Session Service public HTTP API at
|
||||
// baseURL. The resulting client relies only on the caller-provided context for
|
||||
// cancellation and timeout control.
|
||||
func NewHTTPAuthServiceClient(baseURL string) (*HTTPAuthServiceClient, error) {
|
||||
transport, ok := http.DefaultTransport.(*http.Transport)
|
||||
if !ok {
|
||||
return nil, errors.New("new auth service HTTP client: default transport is not *http.Transport")
|
||||
}
|
||||
|
||||
return newHTTPAuthServiceClient(baseURL, &http.Client{
|
||||
Transport: transport.Clone(),
|
||||
})
|
||||
}
|
||||
|
||||
func newHTTPAuthServiceClient(baseURL string, httpClient *http.Client) (*HTTPAuthServiceClient, error) {
|
||||
if httpClient == nil {
|
||||
return nil, errors.New("new auth service HTTP client: http client must not be nil")
|
||||
}
|
||||
|
||||
trimmedBaseURL := strings.TrimSpace(baseURL)
|
||||
if trimmedBaseURL == "" {
|
||||
return nil, errors.New("new auth service HTTP client: base URL must not be empty")
|
||||
}
|
||||
|
||||
parsedBaseURL, err := url.Parse(strings.TrimRight(trimmedBaseURL, "/"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new auth service HTTP client: parse base URL: %w", err)
|
||||
}
|
||||
if parsedBaseURL.Scheme == "" || parsedBaseURL.Host == "" {
|
||||
return nil, errors.New("new auth service HTTP client: base URL must be absolute")
|
||||
}
|
||||
|
||||
return &HTTPAuthServiceClient{
|
||||
baseURL: parsedBaseURL.String(),
|
||||
httpClient: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases idle HTTP connections owned by the client transport.
|
||||
func (c *HTTPAuthServiceClient) Close() error {
|
||||
if c == nil || c.httpClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type idleCloser interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
if transport, ok := c.httpClient.Transport.(idleCloser); ok {
|
||||
transport.CloseIdleConnections()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendEmailCode delegates the public send-email-code route to the configured
|
||||
// Auth / Session Service public HTTP API.
|
||||
func (c *HTTPAuthServiceClient) SendEmailCode(ctx context.Context, input SendEmailCodeInput) (SendEmailCodeResult, error) {
|
||||
payload, statusCode, err := c.doJSONRequest(ctx, authServiceSendEmailCodePath, input)
|
||||
if err != nil {
|
||||
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: %w", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case statusCode == http.StatusOK:
|
||||
var result SendEmailCodeResult
|
||||
if err := decodeStrictJSONPayload(payload, &result); err != nil {
|
||||
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: decode success response: %w", err)
|
||||
}
|
||||
if err := validateSendEmailCodeResult(&result); err != nil {
|
||||
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
case statusCode >= 400 && statusCode <= 599:
|
||||
authErr, err := decodeAuthServiceError(statusCode, payload)
|
||||
if err != nil {
|
||||
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: %w", err)
|
||||
}
|
||||
|
||||
return SendEmailCodeResult{}, authErr
|
||||
default:
|
||||
return SendEmailCodeResult{}, fmt.Errorf("send email code via auth service: unexpected HTTP status %d", statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// ConfirmEmailCode delegates the public confirm-email-code route to the
|
||||
// configured Auth / Session Service public HTTP API.
|
||||
func (c *HTTPAuthServiceClient) ConfirmEmailCode(ctx context.Context, input ConfirmEmailCodeInput) (ConfirmEmailCodeResult, error) {
|
||||
payload, statusCode, err := c.doJSONRequest(ctx, authServiceConfirmEmailCodePath, input)
|
||||
if err != nil {
|
||||
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: %w", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case statusCode == http.StatusOK:
|
||||
var result ConfirmEmailCodeResult
|
||||
if err := decodeStrictJSONPayload(payload, &result); err != nil {
|
||||
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: decode success response: %w", err)
|
||||
}
|
||||
if err := validateConfirmEmailCodeResult(&result); err != nil {
|
||||
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
case statusCode >= 400 && statusCode <= 599:
|
||||
authErr, err := decodeAuthServiceError(statusCode, payload)
|
||||
if err != nil {
|
||||
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: %w", err)
|
||||
}
|
||||
|
||||
return ConfirmEmailCodeResult{}, authErr
|
||||
default:
|
||||
return ConfirmEmailCodeResult{}, fmt.Errorf("confirm email code via auth service: unexpected HTTP status %d", statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HTTPAuthServiceClient) doJSONRequest(ctx context.Context, path string, requestBody any) ([]byte, int, error) {
|
||||
if c == nil || c.httpClient == nil {
|
||||
return nil, 0, errors.New("nil client")
|
||||
}
|
||||
if ctx == nil {
|
||||
return nil, 0, errors.New("nil context")
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("marshal request body: %w", err)
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+path, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("build request: %w", err)
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
response, err := c.httpClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
responsePayload, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
|
||||
return responsePayload, response.StatusCode, nil
|
||||
}
|
||||
|
||||
func decodeAuthServiceError(statusCode int, payload []byte) (*AuthServiceError, error) {
|
||||
var envelope authServiceErrorEnvelope
|
||||
if err := decodeStrictJSONPayload(payload, &envelope); err != nil {
|
||||
return nil, fmt.Errorf("decode error response: %w", err)
|
||||
}
|
||||
if envelope.Error == nil {
|
||||
return nil, errors.New("decode error response: missing error object")
|
||||
}
|
||||
|
||||
return &AuthServiceError{
|
||||
StatusCode: statusCode,
|
||||
Code: envelope.Error.Code,
|
||||
Message: envelope.Error.Message,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeStrictJSONPayload(payload []byte, target any) error {
|
||||
decoder := json.NewDecoder(bytes.NewReader(payload))
|
||||
decoder.DisallowUnknownFields()
|
||||
|
||||
if err := decoder.Decode(target); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := decoder.Decode(&struct{}{}); err != io.EOF {
|
||||
if err == nil {
|
||||
return errors.New("unexpected trailing JSON input")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ AuthServiceClient = (*HTTPAuthServiceClient)(nil)
|
||||
@@ -0,0 +1,346 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewHTTPAuthServiceClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
baseURL string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
baseURL: " http://127.0.0.1:8080/ ",
|
||||
},
|
||||
{
|
||||
name: "empty base url",
|
||||
wantErr: "base URL must not be empty",
|
||||
},
|
||||
{
|
||||
name: "relative base url",
|
||||
baseURL: "/authsession",
|
||||
wantErr: "base URL must be absolute",
|
||||
},
|
||||
{
|
||||
name: "malformed base url",
|
||||
baseURL: "://bad",
|
||||
wantErr: "parse base URL",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, err := NewHTTPAuthServiceClient(tt.baseURL)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://127.0.0.1:8080", client.baseURL)
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientSendEmailCodeSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var requestContentType string
|
||||
var requestBody string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, authServiceSendEmailCodePath, r.URL.Path)
|
||||
|
||||
requestContentType = r.Header.Get("Content-Type")
|
||||
payload, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
requestBody = string(payload)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = io.WriteString(w, `{"challenge_id":"challenge-123"}`)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
|
||||
result, err := client.SendEmailCode(context.Background(), SendEmailCodeInput{
|
||||
Email: "pilot@example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, SendEmailCodeResult{ChallengeID: "challenge-123"}, result)
|
||||
assert.Equal(t, "application/json", requestContentType)
|
||||
assert.JSONEq(t, `{"email":"pilot@example.com"}`, requestBody)
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientConfirmEmailCodeSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, authServiceConfirmEmailCodePath, r.URL.Path)
|
||||
|
||||
payload, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, `{"challenge_id":"challenge-123","code":"123456","client_public_key":"public-key","time_zone":"Europe/Kaliningrad"}`, string(payload))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = io.WriteString(w, `{"device_session_id":"device-session-123"}`)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
|
||||
result, err := client.ConfirmEmailCode(context.Background(), ConfirmEmailCodeInput{
|
||||
ChallengeID: "challenge-123",
|
||||
Code: "123456",
|
||||
ClientPublicKey: "public-key",
|
||||
TimeZone: "Europe/Kaliningrad",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ConfirmEmailCodeResult{DeviceSessionID: "device-session-123"}, result)
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientProjectsAuthServiceErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
responseBody string
|
||||
call func(*HTTPAuthServiceClient) error
|
||||
wantStatusCode int
|
||||
wantCode string
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "send email code error",
|
||||
statusCode: http.StatusServiceUnavailable,
|
||||
responseBody: `{"error":{"code":"service_unavailable","message":"service is unavailable"}}`,
|
||||
call: func(client *HTTPAuthServiceClient) error {
|
||||
_, err := client.SendEmailCode(context.Background(), SendEmailCodeInput{Email: "pilot@example.com"})
|
||||
return err
|
||||
},
|
||||
wantStatusCode: http.StatusServiceUnavailable,
|
||||
wantCode: "service_unavailable",
|
||||
wantMessage: "service is unavailable",
|
||||
},
|
||||
{
|
||||
name: "confirm email code error",
|
||||
statusCode: http.StatusConflict,
|
||||
responseBody: `{"error":{"code":"session_limit_exceeded","message":"active session limit would be exceeded"}}`,
|
||||
call: func(client *HTTPAuthServiceClient) error {
|
||||
_, err := client.ConfirmEmailCode(context.Background(), ConfirmEmailCodeInput{
|
||||
ChallengeID: "challenge-123",
|
||||
Code: "123456",
|
||||
ClientPublicKey: "public-key",
|
||||
TimeZone: "Europe/Kaliningrad",
|
||||
})
|
||||
return err
|
||||
},
|
||||
wantStatusCode: http.StatusConflict,
|
||||
wantCode: "session_limit_exceeded",
|
||||
wantMessage: "active session limit would be exceeded",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(tt.statusCode)
|
||||
_, err := io.WriteString(w, tt.responseBody)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
err := tt.call(client)
|
||||
require.Error(t, err)
|
||||
|
||||
var authErr *AuthServiceError
|
||||
require.ErrorAs(t, err, &authErr)
|
||||
assert.Equal(t, tt.wantStatusCode, authErr.StatusCode)
|
||||
assert.Equal(t, tt.wantCode, authErr.Code)
|
||||
assert.Equal(t, tt.wantMessage, authErr.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientRejectsMalformedPayloads(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
statusCode int
|
||||
responseBody string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "send email code rejects unknown success field",
|
||||
path: authServiceSendEmailCodePath,
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: `{"challenge_id":"challenge-123","extra":true}`,
|
||||
wantErr: "decode success response",
|
||||
},
|
||||
{
|
||||
name: "confirm email code rejects empty success field",
|
||||
path: authServiceConfirmEmailCodePath,
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: `{"device_session_id":" "}`,
|
||||
wantErr: "empty device_session_id",
|
||||
},
|
||||
{
|
||||
name: "rejects missing error object",
|
||||
path: authServiceSendEmailCodePath,
|
||||
statusCode: http.StatusBadRequest,
|
||||
responseBody: `{}`,
|
||||
wantErr: "missing error object",
|
||||
},
|
||||
{
|
||||
name: "rejects malformed error envelope",
|
||||
path: authServiceConfirmEmailCodePath,
|
||||
statusCode: http.StatusBadRequest,
|
||||
responseBody: `{"error":{"code":"invalid_code","message":"confirmation code is invalid","extra":true}}`,
|
||||
wantErr: "decode error response",
|
||||
},
|
||||
{
|
||||
name: "rejects unexpected status",
|
||||
path: authServiceSendEmailCodePath,
|
||||
statusCode: http.StatusCreated,
|
||||
responseBody: `{"challenge_id":"challenge-123"}`,
|
||||
wantErr: "unexpected HTTP status 201",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, tt.path, r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(tt.statusCode)
|
||||
_, err := io.WriteString(w, tt.responseBody)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
|
||||
var err error
|
||||
switch tt.path {
|
||||
case authServiceSendEmailCodePath:
|
||||
_, err = client.SendEmailCode(context.Background(), SendEmailCodeInput{Email: "pilot@example.com"})
|
||||
default:
|
||||
_, err = client.ConfirmEmailCode(context.Background(), ConfirmEmailCodeInput{
|
||||
ChallengeID: "challenge-123",
|
||||
Code: "123456",
|
||||
ClientPublicKey: "public-key",
|
||||
TimeZone: "Europe/Kaliningrad",
|
||||
})
|
||||
}
|
||||
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
assert.NotErrorAs(t, err, new(*AuthServiceError))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientUsesCallerContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"challenge_id":"challenge-123"}`)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := client.SendEmailCode(ctx, SendEmailCodeInput{Email: "pilot@example.com"})
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "send email code via auth service")
|
||||
assert.True(t, errors.Is(err, context.DeadlineExceeded))
|
||||
}
|
||||
|
||||
func TestHTTPAuthServiceClientRejectsNilContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.FailNow(t, "unexpected request", r.URL.Path)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestHTTPAuthServiceClient(t, server)
|
||||
|
||||
_, err := client.SendEmailCode(nil, SendEmailCodeInput{Email: "pilot@example.com"})
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "nil context")
|
||||
}
|
||||
|
||||
func newTestHTTPAuthServiceClient(t *testing.T, server *httptest.Server) *HTTPAuthServiceClient {
|
||||
t.Helper()
|
||||
|
||||
client, err := newHTTPAuthServiceClient(server.URL, server.Client())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, client.Close())
|
||||
})
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func TestDecodeStrictJSONPayloadRejectsTrailingJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var target struct {
|
||||
Value string `json:"value"`
|
||||
}
|
||||
err := decodeStrictJSONPayload([]byte(`{"value":"ok"}{}`), &target)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, "unexpected trailing JSON input", err.Error())
|
||||
}
|
||||
|
||||
func TestDecodeAuthServiceErrorPreservesBlankFieldsForLaterNormalization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authErr, err := decodeAuthServiceError(http.StatusBadGateway, []byte(`{"error":{"code":" ","message":" "}}`))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusBadGateway, authErr.StatusCode)
|
||||
assert.True(t, strings.TrimSpace(authErr.Code) == "")
|
||||
assert.True(t, strings.TrimSpace(authErr.Message) == "")
|
||||
}
|
||||
Reference in New Issue
Block a user