git » chasquid » commit 8bf584b

sts: Don't pre-filter MX list, but skip them if needed

author Alberto Bertogli
2018-05-27 09:43:46 UTC
committer Alberto Bertogli
2018-07-01 11:19:02 UTC
parent 252ab5d3e3eb16437a93c1d8464b2ccbdda06d82

sts: Don't pre-filter MX list, but skip them if needed

Instead of pre-filtering the MX list based on the STS policy, check
whether each MX is allowed just before attempting it, and skip it if not.

This simplifies the code.

internal/courier/smtp.go +12 -34
internal/sts/sts.go +5 -1

diff --git a/internal/courier/smtp.go b/internal/courier/smtp.go
index a46aa29..8bb9526 100644
--- a/internal/courier/smtp.go
+++ b/internal/courier/smtp.go
@@ -73,9 +73,7 @@ func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) {
 		a.from = ""
 	}
 
-	a.stsPolicy = s.fetchSTSPolicy(a.tr, a.toDomain)
-
-	mxs, err := lookupMXs(a.tr, a.toDomain, a.stsPolicy)
+	mxs, err := lookupMXs(a.tr, a.toDomain)
 	if err != nil || len(mxs) == 0 {
 		// Note this is considered a permanent error.
 		// This is in line with what other servers (Exim) do. However, the
@@ -95,7 +93,14 @@ func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) {
 		a.helloDomain, _ = os.Hostname()
 	}
 
+	a.stsPolicy = s.fetchSTSPolicy(a.tr, a.toDomain)
+
 	for _, mx := range mxs {
+		if a.stsPolicy != nil && !a.stsPolicy.MXIsAllowed(mx) {
+			a.tr.Printf("%q skipped as per MTA-STA policy", mx)
+			continue
+		}
+
 		var permanent bool
 		err, permanent = a.deliver(mx)
 		if err == nil {
@@ -248,7 +253,7 @@ func (s *SMTP) fetchSTSPolicy(tr *trace.Trace, domain string) *sts.Policy {
 	return policy
 }
 
-func lookupMXs(tr *trace.Trace, domain string, policy *sts.Policy) ([]string, error) {
+func lookupMXs(tr *trace.Trace, domain string) ([]string, error) {
 	domain, err := idna.ToASCII(domain)
 	if err != nil {
 		return nil, err
@@ -288,39 +293,12 @@ func lookupMXs(tr *trace.Trace, domain string, policy *sts.Policy) ([]string, er
 	// This case is explicitly covered by the SMTP RFC.
 	// https://tools.ietf.org/html/rfc5321#section-5.1
 
-	mxs = filterMXs(tr, policy, mxs)
-	if len(mxs) == 0 {
-		tr.Errorf("domain %q has no valid MX/A record", domain)
-	} else if len(mxs) > 5 {
-		// Cap the list of MXs to 5 hosts, to keep delivery attempt times
-		// sane and prevent abuse.
+	// Cap the list of MXs to 5 hosts, to keep delivery attempt times
+	// sane and prevent abuse.
+	if len(mxs) > 5 {
 		mxs = mxs[:5]
 	}
 
 	tr.Debugf("MXs: %v", mxs)
 	return mxs, nil
 }
-
-func filterMXs(tr *trace.Trace, p *sts.Policy, mxs []string) []string {
-	if p == nil {
-		return mxs
-	}
-
-	filtered := []string{}
-	for _, mx := range mxs {
-		if p.MXIsAllowed(mx) {
-			filtered = append(filtered, mx)
-		} else {
-			tr.Printf("MX %q not allowed by policy, skipping", mx)
-		}
-	}
-
-	// We don't want to return an empty set if the mode is not enforce.
-	// This prevents failures for policies in reporting mode.
-	// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-18#section-5.1
-	if len(filtered) == 0 && p.Mode != sts.Enforce {
-		filtered = mxs
-	}
-
-	return filtered
-}
diff --git a/internal/sts/sts.go b/internal/sts/sts.go
index 0eda169..d4f7298 100644
--- a/internal/sts/sts.go
+++ b/internal/sts/sts.go
@@ -140,9 +140,13 @@ func (p *Policy) Check() error {
 	return nil
 }
 
-// MXMatches checks if the given MX is allowed, according to the policy.
+// MXIsAllowed checks if the given MX is allowed, according to the policy.
 // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-18#section-4.1
 func (p *Policy) MXIsAllowed(mx string) bool {
+	if p.Mode != Enforce {
+		return true
+	}
+
 	for _, pattern := range p.MXs {
 		if matchDomain(mx, pattern) {
 			return true