git » spf » commit f8b7a67

a, mx: Support dual CIDR length masks

author Alberto Bertogli
2019-10-13 20:34:18 UTC
committer Alberto Bertogli
2019-10-14 12:31:32 UTC
parent fcd50be18d1a8aeeb5b31c7dfba8a1c4d4bb411f

a, mx: Support dual CIDR length masks

For the A and MX mechanisms, masks are specified with a special syntax
that allows no mask, an IPv4-only mask, an IPv6-only mask, or both.

The form is "[/<mask4>][//<mask6>]". Note that both are optional, and if
only mask6 is given, it still needs to be prefixed by "//".

The current implementation doesn't handle this well, so this patch
extends the mask management significantly to support the masks properly.

Found by the standard test suite.

spf.go +47 -20
spf_test.go +10 -3

diff --git a/spf.go b/spf.go
index c169110..67ab070 100644
--- a/spf.go
+++ b/spf.go
@@ -402,16 +402,26 @@ func (r *resolution) includeField(res Result, field string) (bool, Result, error
 	return false, "", fmt.Errorf("This should never be reached")
 }
 
-func ipMatch(ip, tomatch net.IP, mask int) (bool, error) {
+type DualMasks struct {
+	v4 int
+	v6 int
+}
+
+func ipMatch(ip, tomatch net.IP, masks DualMasks) (bool, error) {
+	mask := -1
+	if tomatch.To4() != nil && masks.v4 >= 0 {
+		mask = masks.v4
+	} else if tomatch.To4() == nil && masks.v6 >= 0 {
+		mask = masks.v6
+	}
+
 	if mask >= 0 {
-		_, ipnet, err := net.ParseCIDR(fmt.Sprintf("%s/%d", tomatch.String(), mask))
+		_, ipnet, err := net.ParseCIDR(
+			fmt.Sprintf("%s/%d", tomatch.String(), mask))
 		if err != nil {
 			return false, errInvalidMask
 		}
-		if ipnet.Contains(ip) {
-			return true, nil
-		}
-		return false, nil
+		return ipnet.Contains(ip), nil
 	} else {
 		if ip.Equal(tomatch) {
 			return true, nil
@@ -420,31 +430,46 @@ func ipMatch(ip, tomatch net.IP, mask int) (bool, error) {
 	}
 }
 
-var aRegexp = regexp.MustCompile("[aA](:([^/]+))?(/(.+))?")
-var mxRegexp = regexp.MustCompile("[mM][xX](:([^/]+))?(/(.+))?")
+var aRegexp = regexp.MustCompile(`^[aA](:([^/]+))?(/(\w+))?(//(\w+))?$`)
+var mxRegexp = regexp.MustCompile(`^[mM][xX](:([^/]+))?(/(\w+))?(//(\w+))?$`)
 
-func domainAndMask(re *regexp.Regexp, field, domain string) (string, int, error) {
-	var err error
-	mask := -1
-	if groups := re.FindStringSubmatch(field); groups != nil {
+func domainAndMask(re *regexp.Regexp, field, domain string) (string, DualMasks, error) {
+	masks := DualMasks{-1, -1}
+	groups := re.FindStringSubmatch(field)
+	if groups != nil {
 		if groups[2] != "" {
 			domain = groups[2]
 		}
 		if groups[4] != "" {
-			mask, err = strconv.Atoi(groups[4])
-			if err != nil {
-				return "", -1, errInvalidMask
+			mask4, err := strconv.Atoi(groups[4])
+			if err != nil || mask4 < 0 || mask4 > 32 {
+				return "", masks, errInvalidMask
+			}
+			masks.v4 = mask4
+		}
+		if groups[6] != "" {
+			mask6, err := strconv.Atoi(groups[6])
+			if err != nil || mask6 < 0 || mask6 > 128 {
+				return "", masks, errInvalidMask
 			}
+			masks.v6 = mask6
 		}
 	}
+	trace("masks on %q: %q %q %v", field, groups, domain, masks)
+
+	// Test to catch malformed entries: if there's a /, there must be at least
+	// one mask.
+	if strings.Contains(field, "/") && masks.v4 == -1 && masks.v6 == -1 {
+		return "", masks, errInvalidMask
+	}
 
-	return domain, mask, nil
+	return domain, masks, nil
 }
 
 // aField processes an "a" field.
 func (r *resolution) aField(res Result, field, domain string) (bool, Result, error) {
 	// https://tools.ietf.org/html/rfc7208#section-5.3
-	domain, mask, err := domainAndMask(aRegexp, field, domain)
+	domain, masks, err := domainAndMask(aRegexp, field, domain)
 	if err != nil {
 		return true, PermError, err
 	}
@@ -459,8 +484,9 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err
 		return false, "", err
 	}
 	for _, ip := range ips {
-		ok, err := ipMatch(r.ip, ip, mask)
+		ok, err := ipMatch(r.ip, ip, masks)
 		if ok {
+			trace("mx matched %v, %v, %v", r.ip, ip, masks)
 			return true, res, errMatchedA
 		} else if err != nil {
 			return true, PermError, err
@@ -473,7 +499,7 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err
 // mxField processes an "mx" field.
 func (r *resolution) mxField(res Result, field, domain string) (bool, Result, error) {
 	// https://tools.ietf.org/html/rfc7208#section-5.4
-	domain, mask, err := domainAndMask(mxRegexp, field, domain)
+	domain, masks, err := domainAndMask(mxRegexp, field, domain)
 	if err != nil {
 		return true, PermError, err
 	}
@@ -501,8 +527,9 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er
 		mxips = append(mxips, ips...)
 	}
 	for _, ip := range mxips {
-		ok, err := ipMatch(r.ip, ip, mask)
+		ok, err := ipMatch(r.ip, ip, masks)
 		if ok {
+			trace("mx matched %v, %v, %v", r.ip, ip, masks)
 			return true, res, errMatchedMX
 		} else if err != nil {
 			return true, PermError, err
diff --git a/spf_test.go b/spf_test.go
index aa80a4b..acbff91 100644
--- a/spf_test.go
+++ b/spf_test.go
@@ -44,12 +44,16 @@ func TestBasic(t *testing.T) {
 		{"v=spf1 mx/24", Neutral, nil},
 		{"v=spf1 mx:a/montoto ~all", PermError, errInvalidMask},
 		{"v=spf1 mx:d1110/24 ~all", Pass, errMatchedMX},
+		{"v=spf1 mx:d1110/24//100 ~all", Pass, errMatchedMX},
+		{"v=spf1 mx:d1110/24//129 ~all", PermError, errInvalidMask},
+		{"v=spf1 mx:d1110/24/100 ~all", PermError, errInvalidMask},
 		{"v=spf1 mx:d1110/99 ~all", PermError, errInvalidMask},
 		{"v=spf1 ip4:1.2.3.4 ~all", SoftFail, errMatchedAll},
 		{"v=spf1 ip6:12 ~all", PermError, errInvalidIP},
 		{"v=spf1 ip4:1.1.1.1 -all", Pass, errMatchedIP},
 		{"v=spf1 ip4:1.1.1.1/24 -all", Pass, errMatchedIP},
 		{"v=spf1 ip4:1.1.1.1/lala -all", PermError, errInvalidMask},
+		{"v=spf1 ip4:1.1.1.1/33 -all", PermError, errInvalidMask},
 		{"v=spf1 include:doesnotexist", PermError, errNoResult},
 		{"v=spf1 ptr -all", Pass, errMatchedPTR},
 		{"v=spf1 ptr:d1111 -all", Pass, errMatchedPTR},
@@ -90,11 +94,14 @@ func TestIPv6(t *testing.T) {
 		{"v=spf1 all", Pass, errMatchedAll},
 		{"v=spf1 a ~all", SoftFail, errMatchedAll},
 		{"v=spf1 a/24", Neutral, nil},
-		{"v=spf1 a:d6660/24", Pass, errMatchedA},
+		{"v=spf1 a:d6660//24", Pass, errMatchedA},
+		{"v=spf1 a:d6660/24//100", Pass, errMatchedA},
 		{"v=spf1 a:d6660", Neutral, nil},
 		{"v=spf1 a:d6666", Pass, errMatchedA},
-		{"v=spf1 a:nothing/24", Neutral, nil},
-		{"v=spf1 mx:d6660/24 ~all", Pass, errMatchedMX},
+		{"v=spf1 a:nothing//24", Neutral, nil},
+		{"v=spf1 mx:d6660//24 ~all", Pass, errMatchedMX},
+		{"v=spf1 mx:d6660/24//100 ~all", Pass, errMatchedMX},
+		{"v=spf1 mx:d6660/24/100 ~all", PermError, errInvalidMask},
 		{"v=spf1 ip6:2001:db8::68 ~all", Pass, errMatchedIP},
 		{"v=spf1 ip6:2001:db8::1/24 ~all", Pass, errMatchedIP},
 		{"v=spf1 ip6:2001:db8::1/100 ~all", Pass, errMatchedIP},