429 lines
10 KiB
Go
429 lines
10 KiB
Go
package mail
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"galaxy/authsession/internal/domain/common"
|
|
"galaxy/authsession/internal/ports"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestNewRESTClient(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
cfg Config
|
|
wantErr string
|
|
}{
|
|
{
|
|
name: "valid config",
|
|
cfg: Config{
|
|
BaseURL: "http://127.0.0.1:8080",
|
|
RequestTimeout: time.Second,
|
|
},
|
|
},
|
|
{
|
|
name: "empty base url",
|
|
cfg: Config{
|
|
RequestTimeout: time.Second,
|
|
},
|
|
wantErr: "base URL must not be empty",
|
|
},
|
|
{
|
|
name: "relative base url",
|
|
cfg: Config{
|
|
BaseURL: "/relative",
|
|
RequestTimeout: time.Second,
|
|
},
|
|
wantErr: "base URL must be absolute",
|
|
},
|
|
{
|
|
name: "non positive timeout",
|
|
cfg: Config{
|
|
BaseURL: "http://127.0.0.1:8080",
|
|
},
|
|
wantErr: "request timeout must be positive",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
client, err := NewRESTClient(tt.cfg)
|
|
if tt.wantErr != "" {
|
|
require.Error(t, err)
|
|
assert.ErrorContains(t, err, tt.wantErr)
|
|
return
|
|
}
|
|
|
|
require.NoError(t, err)
|
|
assert.NoError(t, client.Close())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRESTClientSendLoginCodeSuccessCases(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
response string
|
|
wantOutcome ports.SendLoginCodeOutcome
|
|
}{
|
|
{
|
|
name: "sent",
|
|
response: `{"outcome":"sent"}`,
|
|
wantOutcome: ports.SendLoginCodeOutcomeSent,
|
|
},
|
|
{
|
|
name: "suppressed",
|
|
response: `{"outcome":"suppressed"}`,
|
|
wantOutcome: ports.SendLoginCodeOutcomeSuppressed,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var requestsMu sync.Mutex
|
|
var requests []capturedRequest
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
requestsMu.Lock()
|
|
requests = append(requests, captureRequest(t, r))
|
|
requestsMu.Unlock()
|
|
|
|
writeJSON(t, w, http.StatusOK, json.RawMessage(tt.response))
|
|
}))
|
|
defer server.Close()
|
|
|
|
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
|
|
|
result, err := client.SendLoginCode(context.Background(), validInput())
|
|
require.NoError(t, err)
|
|
assert.Equal(t, tt.wantOutcome, result.Outcome)
|
|
|
|
requestsMu.Lock()
|
|
defer requestsMu.Unlock()
|
|
|
|
require.Len(t, requests, 1)
|
|
assert.Equal(t, http.MethodPost, requests[0].Method)
|
|
assert.Equal(t, sendLoginCodePath, requests[0].Path)
|
|
assert.Equal(t, "application/json", requests[0].ContentType)
|
|
assert.Equal(t, "challenge-1", requests[0].IdempotencyKey)
|
|
assert.JSONEq(t, `{"email":"pilot@example.com","code":"654321","locale":"en"}`, requests[0].Body)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRESTClientPreservesNormalizedEmailAndCodeExactly(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var captured capturedRequest
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
captured = captureRequest(t, r)
|
|
writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"})
|
|
}))
|
|
defer server.Close()
|
|
|
|
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
|
|
|
result, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
|
|
Email: common.Email("Pilot+Alias@Example.com"),
|
|
IdempotencyKey: "challenge-1",
|
|
Code: "123456",
|
|
Locale: "fr-FR",
|
|
})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, ports.SendLoginCodeOutcomeSent, result.Outcome)
|
|
assert.Equal(t, "challenge-1", captured.IdempotencyKey)
|
|
assert.JSONEq(t, `{"email":"Pilot+Alias@Example.com","code":"123456","locale":"fr-FR"}`, captured.Body)
|
|
}
|
|
|
|
func TestRESTClientSendLoginCodeDoesNotRetry(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("no retry on 503", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var calls atomic.Int64
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
calls.Add(1)
|
|
http.Error(w, "temporary", http.StatusServiceUnavailable)
|
|
}))
|
|
defer server.Close()
|
|
|
|
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
|
|
|
_, err := client.SendLoginCode(context.Background(), validInput())
|
|
require.Error(t, err)
|
|
assert.ErrorContains(t, err, "unexpected HTTP status 503")
|
|
assert.EqualValues(t, 1, calls.Load())
|
|
})
|
|
|
|
t.Run("no retry on transport failure", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var calls atomic.Int64
|
|
client, err := newRESTClient(Config{
|
|
BaseURL: "http://127.0.0.1:8080",
|
|
RequestTimeout: 250 * time.Millisecond,
|
|
}, &http.Client{
|
|
Transport: roundTripperFunc(func(request *http.Request) (*http.Response, error) {
|
|
calls.Add(1)
|
|
return nil, errors.New("temporary transport failure")
|
|
}),
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
_, err = client.SendLoginCode(context.Background(), validInput())
|
|
require.Error(t, err)
|
|
assert.ErrorContains(t, err, "temporary transport failure")
|
|
assert.EqualValues(t, 1, calls.Load())
|
|
})
|
|
}
|
|
|
|
func TestRESTClientStrictDecodingAndUnexpectedStatuses(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
statusCode int
|
|
body string
|
|
wantErrText string
|
|
}{
|
|
{
|
|
name: "rejects unknown field",
|
|
statusCode: http.StatusOK,
|
|
body: `{"outcome":"sent","extra":true}`,
|
|
wantErrText: "decode response body",
|
|
},
|
|
{
|
|
name: "rejects unsupported outcome",
|
|
statusCode: http.StatusOK,
|
|
body: `{"outcome":"queued"}`,
|
|
wantErrText: "unsupported",
|
|
},
|
|
{
|
|
name: "rejects missing outcome",
|
|
statusCode: http.StatusOK,
|
|
body: `{}`,
|
|
wantErrText: "unsupported",
|
|
},
|
|
{
|
|
name: "rejects trailing json",
|
|
statusCode: http.StatusOK,
|
|
body: `{"outcome":"sent"}{}`,
|
|
wantErrText: "unexpected trailing JSON input",
|
|
},
|
|
{
|
|
name: "rejects unexpected status",
|
|
statusCode: http.StatusBadGateway,
|
|
body: `{"error":"temporary"}`,
|
|
wantErrText: "unexpected HTTP status 502",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(tt.statusCode)
|
|
_, err := io.WriteString(w, tt.body)
|
|
require.NoError(t, err)
|
|
}))
|
|
defer server.Close()
|
|
|
|
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
|
|
|
_, err := client.SendLoginCode(context.Background(), validInput())
|
|
require.Error(t, err)
|
|
assert.ErrorContains(t, err, tt.wantErrText)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRESTClientRequestTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
time.Sleep(40 * time.Millisecond)
|
|
writeJSON(t, w, http.StatusOK, map[string]string{"outcome": "sent"})
|
|
}))
|
|
defer server.Close()
|
|
|
|
client := newTestRESTClient(t, server.URL, 10*time.Millisecond)
|
|
|
|
_, err := client.SendLoginCode(context.Background(), validInput())
|
|
require.Error(t, err)
|
|
assert.ErrorContains(t, err, "context deadline exceeded")
|
|
}
|
|
|
|
func TestRESTClientContextAndValidation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatalf("unexpected upstream call")
|
|
}))
|
|
defer server.Close()
|
|
|
|
client := newTestRESTClient(t, server.URL, 250*time.Millisecond)
|
|
cancelledCtx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
run func() error
|
|
}{
|
|
{
|
|
name: "nil context",
|
|
run: func() error {
|
|
_, err := client.SendLoginCode(nil, validInput())
|
|
return err
|
|
},
|
|
},
|
|
{
|
|
name: "cancelled context",
|
|
run: func() error {
|
|
_, err := client.SendLoginCode(cancelledCtx, validInput())
|
|
return err
|
|
},
|
|
},
|
|
{
|
|
name: "invalid email",
|
|
run: func() error {
|
|
_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
|
|
Email: common.Email(" bad@example.com "),
|
|
IdempotencyKey: "challenge-1",
|
|
Code: "123456",
|
|
Locale: "en",
|
|
})
|
|
return err
|
|
},
|
|
},
|
|
{
|
|
name: "invalid code",
|
|
run: func() error {
|
|
_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
|
|
Email: common.Email("pilot@example.com"),
|
|
IdempotencyKey: "challenge-1",
|
|
Code: " 123456 ",
|
|
Locale: "en",
|
|
})
|
|
return err
|
|
},
|
|
},
|
|
{
|
|
name: "invalid locale",
|
|
run: func() error {
|
|
_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
|
|
Email: common.Email("pilot@example.com"),
|
|
IdempotencyKey: "challenge-1",
|
|
Code: "123456",
|
|
Locale: " en ",
|
|
})
|
|
return err
|
|
},
|
|
},
|
|
{
|
|
name: "invalid idempotency key",
|
|
run: func() error {
|
|
_, err := client.SendLoginCode(context.Background(), ports.SendLoginCodeInput{
|
|
Email: common.Email("pilot@example.com"),
|
|
IdempotencyKey: " challenge-1 ",
|
|
Code: "123456",
|
|
Locale: "en",
|
|
})
|
|
return err
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := tt.run()
|
|
require.Error(t, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
// capturedRequest is a snapshot of the parts of an inbound HTTP request that
// the tests assert on; captureRequest fills it in.
type capturedRequest struct {
	// Method and Path identify the endpoint that was called.
	Method, Path string
	// ContentType and IdempotencyKey hold the raw values of the
	// Content-Type and Idempotency-Key request headers.
	ContentType, IdempotencyKey string
	// Body is the request payload as text.
	Body string
}
|
|
|
|
func captureRequest(t *testing.T, request *http.Request) capturedRequest {
|
|
t.Helper()
|
|
|
|
body, err := io.ReadAll(request.Body)
|
|
require.NoError(t, err)
|
|
|
|
return capturedRequest{
|
|
Method: request.Method,
|
|
Path: request.URL.Path,
|
|
ContentType: request.Header.Get("Content-Type"),
|
|
IdempotencyKey: request.Header.Get("Idempotency-Key"),
|
|
Body: strings.TrimSpace(string(body)),
|
|
}
|
|
}
|
|
|
|
func writeJSON(t *testing.T, writer http.ResponseWriter, statusCode int, value any) {
|
|
t.Helper()
|
|
|
|
payload, err := json.Marshal(value)
|
|
require.NoError(t, err)
|
|
|
|
writer.Header().Set("Content-Type", "application/json")
|
|
writer.WriteHeader(statusCode)
|
|
_, err = writer.Write(payload)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func newTestRESTClient(t *testing.T, baseURL string, timeout time.Duration) *RESTClient {
|
|
t.Helper()
|
|
|
|
client, err := NewRESTClient(Config{
|
|
BaseURL: baseURL,
|
|
RequestTimeout: timeout,
|
|
})
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
assert.NoError(t, client.Close())
|
|
})
|
|
|
|
return client
|
|
}
|
|
|
|
// roundTripperFunc adapts a plain function to the http.RoundTripper interface,
// letting tests inject a fake transport into an http.Client.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// Compile-time check that roundTripperFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripperFunc(nil)

// RoundTrip implements http.RoundTripper by delegating to the wrapped function.
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req)
}
|