package harness import ( "bytes" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "io" "math/big" "net" "os" "path/filepath" "strings" "sync" "testing" "time" ) // SMTPCaptureConfig configures one local SMTP capture server. type SMTPCaptureConfig struct { // SupportsSTARTTLS controls whether the server advertises and accepts the // STARTTLS upgrade command. SupportsSTARTTLS bool // FinalDataReply stores the final SMTP status line returned after the // message body has been received. Empty value keeps the default accepted // reply. FinalDataReply string } // SMTPCapture stores one running local SMTP capture server together with the // generated trust anchor used by external processes. type SMTPCapture struct { addr string rootCAPath string listener net.Listener tlsConfig *tls.Config connsMu sync.Mutex conns map[net.Conn]struct{} payloadsMu sync.Mutex payloads []string acceptWG sync.WaitGroup connWG sync.WaitGroup } // StartSMTPCapture starts one local SMTP server suitable for black-box tests // that need to observe captured message payloads. func StartSMTPCapture(t testing.TB, cfg SMTPCaptureConfig) *SMTPCapture { t.Helper() if cfg.FinalDataReply == "" { cfg.FinalDataReply = "250 2.0.0 accepted" } serverCertificate, rootCAPEM := newSMTPCertificates(t) rootCAPath := filepath.Join(t.TempDir(), "smtp-root-ca.pem") if err := os.WriteFile(rootCAPath, rootCAPEM, 0o600); err != nil { t.Fatalf("write SMTP root CA: %v", err) } listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("start SMTP capture listener: %v", err) } capture := &SMTPCapture{ addr: listener.Addr().String(), rootCAPath: rootCAPath, listener: listener, tlsConfig: &tls.Config{ Certificates: []tls.Certificate{serverCertificate}, MinVersion: tls.VersionTLS12, }, conns: make(map[net.Conn]struct{}), } capture.acceptWG.Add(1) go func() { defer capture.acceptWG.Done() for { conn, err := listener.Accept() if err != nil { return } capture.trackConn(conn) capture.connWG.Add(1) go func() { defer capture.connWG.Done() defer capture.untrackConn(conn) defer func() { _ = conn.Close() }() capture.serveConnection(conn, cfg) }() } }() t.Cleanup(func() { _ = capture.listener.Close() capture.closeConnections() capture.acceptWG.Wait() capture.connWG.Wait() }) return capture } // Addr returns the externally reachable TCP address of the capture server. func (capture *SMTPCapture) Addr() string { if capture == nil { return "" } return capture.addr } // RootCAPath returns the PEM path that should be trusted by clients talking to // the capture server over STARTTLS. func (capture *SMTPCapture) RootCAPath() string { if capture == nil { return "" } return capture.rootCAPath } // LatestPayload returns the most recently captured SMTP DATA payload. func (capture *SMTPCapture) LatestPayload() string { if capture == nil { return "" } capture.payloadsMu.Lock() defer capture.payloadsMu.Unlock() if len(capture.payloads) == 0 { return "" } return capture.payloads[len(capture.payloads)-1] } func (capture *SMTPCapture) trackConn(conn net.Conn) { capture.connsMu.Lock() defer capture.connsMu.Unlock() capture.conns[conn] = struct{}{} } func (capture *SMTPCapture) untrackConn(conn net.Conn) { capture.connsMu.Lock() defer capture.connsMu.Unlock() delete(capture.conns, conn) } func (capture *SMTPCapture) closeConnections() { capture.connsMu.Lock() defer capture.connsMu.Unlock() for conn := range capture.conns { _ = conn.Close() } } func (capture *SMTPCapture) appendPayload(payload string) { capture.payloadsMu.Lock() defer capture.payloadsMu.Unlock() capture.payloads = append(capture.payloads, payload) } func (capture *SMTPCapture) serveConnection(conn net.Conn, cfg SMTPCaptureConfig) { reader := newSMTPLineReader(conn) writer := newSMTPLineWriter(conn) writer.writeLine("220 localhost ESMTP") tlsActive := false for { line, err := reader.readLine() if err != nil { return } command := strings.ToUpper(line) switch { case strings.HasPrefix(command, "EHLO "), strings.HasPrefix(command, "HELO "): if cfg.SupportsSTARTTLS && !tlsActive { writer.writeLines( "250-localhost", "250-8BITMIME", "250-STARTTLS", "250 SMTPUTF8", ) continue } writer.writeLines( "250-localhost", "250-8BITMIME", "250 SMTPUTF8", ) case command == "STARTTLS": if !cfg.SupportsSTARTTLS { writer.writeLine("454 4.7.0 TLS not available") continue } writer.writeLine("220 Ready to start TLS") tlsConn := tls.Server(conn, capture.tlsConfig) if err := tlsConn.Handshake(); err != nil { return } capture.trackConn(tlsConn) capture.untrackConn(conn) conn = tlsConn reader = newSMTPLineReader(conn) writer = newSMTPLineWriter(conn) tlsActive = true case strings.HasPrefix(command, "MAIL FROM:"): writer.writeLine("250 2.1.0 Ok") case strings.HasPrefix(command, "RCPT TO:"): writer.writeLine("250 2.1.5 Ok") case command == "DATA": writer.writeLine("354 End data with .") var payload strings.Builder for { dataLine, err := reader.readRawLine() if err != nil { return } if dataLine == ".\r\n" { break } payload.WriteString(dataLine) } capture.appendPayload(payload.String()) writer.writeLine(cfg.FinalDataReply) case command == "RSET": writer.writeLine("250 2.0.0 Ok") case command == "QUIT": writer.writeLine("221 2.0.0 Bye") return default: writer.writeLine("250 2.0.0 Ok") } } } type smtpLineReader struct { conn net.Conn } func newSMTPLineReader(conn net.Conn) *smtpLineReader { return &smtpLineReader{conn: conn} } func (reader *smtpLineReader) readLine() (string, error) { line, err := reader.readRawLine() if err != nil { return "", err } return strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r"), nil } func (reader *smtpLineReader) readRawLine() (string, error) { var buffer bytes.Buffer tmp := make([]byte, 1) for { if _, err := reader.conn.Read(tmp); err != nil { return "", err } buffer.WriteByte(tmp[0]) if tmp[0] == '\n' { return buffer.String(), nil } } } type smtpLineWriter struct { conn net.Conn } func newSMTPLineWriter(conn net.Conn) *smtpLineWriter { return &smtpLineWriter{conn: conn} } func (writer *smtpLineWriter) writeLine(line string) { _, _ = io.WriteString(writer.conn, line+"\r\n") } func (writer *smtpLineWriter) writeLines(lines ...string) { for _, line := range lines { writer.writeLine(line) } } func newSMTPCertificates(t testing.TB) (tls.Certificate, []byte) { t.Helper() rootKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("generate SMTP root key: %v", err) } now := time.Now() rootTemplate := x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{ CommonName: "galaxy-integration-smtp-root", }, NotBefore: now.Add(-time.Hour), NotAfter: now.Add(24 * time.Hour), KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature, IsCA: true, BasicConstraintsValid: true, } rootDER, err := x509.CreateCertificate(rand.Reader, &rootTemplate, &rootTemplate, &rootKey.PublicKey, rootKey) if err != nil { t.Fatalf("create SMTP root certificate: %v", err) } rootPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: rootDER}) serverKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("generate SMTP server key: %v", err) } serverTemplate := x509.Certificate{ SerialNumber: big.NewInt(2), Subject: pkix.Name{ CommonName: "127.0.0.1", }, NotBefore: now.Add(-time.Hour), NotAfter: now.Add(24 * time.Hour), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, DNSNames: []string{"localhost"}, IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, } rootCert, err := x509.ParseCertificate(rootDER) if err != nil { t.Fatalf("parse SMTP root certificate: %v", err) } serverDER, err := x509.CreateCertificate(rand.Reader, &serverTemplate, rootCert, &serverKey.PublicKey, rootKey) if err != nil { t.Fatalf("create SMTP server certificate: %v", err) } serverPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: serverDER}) serverKeyPEM := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(serverKey), }) certificate, err := tls.X509KeyPair(append(serverPEM, rootPEM...), serverKeyPEM) if err != nil { t.Fatalf("load SMTP server key pair: %v", err) } return certificate, rootPEM }