git » chasquid » commit 5451e3c

courier: Support testing STARTTLS

author Alberto Bertogli
2020-09-26 09:36:37 UTC
committer Alberto Bertogli
2021-09-03 11:58:10 UTC
parent a5bd8cbc0dc0a48768629adb13aacfcacb0afba8

courier: Support testing STARTTLS

This patch adds support for testing STARTTLS commands on the SMTP courier.

It works by using an embedded certificate on the server side, and
adding that certificate to the clients' root CA pool.

To keep the tests more organized, and to prepare for upcoming changes,
the fake server is split out into its own file.

internal/courier/fakeserver_test.go +95 -0
internal/courier/smtp_test.go +52 -79
internal/testlib/testlib.go +62 -0

diff --git a/internal/courier/fakeserver_test.go b/internal/courier/fakeserver_test.go
new file mode 100644
index 0000000..ce43892
--- /dev/null
+++ b/internal/courier/fakeserver_test.go
@@ -0,0 +1,95 @@
+package courier
+
+import (
+	"bufio"
+	"crypto/tls"
+	"crypto/x509"
+	"net"
+	"net/textproto"
+	"strings"
+	"sync"
+	"testing"
+
+	"blitiri.com.ar/go/chasquid/internal/testlib"
+)
+
+type fakeServer struct {
+	addr    string          // Listening address, in "host:port" form.
+	wg      *sync.WaitGroup // Done when the serving goroutine returns.
+	rootCAs *x509.CertPool  // Client root CAs trusting the server's cert.
+}
+
+// newFakeServer starts a fake SMTP server, to test SMTP out. It serves one
+// connection: it sends responses["_welcome"], answers each line read with
+// responses[line], replies responses["_DATA"] after a DATA body, and switches
+// to TLS after a "STARTTLS" whose response starts with "220 ".
+func newFakeServer(t *testing.T, responses map[string]string) *fakeServer {
+	l, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("fake server listen: %v", err)
+	}
+
+	clientTLSConfig, serverTLSConfig := testlib.TLSConfig()
+	srv := &fakeServer{
+		addr:    l.Addr().String(),
+		wg:      &sync.WaitGroup{},
+		rootCAs: clientTLSConfig.RootCAs,
+	}
+
+	srv.wg.Add(1)
+	go func() {
+		defer srv.wg.Done()
+		defer l.Close()
+
+		c, err := l.Accept()
+		if err != nil {
+			// Not panic or t.Fatalf: this is not the test goroutine.
+			t.Errorf("fakeServer accept: %v", err)
+			return
+		}
+		defer c.Close()
+
+		t.Logf("fakeServer got connection")
+		r := textproto.NewReader(bufio.NewReader(c))
+		c.Write([]byte(responses["_welcome"]))
+		for {
+			line, err := r.ReadLine()
+			if err != nil {
+				t.Logf("fakeServer exiting: %v\n", err)
+				return
+			}
+
+			t.Logf("fakeServer read: %q\n", line)
+			c.Write([]byte(responses[line]))
+
+			if line == "DATA" {
+				_, err = r.ReadDotBytes()
+				if err != nil {
+					t.Logf("fakeServer exiting: %v\n", err)
+					return
+				}
+				c.Write([]byte(responses["_DATA"]))
+			} else if line == "STARTTLS" && strings.HasPrefix(responses[line], "220 ") {
+				tlsconn := tls.Server(c, serverTLSConfig)
+				defer tlsconn.Close()
+				if err = tlsconn.Handshake(); err != nil {
+					t.Logf("fakeServer error in STARTTLS: %v", err)
+					return
+				}
+				c = tlsconn
+				r = textproto.NewReader(bufio.NewReader(c))
+			}
+		}
+	}()
+
+	return srv
+}
+
+// makeResp builds a responses map out of alternating key, value arguments.
+func makeResp(as ...string) map[string]string {
+	m := make(map[string]string, len(as)/2)
+	for i := 0; i < len(as); i += 2 {
+		m[as[i]] = as[i+1]
+	}
+	return m
+}
diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go
index 0455023..38b11b5 100644
--- a/internal/courier/smtp_test.go
+++ b/internal/courier/smtp_test.go
@@ -1,12 +1,9 @@
 package courier
 
 import (
-	"bufio"
 	"fmt"
 	"net"
-	"net/textproto"
 	"strings"
-	"sync"
 	"testing"
 	"time"
 
@@ -39,54 +36,6 @@ func newSMTP(t *testing.T) (*SMTP, string) {
 	return &SMTP{"hello", dinfo, nil}, dir
 }
 
-// Fake server, to test SMTP out.
-func fakeServer(t *testing.T, responses map[string]string) (string, *sync.WaitGroup) {
-	l, err := net.Listen("tcp", "localhost:0")
-	if err != nil {
-		t.Fatalf("fake server listen: %v", err)
-	}
-
-	wg := &sync.WaitGroup{}
-	wg.Add(1)
-
-	go func() {
-		defer wg.Done()
-		defer l.Close()
-
-		c, err := l.Accept()
-		if err != nil {
-			panic(err)
-		}
-		defer c.Close()
-
-		t.Logf("fakeServer got connection")
-
-		r := textproto.NewReader(bufio.NewReader(c))
-		c.Write([]byte(responses["_welcome"]))
-		for {
-			line, err := r.ReadLine()
-			if err != nil {
-				t.Logf("fakeServer exiting: %v\n", err)
-				return
-			}
-
-			t.Logf("fakeServer read: %q\n", line)
-			c.Write([]byte(responses[line]))
-
-			if line == "DATA" {
-				_, err = r.ReadDotBytes()
-				if err != nil {
-					t.Logf("fakeServer exiting: %v\n", err)
-					return
-				}
-				c.Write([]byte(responses["_DATA"]))
-			}
-		}
-	}()
-
-	return l.Addr().String(), wg
-}
-
 func TestSMTP(t *testing.T) {
 	// Shorten the total timeout, so the test fails quickly if the protocol
 	// gets stuck.
@@ -101,8 +50,8 @@ func TestSMTP(t *testing.T) {
 		"_DATA":             "250 data ok\n",
 		"QUIT":              "250 quit ok\n",
 	}
-	addr, wg := fakeServer(t, responses)
-	host, port, _ := net.SplitHostPort(addr)
+	srv := newFakeServer(t, responses)
+	host, port, _ := net.SplitHostPort(srv.addr)
 
 	// Put a non-existing host first, so we check that if the first host
 	// doesn't work, we try with the rest.
@@ -123,7 +72,7 @@ func TestSMTP(t *testing.T) {
 		t.Errorf("deliver failed: %v", err)
 	}
 
-	wg.Wait()
+	srv.wg.Wait()
 }
 
 func TestSMTPErrors(t *testing.T) {
@@ -131,50 +80,66 @@ func TestSMTPErrors(t *testing.T) {
 	// gets stuck.
 	smtpTotalTimeout = 1 * time.Second
 
-	responses := []map[string]string{
+	cases := []struct {
+		responses map[string]string
+		errPrefix string
+	}{
 		// First test: hang response, should fail due to timeout.
 		{
-			"_welcome": "220 no newline",
+			makeResp("_welcome", "220 no newline"),
+			"",
 		},
 
 		// MAIL FROM not allowed.
 		{
-			"_welcome":          "220 mail from not allowed\n",
-			"EHLO hello":        "250 ehlo ok\n",
-			"MAIL FROM:<me@me>": "501 mail error\n",
+			makeResp(
+				"_welcome", "220 mail from not allowed\n",
+				"EHLO hello", "250 ehlo ok\n",
+				"MAIL FROM:<me@me>", "501 mail error\n",
+			),
+			"MAIL+RCPT 501 mail error",
 		},
 
 		// RCPT TO not allowed.
 		{
-			"_welcome":          "220 rcpt to not allowed\n",
-			"EHLO hello":        "250 ehlo ok\n",
-			"MAIL FROM:<me@me>": "250 mail ok\n",
-			"RCPT TO:<to@to>":   "501 rcpt error\n",
+			makeResp(
+				"_welcome", "220 rcpt to not allowed\n",
+				"EHLO hello", "250 ehlo ok\n",
+				"MAIL FROM:<me@me>", "250 mail ok\n",
+				"RCPT TO:<to@to>", "501 rcpt error\n",
+			),
+			"MAIL+RCPT 501 rcpt error",
 		},
 
 		// DATA error.
 		{
-			"_welcome":          "220 data error\n",
-			"EHLO hello":        "250 ehlo ok\n",
-			"MAIL FROM:<me@me>": "250 mail ok\n",
-			"RCPT TO:<to@to>":   "250 rcpt ok\n",
-			"DATA":              "554 data error\n",
+			makeResp(
+				"_welcome", "220 data error\n",
+				"EHLO hello", "250 ehlo ok\n",
+				"MAIL FROM:<me@me>", "250 mail ok\n",
+				"RCPT TO:<to@to>", "250 rcpt ok\n",
+				"DATA", "554 data error\n",
+			),
+			"DATA 554 data error",
 		},
 
 		// DATA response error.
 		{
-			"_welcome":          "220 data response error\n",
-			"EHLO hello":        "250 ehlo ok\n",
-			"MAIL FROM:<me@me>": "250 mail ok\n",
-			"RCPT TO:<to@to>":   "250 rcpt ok\n",
-			"DATA":              "354 send data\n",
-			"_DATA":             "551 data response error\n",
+			makeResp(
+				"_welcome", "220 data response error\n",
+				"EHLO hello", "250 ehlo ok\n",
+				"MAIL FROM:<me@me>", "250 mail ok\n",
+				"RCPT TO:<to@to>", "250 rcpt ok\n",
+				"DATA", "354 send data\n",
+				"_DATA", "551 data response error\n",
+			),
+			"DATA closing 551 data response error",
 		},
 	}
 
-	for _, rs := range responses {
-		addr, wg := fakeServer(t, rs)
-		host, port, _ := net.SplitHostPort(addr)
+	for _, c := range cases {
+		srv := newFakeServer(t, c.responses)
+		host, port, _ := net.SplitHostPort(srv.addr)
 
 		testMX["to"] = []*net.MX{{Host: host, Pref: 10}}
 		*smtpPort = port
@@ -182,12 +147,20 @@ func TestSMTPErrors(t *testing.T) {
 		s, tmpDir := newSMTP(t)
 		defer testlib.RemoveIfOk(t, tmpDir)
 		err, _ := s.Deliver("me@me", "to@to", []byte("data"))
+		// Wait here: the continue below must not skip it (leaking the server).
+		srv.wg.Wait()
+
 		if err == nil {
-			t.Errorf("deliver not failed in case %q: %v", rs["_welcome"], err)
+			t.Errorf("deliver not failed in case %q: %v",
+				c.responses["_welcome"], err)
+			continue
 		}
 		t.Logf("failed as expected: %v", err)
 
-		wg.Wait()
+		if !strings.HasPrefix(err.Error(), c.errPrefix) {
+			t.Errorf("expected error prefix %q, got %q",
+				c.errPrefix, err)
+		}
 	}
 }
 
diff --git a/internal/testlib/testlib.go b/internal/testlib/testlib.go
index 162f4f8..2eaf5a1 100644
--- a/internal/testlib/testlib.go
+++ b/internal/testlib/testlib.go
@@ -2,6 +2,9 @@
 package testlib
 
 import (
+	"crypto/tls"
+	"crypto/x509"
+	"fmt"
 	"io/ioutil"
 	"net"
 	"os"
@@ -129,3 +132,62 @@ func (c dumbCourier) Deliver(from string, to string, data []byte) (error, bool)
 
 // DumbCourier always succeeds delivery, and ignores everything.
 var DumbCourier = dumbCourier{}
+
+// TLSConfig returns a client and a server TLS configuration for tests. The
+// server presents the embedded testCert/testKey pair, and the client trusts
+// only that certificate, via RootCAs. It panics if that data is invalid.
+func TLSConfig() (client, server *tls.Config) {
+	cert, err := tls.X509KeyPair(testCert, testKey)
+	if err != nil {
+		panic(fmt.Sprintf("error creating key pair: %v", err))
+	}
+
+	// The certificate is its own CA (it was generated with --ca), so it can
+	// be added to the client pool directly as a root.
+	srvCert, err := x509.ParseCertificate(cert.Certificate[0])
+	if err != nil {
+		panic(fmt.Sprintf("error extracting server cert: %v", err))
+	}
+
+	pool := x509.NewCertPool()
+	pool.AddCert(srvCert)
+
+	server = &tls.Config{Certificates: []tls.Certificate{cert}}
+	client = &tls.Config{RootCAs: pool}
+	return client, server
+}
+
+// testCert and testKey are a PEM-encoded, self-signed CA certificate and its
+// private key, valid for "localhost", "127.0.0.1" and "[::1]". Generated with:
+// go run /usr/share/go-1.14/src/crypto/tls/generate_cert.go  --rsa-bits 1024 --host 127.0.0.1,::1,localhost --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
+var testCert = []byte(`-----BEGIN CERTIFICATE-----
+MIICETCCAXqgAwIBAgIQZRK1BVeALoVrF03BQur3kzANBgkqhkiG9w0BAQsFADAS
+MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
+MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
+iQKBgQCseGPB9aV9c/c/MbRVxwjsG5fqsb+/HnTSh1QDbGnVvCDoJZg4wJv3AEh1
+s9+A12+/ImIpB8I+8sr0ErPfGQ3fAJFx+TgQ6xmyDtNWjkTPHt6AF3cb2jv0rmze
+Dpa1vFXe0FwiZ2d9d1ZGvw1sPIugVAyGjW98bCxw42PXGd4s0wIDAQABo2YwZDAO
+BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
+AwEB/zAsBgNVHREEJTAjgglsb2NhbGhvc3SHBH8AAAGHEAAAAAAAAAAAAAAAAAAA
+AAEwDQYJKoZIhvcNAQELBQADgYEASUE1j+hQT7LgKYaP0w1itfcSZGhR1ZGnMThZ
+iiPnHt6ZLINZ39x2P/71KJYZklpJgewGBVRqMNTIW6hAa3UU7giQHQDDwSCtH4Zf
+C4WOwq3LouX8rLqZwq8W6ETnpbSUDEslhVR2IdufcLH947yoWbLuUc30SJET6dZq
+Fd+Ux3o=
+-----END CERTIFICATE-----`)
+
+var testKey = []byte(`-----BEGIN PRIVATE KEY-----
+MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBAKx4Y8H1pX1z9z8x
+tFXHCOwbl+qxv78edNKHVANsadW8IOglmDjAm/cASHWz34DXb78iYikHwj7yyvQS
+s98ZDd8AkXH5OBDrGbIO01aORM8e3oAXdxvaO/SubN4OlrW8Vd7QXCJnZ313Vka/
+DWw8i6BUDIaNb3xsLHDjY9cZ3izTAgMBAAECgYAgOia52YLg3Eh5AHqoBJcAN2+9
+pRUlSzWdGThzo1BrZcnoVw4InMUH9H+VrtS2qIry9iPNcuuzA380+EGwEGhs0o/q
+jHd4HtAgPK0zI/DalbVRzkiU9Qjqq2CHMpuJYIh+S2TlGHDUkShdnKi3RCgV0FC6
+B49JDnMrcOIyGLKFMQJBANSh86c1ZesTMhkw0tJh4DX0pSqfPI2aSXaRBAvUBUcO
+A/6EYlML5r2Fp6TyqxXk4+Dou9oG3yaUUoEA02Osyv0CQQDPpXefcP28PzamivrI
+o3SQMhIJ2dpqZw7YIbQQ1cE46Q4fVTXpq/eut0TQTcOEEUR9XzIsa3x/qbkJbHFm
+vegPAkEAs69gTY7sX6jLD0qY/bxEUpQ49zm1XBxjtFR7zNsQ0qjfazfIN1G5XbMS
+pmuDdG8Gu0sxY9+mt91jkyx1dqfQqQJARnyuAdLSX1+6BojxHsDV5ckJdIyeZzY6
+xMWUIY7eO5ppb9t2JK96sbWGx4tOTnuqG0EAgDGwnomXxYopaK4YowJBAIjMMW31
+h33KQZJ7ON4pX0+AP2yeHZsbdTXpRuVXeRgoZUK7mo0nWuHFy5etz10GVw4AZtze
+cfH9gZHF5jzmpyo=
+-----END PRIVATE KEY-----`)