git » spf » commit 54adca3

Use static errors in most cases

author Alberto Bertogli
2018-03-18 11:58:38 UTC
committer Alberto Bertogli
2018-03-18 13:16:26 UTC
parent a230f6e74cd0e6cde4d3fd530b63f164017bac9d

Use static errors in most cases

In order to reduce the potential attack surface, and to make tests more
stringent, this patch makes the errors generated by the library static.

The tests are extended to check for the particular errors as well.

Some errors from the standard library may still be propagated upwards;
those will be addressed later if needed.

Note the errors are not exported, and not a part of the API; this is
just for security and testability purposes.

spf.go +33 -17
spf_test.go +72 -57

diff --git a/spf.go b/spf.go
index 1f22c56..62f2eca 100644
--- a/spf.go
+++ b/spf.go
@@ -88,6 +88,22 @@ var qualToResult = map[byte]Result{
 	'?': Neutral,
 }
 
+var (
+	errLookupLimitReached = fmt.Errorf("lookup limit reached")
+	errMacrosNotSupported = fmt.Errorf("macros not supported")
+	errExistsNotSupported = fmt.Errorf("'exists' not supported")
+	errExpNotSupported    = fmt.Errorf("'exp' not supported")
+	errUnknownField       = fmt.Errorf("unknown field")
+	errInvalidIP          = fmt.Errorf("invalid ipX value")
+	errInvalidMask        = fmt.Errorf("invalid mask")
+
+	errMatchedAll = fmt.Errorf("matched 'all'")
+	errMatchedA   = fmt.Errorf("matched 'a'")
+	errMatchedIP  = fmt.Errorf("matched 'ip'")
+	errMatchedMX  = fmt.Errorf("matched 'mx'")
+	errMatchedPTR = fmt.Errorf("matched 'ptr'")
+)
+
 // CheckHost fetches SPF records for `domain`, parses them, and evaluates them
 // to determine if `ip` is permitted to send mail for it.
 // Reference: https://tools.ietf.org/html/rfc7208#section-4
