454 lines
11 KiB
Go
454 lines
11 KiB
Go
package smtp
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"io"
|
|
"math/big"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"galaxy/mail/internal/domain/common"
|
|
deliverydomain "galaxy/mail/internal/domain/delivery"
|
|
"galaxy/mail/internal/ports"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestProviderBuildMessageIncludesHeadersBodiesAndAttachments(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := newTestProvider(t)
|
|
message := testMessage(t)
|
|
|
|
msg, err := provider.buildMessage(message)
|
|
require.NoError(t, err)
|
|
|
|
var buffer bytes.Buffer
|
|
_, err = msg.WriteTo(&buffer)
|
|
require.NoError(t, err)
|
|
|
|
payload := buffer.String()
|
|
require.Contains(t, payload, "From: \"Galaxy Mail\" <noreply@example.com>")
|
|
require.Contains(t, payload, "To: <pilot@example.com>")
|
|
require.Contains(t, payload, "Cc: <copilot@example.com>")
|
|
require.Contains(t, payload, "Reply-To: <reply@example.com>")
|
|
require.Contains(t, payload, "Subject: Turn update")
|
|
require.Contains(t, payload, "multipart/mixed")
|
|
require.Contains(t, payload, "multipart/alternative")
|
|
require.Contains(t, payload, "text/plain")
|
|
require.Contains(t, payload, "text/html")
|
|
require.Contains(t, payload, "guide.txt")
|
|
require.Contains(t, payload, "charset=utf-8")
|
|
require.NotContains(t, payload, "\nBcc:")
|
|
}
|
|
|
|
func TestProviderSendClassifiesAccepted(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := startSMTPTestServer(t, smtpTestServerConfig{
|
|
supportsSTARTTLS: true,
|
|
finalDataReply: "250 2.0.0 accepted",
|
|
})
|
|
|
|
provider := newLiveProvider(t, server.addr)
|
|
result, err := provider.Send(context.Background(), testMessage(t))
|
|
require.NoError(t, err)
|
|
require.Equal(t, ports.ClassificationAccepted, result.Classification)
|
|
require.Equal(t, "provider=smtp result=accepted", result.Summary)
|
|
require.Contains(t, server.data(), "Subject: Turn update")
|
|
require.NotContains(t, server.data(), "\nBcc:")
|
|
}
|
|
|
|
func TestProviderSendClassifiesTransientSMTPFailure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := startSMTPTestServer(t, smtpTestServerConfig{
|
|
supportsSTARTTLS: true,
|
|
finalDataReply: "451 4.3.0 temporary_failure",
|
|
})
|
|
|
|
provider := newLiveProvider(t, server.addr)
|
|
result, err := provider.Send(context.Background(), testMessage(t))
|
|
require.NoError(t, err)
|
|
require.Equal(t, ports.ClassificationTransientFailure, result.Classification)
|
|
require.Contains(t, result.Summary, "provider=smtp")
|
|
require.Contains(t, result.Summary, "result=transient_failure")
|
|
require.Contains(t, result.Summary, "phase=data")
|
|
require.Contains(t, result.Summary, "smtp_code=451")
|
|
}
|
|
|
|
func TestProviderSendClassifiesPermanentSMTPFailure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := startSMTPTestServer(t, smtpTestServerConfig{
|
|
supportsSTARTTLS: true,
|
|
finalDataReply: "550 5.7.1 permanent_failure",
|
|
})
|
|
|
|
provider := newLiveProvider(t, server.addr)
|
|
result, err := provider.Send(context.Background(), testMessage(t))
|
|
require.NoError(t, err)
|
|
require.Equal(t, ports.ClassificationPermanentFailure, result.Classification)
|
|
require.Contains(t, result.Summary, "provider=smtp")
|
|
require.Contains(t, result.Summary, "result=permanent_failure")
|
|
require.Contains(t, result.Summary, "phase=data")
|
|
require.Contains(t, result.Summary, "smtp_code=550")
|
|
}
|
|
|
|
func TestProviderSendClassifiesMissingSTARTTLSAsPermanentFailure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
server := startSMTPTestServer(t, smtpTestServerConfig{
|
|
supportsSTARTTLS: false,
|
|
finalDataReply: "250 2.0.0 accepted",
|
|
})
|
|
|
|
provider := newLiveProvider(t, server.addr)
|
|
result, err := provider.Send(context.Background(), testMessage(t))
|
|
require.NoError(t, err)
|
|
require.Equal(t, ports.ClassificationPermanentFailure, result.Classification)
|
|
require.Contains(t, result.Summary, "provider=smtp")
|
|
require.Contains(t, result.Summary, "result=permanent_failure")
|
|
require.Contains(t, result.Summary, "phase=tls")
|
|
}
|
|
|
|
func TestProviderSendClassifiesExpiredDeadlineAsTransientFailure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := newTestProvider(t)
|
|
|
|
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
|
|
defer cancel()
|
|
|
|
result, err := provider.Send(ctx, testMessage(t))
|
|
require.NoError(t, err)
|
|
require.Equal(t, ports.ClassificationTransientFailure, result.Classification)
|
|
require.Contains(t, result.Summary, "result=transient_failure")
|
|
require.Contains(t, result.Summary, "phase=context")
|
|
}
|
|
|
|
func TestNewRejectsUnpairedAuthConfiguration(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
_, err := New(Config{
|
|
Addr: "127.0.0.1:2525",
|
|
Username: "mailer",
|
|
FromEmail: "noreply@example.com",
|
|
Timeout: time.Second,
|
|
})
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), "smtp username and password")
|
|
}
|
|
|
|
func newTestProvider(t *testing.T) *Provider {
|
|
t.Helper()
|
|
|
|
provider, err := New(Config{
|
|
Addr: "127.0.0.1:2525",
|
|
FromEmail: "noreply@example.com",
|
|
FromName: "Galaxy Mail",
|
|
Timeout: 15 * time.Second,
|
|
TLSConfig: &tls.Config{
|
|
ServerName: "localhost",
|
|
InsecureSkipVerify: true, //nolint:gosec // test-only self-signed SMTP server.
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
require.NoError(t, provider.Close())
|
|
})
|
|
|
|
return provider
|
|
}
|
|
|
|
func newLiveProvider(t *testing.T, addr string) *Provider {
|
|
t.Helper()
|
|
|
|
provider, err := New(Config{
|
|
Addr: addr,
|
|
FromEmail: "noreply@example.com",
|
|
FromName: "Galaxy Mail",
|
|
Timeout: 5 * time.Second,
|
|
TLSConfig: &tls.Config{
|
|
ServerName: "localhost",
|
|
InsecureSkipVerify: true, //nolint:gosec // test-only self-signed SMTP server.
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
require.NoError(t, provider.Close())
|
|
})
|
|
|
|
return provider
|
|
}
|
|
|
|
func testMessage(t *testing.T) ports.Message {
|
|
t.Helper()
|
|
|
|
message := ports.Message{
|
|
Envelope: deliverydomain.Envelope{
|
|
To: []common.Email{common.Email("pilot@example.com")},
|
|
Cc: []common.Email{common.Email("copilot@example.com")},
|
|
Bcc: []common.Email{common.Email("ops@example.com")},
|
|
ReplyTo: []common.Email{common.Email("reply@example.com")},
|
|
},
|
|
Content: deliverydomain.Content{
|
|
Subject: "Turn update",
|
|
TextBody: "Turn 54 is ready.",
|
|
HTMLBody: "<p>Turn <strong>54</strong> is ready.</p>",
|
|
},
|
|
Attachments: []ports.Attachment{
|
|
{
|
|
Metadata: common.AttachmentMetadata{
|
|
Filename: "guide.txt",
|
|
ContentType: "text/plain; charset=utf-8",
|
|
SizeBytes: int64(len([]byte("read me"))),
|
|
},
|
|
Content: []byte("read me"),
|
|
},
|
|
},
|
|
}
|
|
require.NoError(t, message.Validate())
|
|
|
|
return message
|
|
}
|
|
|
|
// smtpTestServerConfig controls how the fake SMTP server answers a client.
type smtpTestServerConfig struct {
	// finalDataReply is the complete reply line sent after the DATA
	// terminator, e.g. "250 2.0.0 accepted".
	finalDataReply string
	// supportsSTARTTLS makes the server advertise STARTTLS in its EHLO/HELO
	// response while TLS is not yet active.
	supportsSTARTTLS bool
}
|
|
|
|
// smtpTestServer is an in-process fake SMTP server. listener, addr and
// tlsConfig are set once at construction. conn and payload are shared with
// the serving goroutine and are accessed only while holding mu.
type smtpTestServer struct {
	listener  net.Listener
	addr      string
	tlsConfig *tls.Config

	// mu guards conn and payload.
	mu sync.Mutex
	// conn is the current client connection. It is replaced by the TLS
	// connection after STARTTLS, and is nil until a client is accepted.
	conn net.Conn
	// payload accumulates the message data received after DATA commands.
	payload strings.Builder
}
|
|
|
|
func startSMTPTestServer(t *testing.T, cfg smtpTestServerConfig) *smtpTestServer {
|
|
t.Helper()
|
|
|
|
certificate := newTestCertificate(t)
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
server := &smtpTestServer{
|
|
addr: listener.Addr().String(),
|
|
listener: listener,
|
|
tlsConfig: &tls.Config{
|
|
Certificates: []tls.Certificate{certificate},
|
|
MinVersion: tls.VersionTLS12,
|
|
},
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
server.mu.Lock()
|
|
server.conn = conn
|
|
server.mu.Unlock()
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
server.serveConnection(conn, cfg)
|
|
}()
|
|
|
|
t.Cleanup(func() {
|
|
server.mu.Lock()
|
|
if server.conn != nil {
|
|
_ = server.conn.Close()
|
|
}
|
|
server.mu.Unlock()
|
|
_ = listener.Close()
|
|
<-done
|
|
})
|
|
|
|
return server
|
|
}
|
|
|
|
func (server *smtpTestServer) data() string {
|
|
server.mu.Lock()
|
|
defer server.mu.Unlock()
|
|
return server.payload.String()
|
|
}
|
|
|
|
func (server *smtpTestServer) serveConnection(conn net.Conn, cfg smtpTestServerConfig) {
|
|
reader := newSMTPLineReader(conn)
|
|
writer := newSMTPLineWriter(conn)
|
|
writer.writeLine("220 localhost ESMTP")
|
|
|
|
tlsActive := false
|
|
for {
|
|
line, err := reader.readLine()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
command := strings.ToUpper(line)
|
|
switch {
|
|
case strings.HasPrefix(command, "EHLO "), strings.HasPrefix(command, "HELO "):
|
|
if cfg.supportsSTARTTLS && !tlsActive {
|
|
writer.writeLines(
|
|
"250-localhost",
|
|
"250-8BITMIME",
|
|
"250-STARTTLS",
|
|
"250 SMTPUTF8",
|
|
)
|
|
continue
|
|
}
|
|
writer.writeLines(
|
|
"250-localhost",
|
|
"250-8BITMIME",
|
|
"250 SMTPUTF8",
|
|
)
|
|
case command == "STARTTLS":
|
|
writer.writeLine("220 Ready to start TLS")
|
|
tlsConn := tls.Server(conn, server.tlsConfig)
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
return
|
|
}
|
|
conn = tlsConn
|
|
server.mu.Lock()
|
|
server.conn = conn
|
|
server.mu.Unlock()
|
|
reader = newSMTPLineReader(conn)
|
|
writer = newSMTPLineWriter(conn)
|
|
tlsActive = true
|
|
case strings.HasPrefix(command, "MAIL FROM:"):
|
|
writer.writeLine("250 2.1.0 Ok")
|
|
case strings.HasPrefix(command, "RCPT TO:"):
|
|
writer.writeLine("250 2.1.5 Ok")
|
|
case command == "DATA":
|
|
writer.writeLine("354 End data with <CR><LF>.<CR><LF>")
|
|
|
|
var builder strings.Builder
|
|
for {
|
|
dataLine, err := reader.readRawLine()
|
|
if err != nil {
|
|
return
|
|
}
|
|
if dataLine == ".\r\n" {
|
|
break
|
|
}
|
|
builder.WriteString(dataLine)
|
|
}
|
|
|
|
server.mu.Lock()
|
|
server.payload.WriteString(builder.String())
|
|
server.mu.Unlock()
|
|
|
|
writer.writeLine(cfg.finalDataReply)
|
|
case command == "RSET":
|
|
writer.writeLine("250 2.0.0 Ok")
|
|
case command == "QUIT":
|
|
writer.writeLine("221 2.0.0 Bye")
|
|
return
|
|
default:
|
|
writer.writeLine("250 2.0.0 Ok")
|
|
}
|
|
}
|
|
}
|
|
|
|
// smtpLineReader reads '\n'-terminated lines from an SMTP client connection.
// It issues one-byte reads and never consumes bytes past the current line's
// '\n'. As a result, serveConnection can swap the connection for a TLS one
// after STARTTLS without leaving unread bytes behind in a buffer.
type smtpLineReader struct {
	conn net.Conn
}

// newSMTPLineReader returns a reader that pulls lines directly from conn.
func newSMTPLineReader(conn net.Conn) *smtpLineReader {
	return &smtpLineReader{conn: conn}
}

// readLine returns the next line with its trailing "\n" removed, and also a
// "\r" immediately before it if present. It is used for SMTP commands.
func (r *smtpLineReader) readLine() (string, error) {
	raw, err := r.readRawLine()
	if err != nil {
		return "", err
	}
	return strings.TrimSuffix(strings.TrimSuffix(raw, "\n"), "\r"), nil
}

// readRawLine returns the next line including its terminating "\n" and any
// preceding "\r". The DATA handler needs the raw form to detect the ".\r\n"
// terminator and to keep line endings in the recorded payload.
//
// Behavior on short or failed reads:
//   - A Read error before the line completes is returned with an empty
//     string, and the partial line is discarded.
//   - If the final '\n' arrives together with an error, the completed line
//     is returned and the error surfaces on the next call.
//   - Only bytes actually reported by Read (n > 0) are appended. io.Reader
//     permits a (0, nil) result, and treating it as a byte would inject
//     stale data into the line.
func (r *smtpLineReader) readRawLine() (string, error) {
	var line bytes.Buffer
	var one [1]byte
	for {
		n, err := r.conn.Read(one[:])
		if n > 0 {
			line.WriteByte(one[0])
			if one[0] == '\n' {
				return line.String(), nil
			}
		}
		if err != nil {
			return "", err
		}
	}
}
|
|
|
|
// smtpLineWriter writes SMTP reply lines to a connection, appending the CRLF
// terminator to each. Write errors are deliberately discarded. Presumably
// this is because a broken connection surfaces as a read error in the
// server loop instead.
type smtpLineWriter struct {
	conn net.Conn
}

// newSMTPLineWriter returns a writer that emits reply lines on conn.
func newSMTPLineWriter(conn net.Conn) *smtpLineWriter {
	w := new(smtpLineWriter)
	w.conn = conn
	return w
}

// writeLine sends line followed by "\r\n".
func (w *smtpLineWriter) writeLine(line string) {
	const crlf = "\r\n"
	_, _ = io.WriteString(w.conn, line+crlf)
}

// writeLines sends each line in order, each with its own CRLF. It is used
// for multi-line replies such as the EHLO capability list.
func (w *smtpLineWriter) writeLines(lines ...string) {
	for i := range lines {
		w.writeLine(lines[i])
	}
}
|
|
|
|
func newTestCertificate(t *testing.T) tls.Certificate {
|
|
t.Helper()
|
|
|
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
require.NoError(t, err)
|
|
|
|
template := x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{
|
|
CommonName: "localhost",
|
|
},
|
|
NotBefore: time.Now().Add(-time.Hour),
|
|
NotAfter: time.Now().Add(time.Hour),
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
BasicConstraintsValid: true,
|
|
DNSNames: []string{"localhost"},
|
|
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
|
}
|
|
|
|
der, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
|
require.NoError(t, err)
|
|
|
|
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
|
keyPEM := pem.EncodeToMemory(&pem.Block{
|
|
Type: "RSA PRIVATE KEY",
|
|
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
|
})
|
|
|
|
certificate, err := tls.X509KeyPair(certPEM, keyPEM)
|
|
require.NoError(t, err)
|
|
return certificate
|
|
}
|