git » chasquid » commit 02322a7

courier: Add tests for STS policy checks

author Alberto Bertogli
2021-10-25 11:39:09 UTC
committer Alberto Bertogli
2021-11-26 13:25:31 UTC
parent 14e270b7f524480351a857a008bd392277be663d

courier: Add tests for STS policy checks

This patch adds tests for STS policy checks in combination with TLS
security levels.

This helps ensure we detect mismatches between the TLS status
(plain/insecure/secure) and STS policy enforcement.

internal/courier/smtp.go +5 -0
internal/courier/smtp_test.go +112 -25

diff --git a/internal/courier/smtp.go b/internal/courier/smtp.go
index 9d55cd2..7d70b73 100644
--- a/internal/courier/smtp.go
+++ b/internal/courier/smtp.go
@@ -152,6 +152,7 @@ func (a *attempt) deliver(mx string) (error, bool) {
 				return nil
 			},
 		}
+
 		err = c.StartTLS(config)
 		if err != nil {
 			tlsCount.Add("tls:failed", 1)
@@ -206,6 +207,9 @@ func (a *attempt) deliver(mx string) (error, bool) {
 	return nil, false
 }
 
+// certRoots are the CA roots to validate against (nil means system roots); tests override it.
+var certRoots *x509.CertPool
+
 func (a *attempt) verifyConnection(cs tls.ConnectionState) domaininfo.SecLevel {
 	// Validate certificates, using the same logic Go does, and following the
 	// official example at
@@ -213,6 +217,7 @@ func (a *attempt) verifyConnection(cs tls.ConnectionState) domaininfo.SecLevel {
 	opts := x509.VerifyOptions{
 		DNSName:       cs.ServerName,
 		Intermediates: x509.NewCertPool(),
+		Roots:         certRoots,
 	}
 	for _, cert := range cs.PeerCertificates[1:] {
 		opts.Intermediates.AddCert(cert)
diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go
index 70bb6a7..5827e3d 100644
--- a/internal/courier/smtp_test.go
+++ b/internal/courier/smtp_test.go
@@ -3,7 +3,9 @@ package courier
 import (
 	"bufio"
 	"crypto/tls"
+	"crypto/x509"
 	"fmt"
+	"io/ioutil"
 	"net"
 	"net/textproto"
 	"os"
@@ -13,6 +15,7 @@ import (
 	"time"
 
 	"blitiri.com.ar/go/chasquid/internal/domaininfo"
+	"blitiri.com.ar/go/chasquid/internal/sts"
 	"blitiri.com.ar/go/chasquid/internal/testlib"
 	"blitiri.com.ar/go/chasquid/internal/trace"
 )
@@ -44,6 +47,7 @@ func newSMTP(t *testing.T) (*SMTP, string) {
 // Fake server, to test SMTP out.
 type FakeServer struct {
 	t         *testing.T
+	tmpDir    string
 	responses map[string]string
 	wg        *sync.WaitGroup
 	addr      string
@@ -53,6 +57,7 @@ type FakeServer struct {
 func newFakeServer(t *testing.T, responses map[string]string) *FakeServer {
 	s := &FakeServer{
 		t:         t,
+		tmpDir:    testlib.MustTempDir(t),
 		responses: responses,
 		wg:        &sync.WaitGroup{},
 	}
@@ -60,24 +65,27 @@ func newFakeServer(t *testing.T, responses map[string]string) *FakeServer {
 	return s
 }
 
-func (s *FakeServer) loadTLS() string {
-	tmpDir := testlib.MustTempDir(s.t)
+func (s *FakeServer) Cleanup() {
+	// Remove our temporary data, paranoidly skipping paths that look too shallow.
+	if len(s.tmpDir) <= 8 {
+		return
+	}
+	os.RemoveAll(s.tmpDir)
+}
+
+func (s *FakeServer) initTLS() {
 	var err error
-	s.tlsConfig, err = testlib.GenerateCert(tmpDir)
+	s.tlsConfig, err = testlib.GenerateCert(s.tmpDir)
 	if err != nil {
-		os.RemoveAll(tmpDir)
 		s.t.Fatalf("error generating cert: %v", err)
 	}
 
-	cert, err := tls.LoadX509KeyPair(tmpDir+"/cert.pem", tmpDir+"/key.pem")
+	cert, err := tls.LoadX509KeyPair(s.tmpDir+"/cert.pem", s.tmpDir+"/key.pem")
 	if err != nil {
-		os.RemoveAll(tmpDir)
 		s.t.Fatalf("error loading temp cert: %v", err)
 	}
 
 	s.tlsConfig.Certificates = []tls.Certificate{cert}
-
-	return tmpDir
 }
 
 func (s *FakeServer) start() string {
@@ -88,15 +96,14 @@ func (s *FakeServer) start() string {
 	}
 	s.addr = l.Addr().String()
 
+	s.initTLS()
+
 	s.wg.Add(1)
 
 	go func() {
 		defer s.wg.Done()
 		defer l.Close()
 
-		tmpDir := s.loadTLS()
-		defer os.RemoveAll(tmpDir)
-
 		c, err := l.Accept()
 		if err != nil {
 			panic(err)
@@ -172,6 +179,7 @@ func TestSMTP(t *testing.T) {
 		"QUIT":              "250 quit ok\n",
 	}
 	srv := newFakeServer(t, responses)
+	defer srv.Cleanup()
 	host, port := srv.HostPort()
 
 	// Put a non-existing host first, so we check that if the first host
@@ -244,6 +252,7 @@ func TestSMTPErrors(t *testing.T) {
 
 	for _, rs := range responses {
 		srv := newFakeServer(t, rs)
+		defer srv.Cleanup()
 		host, port := srv.HostPort()
 
 		testMX["to"] = []*net.MX{{Host: host, Pref: 10}}
@@ -359,21 +368,24 @@ func TestLookupInvalidDomain(t *testing.T) {
 	}
 }
 
+// tlsResponses are the fake server responses for a complete delivery that
+// goes through STARTTLS. Several tests use them, so they are shared here.
+var tlsResponses = map[string]string{
+	"_welcome":          "220 welcome\n",
+	"EHLO hello":        "250-ehlo ok\n250 STARTTLS\n",
+	"STARTTLS":          "220 starttls go\n",
+	"_STARTTLS":         "ok",
+	"MAIL FROM:<me@me>": "250 mail ok\n",
+	"RCPT TO:<to@to>":   "250 rcpt ok\n",
+	"DATA":              "354 send data\n",
+	"_DATA":             "250 data ok\n",
+	"QUIT":              "250 quit ok\n",
+}
+
 func TestTLS(t *testing.T) {
 	smtpTotalTimeout = 5 * time.Second
-
-	responses := map[string]string{
-		"_welcome":          "220 welcome\n",
-		"EHLO hello":        "250-ehlo ok\n250 STARTTLS\n",
-		"STARTTLS":          "220 starttls go\n",
-		"_STARTTLS":         "ok",
-		"MAIL FROM:<me@me>": "250 mail ok\n",
-		"RCPT TO:<to@to>":   "250 rcpt ok\n",
-		"DATA":              "354 send data\n",
-		"_DATA":             "250 data ok\n",
-		"QUIT":              "250 quit ok\n",
-	}
-	srv := newFakeServer(t, responses)
+	srv := newFakeServer(t, tlsResponses)
+	defer srv.Cleanup()
 	_, *smtpPort = srv.HostPort()
 
 	testMX["to"] = []*net.MX{
@@ -391,7 +403,7 @@ func TestTLS(t *testing.T) {
 
 	// Now do another delivery, but without TLS, to check that the detection
 	// of connection downgrade is working.
-	responses = map[string]string{
+	responses := map[string]string{
 		"_welcome":          "220 welcome\n",
 		"EHLO hello":        "250 ehlo ok\n",
 		"MAIL FROM:<me@me>": "250 mail ok\n",
@@ -401,6 +413,7 @@ func TestTLS(t *testing.T) {
 		"QUIT":              "250 quit ok\n",
 	}
 	srv = newFakeServer(t, responses)
+	defer srv.Cleanup()
 	_, *smtpPort = srv.HostPort()
 
 	err, permanent := s.Deliver("me@me", "to@to", []byte("data"))
@@ -425,6 +438,7 @@ func TestTLSError(t *testing.T) {
 		"_STARTTLS":  "no",
 	}
 	srv := newFakeServer(t, responses)
+	defer srv.Cleanup()
 	_, *smtpPort = srv.HostPort()
 
 	testMX["to"] = []*net.MX{
@@ -443,3 +457,85 @@ func TestTLSError(t *testing.T) {
 
 	srv.Wait()
 }
+
+// TestSTSPolicyEnforcement checks that an enforced STS policy rejects a
+// delivery whose TLS security level is TLS_INSECURE, and accepts it once the
+// server certificate is trusted (so the level becomes TLS_SECURE).
+func TestSTSPolicyEnforcement(t *testing.T) {
+	smtpTotalTimeout = 5 * time.Second
+	srv := newFakeServer(t, tlsResponses)
+	defer srv.Cleanup()
+	_, *smtpPort = srv.HostPort()
+
+	s, tmpDir := newSMTP(t)
+	defer testlib.RemoveIfOk(t, tmpDir)
+
+	a := &attempt{
+		courier:  s,
+		from:     "me@me",
+		to:       "to@to",
+		toDomain: "to",
+		data:     []byte("data"),
+		tr:       trace.New("test", "test"),
+	}
+
+	a.stsPolicy = &sts.Policy{
+		Version: "STSv1",
+		Mode:    sts.Enforce,
+		MXs:     []string{"mx"},
+		MaxAge:  1 * time.Minute,
+	}
+
+	// At this point the cert is not valid, which is incompatible with STS
+	// policy, so we expect it to fail.
+	err, permanent := a.deliver("localhost")
+	if err == nil {
+		t.Errorf("expected invalid sec level error, got success")
+	} else if !strings.Contains(err.Error(),
+		"invalid security level (TLS_INSECURE) for STS policy") {
+		t.Errorf("expected invalid sec level error, got %v", err)
+	}
+	if permanent {
+		t.Errorf("expected transient error, got permanent")
+	}
+
+	srv.Wait()
+
+	// Do another delivery attempt, but this time we trust the server cert.
+	// This time it should be successful, because the connection level should
+	// be TLS_SECURE which is required by the STS policy.
+	srv = newFakeServer(t, tlsResponses)
+	defer srv.Cleanup()
+	_, *smtpPort = srv.HostPort()
+
+	// Trust the fake server's certificate for the rest of the test, and reset
+	// certRoots afterwards so other tests see the default (nil) roots again.
+	certRoots = loadCert(t, srv.tmpDir+"/cert.pem")
+	defer func() {
+		certRoots = nil
+	}()
+
+	err, permanent = a.deliver("localhost")
+	if err != nil {
+		t.Errorf("expected success, got %v (permanent=%v)", err, permanent)
+	}
+
+	srv.Wait()
+}
+
+// loadCert reads the PEM-encoded certificate(s) at path and returns a pool
+// containing them. It fails the test if the file can't be read or contains
+// no certificates that could be parsed.
+func loadCert(t *testing.T, path string) *x509.CertPool {
+	t.Helper()
+
+	pool := x509.NewCertPool()
+	data, err := ioutil.ReadFile(path)
+	if err != nil {
+		t.Fatalf("error reading cert %q: %v", path, err)
+	}
+	if !pool.AppendCertsFromPEM(data) {
+		t.Fatalf("failed to load cert %q", path)
+	}
+	return pool
+}