@@ -144,11 +160,11 @@ func (r *resolution) Check(domain string) (Result, error) {
 		// Limit the number of resolutions to 10
 		// https://tools.ietf.org/html/rfc7208#section-4.6.4
 		if r.count > 10 {
-			return PermError, fmt.Errorf("lookup limit reached")
+			return PermError, errLookupLimitReached
 		}
 
 		if strings.Contains(field, "%") {
-			return Neutral, fmt.Errorf("macros not supported")
+			return Neutral, errMacrosNotSupported
 		}
 
 		// See if we have a qualifier, defaulting to + (pass).
@@ -162,7 +178,7 @@ func (r *resolution) Check(domain string) (Result, error) {
 
 		if field == "all" {
 			// https://tools.ietf.org/html/rfc7208#section-5.1
-			return result, fmt.Errorf("matched 'all'")
+			return result, errMatchedAll
 		} else if strings.HasPrefix(field, "include:") {
 			if ok, res, err := r.includeField(result, field); ok {
 				return res, err
@@ -184,9 +200,9 @@ func (r *resolution) Check(domain string) (Result, error) {
 				return res, err
 			}
 		} else if strings.HasPrefix(field, "exists") {
-			return Neutral, fmt.Errorf("'exists' not supported")
+			return Neutral, errExistsNotSupported
 		} else if strings.HasPrefix(field, "exp=") {
-			return Neutral, fmt.Errorf("'exp' not supported")
+			return Neutral, errExpNotSupported
 		} else if strings.HasPrefix(field, "redirect=") {
 			// https://tools.ietf.org/html/rfc7208#section-6.1
 			result, err := r.Check(field[len("redirect="):])
@@ -196,7 +212,7 @@ func (r *resolution) Check(domain string) (Result, error) {
 			return result, err
 		} else {
 			// http://www.openspf.org/SPF_Record_Syntax
-			return PermError, fmt.Errorf("unknown field %q", field)
+			return PermError, errUnknownField
 		}
 	}
 
@@ -242,18 +258,18 @@ func (r *resolution) ipField(res Result, field string) (bool, Result, error) {
 	if strings.Contains(fip, "/") {
 		_, ipnet, err := net.ParseCIDR(fip)
 		if err != nil {
-			return true, PermError, err
+			return true, PermError, errInvalidMask
 		}
 		if ipnet.Contains(r.ip) {
-			return true, res, fmt.Errorf("matched %v", ipnet)
+			return true, res, errMatchedIP
 		}
 	} else {
 		ip := net.ParseIP(fip)
 		if ip == nil {
-			return true, PermError, fmt.Errorf("invalid ipX value")
+			return true, PermError, errInvalidIP
 		}
 		if ip.Equal(r.ip) {
-			return true, res, fmt.Errorf("matched %v", ip)
+			return true, res, errMatchedIP
 		}
 	}
 
@@ -283,7 +299,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e
 
 	for _, n := range r.ipNames {
 		if strings.HasSuffix(n, domain+".") {
-			return true, res, fmt.Errorf("matched ptr:%s", domain)
+			return true, res, errMatchedPTR
 		}
 	}
 
@@ -314,15 +330,15 @@ func ipMatch(ip, tomatch net.IP, mask int) (bool, error) {
 	if mask >= 0 {
 		_, ipnet, err := net.ParseCIDR(fmt.Sprintf("%s/%d", tomatch.String(), mask))
 		if err != nil {
-			return false, err
+			return false, errInvalidMask
 		}
 		if ipnet.Contains(ip) {
-			return true, fmt.Errorf("%v", ipnet)
+			return true, nil
 		}
 		return false, nil
 	} else {
 		if ip.Equal(tomatch) {
-			return true, fmt.Errorf("%v", tomatch)
+			return true, nil
 		}
 		return false, nil
 	}
@@ -341,7 +357,7 @@ func domainAndMask(re *regexp.Regexp, field, domain string) (string, int, error)
 		if groups[4] != "" {
 			mask, err = strconv.Atoi(groups[4])
 			if err != nil {
-				return "", -1, fmt.Errorf("error parsing mask")
+				return "", -1, errInvalidMask
 			}
 		}
 	}
@@ -369,7 +385,7 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err
 	for _, ip := range ips {
 		ok, err := ipMatch(r.ip, ip, mask)
 		if ok {
-			return true, res, fmt.Errorf("matched 'a' (%v)", err)
+			return true, res, errMatchedA
 		} else if err != nil {
 			return true, PermError, err
 		}
@@ -411,7 +427,7 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er
 	for _, ip := range mxips {
 		ok, err := ipMatch(r.ip, ip, mask)
 		if ok {
-			return true, res, fmt.Errorf("matched 'mx' (%v)", err)
+			return true, res, errMatchedMX
 		} else if err != nil {
 			return true, PermError, err
 		}
diff --git a/spf_test.go b/spf_test.go
index 8337e75..65ae45a 100644
--- a/spf_test.go
+++ b/spf_test.go
@@ -55,34 +55,42 @@ func TestBasic(t *testing.T) {
 	cases := []struct {
 		txt string
 		res Result
+		err error
 	}{
-		{"", None},
-		{"blah", None},
-		{"v=spf1", Neutral},
-		{"v=spf1 ", Neutral},
-		{"v=spf1 -", PermError},
-		{"v=spf1 all", Pass},
-		{"v=spf1  +all", Pass},
-		{"v=spf1 -all ", Fail},
-		{"v=spf1 ~all", SoftFail},
-		{"v=spf1 ?all", Neutral},
-		{"v=spf1 a ~all", SoftFail},
-		{"v=spf1 a/24", Neutral},
-		{"v=spf1 a:d1110/24", Pass},
-		{"v=spf1 a:d1110", Neutral},
-		{"v=spf1 a:d1111", Pass},
-		{"v=spf1 a:nothing/24", Neutral},
-		{"v=spf1 mx", Neutral},
-		{"v=spf1 mx/24", Neutral},
-		{"v=spf1 mx:a/montoto ~all", PermError},
-		{"v=spf1 mx:d1110/24 ~all", Pass},
-		{"v=spf1 ip4:1.2.3.4 ~all", SoftFail},
-		{"v=spf1 ip6:12 ~all", PermError},
-		{"v=spf1 ip4:1.1.1.1 -all", Pass},
-		{"v=spf1 ptr -all", Pass},
-		{"v=spf1 ptr:d1111 -all", Pass},
-		{"v=spf1 ptr:lalala -all", Pass},
-		{"v=spf1 blah", PermError},
+		{"", None, nil},
+		{"blah", None, nil},
+		{"v=spf1", Neutral, nil},
+		{"v=spf1 ", Neutral, nil},
+		{"v=spf1 -", PermError, errUnknownField},
+		{"v=spf1 all", Pass, errMatchedAll},
+		{"v=spf1  +all", Pass, errMatchedAll},
+		{"v=spf1 -all ", Fail, errMatchedAll},
+		{"v=spf1 ~all", SoftFail, errMatchedAll},
+		{"v=spf1 ?all", Neutral, errMatchedAll},
+		{"v=spf1 a ~all", SoftFail, errMatchedAll},
+		{"v=spf1 a/24", Neutral, nil},
+		{"v=spf1 a:d1110/24", Pass, errMatchedA},
+		{"v=spf1 a:d1110/montoto", PermError, errInvalidMask},
+		{"v=spf1 a:d1110/99", PermError, errInvalidMask},
+		{"v=spf1 a:d1110/32", Neutral, nil},
+		{"v=spf1 a:d1110", Neutral, nil},
+		{"v=spf1 a:d1111", Pass, errMatchedA},
+		{"v=spf1 a:nothing/24", Neutral, nil},
+		{"v=spf1 mx", Neutral, nil},
+		{"v=spf1 mx/24", Neutral, nil},
+		{"v=spf1 mx:a/montoto ~all", PermError, errInvalidMask},
+		{"v=spf1 mx:d1110/24 ~all", Pass, errMatchedMX},
+		{"v=spf1 mx:d1110/99 ~all", PermError, errInvalidMask},
+		{"v=spf1 ip4:1.2.3.4 ~all", SoftFail, errMatchedAll},
+		{"v=spf1 ip6:12 ~all", PermError, errInvalidIP},
+		{"v=spf1 ip4:1.1.1.1 -all", Pass, errMatchedIP},
+		{"v=spf1 ip4:1.1.1.1/24 -all", Pass, errMatchedIP},
+		{"v=spf1 ip4:1.1.1.1/lala -all", PermError, errInvalidMask},
+		{"v=spf1 ptr -all", Pass, errMatchedPTR},
+		{"v=spf1 ptr:d1111 -all", Pass, errMatchedPTR},
+		{"v=spf1 ptr:lalala -all", Pass, errMatchedPTR},
+		{"v=spf1 ptr:doesnotexist -all", Fail, errMatchedAll},
+		{"v=spf1 blah", PermError, errUnknownField},
 	}
 
 	ipResults["d1111"] = []net.IP{ip1111}
@@ -98,7 +106,9 @@ func TestBasic(t *testing.T) {
 		}
 		if res != c.res {
 			t.Errorf("%q: expected %q, got %q", c.txt, c.res, res)
-			t.Logf("%q:   error: %v", c.txt, err)
+		}
+		if err != c.err {
+			t.Errorf("%q: expected error [%v], got [%v]", c.txt, c.err, err)
 		}
 	}
 }
@@ -107,21 +117,22 @@ func TestIPv6(t *testing.T) {
 	cases := []struct {
 		txt string
 		res Result
+		err error
 	}{
-		{"v=spf1 all", Pass},
-		{"v=spf1 a ~all", SoftFail},
-		{"v=spf1 a/24", Neutral},
-		{"v=spf1 a:d6660/24", Pass},
-		{"v=spf1 a:d6660", Neutral},
-		{"v=spf1 a:d6666", Pass},
-		{"v=spf1 a:nothing/24", Neutral},
-		{"v=spf1 mx:d6660/24 ~all", Pass},
-		{"v=spf1 ip6:2001:db8::68 ~all", Pass},
-		{"v=spf1 ip6:2001:db8::1/24 ~all", Pass},
-		{"v=spf1 ip6:2001:db8::1/100 ~all", Pass},
-		{"v=spf1 ptr -all", Pass},
-		{"v=spf1 ptr:d6666 -all", Pass},
-		{"v=spf1 ptr:sonlas6 -all", Pass},
+		{"v=spf1 all", Pass, errMatchedAll},
+		{"v=spf1 a ~all", SoftFail, errMatchedAll},
+		{"v=spf1 a/24", Neutral, nil},
+		{"v=spf1 a:d6660/24", Pass, errMatchedA},
+		{"v=spf1 a:d6660", Neutral, nil},
+		{"v=spf1 a:d6666", Pass, errMatchedA},
+		{"v=spf1 a:nothing/24", Neutral, nil},
+		{"v=spf1 mx:d6660/24 ~all", Pass, errMatchedMX},
+		{"v=spf1 ip6:2001:db8::68 ~all", Pass, errMatchedIP},
+		{"v=spf1 ip6:2001:db8::1/24 ~all", Pass, errMatchedIP},
+		{"v=spf1 ip6:2001:db8::1/100 ~all", Pass, errMatchedIP},
+		{"v=spf1 ptr -all", Pass, errMatchedPTR},
+		{"v=spf1 ptr:d6666 -all", Pass, errMatchedPTR},
+		{"v=spf1 ptr:sonlas6 -all", Pass, errMatchedPTR},
 	}
 
 	ipResults["d6666"] = []net.IP{ip6666}
@@ -137,25 +148,29 @@ func TestIPv6(t *testing.T) {
 		}
 		if res != c.res {
 			t.Errorf("%q: expected %q, got %q", c.txt, c.res, res)
-			t.Logf("%q:   error: %v", c.txt, err)
+		}
+		if err != c.err {
+			t.Errorf("%q: expected error [%v], got [%v]", c.txt, c.err, err)
 		}
 	}
 }
 
 func TestNotSupported(t *testing.T) {
-	cases := []string{
-		"v=spf1 exists:blah -all",
-		"v=spf1 exp=blah -all",
-		"v=spf1 a:%{o} -all",
-		"v=spf1 redirect=_spf.%{d}",
+	cases := []struct {
+		txt string
+		err error
+	}{
+		{"v=spf1 exists:blah -all", errExistsNotSupported},
+		{"v=spf1 exp=blah -all", errExpNotSupported},
+		{"v=spf1 a:%{o} -all", errMacrosNotSupported},
+		{"v=spf1 redirect=_spf.%{d}", errMacrosNotSupported},
 	}
 
-	for _, txt := range cases {
-		txtResults["domain"] = []string{txt}
+	for _, c := range cases {
+		txtResults["domain"] = []string{c.txt}
 		res, err := CheckHost(ip1111, "domain")
-		if res != Neutral {
-			t.Errorf("%q: expected neutral, got %v", txt, res)
-			t.Logf("%q:   error: %v", txt, err)
+		if res != Neutral || err != c.err {
+			t.Errorf("%q: expected neutral/%q, got %v/%q", c.txt, c.err, res, err)
 		}
 	}
 }
@@ -164,7 +179,7 @@ func TestRecursion(t *testing.T) {
 	txtResults["domain"] = []string{"v=spf1 include:domain ~all"}
 
 	res, err := CheckHost(ip1111, "domain")
-	if res != PermError {
+	if res != PermError || err != errLookupLimitReached {
 		t.Errorf("expected permerror, got %v (%v)", res, err)
 	}
 }
@@ -191,7 +206,7 @@ func TestInvalidRedirect(t *testing.T) {
 	}
 
 	res, err = CheckHost(ip1111, "domain")
-	if res != PermError {
+	if res != PermError || err != nil {
 		t.Errorf("expected permerror, got %v (%v)", res, err)
 	}
 }
@@ -203,13 +218,13 @@ func TestRedirectOrder(t *testing.T) {
 
 	txtResults["domain"] = []string{"v=spf1 redirect=faildom"}
 	res, err := CheckHost(ip1111, "domain")
-	if res != Fail {
+	if res != Fail || err != errMatchedAll {
 		t.Errorf("expected fail, got %v (%v)", res, err)
 	}
 
 	txtResults["domain"] = []string{"v=spf1 redirect=faildom all"}
 	res, err = CheckHost(ip1111, "domain")
-	if res != Pass {
+	if res != Pass || err != errMatchedAll {
 		t.Errorf("expected pass, got %v (%v)", res, err)
 	}
 }