378 lines
8.8 KiB
Go
378 lines
8.8 KiB
Go
package harness
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"io"
|
|
"math/big"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// SMTPCaptureConfig holds the options for a single local SMTP capture server
// started by StartSMTPCapture.
type SMTPCaptureConfig struct {
	// SupportsSTARTTLS, when true, makes the server list STARTTLS in its
	// EHLO/HELO reply (until the session is upgraded) and accept the STARTTLS
	// command. When false, STARTTLS is refused with a 454 reply.
	SupportsSTARTTLS bool

	// FinalDataReply is the complete SMTP status line sent once the DATA
	// body has been received. Leaving it empty selects the default
	// "250 2.0.0 accepted" reply.
	FinalDataReply string
}
|
|
|
|
// SMTPCapture is a running local SMTP capture server together with the
// generated root CA certificate that external clients must trust when they
// upgrade the session with STARTTLS.
type SMTPCapture struct {
	addr       string       // address reported by the listener
	rootCAPath string       // PEM file holding the generated root CA
	listener   net.Listener // accepting socket, closed during test cleanup
	tlsConfig  *tls.Config  // server-side TLS settings used for STARTTLS

	// conns holds the live client connections so cleanup can force them
	// closed; guarded by connsMu.
	connsMu sync.Mutex
	conns   map[net.Conn]struct{}

	// payloads records every received DATA body in arrival order; guarded
	// by payloadsMu.
	payloadsMu sync.Mutex
	payloads   []string

	// acceptWG tracks the accept loop, connWG the per-connection handlers.
	acceptWG, connWG sync.WaitGroup
}
|
|
|
|
// StartSMTPCapture starts one local SMTP server suitable for black-box tests
|
|
// that need to observe captured message payloads.
|
|
func StartSMTPCapture(t testing.TB, cfg SMTPCaptureConfig) *SMTPCapture {
|
|
t.Helper()
|
|
|
|
if cfg.FinalDataReply == "" {
|
|
cfg.FinalDataReply = "250 2.0.0 accepted"
|
|
}
|
|
|
|
serverCertificate, rootCAPEM := newSMTPCertificates(t)
|
|
rootCAPath := filepath.Join(t.TempDir(), "smtp-root-ca.pem")
|
|
if err := os.WriteFile(rootCAPath, rootCAPEM, 0o600); err != nil {
|
|
t.Fatalf("write SMTP root CA: %v", err)
|
|
}
|
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("start SMTP capture listener: %v", err)
|
|
}
|
|
|
|
capture := &SMTPCapture{
|
|
addr: listener.Addr().String(),
|
|
rootCAPath: rootCAPath,
|
|
listener: listener,
|
|
tlsConfig: &tls.Config{
|
|
Certificates: []tls.Certificate{serverCertificate},
|
|
MinVersion: tls.VersionTLS12,
|
|
},
|
|
conns: make(map[net.Conn]struct{}),
|
|
}
|
|
|
|
capture.acceptWG.Add(1)
|
|
go func() {
|
|
defer capture.acceptWG.Done()
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
capture.trackConn(conn)
|
|
capture.connWG.Add(1)
|
|
go func() {
|
|
defer capture.connWG.Done()
|
|
defer capture.untrackConn(conn)
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
capture.serveConnection(conn, cfg)
|
|
}()
|
|
}
|
|
}()
|
|
|
|
t.Cleanup(func() {
|
|
_ = capture.listener.Close()
|
|
capture.closeConnections()
|
|
capture.acceptWG.Wait()
|
|
capture.connWG.Wait()
|
|
})
|
|
|
|
return capture
|
|
}
|
|
|
|
// Addr returns the externally reachable TCP address of the capture server.
|
|
func (capture *SMTPCapture) Addr() string {
|
|
if capture == nil {
|
|
return ""
|
|
}
|
|
|
|
return capture.addr
|
|
}
|
|
|
|
// RootCAPath returns the PEM path that should be trusted by clients talking to
|
|
// the capture server over STARTTLS.
|
|
func (capture *SMTPCapture) RootCAPath() string {
|
|
if capture == nil {
|
|
return ""
|
|
}
|
|
|
|
return capture.rootCAPath
|
|
}
|
|
|
|
// LatestPayload returns the most recently captured SMTP DATA payload.
|
|
func (capture *SMTPCapture) LatestPayload() string {
|
|
if capture == nil {
|
|
return ""
|
|
}
|
|
|
|
capture.payloadsMu.Lock()
|
|
defer capture.payloadsMu.Unlock()
|
|
|
|
if len(capture.payloads) == 0 {
|
|
return ""
|
|
}
|
|
|
|
return capture.payloads[len(capture.payloads)-1]
|
|
}
|
|
|
|
func (capture *SMTPCapture) trackConn(conn net.Conn) {
|
|
capture.connsMu.Lock()
|
|
defer capture.connsMu.Unlock()
|
|
capture.conns[conn] = struct{}{}
|
|
}
|
|
|
|
func (capture *SMTPCapture) untrackConn(conn net.Conn) {
|
|
capture.connsMu.Lock()
|
|
defer capture.connsMu.Unlock()
|
|
delete(capture.conns, conn)
|
|
}
|
|
|
|
func (capture *SMTPCapture) closeConnections() {
|
|
capture.connsMu.Lock()
|
|
defer capture.connsMu.Unlock()
|
|
|
|
for conn := range capture.conns {
|
|
_ = conn.Close()
|
|
}
|
|
}
|
|
|
|
func (capture *SMTPCapture) appendPayload(payload string) {
|
|
capture.payloadsMu.Lock()
|
|
defer capture.payloadsMu.Unlock()
|
|
capture.payloads = append(capture.payloads, payload)
|
|
}
|
|
|
|
func (capture *SMTPCapture) serveConnection(conn net.Conn, cfg SMTPCaptureConfig) {
|
|
reader := newSMTPLineReader(conn)
|
|
writer := newSMTPLineWriter(conn)
|
|
writer.writeLine("220 localhost ESMTP")
|
|
|
|
tlsActive := false
|
|
for {
|
|
line, err := reader.readLine()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
command := strings.ToUpper(line)
|
|
switch {
|
|
case strings.HasPrefix(command, "EHLO "), strings.HasPrefix(command, "HELO "):
|
|
if cfg.SupportsSTARTTLS && !tlsActive {
|
|
writer.writeLines(
|
|
"250-localhost",
|
|
"250-8BITMIME",
|
|
"250-STARTTLS",
|
|
"250 SMTPUTF8",
|
|
)
|
|
continue
|
|
}
|
|
|
|
writer.writeLines(
|
|
"250-localhost",
|
|
"250-8BITMIME",
|
|
"250 SMTPUTF8",
|
|
)
|
|
case command == "STARTTLS":
|
|
if !cfg.SupportsSTARTTLS {
|
|
writer.writeLine("454 4.7.0 TLS not available")
|
|
continue
|
|
}
|
|
|
|
writer.writeLine("220 Ready to start TLS")
|
|
tlsConn := tls.Server(conn, capture.tlsConfig)
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
return
|
|
}
|
|
|
|
capture.trackConn(tlsConn)
|
|
capture.untrackConn(conn)
|
|
conn = tlsConn
|
|
reader = newSMTPLineReader(conn)
|
|
writer = newSMTPLineWriter(conn)
|
|
tlsActive = true
|
|
case strings.HasPrefix(command, "MAIL FROM:"):
|
|
writer.writeLine("250 2.1.0 Ok")
|
|
case strings.HasPrefix(command, "RCPT TO:"):
|
|
writer.writeLine("250 2.1.5 Ok")
|
|
case command == "DATA":
|
|
writer.writeLine("354 End data with <CR><LF>.<CR><LF>")
|
|
|
|
var payload strings.Builder
|
|
for {
|
|
dataLine, err := reader.readRawLine()
|
|
if err != nil {
|
|
return
|
|
}
|
|
if dataLine == ".\r\n" {
|
|
break
|
|
}
|
|
payload.WriteString(dataLine)
|
|
}
|
|
|
|
capture.appendPayload(payload.String())
|
|
writer.writeLine(cfg.FinalDataReply)
|
|
case command == "RSET":
|
|
writer.writeLine("250 2.0.0 Ok")
|
|
case command == "QUIT":
|
|
writer.writeLine("221 2.0.0 Bye")
|
|
return
|
|
default:
|
|
writer.writeLine("250 2.0.0 Ok")
|
|
}
|
|
}
|
|
}
|
|
|
|
// smtpLineReader reads SMTP command and data lines straight from a
// connection.
//
// It reads one byte per Read call and keeps no buffer between calls. After
// STARTTLS the same connection is handed to tls.Server, and any bytes read
// ahead here would be missing from the TLS handshake.
type smtpLineReader struct {
	conn net.Conn
}

