Files
galaxy-game/authsession/internal/adapters/mail/rest_client_test.go
T
2026-04-17 18:39:16 +02:00

429 lines
10 KiB
Go

package mail
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"galaxy/authsession/internal/domain/common"
"galaxy/authsession/internal/ports"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewRESTClient checks that NewRESTClient accepts a well-formed config and
// rejects configs with an empty or relative base URL or an unset timeout.
func TestNewRESTClient(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name    string
		cfg     Config
		wantErr string
	}{
		{
			name: "valid config",
			cfg:  Config{BaseURL: "http://127.0.0.1:8080", RequestTimeout: time.Second},
		},
		{
			name:    "empty base url",
			cfg:     Config{RequestTimeout: time.Second},
			wantErr: "base URL must not be empty",
		},
		{
			name:    "relative base url",
			cfg:     Config{BaseURL: "/relative", RequestTimeout: time.Second},
			wantErr: "base URL must be absolute",
		},
		{
			name:    "non positive timeout",
			cfg:     Config{BaseURL: "http://127.0.0.1:8080"},
			wantErr: "request timeout must be positive",
		},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy for the parallel subtest (pre-Go 1.22 loop semantics)
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			client, err := NewRESTClient(tc.cfg)
			if tc.wantErr == "" {
				// Happy path: construction succeeds and the client closes cleanly.
				require.NoError(t, err)
				assert.NoError(t, client.Close())
				return
			}

			require.Error(t, err)
			assert.ErrorContains(t, err, tc.wantErr)
		})
	}
}
// TestRESTClientSendLoginCodeSuccessCases verifies that each accepted upstream
// outcome is surfaced unchanged, and that the client issues exactly one POST
// carrying the expected path, content type, idempotency key, and JSON payload.
func TestRESTClientSendLoginCodeSuccessCases(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name        string
		response    string
		wantOutcome ports.SendLoginCodeOutcome
	}{
		{name: "sent", response: `{"outcome":"sent"}`, wantOutcome: ports.SendLoginCodeOutcomeSent},
		{name: "suppressed", response: `{"outcome":"suppressed"}`, wantOutcome: ports.SendLoginCodeOutcomeSuppressed},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// The handler runs on a server goroutine, so the captured slice is
			// shared with the test goroutine and guarded by mu.
			var (
				mu  sync.Mutex
				got []capturedRequest
			)
			handler := func(w http.ResponseWriter, r *http.Request) {
				mu.Lock()
				got = append(got, captureRequest(t, r))
				mu.Unlock()
				writeJSON(t, w, http.StatusOK, json.RawMessage(tc.response))
			}
			server := httptest.NewServer(http.HandlerFunc(handler))
			defer server.Close()

			client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
			result, err := client.SendLoginCode(context.Background(), validInput())
			require.NoError(t, err)
			assert.Equal(t, tc.wantOutcome, result.Outcome)

			mu.Lock()
			defer mu.Unlock()

			require.Len(t, got, 1)
			only := got[0]
			assert.Equal(t, http.MethodPost, only.Method)
			assert.Equal(t, sendLoginCodePath, only.Path)
			assert.Equal(t, "application/json", only.ContentType)
			assert.Equal(t, "challenge-1", only.IdempotencyKey)
			assert.JSONEq(t, `{"email":"pilot@example.com","code":"654321","locale":"en"}`, only.Body)
		})
	}
}
// TestRESTClientPreservesNormalizedEmailAndCodeExactly verifies that the
// client forwards email, code, and locale verbatim (no case folding or other
// rewriting) in a single upstream request.
func TestRESTClientPreservesNormalizedEmailAndCodeExactly(t *testing.T) {
	t.Parallel()

	// The handler runs on the server's goroutine, so the captured requests are
	// shared with the test goroutine and must be guarded explicitly, as in
	// TestRESTClientSendLoginCodeSuccessCases.
	var (
		mu       sync.Mutex
		requests []capturedRequest
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		requests = append(requests, captureRequest(t, r))
		mu.Unlock()
		writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"})
	}))
	defer server.Close()

	client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
	result, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
		Email:          common.Email("Pilot+Alias@Example.com"),
		IdempotencyKey: "challenge-1",
		Code:           "123456",
		Locale:         "fr-FR",
	})
	require.NoError(t, err)
	assert.Equal(t, ports.SendLoginCodeOutcomeSent, result.Outcome)

	mu.Lock()
	defer mu.Unlock()

	// Exactly one request: a retry would otherwise silently overwrite the
	// capture and hide a second send.
	require.Len(t, requests, 1)
	assert.Equal(t, "challenge-1", requests[0].IdempotencyKey)
	assert.JSONEq(t, `{"email":"Pilot+Alias@Example.com","code":"123456","locale":"fr-FR"}`, requests[0].Body)
}
// TestRESTClientSendLoginCodeDoesNotRetry verifies that SendLoginCode makes a
// single attempt and returns the error both for a 503 response and for a
// transport-level failure.
func TestRESTClientSendLoginCodeDoesNotRetry(t *testing.T) {
	t.Parallel()

	t.Run("no retry on 503", func(t *testing.T) {
		t.Parallel()

		var calls atomic.Int64
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			calls.Add(1)
			http.Error(w, "temporary", http.StatusServiceUnavailable)
		}))
		defer server.Close()

		client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
		_, err := client.SendLoginCode(context.Background(), validInput())
		require.Error(t, err)
		assert.ErrorContains(t, err, "unexpected HTTP status 503")
		assert.EqualValues(t, 1, calls.Load())
	})

	t.Run("no retry on transport failure", func(t *testing.T) {
		t.Parallel()

		var calls atomic.Int64
		client, err := newRESTClient(Config{
			BaseURL:        "http://127.0.0.1:8080",
			RequestTimeout: 250 * time.Millisecond,
		}, &http.Client{
			// Every round trip fails before reaching the network.
			Transport: roundTripperFunc(func(_ *http.Request) (*http.Response, error) {
				calls.Add(1)
				return nil, errors.New("temporary transport failure")
			}),
		})
		require.NoError(t, err)
		// Release the client like newTestRESTClient does for its clients.
		t.Cleanup(func() {
			assert.NoError(t, client.Close())
		})

		_, err = client.SendLoginCode(context.Background(), validInput())
		require.Error(t, err)
		assert.ErrorContains(t, err, "temporary transport failure")
		assert.EqualValues(t, 1, calls.Load())
	})
}
func TestRESTClientStrictDecodingAndUnexpectedStatuses(t *testing.T) {
t.Parallel()
tests := []struct {
name string
statusCode int
body string
wantErrText string
}{
{
name: "rejects unknown field",
statusCode: http.StatusOK,
body: `{"outcome":"sent","extra":true}`,
wantErrText: "decode response body",
},
{
name: "rejects unsupported outcome",
statusCode: http.StatusOK,
body: `{"outcome":"queued"}`,
wantErrText: "unsupported",
},
{
name: "rejects missing outcome",
statusCode: http.StatusOK,
body: `{}`,
wantErrText: "unsupported",
},
{
name: "rejects trailing json",
statusCode: http.StatusOK,
body: `{"outcome":"sent"}{}`,
wantErrText: "unexpected trailing JSON input",
},
{
name: "rejects unexpected status",
statusCode: http.StatusBadGateway,
body: `{"error":"temporary"}`,
wantErrText: "unexpected HTTP status 502",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(tt.statusCode)
_, err := io.WriteString(w, tt.body)
require.NoError(t, err)
}))
defer server.Close()
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
_, err := client.SendLoginCode(context.Background(), validInput())
require.Error(t, err)
assert.ErrorContains(t, err, tt.wantErrText)
})
}
}
// TestRESTClientRequestTimeout verifies that the configured RequestTimeout
// bounds the upstream call and surfaces as a deadline error.
func TestRESTClientRequestTimeout(t *testing.T) {
	t.Parallel()

	// Hold the response far longer than the client timeout, but return as soon
	// as the client abandons the request. Watching the request context keeps
	// server.Close (which waits for in-flight handlers) from stalling and
	// avoids writing to a dead connection. The wide margin between the two
	// durations keeps the test stable on slow machines.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		select {
		case <-r.Context().Done():
			return
		case <-time.After(time.Second):
		}
		writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"})
	}))
	defer server.Close()

	client := newTestRESTClient(t, server.URL, 10*time.Millisecond)
	_, err := client.SendLoginCode(context.Background(), validInput())
	require.Error(t, err)
	assert.ErrorContains(t, err, "context deadline exceeded")
}
// TestRESTClientContextAndValidation verifies that SendLoginCode returns an
// error for a nil or already-cancelled context and for inputs whose fields
// carry surrounding whitespace. The upstream handler fails the test if it is
// ever reached.
func TestRESTClientContextAndValidation(t *testing.T) {
	t.Parallel()

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This runs on a server goroutine, where t.Fatalf would only Goexit
		// that goroutine. Record the failure and reply explicitly instead.
		t.Errorf("unexpected upstream call: %s %s", r.Method, r.URL.Path)
		http.Error(w, "unexpected upstream call", http.StatusInternalServerError)
	}))
	defer server.Close()

	client := newTestRESTClient(t, server.URL, 250*time.Millisecond)

	cancelledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	// Deliberately nil. It is passed through a variable so linters that flag a
	// literal nil Context (staticcheck SA1012) don't report this intended misuse.
	var nilCtx context.Context

	tests := []struct {
		name string
		run  func() error
	}{
		{
			name: "nil context",
			run: func() error {
				_, err := client.SendLoginCode(nilCtx, validInput())
				return err
			},
		},
		{
			name: "cancelled context",
			run: func() error {
				_, err := client.SendLoginCode(cancelledCtx, validInput())
				return err
			},
		},
		{
			name: "invalid email",
			run: func() error {
				_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
					Email:          common.Email(" bad@example.com "),
					IdempotencyKey: "challenge-1",
					Code:           "123456",
					Locale:         "en",
				})
				return err
			},
		},
		{
			name: "invalid code",
			run: func() error {
				_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
					Email:          common.Email("pilot@example.com"),
					IdempotencyKey: "challenge-1",
					Code:           " 123456 ",
					Locale:         "en",
				})
				return err
			},
		},
		{
			name: "invalid locale",
			run: func() error {
				_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
					Email:          common.Email("pilot@example.com"),
					IdempotencyKey: "challenge-1",
					Code:           "123456",
					Locale:         " en ",
				})
				return err
			},
		},
		{
			name: "invalid idempotency key",
			run: func() error {
				_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
					Email:          common.Email("pilot@example.com"),
					IdempotencyKey: " challenge-1 ",
					Code:           "123456",
					Locale:         "en",
				})
				return err
			},
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			err := tt.run()
			require.Error(t, err)
		})
	}
}
// capturedRequest is a snapshot of the parts of an inbound HTTP request that
// the tests assert on, as recorded by captureRequest.
type capturedRequest struct {
	// Request line.
	Method, Path string

	// Selected headers.
	ContentType, IdempotencyKey string

	// Request payload as text.
	Body string
}
// captureRequest records the method, path, Content-Type and Idempotency-Key
// headers, and whitespace-trimmed body of request for later assertions.
//
// It is called from httptest handler goroutines, so a body read failure is
// reported with assert (t.Errorf) rather than require: FailNow must only be
// called from the goroutine running the test.
func captureRequest(t *testing.T, request *http.Request) capturedRequest {
	t.Helper()

	body, err := io.ReadAll(request.Body)
	assert.NoError(t, err, "read request body")

	return capturedRequest{
		Method:         request.Method,
		Path:           request.URL.Path,
		ContentType:    request.Header.Get("Content-Type"),
		IdempotencyKey: request.Header.Get("Idempotency-Key"),
		Body:           strings.TrimSpace(string(body)),
	}
}
func writeJSON(t *testing.T, writer http.ResponseWriter, statusCode int, value any) {
t.Helper()
payload, err := json.Marshal(value)
require.NoError(t, err)
writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(statusCode)
_, err = writer.Write(payload)
require.NoError(t, err)
}
// newTestRESTClient builds a RESTClient for baseURL with the given request
// timeout. It stops the test if construction fails and registers a cleanup
// that closes the client when the test finishes.
func newTestRESTClient(t *testing.T, baseURL string, timeout time.Duration) *RESTClient {
	t.Helper()

	cfg := Config{BaseURL: baseURL, RequestTimeout: timeout}
	c, err := NewRESTClient(cfg)
	require.NoError(t, err)

	t.Cleanup(func() { assert.NoError(t, c.Close()) })

	return c
}
// roundTripperFunc adapts a plain function to http.RoundTripper so tests can
// inject transport behavior into an http.Client.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip delegates to the wrapped function and returns its results as-is.
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req)
}