author | Alberto Bertogli
<albertito@blitiri.com.ar> 2021-10-15 10:00:08 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2021-10-25 11:41:24 UTC |
parent | 90d385556fc4aa1985f91195a3d1322ddd67c98f |
internal/courier/smtp.go | +41 | -28 |
internal/courier/smtp_test.go | +173 | -20 |
diff --git a/internal/courier/smtp.go b/internal/courier/smtp.go index 27b6c9b..9d55cd2 100644 --- a/internal/courier/smtp.go +++ b/internal/courier/smtp.go @@ -3,6 +3,7 @@ package courier import ( "context" "crypto/tls" + "crypto/x509" "flag" "net" "time" @@ -119,12 +120,6 @@ type attempt struct { } func (a *attempt) deliver(mx string) (error, bool) { - // Do we use insecure TLS? - // Set as fallback when retrying. - insecure := false - secLevel := domaininfo.SecLevel_PLAIN - -retry: conn, err := net.DialTimeout("tcp", mx+":"+*smtpPort, smtpDialTimeout) if err != nil { return a.tr.Errorf("Could not dial: %v", err), false @@ -141,33 +136,26 @@ retry: return a.tr.Errorf("Error saying hello: %v", err), false } + secLevel := domaininfo.SecLevel_PLAIN if ok, _ := c.Extension("STARTTLS"); ok { config := &tls.Config{ - ServerName: mx, - InsecureSkipVerify: insecure, + ServerName: mx, + + // Unfortunately, many servers use self-signed and invalid + // certificates. So we use a custom verification (identical to + // Go's) to distinguish between invalid and valid certificates. + // That information is used to track the security level, to + // prevent downgrade attacks. + InsecureSkipVerify: true, + VerifyConnection: func(cs tls.ConnectionState) error { + secLevel = a.verifyConnection(cs) + return nil + }, } err = c.StartTLS(config) if err != nil { - // Unfortunately, many servers use self-signed certs, so if we - // fail verification we just try again without validating. - if insecure { - tlsCount.Add("tls:failed", 1) - return a.tr.Errorf("TLS error: %v", err), false - } - - insecure = true - a.tr.Debugf("TLS error, retrying insecurely") - goto retry - } - - if config.InsecureSkipVerify { - a.tr.Debugf("Insecure - using TLS, but cert does not match %s", mx) - tlsCount.Add("tls:insecure", 1) - secLevel = domaininfo.SecLevel_TLS_INSECURE - } else { - tlsCount.Add("tls:secure", 1) - a.tr.Debugf("Secure - using TLS") - secLevel = domaininfo.SecLevel_TLS_SECURE + tlsCount.Add("tls:failed", 1) + return a.tr.Errorf("TLS error: %v", err), false } } else { tlsCount.Add("plain", 1) @@ -218,6 +206,31 @@ retry: return nil, false } +func (a *attempt) verifyConnection(cs tls.ConnectionState) domaininfo.SecLevel { + // Validate certificates, using the same logic Go does, and following the + // official example at + // https://pkg.go.dev/crypto/tls#example-Config-VerifyConnection. + opts := x509.VerifyOptions{ + DNSName: cs.ServerName, + Intermediates: x509.NewCertPool(), + } + for _, cert := range cs.PeerCertificates[1:] { + opts.Intermediates.AddCert(cert) + } + _, err := cs.PeerCertificates[0].Verify(opts) + + if err != nil { + // Invalid TLS cert, since it could not be verified. + a.tr.Debugf("Insecure - using TLS, but with an invalid cert") + tlsCount.Add("tls:insecure", 1) + return domaininfo.SecLevel_TLS_INSECURE + } else { + tlsCount.Add("tls:secure", 1) + a.tr.Debugf("Secure - using TLS") + return domaininfo.SecLevel_TLS_SECURE + } +} + func (s *SMTP) fetchSTSPolicy(tr *trace.Trace, domain string) *sts.Policy { if s.STSCache == nil { return nil diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go index 93cedbf..70bb6a7 100644 --- a/internal/courier/smtp_test.go +++ b/internal/courier/smtp_test.go @@ -2,9 +2,11 @@ package courier import ( "bufio" + "crypto/tls" "fmt" "net" "net/textproto" + "os" "strings" "sync" "testing" @@ -40,51 +42,119 @@ func newSMTP(t *testing.T) (*SMTP, string) { } // Fake server, to test SMTP out. -func fakeServer(t *testing.T, responses map[string]string) (string, *sync.WaitGroup) { +type FakeServer struct { + t *testing.T + responses map[string]string + wg *sync.WaitGroup + addr string + tlsConfig *tls.Config +} + +func newFakeServer(t *testing.T, responses map[string]string) *FakeServer { + s := &FakeServer{ + t: t, + responses: responses, + wg: &sync.WaitGroup{}, + } + s.start() + return s +} + +func (s *FakeServer) loadTLS() string { + tmpDir := testlib.MustTempDir(s.t) + var err error + s.tlsConfig, err = testlib.GenerateCert(tmpDir) + if err != nil { + os.RemoveAll(tmpDir) + s.t.Fatalf("error generating cert: %v", err) + } + + cert, err := tls.LoadX509KeyPair(tmpDir+"/cert.pem", tmpDir+"/key.pem") + if err != nil { + os.RemoveAll(tmpDir) + s.t.Fatalf("error loading temp cert: %v", err) + } + + s.tlsConfig.Certificates = []tls.Certificate{cert} + + return tmpDir +} + +func (s *FakeServer) start() string { + s.t.Helper() l, err := net.Listen("tcp", "localhost:0") if err != nil { - t.Fatalf("fake server listen: %v", err) + s.t.Fatalf("fake server listen: %v", err) } + s.addr = l.Addr().String() - wg := &sync.WaitGroup{} - wg.Add(1) + s.wg.Add(1) go func() { - defer wg.Done() + defer s.wg.Done() defer l.Close() + tmpDir := s.loadTLS() + defer os.RemoveAll(tmpDir) + c, err := l.Accept() if err != nil { panic(err) } defer c.Close() - t.Logf("fakeServer got connection") + s.t.Logf("fakeServer got connection") r := textproto.NewReader(bufio.NewReader(c)) - c.Write([]byte(responses["_welcome"])) + c.Write([]byte(s.responses["_welcome"])) for { line, err := r.ReadLine() if err != nil { - t.Logf("fakeServer exiting: %v\n", err) + s.t.Logf("fakeServer exiting: %v\n", err) return } - t.Logf("fakeServer read: %q\n", line) - c.Write([]byte(responses[line])) + s.t.Logf("fakeServer read: %q\n", line) + if line == "STARTTLS" && s.responses["_STARTTLS"] == "ok" { + c.Write([]byte(s.responses["STARTTLS"])) + tlssrv := tls.Server(c, s.tlsConfig) + err = tlssrv.Handshake() + if err != nil { + s.t.Logf("starttls handshake error: %v", err) + return + } + + // Replace the connection with the wrapped one. + // Don't send a reply, as per the protocol. + c = tlssrv + defer c.Close() + r = textproto.NewReader(bufio.NewReader(c)) + continue + } + + c.Write([]byte(s.responses[line])) if line == "DATA" { _, err = r.ReadDotBytes() if err != nil { - t.Logf("fakeServer exiting: %v\n", err) + s.t.Logf("fakeServer exiting: %v\n", err) return } - c.Write([]byte(responses["_DATA"])) + c.Write([]byte(s.responses["_DATA"])) } } }() - return l.Addr().String(), wg + return s.addr +} + +func (s *FakeServer) HostPort() (string, string) { + host, port, _ := net.SplitHostPort(s.addr) + return host, port +} + +func (s *FakeServer) Wait() { + s.wg.Wait() } func TestSMTP(t *testing.T) { @@ -101,8 +171,8 @@ func TestSMTP(t *testing.T) { "_DATA": "250 data ok\n", "QUIT": "250 quit ok\n", } - addr, wg := fakeServer(t, responses) - host, port, _ := net.SplitHostPort(addr) + srv := newFakeServer(t, responses) + host, port := srv.HostPort() // Put a non-existing host first, so we check that if the first host // doesn't work, we try with the rest. @@ -123,7 +193,7 @@ func TestSMTP(t *testing.T) { t.Errorf("deliver failed: %v", err) } - wg.Wait() + srv.Wait() } func TestSMTPErrors(t *testing.T) { @@ -173,8 +243,8 @@ func TestSMTPErrors(t *testing.T) { } for _, rs := range responses { - addr, wg := fakeServer(t, rs) - host, port, _ := net.SplitHostPort(addr) + srv := newFakeServer(t, rs) + host, port := srv.HostPort() testMX["to"] = []*net.MX{{Host: host, Pref: 10}} *smtpPort = port @@ -187,7 +257,7 @@ func TestSMTPErrors(t *testing.T) { } t.Logf("failed as expected: %v", err) - wg.Wait() + srv.Wait() } } @@ -289,4 +359,87 @@ func TestLookupInvalidDomain(t *testing.T) { } } -// TODO: Test STARTTLS negotiation. +func TestTLS(t *testing.T) { + smtpTotalTimeout = 5 * time.Second + + responses := map[string]string{ + "_welcome": "220 welcome\n", + "EHLO hello": "250-ehlo ok\n250 STARTTLS\n", + "STARTTLS": "220 starttls go\n", + "_STARTTLS": "ok", + "MAIL FROM:<me@me>": "250 mail ok\n", + "RCPT TO:<to@to>": "250 rcpt ok\n", + "DATA": "354 send data\n", + "_DATA": "250 data ok\n", + "QUIT": "250 quit ok\n", + } + srv := newFakeServer(t, responses) + _, *smtpPort = srv.HostPort() + + testMX["to"] = []*net.MX{ + {Host: "localhost", Pref: 20}, + } + + s, tmpDir := newSMTP(t) + defer testlib.RemoveIfOk(t, tmpDir) + err, _ := s.Deliver("me@me", "to@to", []byte("data")) + if err != nil { + t.Errorf("deliver failed: %v", err) + } + + srv.Wait() + + // Now do another delivery, but without TLS, to check that the detection + // of connection downgrade is working. + responses = map[string]string{ + "_welcome": "220 welcome\n", + "EHLO hello": "250 ehlo ok\n", + "MAIL FROM:<me@me>": "250 mail ok\n", + "RCPT TO:<to@to>": "250 rcpt ok\n", + "DATA": "354 send data\n", + "_DATA": "250 data ok\n", + "QUIT": "250 quit ok\n", + } + srv = newFakeServer(t, responses) + _, *smtpPort = srv.HostPort() + + err, permanent := s.Deliver("me@me", "to@to", []byte("data")) + if !strings.Contains(err.Error(), + "Security level check failed (level:PLAIN)") { + t.Errorf("expected sec level check failed, got: %v", err) + } + if permanent != false { + t.Errorf("expected transient failure, got permanent") + } + + srv.Wait() +} + +func TestTLSError(t *testing.T) { + smtpTotalTimeout = 5 * time.Second + + responses := map[string]string{ + "_welcome": "220 welcome\n", + "EHLO hello": "250-ehlo ok\n250 STARTTLS\n", + "STARTTLS": "500 starttls err\n", + "_STARTTLS": "no", + } + srv := newFakeServer(t, responses) + _, *smtpPort = srv.HostPort() + + testMX["to"] = []*net.MX{ + {Host: "localhost", Pref: 20}, + } + + s, tmpDir := newSMTP(t) + defer testlib.RemoveIfOk(t, tmpDir) + err, permanent := s.Deliver("me@me", "to@to", []byte("data")) + if !strings.Contains(err.Error(), "TLS error:") { + t.Errorf("expected TLS error, got: %v", err) + } + if permanent != false { + t.Errorf("expected transient failure, got permanent") + } + + srv.Wait() +}