// newSMTPLineReader returns a line reader bound to conn.
func newSMTPLineReader(conn net.Conn) *smtpLineReader {
	return &smtpLineReader{conn: conn}
}

// readLine returns the next line with its trailing "\n" removed, along with
// the "\r" before it if present. Read errors are returned unchanged.
func (reader *smtpLineReader) readLine() (string, error) {
	line, err := reader.readRawLine()
	if err != nil {
		return "", err
	}

	return strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r"), nil
}

// readRawLine returns the next line including its terminating "\n" (and any
// "\r" before it). If the connection fails before a "\n" arrives, the
// partial line is discarded and the read error is returned.
func (reader *smtpLineReader) readRawLine() (string, error) {
	var line bytes.Buffer
	var b [1]byte
	for {
		n, err := reader.conn.Read(b[:])
		// Per the io.Reader contract, consume returned bytes before looking
		// at err. A Read may deliver the final "\n" together with io.EOF,
		// and a (0, nil) result must not append a stale byte.
		if n > 0 {
			line.WriteByte(b[0])
			if b[0] == '\n' {
				return line.String(), nil
			}
		}
		if err != nil {
			return "", err
		}
	}
}
|
|
|
|
// smtpLineWriter writes SMTP reply lines to a connection, terminating each
// one with CRLF.
type smtpLineWriter struct {
	conn net.Conn
}

