git » debian:golang-blitiri-go-spf » commit 8d4219a

spf: Support the "ptr" mechanism

author Alberto Bertogli
2016-10-21 23:07:07 UTC
committer Alberto Bertogli
2016-10-21 23:07:07 UTC
parent b45726f5cbfa404e5fd79b29eb20ba7a6f8532fc

spf: Support the "ptr" mechanism

The "ptr" mechanism should not be used, but unfortunately many domains
still use it.

This patch extends our SPF implementation to support checking it.

https://tools.ietf.org/html/rfc7208#section-5.5

spf.go +42 -7
spf_test.go +52 -1

diff --git a/spf.go b/spf.go
index fbc28d6..adae64b 100644
--- a/spf.go
+++ b/spf.go
@@ -11,7 +11,6 @@
 //
 // Not supported (return Neutral if used):
 //  - "exists".
-//  - "ptr".
 //  - "exp".
 //  - Macros.
 //
@@ -30,9 +29,10 @@ import (
 
 // Functions that we can override for testing purposes.
 var (
-	lookupTXT = net.LookupTXT
-	lookupMX  = net.LookupMX
-	lookupIP  = net.LookupIP
+	lookupTXT  = net.LookupTXT
+	lookupMX   = net.LookupMX
+	lookupIP   = net.LookupIP
+	lookupAddr = net.LookupAddr
 )
 
 // Results and Errors. Note the values have meaning, we use them in headers.
@@ -81,13 +81,16 @@ var QualToResult = map[byte]Result{
 // with a given identity.
 // Reference: https://tools.ietf.org/html/rfc7208#section-4
 func CheckHost(ip net.IP, domain string) (Result, error) {
-	r := &resolution{ip, 0}
+	r := &resolution{ip: ip} // keyed literal: count and ipNames keep their zero values
 	return r.Check(domain)
 }
 
 type resolution struct {
 	ip    net.IP
 	count uint
+
+	// Reverse-lookup names for ip, cached by ptrField after a successful lookup (so we only do it once).
+	ipNames []string
 }
 
 func (r *resolution) Check(domain string) (Result, error) {
@@ -167,10 +170,12 @@ func (r *resolution) Check(domain string) (Result, error) {
 			if ok, res, err := r.ipField(result, field); ok {
 				return res, err
 			}
+		} else if strings.HasPrefix(field, "ptr") {
+			if ok, res, err := r.ptrField(result, field, domain); ok {
+				return res, err
+			}
 		} else if strings.HasPrefix(field, "exists") {
 			return Neutral, fmt.Errorf("'exists' not supported")
-		} else if strings.HasPrefix(field, "ptr") {
-			return Neutral, fmt.Errorf("'ptr' not supported")
 		} else if strings.HasPrefix(field, "exp=") {
 			return Neutral, fmt.Errorf("'exp' not supported")
 		} else if strings.HasPrefix(field, "redirect=") {
@@ -246,6 +251,36 @@ func (r *resolution) ipField(res Result, field string) (bool, Result, error) {
 	return false, "", nil
 }
 
+// ptrField processes a "ptr" or "ptr:<domain>" field: it matches when a reverse-DNS name of r.ip is
+// the domain or a subdomain of it. NOTE(review): names are not forward-confirmed (RFC 7208 5.5).
+func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, error) {
+	if strings.HasPrefix(field, "ptr:") && len(field) > len("ptr:") {
+		domain = field[len("ptr:"):]
+	} else if field != "ptr" {
+		return true, PermError, fmt.Errorf("invalid ptr field %q", field)
+	}
+
+	if r.ipNames == nil {
+		r.count++
+		n, err := lookupAddr(r.ip.String())
+		if err != nil {
+			// https://tools.ietf.org/html/rfc7208#section-5
+			if isTemporary(err) {
+				return true, TempError, err
+			}
+			return false, "", err
+		}
+		r.ipNames = n
+	}
+	// Compare whole labels, so "notexample.com." is not accepted for "example.com".
+	for _, n := range r.ipNames {
+		if n == domain+"." || strings.HasSuffix(n, "."+domain+".") {
+			return true, res, fmt.Errorf("matched ptr:%s", domain)
+		}
+	}
+	return false, "", nil
+}
+
 // includeField processes an "include" field.
 func (r *resolution) includeField(res Result, field string) (bool, Result, error) {
 	// https://tools.ietf.org/html/rfc7208#section-5.2
diff --git a/spf_test.go b/spf_test.go
index a5007dc..8d8d813 100644
--- a/spf_test.go
+++ b/spf_test.go
@@ -29,10 +29,18 @@ func LookupIP(host string) (ips []net.IP, err error) {
 	return ipResults[host], ipErrors[host]
 }
 
+var addrResults = map[string][]string{}
+var addrErrors = map[string]error{}
+// LookupAddr is the test replacement for lookupAddr: it answers from addrResults and addrErrors.
+func LookupAddr(host string) (addrs []string, err error) {
+	return addrResults[host], addrErrors[host]
+}
+
 func TestMain(m *testing.M) {
 	lookupTXT = LookupTXT
 	lookupMX = LookupMX
 	lookupIP = LookupIP
+	lookupAddr = LookupAddr
 
 	flag.Parse()
 	os.Exit(m.Run())
@@ -41,6 +49,7 @@ func TestMain(m *testing.M) {
 var ip1110 = net.ParseIP("1.1.1.0")
 var ip1111 = net.ParseIP("1.1.1.1")
 var ip6666 = net.ParseIP("2001:db8::68")
+var ip6660 = net.ParseIP("2001:db8::0")
 
 func TestBasic(t *testing.T) {
 	cases := []struct {
@@ -70,12 +79,16 @@ func TestBasic(t *testing.T) {
 		{"v=spf1 ip4:1.2.3.4 ~all", SoftFail},
 		{"v=spf1 ip6:12 ~all", PermError},
 		{"v=spf1 ip4:1.1.1.1 -all", Pass},
+		{"v=spf1 ptr -all", Pass},
+		{"v=spf1 ptr:d1111 -all", Pass},
+		{"v=spf1 ptr:lalala -all", Pass},
 		{"v=spf1 blah", PermError},
 	}
 
 	ipResults["d1111"] = []net.IP{ip1111}
 	ipResults["d1110"] = []net.IP{ip1110}
 	mxResults["d1110"] = []*net.MX{{"d1110", 5}, {"nothing", 10}}
+	addrResults["1.1.1.1"] = []string{"lalala.", "domain.", "d1111."}
 
 	for _, c := range cases {
 		txtResults["domain"] = []string{c.txt}
@@ -90,10 +103,48 @@ func TestBasic(t *testing.T) {
 	}
 }
 
+// TestIPv6 runs CheckHost for the IPv6 address ip6666 against each SPF record in the table below.
+func TestIPv6(t *testing.T) {
+	mxResults["d6660"] = []*net.MX{{"d6660", 5}, {"nothing", 10}}
+	ipResults["d6660"] = []net.IP{ip6660}
+	ipResults["d6666"] = []net.IP{ip6666}
+	addrResults["2001:db8::68"] = []string{"sonlas6.", "domain.", "d6666."}
+
+	tests := []struct {
+		record string
+		want   Result
+	}{
+		{"v=spf1 all", Pass},
+		{"v=spf1 a ~all", SoftFail},
+		{"v=spf1 a/24", Neutral},
+		{"v=spf1 a:d6660/24", Pass},
+		{"v=spf1 a:d6660", Neutral},
+		{"v=spf1 a:d6666", Pass},
+		{"v=spf1 a:nothing/24", Neutral},
+		{"v=spf1 mx:d6660/24 ~all", Pass},
+		{"v=spf1 ip6:2001:db8::68 ~all", Pass},
+		{"v=spf1 ip6:2001:db8::1/24 ~all", Pass},
+		{"v=spf1 ip6:2001:db8::1/100 ~all", Pass},
+		{"v=spf1 ptr -all", Pass},
+		{"v=spf1 ptr:d6666 -all", Pass},
+		{"v=spf1 ptr:sonlas6 -all", Pass},
+	}
+	for _, tc := range tests {
+		txtResults["domain"] = []string{tc.record}
+		got, err := CheckHost(ip6666, "domain")
+		if err == nil && (got == TempError || got == PermError) {
+			t.Errorf("%q: expected error, got nil", tc.record)
+		}
+		if got != tc.want {
+			t.Errorf("%q: expected %q, got %q", tc.record, tc.want, got)
+			t.Logf("%q:   error: %v", tc.record, err)
+		}
+	}
+}
+
 func TestNotSupported(t *testing.T) {
 	cases := []string{
 		"v=spf1 exists:blah -all",
-		"v=spf1 ptr -all",
 		"v=spf1 exp=blah -all",
 		"v=spf1 a:%{o} -all",
 	}