feat: mail service
This commit is contained in:
@@ -0,0 +1,440 @@
|
||||
// Package smtp provides the SMTP-backed provider adapter used by Mail
|
||||
// Service.
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
stdmail "net/mail"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"galaxy/mail/internal/ports"
|
||||
|
||||
gomail "github.com/wneessen/go-mail"
|
||||
)
|
||||
|
||||
const providerName = "smtp"
|
||||
|
||||
// Config stores the SMTP provider connection settings.
|
||||
type Config struct {
|
||||
// Addr stores the SMTP server network address.
|
||||
Addr string
|
||||
|
||||
// Username stores the optional SMTP authentication username.
|
||||
Username string
|
||||
|
||||
// Password stores the optional SMTP authentication password.
|
||||
Password string
|
||||
|
||||
// FromEmail stores the envelope sender mailbox.
|
||||
FromEmail string
|
||||
|
||||
// FromName stores the optional display name of the sender.
|
||||
FromName string
|
||||
|
||||
// Timeout stores the maximum SMTP dial-and-send window enforced by the
|
||||
// adapter when the caller does not provide an earlier deadline.
|
||||
Timeout time.Duration
|
||||
|
||||
// InsecureSkipVerify disables SMTP certificate verification. This is meant
|
||||
// only for local development and black-box tests with self-signed capture
|
||||
// servers.
|
||||
InsecureSkipVerify bool
|
||||
|
||||
// TLSConfig stores the optional TLS client configuration override used by
|
||||
// tests. Production wiring leaves it nil and uses secure defaults.
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// Provider stores the SMTP-backed delivery adapter.
|
||||
type Provider struct {
|
||||
client *gomail.Client
|
||||
fromEmail string
|
||||
fromName string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// New constructs one SMTP-backed provider and validates cfg.
|
||||
func New(cfg Config) (*Provider, error) {
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("new smtp provider: %w", err)
|
||||
}
|
||||
|
||||
host, portText, err := net.SplitHostPort(strings.TrimSpace(cfg.Addr))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new smtp provider: split smtp addr: %w", err)
|
||||
}
|
||||
port, err := strconv.Atoi(portText)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new smtp provider: parse smtp port: %w", err)
|
||||
}
|
||||
|
||||
options := []gomail.Option{
|
||||
gomail.WithPort(port),
|
||||
gomail.WithTimeout(cfg.Timeout),
|
||||
gomail.WithTLSPolicy(gomail.TLSMandatory),
|
||||
}
|
||||
if cfg.TLSConfig != nil {
|
||||
options = append(options, gomail.WithTLSConfig(cfg.TLSConfig))
|
||||
} else if cfg.InsecureSkipVerify {
|
||||
options = append(options, gomail.WithTLSConfig(&tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: host,
|
||||
InsecureSkipVerify: true, //nolint:gosec // Explicit opt-in for local integration scenarios only.
|
||||
}))
|
||||
} else {
|
||||
options = append(options, gomail.WithTLSConfig(&tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: host,
|
||||
}))
|
||||
}
|
||||
if cfg.Username != "" {
|
||||
options = append(options,
|
||||
gomail.WithUsername(cfg.Username),
|
||||
gomail.WithPassword(cfg.Password),
|
||||
gomail.WithSMTPAuth(gomail.SMTPAuthAutoDiscover),
|
||||
)
|
||||
}
|
||||
|
||||
client, err := gomail.NewClient(host, options...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new smtp provider: %w", err)
|
||||
}
|
||||
|
||||
return &Provider{
|
||||
client: client,
|
||||
fromEmail: cfg.FromEmail,
|
||||
fromName: cfg.FromName,
|
||||
timeout: cfg.Timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Send attempts one outbound SMTP delivery and returns a classified provider
|
||||
// outcome whenever the interaction reached a stable SMTP result.
|
||||
func (provider *Provider) Send(ctx context.Context, message ports.Message) (ports.Result, error) {
|
||||
switch {
|
||||
case ctx == nil:
|
||||
return ports.Result{}, errors.New("send with smtp provider: nil context")
|
||||
case provider == nil || provider.client == nil:
|
||||
return ports.Result{}, errors.New("send with smtp provider: nil provider")
|
||||
}
|
||||
if err := message.Validate(); err != nil {
|
||||
return ports.Result{}, fmt.Errorf("send with smtp provider: %w", err)
|
||||
}
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return newResult(ports.ClassificationTransientFailure, summaryFields{
|
||||
Phase: "context",
|
||||
}, map[string]string{
|
||||
"phase": "context",
|
||||
"error": "deadline_exceeded",
|
||||
})
|
||||
}
|
||||
|
||||
return ports.Result{}, fmt.Errorf("send with smtp provider: %w", err)
|
||||
}
|
||||
|
||||
msg, err := provider.buildMessage(message)
|
||||
if err != nil {
|
||||
return newResult(ports.ClassificationPermanentFailure, summaryFields{
|
||||
Phase: "build",
|
||||
}, map[string]string{
|
||||
"phase": "build",
|
||||
"error": classifyLocalBuildError(err),
|
||||
})
|
||||
}
|
||||
|
||||
sendCtx, cancel := provider.sendContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
err = provider.client.DialAndSendWithContext(sendCtx, msg)
|
||||
if err == nil {
|
||||
return newResult(ports.ClassificationAccepted, summaryFields{}, nil)
|
||||
}
|
||||
|
||||
return provider.classifySendError(err)
|
||||
}
|
||||
|
||||
// Close releases SMTP client resources.
|
||||
func (provider *Provider) Close() error {
|
||||
if provider == nil || provider.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider.client.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate reports whether cfg stores a complete SMTP provider configuration.
|
||||
func (cfg Config) Validate() error {
|
||||
host, port, err := net.SplitHostPort(strings.TrimSpace(cfg.Addr))
|
||||
switch {
|
||||
case err != nil || port == "":
|
||||
return fmt.Errorf("smtp addr %q must use host:port form", cfg.Addr)
|
||||
case host != "" && strings.Contains(host, " "):
|
||||
return fmt.Errorf("smtp addr %q must use host:port form", cfg.Addr)
|
||||
case cfg.Timeout <= 0:
|
||||
return fmt.Errorf("smtp timeout must be positive")
|
||||
case strings.TrimSpace(cfg.Username) == "" && strings.TrimSpace(cfg.Password) != "":
|
||||
return fmt.Errorf("smtp username and password must be configured together")
|
||||
case strings.TrimSpace(cfg.Username) != "" && strings.TrimSpace(cfg.Password) == "":
|
||||
return fmt.Errorf("smtp username and password must be configured together")
|
||||
}
|
||||
|
||||
parsed, err := stdmail.ParseAddress(strings.TrimSpace(cfg.FromEmail))
|
||||
if err != nil || parsed == nil || parsed.Name != "" || parsed.Address != strings.TrimSpace(cfg.FromEmail) {
|
||||
return fmt.Errorf("smtp from email %q must be a single valid email address", cfg.FromEmail)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *Provider) buildMessage(message ports.Message) (*gomail.Msg, error) {
|
||||
msg := gomail.NewMsg()
|
||||
msg.EnvelopeFrom(provider.fromEmail)
|
||||
|
||||
switch strings.TrimSpace(provider.fromName) {
|
||||
case "":
|
||||
if err := msg.From(provider.fromEmail); err != nil {
|
||||
return nil, fmt.Errorf("set from header: %w", err)
|
||||
}
|
||||
default:
|
||||
if err := msg.FromFormat(provider.fromName, provider.fromEmail); err != nil {
|
||||
return nil, fmt.Errorf("set from header: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
msg.SetBodyString(gomail.TypeTextPlain, message.Content.TextBody)
|
||||
if message.Content.HTMLBody != "" {
|
||||
msg.AddAlternativeString(gomail.TypeTextHTML, message.Content.HTMLBody)
|
||||
}
|
||||
msg.Subject(message.Content.Subject)
|
||||
|
||||
for _, address := range message.Envelope.To {
|
||||
if err := msg.AddTo(address.String()); err != nil {
|
||||
return nil, fmt.Errorf("add to recipient: %w", err)
|
||||
}
|
||||
}
|
||||
for _, address := range message.Envelope.Cc {
|
||||
if err := msg.AddCc(address.String()); err != nil {
|
||||
return nil, fmt.Errorf("add cc recipient: %w", err)
|
||||
}
|
||||
}
|
||||
for _, address := range message.Envelope.Bcc {
|
||||
if err := msg.AddBcc(address.String()); err != nil {
|
||||
return nil, fmt.Errorf("add bcc recipient: %w", err)
|
||||
}
|
||||
}
|
||||
for _, address := range message.Envelope.ReplyTo {
|
||||
if err := msg.ReplyTo(address.String()); err != nil {
|
||||
return nil, fmt.Errorf("add reply-to recipient: %w", err)
|
||||
}
|
||||
}
|
||||
for _, attachment := range message.Attachments {
|
||||
if err := attachment.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("attach file %q: %w", attachment.Metadata.Filename, err)
|
||||
}
|
||||
if err := msg.AttachReader(
|
||||
attachment.Metadata.Filename,
|
||||
bytes.NewReader(attachment.Content),
|
||||
gomail.WithFileContentType(gomail.ContentType(attachment.Metadata.ContentType)),
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("attach file %q: %w", attachment.Metadata.Filename, err)
|
||||
}
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (provider *Provider) classifySendError(err error) (ports.Result, error) {
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded):
|
||||
return newResult(ports.ClassificationTransientFailure, summaryFields{
|
||||
Phase: "send",
|
||||
}, map[string]string{
|
||||
"phase": "send",
|
||||
"error": "deadline_exceeded",
|
||||
})
|
||||
case strings.Contains(strings.ToLower(err.Error()), "starttls"):
|
||||
return newResult(ports.ClassificationPermanentFailure, summaryFields{
|
||||
Phase: "tls",
|
||||
}, map[string]string{
|
||||
"phase": "tls",
|
||||
"error": "starttls_required",
|
||||
})
|
||||
}
|
||||
|
||||
var sendErr *gomail.SendError
|
||||
if errors.As(err, &sendErr) {
|
||||
codeText := ""
|
||||
if code := sendErr.ErrorCode(); code > 0 {
|
||||
codeText = strconv.Itoa(code)
|
||||
}
|
||||
phase := smtpReasonPhase(sendErr, err)
|
||||
|
||||
details := map[string]string{
|
||||
"phase": phase,
|
||||
"error": sanitizeDetailValue(strings.ToLower(sendErr.Reason.String())),
|
||||
}
|
||||
if codeText != "" {
|
||||
details["smtp_code"] = codeText
|
||||
}
|
||||
|
||||
switch {
|
||||
case sendErr.ErrorCode() >= 500:
|
||||
return newResult(ports.ClassificationPermanentFailure, summaryFields{
|
||||
Phase: phase,
|
||||
SMTPCode: codeText,
|
||||
}, details)
|
||||
case sendErr.ErrorCode() >= 400:
|
||||
return newResult(ports.ClassificationTransientFailure, summaryFields{
|
||||
Phase: phase,
|
||||
SMTPCode: codeText,
|
||||
}, details)
|
||||
case sendErr.IsTemp():
|
||||
return newResult(ports.ClassificationTransientFailure, summaryFields{
|
||||
Phase: phase,
|
||||
}, details)
|
||||
default:
|
||||
return newResult(ports.ClassificationPermanentFailure, summaryFields{
|
||||
Phase: phase,
|
||||
}, details)
|
||||
}
|
||||
}
|
||||
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
return newResult(ports.ClassificationTransientFailure, summaryFields{
|
||||
Phase: "dial",
|
||||
}, map[string]string{
|
||||
"phase": "dial",
|
||||
"net_op": "smtp",
|
||||
"net_err": sanitizeDetailValue(strings.ToLower(netErr.Error())),
|
||||
})
|
||||
}
|
||||
|
||||
return newResult(ports.ClassificationPermanentFailure, summaryFields{
|
||||
Phase: "send",
|
||||
}, map[string]string{
|
||||
"phase": "send",
|
||||
"error": sanitizeDetailValue(strings.ToLower(err.Error())),
|
||||
})
|
||||
}
|
||||
|
||||
func (provider *Provider) sendContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
remaining := time.Until(deadline)
|
||||
if remaining <= provider.timeout {
|
||||
return ctx, func() {}
|
||||
}
|
||||
}
|
||||
|
||||
return context.WithTimeout(ctx, provider.timeout)
|
||||
}
|
||||
|
||||
type summaryFields struct {
|
||||
Phase string
|
||||
SMTPCode string
|
||||
}
|
||||
|
||||
func newResult(classification ports.Classification, fields summaryFields, details map[string]string) (ports.Result, error) {
|
||||
summary, err := ports.BuildSafeSummary(ports.SummaryFields{
|
||||
Provider: providerName,
|
||||
Result: string(classification),
|
||||
Phase: fields.Phase,
|
||||
SMTPCode: fields.SMTPCode,
|
||||
})
|
||||
if err != nil {
|
||||
return ports.Result{}, fmt.Errorf("build smtp provider summary: %w", err)
|
||||
}
|
||||
|
||||
result := ports.Result{
|
||||
Classification: classification,
|
||||
Summary: summary,
|
||||
Details: ports.CloneDetails(details),
|
||||
}
|
||||
if err := result.Validate(); err != nil {
|
||||
return ports.Result{}, fmt.Errorf("build smtp provider result: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func classifyLocalBuildError(err error) string {
|
||||
return sanitizeDetailValue(strings.ToLower(err.Error()))
|
||||
}
|
||||
|
||||
func smtpReasonPhase(sendErr *gomail.SendError, err error) string {
|
||||
if sendErr == nil {
|
||||
return "send"
|
||||
}
|
||||
|
||||
switch sendErr.Reason {
|
||||
case gomail.ErrConnCheck:
|
||||
return "dial"
|
||||
case gomail.ErrSMTPMailFrom:
|
||||
return "mail_from"
|
||||
case gomail.ErrSMTPRcptTo:
|
||||
return "rcpt_to"
|
||||
case gomail.ErrSMTPData:
|
||||
return "data"
|
||||
case gomail.ErrSMTPDataClose:
|
||||
return "data"
|
||||
case gomail.ErrSMTPReset:
|
||||
return "reset"
|
||||
case gomail.ErrWriteContent:
|
||||
return "build"
|
||||
case gomail.ErrGetSender, gomail.ErrGetRcpts:
|
||||
return "build"
|
||||
case gomail.ErrNoUnencoded:
|
||||
return "build"
|
||||
default:
|
||||
lower := strings.ToLower(err.Error())
|
||||
switch {
|
||||
case strings.Contains(lower, "starttls"):
|
||||
return "tls"
|
||||
case strings.Contains(lower, "auth"):
|
||||
return "auth"
|
||||
default:
|
||||
return "send"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeDetailValue(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
for _, r := range value {
|
||||
if r > 0x7f {
|
||||
builder.WriteByte('_')
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
builder.WriteRune(r)
|
||||
case r >= '0' && r <= '9':
|
||||
builder.WriteRune(r)
|
||||
case r == '.', r == '_', r == '-':
|
||||
builder.WriteRune(r)
|
||||
default:
|
||||
builder.WriteByte('_')
|
||||
}
|
||||
}
|
||||
|
||||
if builder.Len() == 0 {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
@@ -0,0 +1,453 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"galaxy/mail/internal/domain/common"
|
||||
deliverydomain "galaxy/mail/internal/domain/delivery"
|
||||
"galaxy/mail/internal/ports"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProviderBuildMessageIncludesHeadersBodiesAndAttachments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := newTestProvider(t)
|
||||
message := testMessage(t)
|
||||
|
||||
msg, err := provider.buildMessage(message)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buffer bytes.Buffer
|
||||
_, err = msg.WriteTo(&buffer)
|
||||
require.NoError(t, err)
|
||||
|
||||
payload := buffer.String()
|
||||
require.Contains(t, payload, "From: \"Galaxy Mail\" <noreply@example.com>")
|
||||
require.Contains(t, payload, "To: <pilot@example.com>")
|
||||
require.Contains(t, payload, "Cc: <copilot@example.com>")
|
||||
require.Contains(t, payload, "Reply-To: <reply@example.com>")
|
||||
require.Contains(t, payload, "Subject: Turn update")
|
||||
require.Contains(t, payload, "multipart/mixed")
|
||||
require.Contains(t, payload, "multipart/alternative")
|
||||
require.Contains(t, payload, "text/plain")
|
||||
require.Contains(t, payload, "text/html")
|
||||
require.Contains(t, payload, "guide.txt")
|
||||
require.Contains(t, payload, "charset=utf-8")
|
||||
require.NotContains(t, payload, "\nBcc:")
|
||||
}
|
||||
|
||||
func TestProviderSendClassifiesAccepted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := startSMTPTestServer(t, smtpTestServerConfig{
|
||||
supportsSTARTTLS: true,
|
||||
finalDataReply: "250 2.0.0 accepted",
|
||||
})
|
||||
|
||||
provider := newLiveProvider(t, server.addr)
|
||||
result, err := provider.Send(context.Background(), testMessage(t))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.ClassificationAccepted, result.Classification)
|
||||
require.Equal(t, "provider=smtp result=accepted", result.Summary)
|
||||
require.Contains(t, server.data(), "Subject: Turn update")
|
||||
require.NotContains(t, server.data(), "\nBcc:")
|
||||
}
|
||||
|
||||
func TestProviderSendClassifiesTransientSMTPFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := startSMTPTestServer(t, smtpTestServerConfig{
|
||||
supportsSTARTTLS: true,
|
||||
finalDataReply: "451 4.3.0 temporary_failure",
|
||||
})
|
||||
|
||||
provider := newLiveProvider(t, server.addr)
|
||||
result, err := provider.Send(context.Background(), testMessage(t))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.ClassificationTransientFailure, result.Classification)
|
||||
require.Contains(t, result.Summary, "provider=smtp")
|
||||
require.Contains(t, result.Summary, "result=transient_failure")
|
||||
require.Contains(t, result.Summary, "phase=data")
|
||||
require.Contains(t, result.Summary, "smtp_code=451")
|
||||
}
|
||||
|
||||
func TestProviderSendClassifiesPermanentSMTPFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := startSMTPTestServer(t, smtpTestServerConfig{
|
||||
supportsSTARTTLS: true,
|
||||
finalDataReply: "550 5.7.1 permanent_failure",
|
||||
})
|
||||
|
||||
provider := newLiveProvider(t, server.addr)
|
||||
result, err := provider.Send(context.Background(), testMessage(t))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.ClassificationPermanentFailure, result.Classification)
|
||||
require.Contains(t, result.Summary, "provider=smtp")
|
||||
require.Contains(t, result.Summary, "result=permanent_failure")
|
||||
require.Contains(t, result.Summary, "phase=data")
|
||||
require.Contains(t, result.Summary, "smtp_code=550")
|
||||
}
|
||||
|
||||
func TestProviderSendClassifiesMissingSTARTTLSAsPermanentFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := startSMTPTestServer(t, smtpTestServerConfig{
|
||||
supportsSTARTTLS: false,
|
||||
finalDataReply: "250 2.0.0 accepted",
|
||||
})
|
||||
|
||||
provider := newLiveProvider(t, server.addr)
|
||||
result, err := provider.Send(context.Background(), testMessage(t))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.ClassificationPermanentFailure, result.Classification)
|
||||
require.Contains(t, result.Summary, "provider=smtp")
|
||||
require.Contains(t, result.Summary, "result=permanent_failure")
|
||||
require.Contains(t, result.Summary, "phase=tls")
|
||||
}
|
||||
|
||||
func TestProviderSendClassifiesExpiredDeadlineAsTransientFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := newTestProvider(t)
|
||||
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
|
||||
defer cancel()
|
||||
|
||||
result, err := provider.Send(ctx, testMessage(t))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ports.ClassificationTransientFailure, result.Classification)
|
||||
require.Contains(t, result.Summary, "result=transient_failure")
|
||||
require.Contains(t, result.Summary, "phase=context")
|
||||
}
|
||||
|
||||
func TestNewRejectsUnpairedAuthConfiguration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := New(Config{
|
||||
Addr: "127.0.0.1:2525",
|
||||
Username: "mailer",
|
||||
FromEmail: "noreply@example.com",
|
||||
Timeout: time.Second,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "smtp username and password")
|
||||
}
|
||||
|
||||
func newTestProvider(t *testing.T) *Provider {
|
||||
t.Helper()
|
||||
|
||||
provider, err := New(Config{
|
||||
Addr: "127.0.0.1:2525",
|
||||
FromEmail: "noreply@example.com",
|
||||
FromName: "Galaxy Mail",
|
||||
Timeout: 15 * time.Second,
|
||||
TLSConfig: &tls.Config{
|
||||
ServerName: "localhost",
|
||||
InsecureSkipVerify: true, //nolint:gosec // test-only self-signed SMTP server.
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, provider.Close())
|
||||
})
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
func newLiveProvider(t *testing.T, addr string) *Provider {
|
||||
t.Helper()
|
||||
|
||||
provider, err := New(Config{
|
||||
Addr: addr,
|
||||
FromEmail: "noreply@example.com",
|
||||
FromName: "Galaxy Mail",
|
||||
Timeout: 5 * time.Second,
|
||||
TLSConfig: &tls.Config{
|
||||
ServerName: "localhost",
|
||||
InsecureSkipVerify: true, //nolint:gosec // test-only self-signed SMTP server.
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, provider.Close())
|
||||
})
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
func testMessage(t *testing.T) ports.Message {
|
||||
t.Helper()
|
||||
|
||||
message := ports.Message{
|
||||
Envelope: deliverydomain.Envelope{
|
||||
To: []common.Email{common.Email("pilot@example.com")},
|
||||
Cc: []common.Email{common.Email("copilot@example.com")},
|
||||
Bcc: []common.Email{common.Email("ops@example.com")},
|
||||
ReplyTo: []common.Email{common.Email("reply@example.com")},
|
||||
},
|
||||
Content: deliverydomain.Content{
|
||||
Subject: "Turn update",
|
||||
TextBody: "Turn 54 is ready.",
|
||||
HTMLBody: "<p>Turn <strong>54</strong> is ready.</p>",
|
||||
},
|
||||
Attachments: []ports.Attachment{
|
||||
{
|
||||
Metadata: common.AttachmentMetadata{
|
||||
Filename: "guide.txt",
|
||||
ContentType: "text/plain; charset=utf-8",
|
||||
SizeBytes: int64(len([]byte("read me"))),
|
||||
},
|
||||
Content: []byte("read me"),
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(t, message.Validate())
|
||||
|
||||
return message
|
||||
}
|
||||
|
||||
type smtpTestServerConfig struct {
|
||||
supportsSTARTTLS bool
|
||||
finalDataReply string
|
||||
}
|
||||
|
||||
type smtpTestServer struct {
|
||||
addr string
|
||||
listener net.Listener
|
||||
tlsConfig *tls.Config
|
||||
|
||||
mu sync.Mutex
|
||||
conn net.Conn
|
||||
payload strings.Builder
|
||||
}
|
||||
|
||||
func startSMTPTestServer(t *testing.T, cfg smtpTestServerConfig) *smtpTestServer {
|
||||
t.Helper()
|
||||
|
||||
certificate := newTestCertificate(t)
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
server := &smtpTestServer{
|
||||
addr: listener.Addr().String(),
|
||||
listener: listener,
|
||||
tlsConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
server.mu.Lock()
|
||||
server.conn = conn
|
||||
server.mu.Unlock()
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
server.serveConnection(conn, cfg)
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
server.mu.Lock()
|
||||
if server.conn != nil {
|
||||
_ = server.conn.Close()
|
||||
}
|
||||
server.mu.Unlock()
|
||||
_ = listener.Close()
|
||||
<-done
|
||||
})
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
func (server *smtpTestServer) data() string {
|
||||
server.mu.Lock()
|
||||
defer server.mu.Unlock()
|
||||
return server.payload.String()
|
||||
}
|
||||
|
||||
func (server *smtpTestServer) serveConnection(conn net.Conn, cfg smtpTestServerConfig) {
|
||||
reader := newSMTPLineReader(conn)
|
||||
writer := newSMTPLineWriter(conn)
|
||||
writer.writeLine("220 localhost ESMTP")
|
||||
|
||||
tlsActive := false
|
||||
for {
|
||||
line, err := reader.readLine()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
command := strings.ToUpper(line)
|
||||
switch {
|
||||
case strings.HasPrefix(command, "EHLO "), strings.HasPrefix(command, "HELO "):
|
||||
if cfg.supportsSTARTTLS && !tlsActive {
|
||||
writer.writeLines(
|
||||
"250-localhost",
|
||||
"250-8BITMIME",
|
||||
"250-STARTTLS",
|
||||
"250 SMTPUTF8",
|
||||
)
|
||||
continue
|
||||
}
|
||||
writer.writeLines(
|
||||
"250-localhost",
|
||||
"250-8BITMIME",
|
||||
"250 SMTPUTF8",
|
||||
)
|
||||
case command == "STARTTLS":
|
||||
writer.writeLine("220 Ready to start TLS")
|
||||
tlsConn := tls.Server(conn, server.tlsConfig)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
conn = tlsConn
|
||||
server.mu.Lock()
|
||||
server.conn = conn
|
||||
server.mu.Unlock()
|
||||
reader = newSMTPLineReader(conn)
|
||||
writer = newSMTPLineWriter(conn)
|
||||
tlsActive = true
|
||||
case strings.HasPrefix(command, "MAIL FROM:"):
|
||||
writer.writeLine("250 2.1.0 Ok")
|
||||
case strings.HasPrefix(command, "RCPT TO:"):
|
||||
writer.writeLine("250 2.1.5 Ok")
|
||||
case command == "DATA":
|
||||
writer.writeLine("354 End data with <CR><LF>.<CR><LF>")
|
||||
|
||||
var builder strings.Builder
|
||||
for {
|
||||
dataLine, err := reader.readRawLine()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if dataLine == ".\r\n" {
|
||||
break
|
||||
}
|
||||
builder.WriteString(dataLine)
|
||||
}
|
||||
|
||||
server.mu.Lock()
|
||||
server.payload.WriteString(builder.String())
|
||||
server.mu.Unlock()
|
||||
|
||||
writer.writeLine(cfg.finalDataReply)
|
||||
case command == "RSET":
|
||||
writer.writeLine("250 2.0.0 Ok")
|
||||
case command == "QUIT":
|
||||
writer.writeLine("221 2.0.0 Bye")
|
||||
return
|
||||
default:
|
||||
writer.writeLine("250 2.0.0 Ok")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type smtpLineReader struct {
|
||||
reader *bytes.Buffer
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func newSMTPLineReader(conn net.Conn) *smtpLineReader {
|
||||
return &smtpLineReader{conn: conn}
|
||||
}
|
||||
|
||||
func (reader *smtpLineReader) readLine() (string, error) {
|
||||
line, err := reader.readRawLine()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r"), nil
|
||||
}
|
||||
|
||||
func (reader *smtpLineReader) readRawLine() (string, error) {
|
||||
var buffer bytes.Buffer
|
||||
tmp := make([]byte, 1)
|
||||
for {
|
||||
_, err := reader.conn.Read(tmp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
buffer.WriteByte(tmp[0])
|
||||
if tmp[0] == '\n' {
|
||||
return buffer.String(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type smtpLineWriter struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func newSMTPLineWriter(conn net.Conn) *smtpLineWriter {
|
||||
return &smtpLineWriter{conn: conn}
|
||||
}
|
||||
|
||||
func (writer *smtpLineWriter) writeLine(line string) {
|
||||
_, _ = io.WriteString(writer.conn, line+"\r\n")
|
||||
}
|
||||
|
||||
func (writer *smtpLineWriter) writeLines(lines ...string) {
|
||||
for _, line := range lines {
|
||||
writer.writeLine(line)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestCertificate(t *testing.T) tls.Certificate {
|
||||
t.Helper()
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "localhost",
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
|
||||
certificate, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
require.NoError(t, err)
|
||||
return certificate
|
||||
}
|
||||
Reference in New Issue
Block a user