author | Alberto Bertogli
<albertito@blitiri.com.ar> 2023-03-03 09:51:48 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2023-03-03 09:51:48 UTC |
parent | 1927e15ea2d8f8917642c14819a80da640806f01 |
internal/courier/fakeserver_test.go | +10 | -5 |
internal/courier/smtp.go | +17 | -2 |
internal/courier/smtp_test.go | +34 | -15 |
diff --git a/internal/courier/fakeserver_test.go b/internal/courier/fakeserver_test.go index 677350c..a183cd3 100644 --- a/internal/courier/fakeserver_test.go +++ b/internal/courier/fakeserver_test.go @@ -20,14 +20,16 @@ type FakeServer struct { responses map[string]string wg *sync.WaitGroup addr string + conns int tlsConfig *tls.Config } -func newFakeServer(t *testing.T, responses map[string]string) *FakeServer { +func newFakeServer(t *testing.T, responses map[string]string, conns int) *FakeServer { s := &FakeServer{ t: t, tmpDir: testlib.MustTempDir(t), responses: responses, + conns: conns, wg: &sync.WaitGroup{}, } s.start() @@ -82,11 +84,10 @@ func (s *FakeServer) start() string { s.initTLS() - s.wg.Add(1) + s.wg.Add(s.conns) - go func() { + accept := func() { defer s.wg.Done() - defer l.Close() c, err := l.Accept() if err != nil { @@ -134,7 +135,11 @@ func (s *FakeServer) start() string { c.Write([]byte(s.responses["_DATA"])) } } - }() + } + + for i := 0; i < s.conns; i++ { + go accept() + } return s.addr } diff --git a/internal/courier/smtp.go b/internal/courier/smtp.go index 17a2cb4..0b7cc1b 100644 --- a/internal/courier/smtp.go +++ b/internal/courier/smtp.go @@ -120,6 +120,8 @@ type attempt struct { } func (a *attempt) deliver(mx string) (error, bool) { + skipTLS := false +retry: conn, err := net.DialTimeout("tcp", mx+":"+*smtpPort, smtpDialTimeout) if err != nil { return a.tr.Errorf("Could not dial: %v", err), false @@ -137,7 +139,7 @@ func (a *attempt) deliver(mx string) (error, bool) { } secLevel := domaininfo.SecLevel_PLAIN - if ok, _ := c.Extension("STARTTLS"); ok { + if ok, _ := c.Extension("STARTTLS"); ok && !skipTLS { config := &tls.Config{ ServerName: mx, @@ -155,8 +157,21 @@ func (a *attempt) deliver(mx string) (error, bool) { err = c.StartTLS(config) if err != nil { + // If we could not complete a jump to TLS (either because the + // STARTTLS command itself failed server-side, or because we got a + // TLS negotiation error), retry but without trying to use TLS. + // This should be quite rare, but it can happen if the server + // certificate is not parseable by the Go library, or if it has a + // broken TLS stack. + // Note that invalid and self-signed certs do NOT fall in this + // category, those are handled by the VerifyConnection function + // above, and don't need a retry. This is only needed for lower + // level errors. tlsCount.Add("tls:failed", 1) - return a.tr.Errorf("TLS error: %v", err), false + a.tr.Errorf("TLS error, retrying without TLS: %v", err) + skipTLS = true + conn.Close() + goto retry } } else { tlsCount.Add("plain", 1) diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go index 2f47627..76584c5 100644 --- a/internal/courier/smtp_test.go +++ b/internal/courier/smtp_test.go @@ -51,7 +51,7 @@ func TestSMTP(t *testing.T) { "_DATA": "250 data ok\n", "QUIT": "250 quit ok\n", } - srv := newFakeServer(t, responses) + srv := newFakeServer(t, responses, 1) defer srv.Cleanup() host, port := srv.HostPort() @@ -124,7 +124,7 @@ func TestSMTPErrors(t *testing.T) { } for _, rs := range responses { - srv := newFakeServer(t, rs) + srv := newFakeServer(t, rs, 1) defer srv.Cleanup() host, port := srv.HostPort() @@ -257,7 +257,7 @@ var tlsResponses = map[string]string{ func TestTLS(t *testing.T) { smtpTotalTimeout = 5 * time.Second - srv := newFakeServer(t, tlsResponses) + srv := newFakeServer(t, tlsResponses, 1) defer srv.Cleanup() _, *smtpPort = srv.HostPort() @@ -285,7 +285,7 @@ func TestTLS(t *testing.T) { "_DATA": "250 data ok\n", "QUIT": "250 quit ok\n", } - srv = newFakeServer(t, responses) + srv = newFakeServer(t, responses, 1) defer srv.Cleanup() _, *smtpPort = srv.HostPort() @@ -305,12 +305,27 @@ func TestTLSError(t *testing.T) { smtpTotalTimeout = 5 * time.Second responses := map[string]string{ - "_welcome": "220 welcome\n", + "_welcome": "220 welcome\n", + + // STARTTLS should be advertised so we try to initiate it. "EHLO hello": "250-ehlo ok\n250 STARTTLS\n", - "STARTTLS": "500 starttls err\n", - "_STARTTLS": "no", + + // Error in STARTTLS request. Note that a TLS-layer error also falls + // under this code path, so both situations are covered by this test. + "STARTTLS": "500 starttls err\n", + "_STARTTLS": "no", + + // Rest of the transaction is normal and straightforward. + "MAIL FROM:<me@me>": "250 mail ok\n", + "RCPT TO:<to@to>": "250 rcpt ok\n", + "DATA": "354 send data\n", + "_DATA": "250 data ok\n", + "QUIT": "250 quit ok\n", } - srv := newFakeServer(t, responses) + // Note we expect 2 connections to the fake server (because of the retry + // after the failed STARTTLS). Note this also checks that we correctly + // close the errored connection, instead of leaving it lingering. + srv := newFakeServer(t, responses, 2) defer srv.Cleanup() _, *smtpPort = srv.HostPort() @@ -320,12 +335,16 @@ func TestTLSError(t *testing.T) { s, tmpDir := newSMTP(t) defer testlib.RemoveIfOk(t, tmpDir) - err, permanent := s.Deliver("me@me", "to@to", []byte("data")) - if !strings.Contains(err.Error(), "TLS error:") { - t.Errorf("expected TLS error, got: %v", err) + err, _ := s.Deliver("me@me", "to@to", []byte("data")) + if err != nil { + t.Errorf("deliver failed: %v", err) } - if permanent != false { - t.Errorf("expected transient failure, got permanent") + + // Double check that we delivered over a plaintext connection. + tr := trace.New("test", "test") + defer tr.Finish() + if !s.Dinfo.OutgoingSecLevel(tr, "to", domaininfo.SecLevel_PLAIN) { + t.Errorf("delivery did not took place over plaintext as expected") } srv.Wait() @@ -333,7 +352,7 @@ func TestTLSError(t *testing.T) { func TestSTSPolicyEnforcement(t *testing.T) { smtpTotalTimeout = 5 * time.Second - srv := newFakeServer(t, tlsResponses) + srv := newFakeServer(t, tlsResponses, 1) defer srv.Cleanup() _, *smtpPort = srv.HostPort() @@ -372,7 +391,7 @@ func TestSTSPolicyEnforcement(t *testing.T) { // Do another delivery attempt, but this time we trust the server cert. // This time it should be successful, because the connection level should // be TLS_SECURE which is required by the STS policy. - srv = newFakeServer(t, tlsResponses) + srv = newFakeServer(t, tlsResponses, 1) _, *smtpPort = srv.HostPort() defer srv.Cleanup()