package mail import ( "context" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "strings" "sync" "sync/atomic" "testing" "time" "galaxy/authsession/internal/domain/common" "galaxy/authsession/internal/ports" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewRESTClient(t *testing.T) { t.Parallel() tests := []struct { name string cfg Config wantErr string }{ { name: "valid config", cfg: Config{ BaseURL: "http://127.0.0.1:8080", RequestTimeout: time.Second, }, }, { name: "empty base url", cfg: Config{ RequestTimeout: time.Second, }, wantErr: "base URL must not be empty", }, { name: "relative base url", cfg: Config{ BaseURL: "/relative", RequestTimeout: time.Second, }, wantErr: "base URL must be absolute", }, { name: "non positive timeout", cfg: Config{ BaseURL: "http://127.0.0.1:8080", }, wantErr: "request timeout must be positive", }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() client, err := NewRESTClient(tt.cfg) if tt.wantErr != "" { require.Error(t, err) assert.ErrorContains(t, err, tt.wantErr) return } require.NoError(t, err) assert.NoError(t, client.Close()) }) } } func TestRESTClientSendLoginCodeSuccessCases(t *testing.T) { t.Parallel() tests := []struct { name string response string wantOutcome ports.SendLoginCodeOutcome }{ { name: "sent", response: `{"outcome":"sent"}`, wantOutcome: ports.SendLoginCodeOutcomeSent, }, { name: "suppressed", response: `{"outcome":"suppressed"}`, wantOutcome: ports.SendLoginCodeOutcomeSuppressed, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() var requestsMu sync.Mutex var requests []capturedRequest server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestsMu.Lock() requests = append(requests, captureRequest(t, r)) requestsMu.Unlock() writeJSON(t, w, http.StatusOK, json.RawMessage(tt.response)) })) defer server.Close() client := newTestRESTClient(t, server.URL, 250*time.Millisecond) result, err := client.SendLoginCode(context.Background(), validInput()) require.NoError(t, err) assert.Equal(t, tt.wantOutcome, result.Outcome) requestsMu.Lock() defer requestsMu.Unlock() require.Len(t, requests, 1) assert.Equal(t, http.MethodPost, requests[0].Method) assert.Equal(t, sendLoginCodePath, requests[0].Path) assert.Equal(t, "application/json", requests[0].ContentType) assert.Equal(t, "challenge-1", requests[0].IdempotencyKey) assert.JSONEq(t, `{"email":"pilot@example.com","code":"654321","locale":"en"}`, requests[0].Body) }) } } func TestRESTClientPreservesNormalizedEmailAndCodeExactly(t *testing.T) { t.Parallel() var captured capturedRequest server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { captured = captureRequest(t, r) writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"}) })) defer server.Close() client := newTestRESTClient(t, server.URL, 250*time.Millisecond) result, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{ Email: common.Email("Pilot+Alias@Example.com"), IdempotencyKey: "challenge-1", Code: "123456", Locale: "fr-FR", }) require.NoError(t, err) assert.Equal(t, ports.SendLoginCodeOutcomeSent, result.Outcome) assert.Equal(t, "challenge-1", captured.IdempotencyKey) assert.JSONEq(t, `{"email":"Pilot+Alias@Example.com","code":"123456","locale":"fr-FR"}`, captured.Body) } func TestRESTClientSendLoginCodeDoesNotRetry(t *testing.T) { t.Parallel() t.Run("no retry on 503", func(t *testing.T) { t.Parallel() var calls atomic.Int64 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { calls.Add(1) http.Error(w, "temporary", http.StatusServiceUnavailable) })) defer server.Close() client := newTestRESTClient(t, server.URL, 250*time.Millisecond) _, err := client.SendLoginCode(context.Background(), validInput()) require.Error(t, err) assert.ErrorContains(t, err, "unexpected HTTP status 503") assert.EqualValues(t, 1, calls.Load()) }) t.Run("no retry on transport failure", func(t *testing.T) { t.Parallel() var calls atomic.Int64 client, err := newRESTClient(Config{ BaseURL: "http://127.0.0.1:8080", RequestTimeout: 250 * time.Millisecond, }, &http.Client{ Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) { calls.Add(1) return nil, errors.New("temporary transport failure") }), }) require.NoError(t, err) _, err = client.SendLoginCode(context.Background(), validInput()) require.Error(t, err) assert.ErrorContains(t, err, "temporary transport failure") assert.EqualValues(t, 1, calls.Load()) }) } func TestRESTClientStrictDecodingAndUnexpectedStatuses(t *testing.T) { t.Parallel() tests := []struct { name string statusCode int body string wantErrText string }{ { name: "rejects unknown field", statusCode: http.StatusOK, body: `{"outcome":"sent","extra":true}`, wantErrText: "decode response body", }, { name: "rejects unsupported outcome", statusCode: http.StatusOK, body: `{"outcome":"queued"}`, wantErrText: "unsupported", }, { name: "rejects missing outcome", statusCode: http.StatusOK, body: `{}`, wantErrText: "unsupported", }, { name: "rejects trailing json", statusCode: http.StatusOK, body: `{"outcome":"sent"}{}`, wantErrText: "unexpected trailing JSON input", }, { name: "rejects unexpected status", statusCode: http.StatusBadGateway, body: `{"error":"temporary"}`, wantErrText: "unexpected HTTP status 502", }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(tt.statusCode) _, err := io.WriteString(w, tt.body) require.NoError(t, err) })) defer server.Close() client := newTestRESTClient(t, server.URL, 250*time.Millisecond) _, err := client.SendLoginCode(context.Background(), validInput()) require.Error(t, err) assert.ErrorContains(t, err, tt.wantErrText) }) } } func TestRESTClientRequestTimeout(t *testing.T) { t.Parallel() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(40 * time.Millisecond) writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"}) })) defer server.Close() client := newTestRESTClient(t, server.URL, 10*time.Millisecond) _, err := client.SendLoginCode(context.Background(), validInput()) require.Error(t, err) assert.ErrorContains(t, err, "context deadline exceeded") } func TestRESTClientContextAndValidation(t *testing.T) { t.Parallel() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatalf("unexpected upstream call") })) defer server.Close() client := newTestRESTClient(t, server.URL, 250*time.Millisecond) cancelledCtx, cancel := context.WithCancel(context.Background()) cancel() tests := []struct { name string run func() error }{ { name: "nil context", run: func() error { _, err := client.SendLoginCode(nil, validInput()) return err }, }, { name: "cancelled context", run: func() error { _, err := client.SendLoginCode(cancelledCtx, validInput()) return err }, }, { name: "invalid email", run: func() error { _, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{ Email: common.Email(" bad@example.com "), IdempotencyKey: "challenge-1", Code: "123456", Locale: "en", }) return err }, }, { name: "invalid code", run: func() error { _, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{ Email: common.Email("pilot@example.com"), IdempotencyKey: "challenge-1", Code: " 123456 ", Locale: "en", }) return err }, }, { name: "invalid locale", run: func() error { _, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{ Email: common.Email("pilot@example.com"), IdempotencyKey: "challenge-1", Code: "123456", Locale: " en ", }) return err }, }, { name: "invalid idempotency key", run: func() error { _, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{ Email: common.Email("pilot@example.com"), IdempotencyKey: " challenge-1 ", Code: "123456", Locale: "en", }) return err }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { err := tt.run() require.Error(t, err) }) } } type capturedRequest struct { Method string Path string ContentType string IdempotencyKey string Body string } func captureRequest(t *testing.T, request *http.Request) capturedRequest { t.Helper() body, err := io.ReadAll(request.Body) require.NoError(t, err) return capturedRequest{ Method: request.Method, Path: request.URL.Path, ContentType: request.Header.Get("Content-Type"), IdempotencyKey: request.Header.Get("Idempotency-Key"), Body: strings.TrimSpace(string(body)), } } func writeJSON(t *testing.T, writer http.ResponseWriter, statusCode int, value any) { t.Helper() payload, err := json.Marshal(value) require.NoError(t, err) writer.Header().Set("Content-Type", "application/json") writer.WriteHeader(statusCode) _, err = writer.Write(payload) require.NoError(t, err) } func newTestRESTClient(t *testing.T, baseURL string, timeout time.Duration) *RESTClient { t.Helper() client, err := NewRESTClient(Config{ BaseURL: baseURL, RequestTimeout: timeout, }) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, client.Close()) }) return client } type roundTripperFunc func(*http.Request) (*http.Response, error) func (fn roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) { return fn(request) }