author | Alberto Bertogli
<albertito@blitiri.com.ar> 2021-10-25 11:39:09 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2021-11-26 13:25:31 UTC |
parent | 14e270b7f524480351a857a008bd392277be663d |
internal/courier/smtp.go | +5 | -0 |
internal/courier/smtp_test.go | +112 | -25 |
diff --git a/internal/courier/smtp.go b/internal/courier/smtp.go index 9d55cd2..7d70b73 100644 --- a/internal/courier/smtp.go +++ b/internal/courier/smtp.go @@ -152,6 +152,7 @@ func (a *attempt) deliver(mx string) (error, bool) { return nil }, } + err = c.StartTLS(config) if err != nil { tlsCount.Add("tls:failed", 1) @@ -206,6 +207,9 @@ func (a *attempt) deliver(mx string) (error, bool) { return nil, false } +// CA roots to validate against, so we can override it for testing. +var certRoots *x509.CertPool = nil + func (a *attempt) verifyConnection(cs tls.ConnectionState) domaininfo.SecLevel { // Validate certificates, using the same logic Go does, and following the // official example at @@ -213,6 +217,7 @@ func (a *attempt) verifyConnection(cs tls.ConnectionState) domaininfo.SecLevel { opts := x509.VerifyOptions{ DNSName: cs.ServerName, Intermediates: x509.NewCertPool(), + Roots: certRoots, } for _, cert := range cs.PeerCertificates[1:] { opts.Intermediates.AddCert(cert) diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go index 70bb6a7..5827e3d 100644 --- a/internal/courier/smtp_test.go +++ b/internal/courier/smtp_test.go @@ -3,7 +3,9 @@ package courier import ( "bufio" "crypto/tls" + "crypto/x509" "fmt" + "io/ioutil" "net" "net/textproto" "os" @@ -13,6 +15,7 @@ import ( "time" "blitiri.com.ar/go/chasquid/internal/domaininfo" + "blitiri.com.ar/go/chasquid/internal/sts" "blitiri.com.ar/go/chasquid/internal/testlib" "blitiri.com.ar/go/chasquid/internal/trace" ) @@ -44,6 +47,7 @@ func newSMTP(t *testing.T) (*SMTP, string) { // Fake server, to test SMTP out. type FakeServer struct { t *testing.T + tmpDir string responses map[string]string wg *sync.WaitGroup addr string @@ -53,6 +57,7 @@ type FakeServer struct { func newFakeServer(t *testing.T, responses map[string]string) *FakeServer { s := &FakeServer{ t: t, + tmpDir: testlib.MustTempDir(t), responses: responses, wg: &sync.WaitGroup{}, } @@ -60,24 +65,27 @@ func newFakeServer(t *testing.T, responses map[string]string) *FakeServer { return s } -func (s *FakeServer) loadTLS() string { - tmpDir := testlib.MustTempDir(s.t) +func (s *FakeServer) Cleanup() { + // Remove our temporary data. Be extra paranoid and make sure the + // directory isn't too shallow. + if len(s.tmpDir) > 8 { + os.RemoveAll(s.tmpDir) + } +} + +func (s *FakeServer) initTLS() { var err error - s.tlsConfig, err = testlib.GenerateCert(tmpDir) + s.tlsConfig, err = testlib.GenerateCert(s.tmpDir) if err != nil { - os.RemoveAll(tmpDir) s.t.Fatalf("error generating cert: %v", err) } - cert, err := tls.LoadX509KeyPair(tmpDir+"/cert.pem", tmpDir+"/key.pem") + cert, err := tls.LoadX509KeyPair(s.tmpDir+"/cert.pem", s.tmpDir+"/key.pem") if err != nil { - os.RemoveAll(tmpDir) s.t.Fatalf("error loading temp cert: %v", err) } s.tlsConfig.Certificates = []tls.Certificate{cert} - - return tmpDir } func (s *FakeServer) start() string { @@ -88,15 +96,14 @@ func (s *FakeServer) start() string { } s.addr = l.Addr().String() + s.initTLS() + s.wg.Add(1) go func() { defer s.wg.Done() defer l.Close() - tmpDir := s.loadTLS() - defer os.RemoveAll(tmpDir) - c, err := l.Accept() if err != nil { panic(err) @@ -172,6 +179,7 @@ func TestSMTP(t *testing.T) { "QUIT": "250 quit ok\n", } srv := newFakeServer(t, responses) + defer srv.Cleanup() host, port := srv.HostPort() // Put a non-existing host first, so we check that if the first host @@ -244,6 +252,7 @@ func TestSMTPErrors(t *testing.T) { for _, rs := range responses { srv := newFakeServer(t, rs) + defer srv.Cleanup() host, port := srv.HostPort() testMX["to"] = []*net.MX{{Host: host, Pref: 10}} @@ -359,21 +368,24 @@ func TestLookupInvalidDomain(t *testing.T) { } } +// Server fake responses for a complete TLS delivery. +// We use this in a few tests, so make it common. +var tlsResponses = map[string]string{ + "_welcome": "220 welcome\n", + "EHLO hello": "250-ehlo ok\n250 STARTTLS\n", + "STARTTLS": "220 starttls go\n", + "_STARTTLS": "ok", + "MAIL FROM:<me@me>": "250 mail ok\n", + "RCPT TO:<to@to>": "250 rcpt ok\n", + "DATA": "354 send data\n", + "_DATA": "250 data ok\n", + "QUIT": "250 quit ok\n", +} + func TestTLS(t *testing.T) { smtpTotalTimeout = 5 * time.Second - - responses := map[string]string{ - "_welcome": "220 welcome\n", - "EHLO hello": "250-ehlo ok\n250 STARTTLS\n", - "STARTTLS": "220 starttls go\n", - "_STARTTLS": "ok", - "MAIL FROM:<me@me>": "250 mail ok\n", - "RCPT TO:<to@to>": "250 rcpt ok\n", - "DATA": "354 send data\n", - "_DATA": "250 data ok\n", - "QUIT": "250 quit ok\n", - } - srv := newFakeServer(t, responses) + srv := newFakeServer(t, tlsResponses) + defer srv.Cleanup() _, *smtpPort = srv.HostPort() testMX["to"] = []*net.MX{ @@ -391,7 +403,7 @@ func TestTLS(t *testing.T) { // Now do another delivery, but without TLS, to check that the detection // of connection downgrade is working. - responses = map[string]string{ + responses := map[string]string{ "_welcome": "220 welcome\n", "EHLO hello": "250 ehlo ok\n", "MAIL FROM:<me@me>": "250 mail ok\n", @@ -401,6 +413,7 @@ func TestTLS(t *testing.T) { "QUIT": "250 quit ok\n", } srv = newFakeServer(t, responses) + defer srv.Cleanup() _, *smtpPort = srv.HostPort() err, permanent := s.Deliver("me@me", "to@to", []byte("data")) @@ -425,6 +438,7 @@ func TestTLSError(t *testing.T) { "_STARTTLS": "no", } srv := newFakeServer(t, responses) + defer srv.Cleanup() _, *smtpPort = srv.HostPort() testMX["to"] = []*net.MX{ @@ -443,3 +457,76 @@ func TestTLSError(t *testing.T) { srv.Wait() } + +func TestSTSPolicyEnforcement(t *testing.T) { + smtpTotalTimeout = 5 * time.Second + srv := newFakeServer(t, tlsResponses) + defer srv.Cleanup() + _, *smtpPort = srv.HostPort() + + s, tmpDir := newSMTP(t) + defer testlib.RemoveIfOk(t, tmpDir) + + a := &attempt{ + courier: s, + from: "me@me", + to: "to@to", + toDomain: "to", + data: []byte("data"), + tr: trace.New("test", "test"), + } + + a.stsPolicy = &sts.Policy{ + Version: "STSv1", + Mode: sts.Enforce, + MXs: []string{"mx"}, + MaxAge: 1 * time.Minute, + } + + // At this point the cert is not valid, which is incompatible with STS + // policy, so we expect it to fail. + err, permanent := a.deliver("localhost") + if !strings.Contains(err.Error(), + "invalid security level (TLS_INSECURE) for STS policy") { + t.Errorf("expected invalid sec level error, got %v", err) + } + if permanent != false { + t.Errorf("expected transient error, got permanent") + } + + srv.Wait() + + // Do another delivery attempt, but this time we trust the server cert. + // This time it should be successful, because the connection level should + // be TLS_SECURE which is required by the STS policy. + srv = newFakeServer(t, tlsResponses) + _, *smtpPort = srv.HostPort() + defer srv.Cleanup() + + certRoots = loadCert(t, srv.tmpDir+"/cert.pem") + defer func() { + certRoots = nil + }() + + err, permanent = a.deliver("localhost") + if err != nil { + t.Errorf("expected success, got %v (permanent=%v)", err, permanent) + } + + srv.Wait() +} + +func loadCert(t *testing.T, path string) *x509.CertPool { + t.Helper() + + pool := x509.NewCertPool() + data, err := ioutil.ReadFile(path) + if err != nil { + t.Fatalf("error reading cert %q: %v", path, err) + } + ok := pool.AppendCertsFromPEM(data) + if !ok { + t.Fatalf("failed to load cert %q", path) + } + return pool +}