package internalhttp

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"testing"

	"galaxy/mail/internal/service/acceptauthdelivery"
	mailtelemetry "galaxy/mail/internal/telemetry"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.opentelemetry.io/otel/attribute"
	sdkmetric "go.opentelemetry.io/otel/sdk/metric"
	"go.opentelemetry.io/otel/sdk/metric/metricdata"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	"go.opentelemetry.io/otel/sdk/trace/tracetest"
)

func TestLoginCodeDeliveryHandlerReturnsSuccessOutcomes(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		result      acceptauthdelivery.Result
		wantOutcome string
	}{
		{name: "sent", result: acceptauthdelivery.Result{Outcome: acceptauthdelivery.OutcomeSent}, wantOutcome: "sent"},
		{name: "suppressed", result: acceptauthdelivery.Result{Outcome: acceptauthdelivery.OutcomeSuppressed}, wantOutcome: "suppressed"},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Stub the use case so it returns the outcome under test.
			handler := newHandler(Dependencies{
				Logger: slog.New(slog.NewJSONHandler(io.Discard, nil)),
				AcceptLoginCodeDelivery: acceptLoginCodeDeliveryFunc(func(context.Context, acceptauthdelivery.Input) (acceptauthdelivery.Result, error) {
					return tt.result, nil
				}),
			})

			response := doLoginCodeDeliveryRequest(t, handler, `{"email":"pilot@example.com","code":"123456","locale":"en"}`, "challenge-1")
			defer response.Body.Close()

			require.Equal(t, http.StatusOK, response.StatusCode)
			require.Equal(t, "application/json", response.Header.Get("Content-Type"))

			var payload LoginCodeDeliveryResponse
			require.NoError(t, decodeJSONBody(response, &payload))
			require.Equal(t, LoginCodeDeliveryOutcome(tt.wantOutcome), payload.Outcome)
		})
	}
}

func TestLoginCodeDeliveryHandlerMapsErrors(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		body       string
		header     string
		useCaseErr error
		wantCode   int
		wantErr    string
	}{
		{
			// Valid body but no idempotency key header; the stubbed use case
			// would succeed, so the handler itself must reject the request.
			name:     "invalid request",
			body:     `{"email":"pilot@example.com","code":"123456","locale":"en"}`,
			wantCode: http.StatusBadRequest,
			wantErr:  ErrorCodeInvalidRequest,
		},
		{
			name:       "conflict",
			body:       `{"email":"pilot@example.com","code":"123456","locale":"en"}`,
			header:     "challenge-1",
			useCaseErr: acceptauthdelivery.ErrConflict,
			wantCode:   http.StatusConflict,
			wantErr:    ErrorCodeConflict,
		},
		{
			name:       "service unavailable",
			body:       `{"email":"pilot@example.com","code":"123456","locale":"en"}`,
			header:     "challenge-1",
			useCaseErr: acceptauthdelivery.ErrServiceUnavailable,
			wantCode:   http.StatusServiceUnavailable,
			wantErr:    ErrorCodeServiceUnavailable,
		},
		{
			// Any error the handler does not recognize maps to an internal error.
			name:       "internal error",
			body:       `{"email":"pilot@example.com","code":"123456","locale":"en"}`,
			header:     "challenge-1",
			useCaseErr: context.DeadlineExceeded,
			wantCode:   http.StatusInternalServerError,
			wantErr:    ErrorCodeInternalError,
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			handler := newHandler(Dependencies{
				Logger: slog.New(slog.NewJSONHandler(io.Discard, nil)),
				AcceptLoginCodeDelivery: acceptLoginCodeDeliveryFunc(func(context.Context, acceptauthdelivery.Input) (acceptauthdelivery.Result, error) {
					if tt.useCaseErr != nil {
						return acceptauthdelivery.Result{}, tt.useCaseErr
					}
					return acceptauthdelivery.Result{Outcome: acceptauthdelivery.OutcomeSent}, nil
				}),
			})

			response := doLoginCodeDeliveryRequest(t, handler, tt.body, tt.header)
			defer response.Body.Close()

			require.Equal(t, tt.wantCode, response.StatusCode)

			var payload ErrorResponse
			require.NoError(t, decodeJSONBody(response, &payload))
			require.Equal(t, tt.wantErr, payload.Error.Code)
		})
	}
}

