git » spf » commit 0af9dac

Implement "exists"

author Alberto Bertogli
2019-10-14 01:08:50 UTC
committer Alberto Bertogli
2019-10-14 12:35:13 UTC
parent aa9c76649b5d4abdff6742664904f6f93d237aef

Implement "exists"

This patch implements the "exists" mechanism (RFC 7208, section 5.7), which was previously unsupported and made evaluation return Neutral.

spf.go +47 -16
spf_test.go +1 -18

diff --git a/spf.go b/spf.go
index 1d22bff..c362b9c 100644
--- a/spf.go
+++ b/spf.go
@@ -16,15 +16,10 @@
 //   ip4
 //   ip6
 //   redirect
+//   exists
 //   exp (ignored)
 //   Macros
 //
-// Not supported (return Neutral if used):
-//   exists
-//
-// This is intentional and there are no plans to add them for now, as they are
-// very rare, convoluted and not worth the additional complexity.
-//
 // References:
 //   https://tools.ietf.org/html/rfc7208
 //   https://en.wikipedia.org/wiki/Sender_Policy_Framework
@@ -92,19 +87,20 @@ var qualToResult = map[byte]Result{
 
 var (
 	errLookupLimitReached = fmt.Errorf("lookup limit reached")
-	errExistsNotSupported = fmt.Errorf("'exists' not supported")
 	errUnknownField       = fmt.Errorf("unknown field")
 	errInvalidIP          = fmt.Errorf("invalid ipX value")
 	errInvalidMask        = fmt.Errorf("invalid mask")
 	errInvalidMacro       = fmt.Errorf("invalid macro")
+	errInvalidDomain      = fmt.Errorf("invalid domain")
 	errNoResult           = fmt.Errorf("lookup yielded no result")
 	errMultipleRecords    = fmt.Errorf("multiple matching DNS records")
 
-	errMatchedAll = fmt.Errorf("matched 'all'")
-	errMatchedA   = fmt.Errorf("matched 'a'")
-	errMatchedIP  = fmt.Errorf("matched 'ip'")
-	errMatchedMX  = fmt.Errorf("matched 'mx'")
-	errMatchedPTR = fmt.Errorf("matched 'ptr'")
+	errMatchedAll    = fmt.Errorf("matched 'all'")
+	errMatchedA      = fmt.Errorf("matched 'a'")
+	errMatchedIP     = fmt.Errorf("matched 'ip'")
+	errMatchedMX     = fmt.Errorf("matched 'mx'")
+	errMatchedPTR    = fmt.Errorf("matched 'ptr'")
+	errMatchedExists = fmt.Errorf("matched 'exists'")
 )
 
 // CheckHost fetches SPF records for `domain`, parses them, and evaluates them
@@ -252,9 +248,11 @@ func (r *resolution) Check(domain string) (Result, error) {
 				trace("ptr ok, %v %v", res, err)
 				return res, err
 			}
-		} else if strings.HasPrefix(lfield, "exists") {
-			trace("exists, neutral / not supported")
-			return Neutral, errExistsNotSupported
+		} else if strings.HasPrefix(lfield, "exists:") {
+			if ok, res, err := r.existsField(result, field, domain); ok {
+				trace("exists ok, %v %v", res, err)
+				return res, err
+			}
 		} else if strings.HasPrefix(lfield, "exp=") {
 			trace("exp= not used, skipping")
 			continue
@@ -378,6 +376,39 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e
 	return false, "", nil
 }
 
+// existsField processes an "exists" field: it matches if the expanded domain has an IPv4 address.
+// https://tools.ietf.org/html/rfc7208#section-5.7
+func (r *resolution) existsField(res Result, field, domain string) (bool, Result, error) {
+	// The field is "exists:<domain-spec>"; skip the 7-byte "exists:" prefix.
+	eDomain := field[7:]
+	eDomain, err := r.expandMacros(eDomain, domain)
+	if err != nil {
+		return true, PermError, errInvalidMacro // NOTE(review): discards err, which may be errInvalidDomain
+	}
+
+	if eDomain == "" {
+		return true, PermError, errInvalidDomain
+	}
+
+	r.count++ // one DNS query; presumably checked against the lookup limit elsewhere — confirm
+	ips, err := lookupIP(eDomain)
+	if err != nil {
+		// Temporary DNS failure => TempError; other errors => no match: https://tools.ietf.org/html/rfc7208#section-5
+		if isTemporary(err) {
+			return true, TempError, err
+		}
+		return false, "", err
+	}
+
+	// RFC 7208 5.7 specifies an A lookup, so only IPv4 results count as a match.
+	for _, ip := range ips {
+		if ip.To4() != nil {
+			return true, res, errMatchedExists
+		}
+	}
+	return false, "", nil
+}
+
 // includeField processes an "include" field.
 func (r *resolution) includeField(res Result, field, domain string) (bool, Result, error) {
 	// https://tools.ietf.org/html/rfc7208#section-5.2
@@ -579,7 +610,7 @@ func (r *resolution) expandMacros(s, domain string) (string, error) {
 	// doesn't, prevent them from sneaking through.
 	if strings.Contains(s, "/") {
 		trace("macro contains /")
-		return "", errInvalidMacro
+		return "", errInvalidDomain
 	}
 
 	// Bypass the complex logic if there are no macros present.
diff --git a/spf_test.go b/spf_test.go
index c2393e1..bbe3f42 100644
--- a/spf_test.go
+++ b/spf_test.go
@@ -60,6 +60,7 @@ func TestBasic(t *testing.T) {
 		{"v=spf1 ptr:lalala -all", Pass, errMatchedPTR},
 		{"v=spf1 ptr:doesnotexist -all", Fail, errMatchedAll},
 		{"v=spf1 blah", PermError, errUnknownField},
+		{"v=spf1 exists:d1111 -all", Pass, errMatchedExists},
 	}
 
 	dns.ip["d1111"] = []net.IP{ip1111}
@@ -130,24 +131,6 @@ func TestIPv6(t *testing.T) {
 	}
 }
 
-func TestNotSupported(t *testing.T) {
-	trace = t.Logf
-	cases := []struct {
-		txt string
-		err error
-	}{
-		{"v=spf1 exists:blah -all", errExistsNotSupported},
-	}
-
-	for _, c := range cases {
-		dns.txt["domain"] = []string{c.txt}
-		res, err := CheckHost(ip1111, "domain")
-		if res != Neutral || err != c.err {
-			t.Errorf("%q: expected neutral/%q, got %v/%q", c.txt, c.err, res, err)
-		}
-	}
-}
-
 func TestInclude(t *testing.T) {
 	// Test that the include is doing a recursive lookup.
 	// If we got a match on 1.1.1.1, is because include:domain2 did not match.