// newSMTPLineWriter returns a reply writer bound to conn.
func newSMTPLineWriter(conn net.Conn) *smtpLineWriter {
	return &smtpLineWriter{conn}
}

// writeLine sends line followed by "\r\n". Write errors are deliberately
// dropped; a dead connection presumably surfaces as a read error on the
// caller's next readLine.
func (writer *smtpLineWriter) writeLine(line string) {
	_, _ = writer.conn.Write([]byte(line + "\r\n"))
}

// writeLines sends each line in order, one writeLine call per line.
func (writer *smtpLineWriter) writeLines(lines ...string) {
	for i := range lines {
		writer.writeLine(lines[i])
	}
}
|
|
|
|
// newSMTPCertificates builds a throwaway PKI for the capture server: a
// self-signed root CA and an RSA-2048 server certificate for 127.0.0.1 and
// localhost issued by that root. Both certificates are valid from one hour
// before creation until 24 hours after it.
//
// It returns the server key pair, whose chain is the leaf followed by the
// root, and the PEM-encoded root certificate for clients to trust. Any
// failure aborts the test via t.Fatalf.
func newSMTPCertificates(t testing.TB) (tls.Certificate, []byte) {
	t.Helper()

	// failOn aborts the test with "<step>: <err>" when err is non-nil.
	failOn := func(err error, step string) {
		t.Helper()
		if err != nil {
			t.Fatalf("%s: %v", step, err)
		}
	}

	caKey, err := rsa.GenerateKey(rand.Reader, 2048)
	failOn(err, "generate SMTP root key")

	issuedAt := time.Now()
	validFrom, validUntil := issuedAt.Add(-time.Hour), issuedAt.Add(24*time.Hour)

	caTemplate := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "galaxy-integration-smtp-root"},
		NotBefore:             validFrom,
		NotAfter:              validUntil,
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature,
		IsCA:                  true,
		BasicConstraintsValid: true,
	}
	// Self-signed: the template is both subject and issuer.
	caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
	failOn(err, "create SMTP root certificate")
	caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caDER})

	leafKey, err := rsa.GenerateKey(rand.Reader, 2048)
	failOn(err, "generate SMTP server key")

	leafTemplate := &x509.Certificate{
		SerialNumber:          big.NewInt(2),
		Subject:               pkix.Name{CommonName: "127.0.0.1"},
		NotBefore:             validFrom,
		NotAfter:              validUntil,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              []string{"localhost"},
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
	}

	caCert, err := x509.ParseCertificate(caDER)
	failOn(err, "parse SMTP root certificate")

	leafDER, err := x509.CreateCertificate(rand.Reader, leafTemplate, caCert, &leafKey.PublicKey, caKey)
	failOn(err, "create SMTP server certificate")

	// Leaf first, then the root, as one concatenated PEM chain.
	chainPEM := append(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: leafDER}), caPEM...)
	keyPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(leafKey),
	})

	pair, err := tls.X509KeyPair(chainPEM, keyPEM)
	failOn(err, "load SMTP server key pair")

	return pair, caPEM
}
|