func TestLoginCodeDeliveryHandlerEmitsMetricsAndSpan(t *testing.T) {
	t.Parallel()

	// In-memory OpenTelemetry providers let the test inspect recorded metrics and spans.
	reader := sdkmetric.NewManualReader()
	meterProvider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))
	recorder := tracetest.NewSpanRecorder()
	tracerProvider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))

	telemetryRuntime, err := mailtelemetry.NewWithProviders(meterProvider, tracerProvider)
	require.NoError(t, err)

	// Capture log output to check that trace and span IDs are attached to log records.
	loggerBuffer := &bytes.Buffer{}
	logger := slog.New(slog.NewJSONHandler(loggerBuffer, nil))

	handler := newHandler(Dependencies{
		Logger:    logger,
		Telemetry: telemetryRuntime,
		AcceptLoginCodeDelivery: acceptLoginCodeDeliveryFunc(func(context.Context, acceptauthdelivery.Input) (acceptauthdelivery.Result, error) {
			return acceptauthdelivery.Result{Outcome: acceptauthdelivery.OutcomeSent}, nil
		}),
	})

	response := doLoginCodeDeliveryRequest(t, handler, `{"email":"pilot@example.com","code":"123456","locale":"en"}`, "challenge-1")
	defer response.Body.Close()

	require.Equal(t, http.StatusOK, response.StatusCode)

	// Exactly one span is ended, and it is named after the route.
	require.Len(t, recorder.Ended(), 1)
	assert.Equal(t, LoginCodeDeliveriesPath, recorder.Ended()[0].Name())

	assert.Contains(t, loggerBuffer.String(), "otel_trace_id")
	assert.Contains(t, loggerBuffer.String(), "otel_span_id")

	assertMetricCount(t, reader, "mail.internal_http.requests", map[string]string{
		"route":        LoginCodeDeliveriesPath,
		"method":       http.MethodPost,
		"edge_outcome": "success",
	}, 1)
}

// acceptLoginCodeDeliveryFunc adapts a plain function to the login-code
// delivery use case so tests can stub its behavior inline.
type acceptLoginCodeDeliveryFunc func(context.Context, acceptauthdelivery.Input) (acceptauthdelivery.Result, error)

// Execute calls the wrapped function.
func (fn acceptLoginCodeDeliveryFunc) Execute(ctx context.Context, input acceptauthdelivery.Input) (acceptauthdelivery.Result, error) {
	return fn(ctx, input)
}

// doLoginCodeDeliveryRequest sends a JSON POST to LoginCodeDeliveriesPath through
// handler and returns the recorded response. An empty idempotencyKey omits the header.
func doLoginCodeDeliveryRequest(t *testing.T, handler http.Handler, body string, idempotencyKey string) *http.Response {
	t.Helper()

	request := httptest.NewRequest(http.MethodPost, LoginCodeDeliveriesPath, bytes.NewBufferString(body))
	request.Header.Set("Content-Type", "application/json")
	if idempotencyKey != "" {
		request.Header.Set(IdempotencyKeyHeader, idempotencyKey)
	}

	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, request)

	return recorder.Result()
}

// decodeJSONBody decodes the response body as JSON into target.
func decodeJSONBody(response *http.Response, target any) error {
	return json.NewDecoder(response.Body).Decode(target)
}

// assertMetricCount collects metrics from reader and asserts that the int64 sum
// named metricName has a data point with exactly wantAttrs and value wantValue.
func assertMetricCount(t *testing.T, reader *sdkmetric.ManualReader, metricName string, wantAttrs map[string]string, wantValue int64) {
	t.Helper()

	var resourceMetrics metricdata.ResourceMetrics
	require.NoError(t, reader.Collect(context.Background(), &resourceMetrics))

	for _, scopeMetrics := range resourceMetrics.ScopeMetrics {
		for _, metric := range scopeMetrics.Metrics {
			if metric.Name != metricName {
				continue
			}

			sum, ok := metric.Data.(metricdata.Sum[int64])
			require.True(t, ok, "metric %q is not an int64 sum", metricName)

			for _, point := range sum.DataPoints {
				if hasMetricAttributes(point.Attributes.ToSlice(), wantAttrs) {
					assert.Equal(t, wantValue, point.Value)
					return
				}
			}
		}
	}

	require.Failf(t, "test failed", "metric %q with attrs %v not found", metricName, wantAttrs)
}

// hasMetricAttributes reports whether values contains exactly the keys in want,
// each with the expected string value.
func hasMetricAttributes(values []attribute.KeyValue, want map[string]string) bool {
	if len(values) != len(want) {
		return false
	}

	for _, value := range values {
		// Check key presence explicitly: an unexpected key with an empty value
		// would otherwise match the map's zero value "".
		wantValue, ok := want[string(value.Key)]
		if !ok || wantValue != value.Value.AsString() {
			return false
		}
	}

	return true
}