git » spf » commit 3cf50bb

Use net.IPMask to represent masks internally

author Alberto Bertogli
2021-05-09 09:10:19 UTC
committer Alberto Bertogli
2021-05-09 09:38:54 UTC
parent 25345a2e2784a550b613409b0abd49f59b209ab2

Use net.IPMask to represent masks internally

Currently we use ints internally to represent IP masks, and then
convert them (via net.ParseCIDR) at checking time.

This carries unnecessary complexity, as we need to check the validity of
the masks in a couple of different places.

This patch simplifies the handling by representing masks with
net.IPMask instead, which in turn simplifies a couple of helper functions.

Some tests are added to validate that corner cases continue to work as
expected.

spf.go +20 -28
spf_test.go +64 -16

diff --git a/spf.go b/spf.go
index 38710e2..0cf1319 100644
--- a/spf.go
+++ b/spf.go
@@ -568,50 +568,48 @@ func (r *resolution) includeField(res Result, field, domain string) (bool, Resul
 }
 
 type dualMasks struct {
-	v4 int
-	v6 int
+	v4 net.IPMask
+	v6 net.IPMask
 }
 
-func ipMatch(ip, tomatch net.IP, masks dualMasks) (bool, error) {
-	mask := -1
-	if tomatch.To4() != nil && masks.v4 >= 0 {
+func ipMatch(ip, tomatch net.IP, masks dualMasks) bool {
+	mask := net.IPMask(nil)
+	if tomatch.To4() != nil && masks.v4 != nil {
 		mask = masks.v4
-	} else if tomatch.To4() == nil && masks.v6 >= 0 {
+	} else if tomatch.To4() == nil && masks.v6 != nil {
 		mask = masks.v6
 	}
 
-	if mask >= 0 {
-		_, ipnet, err := net.ParseCIDR(
-			fmt.Sprintf("%s/%d", tomatch.String(), mask))
-		if err != nil {
-			return false, errInvalidMask
-		}
-		return ipnet.Contains(ip), nil
+	if mask != nil {
+		ipnet := net.IPNet{IP: tomatch, Mask: mask}
+		return ipnet.Contains(ip)
 	}
 
-	return ip.Equal(tomatch), nil
+	return ip.Equal(tomatch)
 }
 
 var aRegexp = regexp.MustCompile(`^[aA](:([^/]+))?(/(\w+))?(//(\w+))?$`)
 var mxRegexp = regexp.MustCompile(`^[mM][xX](:([^/]+))?(/(\w+))?(//(\w+))?$`)
 
 func domainAndMask(re *regexp.Regexp, field, domain string) (string, dualMasks, error) {
-	masks := dualMasks{-1, -1}
+	masks := dualMasks{}
 	groups := re.FindStringSubmatch(field)
 	if groups != nil {
 		if groups[2] != "" {
 			domain = groups[2]
 		}
 		if groups[4] != "" {
-			mask4, err := strconv.Atoi(groups[4])
-			if err != nil || mask4 < 0 || mask4 > 32 {
+			i, err := strconv.Atoi(groups[4])
+			mask4 := net.CIDRMask(i, 32)
+			if err != nil || mask4 == nil {
 				return "", masks, errInvalidMask
 			}
 			masks.v4 = mask4
 		}
 		if groups[6] != "" {
-			mask6, err := strconv.Atoi(groups[6])
-			if err != nil || mask6 < 0 || mask6 > 128 {
+			i, err := strconv.Atoi(groups[6])
+			mask6 := net.CIDRMask(i, 128)
+			if err != nil || mask6 == nil {
 				return "", masks, errInvalidMask
 			}
 			masks.v6 = mask6
@@ -621,7 +619,7 @@ func domainAndMask(re *regexp.Regexp, field, domain string) (string, dualMasks,
 
 	// Test to catch malformed entries: if there's a /, there must be at least
 	// one mask.
-	if strings.Contains(field, "/") && masks.v4 == -1 && masks.v6 == -1 {
+	if strings.Contains(field, "/") && masks.v4 == nil && masks.v6 == nil {
 		return "", masks, errInvalidMask
 	}
 
@@ -650,12 +648,9 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err
 		return false, "", err
 	}
 	for _, ip := range ips {
-		ok, err := ipMatch(r.ip, ip.IP, masks)
-		if ok {
+		if ipMatch(r.ip, ip.IP, masks) {
 			trace("a matched %v, %v, %v", r.ip, ip.IP, masks)
 			return true, res, errMatchedA
-		} else if err != nil {
-			return true, PermError, err
 		}
 	}
 
@@ -706,12 +701,9 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er
 		}
 	}
 	for _, ip := range mxips {
-		ok, err := ipMatch(r.ip, ip, masks)
-		if ok {
+		if ipMatch(r.ip, ip, masks) {
 			trace("mx matched %v, %v, %v", r.ip, ip, masks)
 			return true, res, errMatchedMX
-		} else if err != nil {
-			return true, PermError, err
 		}
 	}
 
diff --git a/spf_test.go b/spf_test.go
index 6b7787f..6d76d04 100644
--- a/spf_test.go
+++ b/spf_test.go
@@ -403,31 +403,34 @@ func mx(host string, pref uint16) *net.MX {
 	return &net.MX{Host: host, Pref: pref}
 }
 
+func mkDM(v4, v6 int) dualMasks {
+	return dualMasks{net.CIDRMask(v4, 32), net.CIDRMask(v6, 128)}
+}
+
 func TestIPMatchHelper(t *testing.T) {
 	cases := []struct {
 		ip      net.IP
 		tomatch net.IP
 		masks   dualMasks
 		ok      bool
-		err     error
 	}{
-		{ip1111, ip1110, dualMasks{24, -1}, true, nil},
-		{ip1111, ip1111, dualMasks{-1, -1}, true, nil},
-		{ip1111, ip1110, dualMasks{-1, -1}, false, nil},
-		{ip1111, ip1110, dualMasks{32, -1}, false, nil},
-		{ip1111, ip1110, dualMasks{99, -1}, false, errInvalidMask},
-
-		{ip6666, ip6660, dualMasks{-1, 100}, true, nil},
-		{ip6666, ip6666, dualMasks{-1, -1}, true, nil},
-		{ip6666, ip6660, dualMasks{-1, -1}, false, nil},
-		{ip6666, ip6660, dualMasks{-1, 128}, false, nil},
-		{ip6666, ip6660, dualMasks{-1, 200}, false, errInvalidMask},
+		{ip1111, ip1110, mkDM(24, -1), true},
+		{ip1111, ip1111, mkDM(-1, -1), true},
+		{ip1111, ip1110, mkDM(-1, -1), false},
+		{ip1111, ip1110, mkDM(32, -1), false},
+		{ip1111, ip1110, mkDM(99, -1), false},
+
+		{ip6666, ip6660, mkDM(-1, 100), true},
+		{ip6666, ip6666, mkDM(-1, -1), true},
+		{ip6666, ip6660, mkDM(-1, -1), false},
+		{ip6666, ip6660, mkDM(-1, 128), false},
+		{ip6666, ip6660, mkDM(-1, 200), false},
 	}
 	for _, c := range cases {
-		ok, err := ipMatch(c.ip, c.tomatch, c.masks)
-		if ok != c.ok || err != c.err {
-			t.Errorf("[%s %s/%v]: expected %v/%v, got %v/%v",
-				c.ip, c.tomatch, c.masks, c.ok, c.err, ok, err)
+		ok := ipMatch(c.ip, c.tomatch, c.masks)
+		if ok != c.ok {
+			t.Errorf("[%s %s/%v]: expected %v, got %v",
+				c.ip, c.tomatch, c.masks, c.ok, ok)
 		}
 	}
 }
@@ -539,3 +542,48 @@ func TestWithResolver(t *testing.T) {
 		t.Errorf("expected pass, got %q / %q", res, err)
 	}
 }
+
+// Test some corner cases when resolver.LookupIPAddr returns an invalid
+// address. This can happen if using a buggy custom resolver.
+func TestBadResolverResponse(t *testing.T) {
+	dns := NewResolver()
+	trace = t.Logf
+
+	// When LookupIPAddr returns an invalid ip, for an "a" field.
+	dns.ip["domain1"] = []net.IP{nil}
+	dns.txt["domain1"] = []string{"v=spf1 a:domain1 -all"}
+	res, err := CheckHostWithSender(ip1111, "helo", "user@domain1",
+		WithResolver(dns))
+	if res != Fail {
+		t.Errorf("expected fail, got %q / %q", res, err)
+	}
+
+	// Same as above, except the field has a mask.
+	dns.ip["domain1"] = []net.IP{nil}
+	dns.txt["domain1"] = []string{"v=spf1 a:domain1//24 -all"}
+	res, err = CheckHostWithSender(ip1111, "helo", "user@domain1",
+		WithResolver(dns))
+	if res != Fail {
+		t.Errorf("expected fail, got %q / %q", res, err)
+	}
+
+	// When LookupIPAddr returns an invalid ip, for an "mx" field.
+	dns.ip["mx.domain1"] = []net.IP{nil}
+	dns.mx["domain1"] = []*net.MX{mx("mx.domain1", 5)}
+	dns.txt["domain1"] = []string{"v=spf1 mx:domain1 -all"}
+	res, err = CheckHostWithSender(ip1111, "helo", "user@domain1",
+		WithResolver(dns))
+	if res != Fail {
+		t.Errorf("expected fail, got %q / %q", res, err)
+	}
+
+	// Same as above, except the field has a mask.
+	dns.ip["mx.domain1"] = []net.IP{nil}
+	dns.mx["domain1"] = []*net.MX{mx("mx.domain1", 5)}
+	dns.txt["domain1"] = []string{"v=spf1 mx:domain1//24 -all"}
+	res, err = CheckHostWithSender(ip1111, "helo", "user@domain1",
+		WithResolver(dns))
+	if res != Fail {
+		t.Errorf("expected fail, got %q / %q", res, err)
+	}
+}