author | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-03-18 11:58:38 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-03-18 13:16:26 UTC |
parent | a230f6e74cd0e6cde4d3fd530b63f164017bac9d |
spf.go | +33 | -17 |
spf_test.go | +72 | -57 |
diff --git a/spf.go b/spf.go index 1f22c56..62f2eca 100644 --- a/spf.go +++ b/spf.go @@ -88,6 +88,22 @@ var qualToResult = map[byte]Result{ '?': Neutral, } +var ( + errLookupLimitReached = fmt.Errorf("lookup limit reached") + errMacrosNotSupported = fmt.Errorf("macros not supported") + errExistsNotSupported = fmt.Errorf("'exists' not supported") + errExpNotSupported = fmt.Errorf("'exp' not supported") + errUnknownField = fmt.Errorf("unknown field") + errInvalidIP = fmt.Errorf("invalid ipX value") + errInvalidMask = fmt.Errorf("invalid mask") + + errMatchedAll = fmt.Errorf("matched 'all'") + errMatchedA = fmt.Errorf("matched 'a'") + errMatchedIP = fmt.Errorf("matched 'ip'") + errMatchedMX = fmt.Errorf("matched 'mx'") + errMatchedPTR = fmt.Errorf("matched 'ptr'") +) + // CheckHost fetches SPF records for `domain`, parses them, and evaluates them // to determine if `ip` is permitted to send mail for it. // Reference: https://tools.ietf.org/html/rfc7208#section-4 @@ -144,11 +160,11 @@ func (r *resolution) Check(domain string) (Result, error) { // Limit the number of resolutions to 10 // https://tools.ietf.org/html/rfc7208#section-4.6.4 if r.count > 10 { - return PermError, fmt.Errorf("lookup limit reached") + return PermError, errLookupLimitReached } if strings.Contains(field, "%") { - return Neutral, fmt.Errorf("macros not supported") + return Neutral, errMacrosNotSupported } // See if we have a qualifier, defaulting to + (pass). @@ -162,7 +178,7 @@ func (r *resolution) Check(domain string) (Result, error) { if field == "all" { // https://tools.ietf.org/html/rfc7208#section-5.1 - return result, fmt.Errorf("matched 'all'") + return result, errMatchedAll } else if strings.HasPrefix(field, "include:") { if ok, res, err := r.includeField(result, field); ok { return res, err @@ -184,9 +200,9 @@ func (r *resolution) Check(domain string) (Result, error) { return res, err } } else if strings.HasPrefix(field, "exists") { - return Neutral, fmt.Errorf("'exists' not supported") + return Neutral, errExistsNotSupported } else if strings.HasPrefix(field, "exp=") { - return Neutral, fmt.Errorf("'exp' not supported") + return Neutral, errExpNotSupported } else if strings.HasPrefix(field, "redirect=") { // https://tools.ietf.org/html/rfc7208#section-6.1 result, err := r.Check(field[len("redirect="):]) @@ -196,7 +212,7 @@ func (r *resolution) Check(domain string) (Result, error) { return result, err } else { // http://www.openspf.org/SPF_Record_Syntax - return PermError, fmt.Errorf("unknown field %q", field) + return PermError, errUnknownField } } @@ -242,18 +258,18 @@ func (r *resolution) ipField(res Result, field string) (bool, Result, error) { if strings.Contains(fip, "/") { _, ipnet, err := net.ParseCIDR(fip) if err != nil { - return true, PermError, err + return true, PermError, errInvalidMask } if ipnet.Contains(r.ip) { - return true, res, fmt.Errorf("matched %v", ipnet) + return true, res, errMatchedIP } } else { ip := net.ParseIP(fip) if ip == nil { - return true, PermError, fmt.Errorf("invalid ipX value") + return true, PermError, errInvalidIP } if ip.Equal(r.ip) { - return true, res, fmt.Errorf("matched %v", ip) + return true, res, errMatchedIP } } @@ -283,7 +299,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e for _, n := range r.ipNames { if strings.HasSuffix(n, domain+".") { - return true, res, fmt.Errorf("matched ptr:%s", domain) + return true, res, errMatchedPTR } } @@ -314,15 +330,15 @@ func ipMatch(ip, tomatch net.IP, mask int) (bool, error) { if mask >= 0 { _, ipnet, err := net.ParseCIDR(fmt.Sprintf("%s/%d", tomatch.String(), mask)) if err != nil { - return false, err + return false, errInvalidMask } if ipnet.Contains(ip) { - return true, fmt.Errorf("%v", ipnet) + return true, nil } return false, nil } else { if ip.Equal(tomatch) { - return true, fmt.Errorf("%v", tomatch) + return true, nil } return false, nil } @@ -341,7 +357,7 @@ func domainAndMask(re *regexp.Regexp, field, domain string) (string, int, error) if groups[4] != "" { mask, err = strconv.Atoi(groups[4]) if err != nil { - return "", -1, fmt.Errorf("error parsing mask") + return "", -1, errInvalidMask } } } @@ -369,7 +385,7 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err for _, ip := range ips { ok, err := ipMatch(r.ip, ip, mask) if ok { - return true, res, fmt.Errorf("matched 'a' (%v)", err) + return true, res, errMatchedA } else if err != nil { return true, PermError, err } @@ -411,7 +427,7 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er for _, ip := range mxips { ok, err := ipMatch(r.ip, ip, mask) if ok { - return true, res, fmt.Errorf("matched 'mx' (%v)", err) + return true, res, errMatchedMX } else if err != nil { return true, PermError, err } diff --git a/spf_test.go b/spf_test.go index 8337e75..65ae45a 100644 --- a/spf_test.go +++ b/spf_test.go @@ -55,34 +55,42 @@ func TestBasic(t *testing.T) { cases := []struct { txt string res Result + err error }{ - {"", None}, - {"blah", None}, - {"v=spf1", Neutral}, - {"v=spf1 ", Neutral}, - {"v=spf1 -", PermError}, - {"v=spf1 all", Pass}, - {"v=spf1 +all", Pass}, - {"v=spf1 -all ", Fail}, - {"v=spf1 ~all", SoftFail}, - {"v=spf1 ?all", Neutral}, - {"v=spf1 a ~all", SoftFail}, - {"v=spf1 a/24", Neutral}, - {"v=spf1 a:d1110/24", Pass}, - {"v=spf1 a:d1110", Neutral}, - {"v=spf1 a:d1111", Pass}, - {"v=spf1 a:nothing/24", Neutral}, - {"v=spf1 mx", Neutral}, - {"v=spf1 mx/24", Neutral}, - {"v=spf1 mx:a/montoto ~all", PermError}, - {"v=spf1 mx:d1110/24 ~all", Pass}, - {"v=spf1 ip4:1.2.3.4 ~all", SoftFail}, - {"v=spf1 ip6:12 ~all", PermError}, - {"v=spf1 ip4:1.1.1.1 -all", Pass}, - {"v=spf1 ptr -all", Pass}, - {"v=spf1 ptr:d1111 -all", Pass}, - {"v=spf1 ptr:lalala -all", Pass}, - {"v=spf1 blah", PermError}, + {"", None, nil}, + {"blah", None, nil}, + {"v=spf1", Neutral, nil}, + {"v=spf1 ", Neutral, nil}, + {"v=spf1 -", PermError, errUnknownField}, + {"v=spf1 all", Pass, errMatchedAll}, + {"v=spf1 +all", Pass, errMatchedAll}, + {"v=spf1 -all ", Fail, errMatchedAll}, + {"v=spf1 ~all", SoftFail, errMatchedAll}, + {"v=spf1 ?all", Neutral, errMatchedAll}, + {"v=spf1 a ~all", SoftFail, errMatchedAll}, + {"v=spf1 a/24", Neutral, nil}, + {"v=spf1 a:d1110/24", Pass, errMatchedA}, + {"v=spf1 a:d1110/montoto", PermError, errInvalidMask}, + {"v=spf1 a:d1110/99", PermError, errInvalidMask}, + {"v=spf1 a:d1110/32", Neutral, nil}, + {"v=spf1 a:d1110", Neutral, nil}, + {"v=spf1 a:d1111", Pass, errMatchedA}, + {"v=spf1 a:nothing/24", Neutral, nil}, + {"v=spf1 mx", Neutral, nil}, + {"v=spf1 mx/24", Neutral, nil}, + {"v=spf1 mx:a/montoto ~all", PermError, errInvalidMask}, + {"v=spf1 mx:d1110/24 ~all", Pass, errMatchedMX}, + {"v=spf1 mx:d1110/99 ~all", PermError, errInvalidMask}, + {"v=spf1 ip4:1.2.3.4 ~all", SoftFail, errMatchedAll}, + {"v=spf1 ip6:12 ~all", PermError, errInvalidIP}, + {"v=spf1 ip4:1.1.1.1 -all", Pass, errMatchedIP}, + {"v=spf1 ip4:1.1.1.1/24 -all", Pass, errMatchedIP}, + {"v=spf1 ip4:1.1.1.1/lala -all", PermError, errInvalidMask}, + {"v=spf1 ptr -all", Pass, errMatchedPTR}, + {"v=spf1 ptr:d1111 -all", Pass, errMatchedPTR}, + {"v=spf1 ptr:lalala -all", Pass, errMatchedPTR}, + {"v=spf1 ptr:doesnotexist -all", Fail, errMatchedAll}, + {"v=spf1 blah", PermError, errUnknownField}, } ipResults["d1111"] = []net.IP{ip1111} @@ -98,7 +106,9 @@ func TestBasic(t *testing.T) { } if res != c.res { t.Errorf("%q: expected %q, got %q", c.txt, c.res, res) - t.Logf("%q: error: %v", c.txt, err) + } + if err != c.err { + t.Errorf("%q: expected error [%v], got [%v]", c.txt, c.err, err) } } } @@ -107,21 +117,22 @@ func TestIPv6(t *testing.T) { cases := []struct { txt string res Result + err error }{ - {"v=spf1 all", Pass}, - {"v=spf1 a ~all", SoftFail}, - {"v=spf1 a/24", Neutral}, - {"v=spf1 a:d6660/24", Pass}, - {"v=spf1 a:d6660", Neutral}, - {"v=spf1 a:d6666", Pass}, - {"v=spf1 a:nothing/24", Neutral}, - {"v=spf1 mx:d6660/24 ~all", Pass}, - {"v=spf1 ip6:2001:db8::68 ~all", Pass}, - {"v=spf1 ip6:2001:db8::1/24 ~all", Pass}, - {"v=spf1 ip6:2001:db8::1/100 ~all", Pass}, - {"v=spf1 ptr -all", Pass}, - {"v=spf1 ptr:d6666 -all", Pass}, - {"v=spf1 ptr:sonlas6 -all", Pass}, + {"v=spf1 all", Pass, errMatchedAll}, + {"v=spf1 a ~all", SoftFail, errMatchedAll}, + {"v=spf1 a/24", Neutral, nil}, + {"v=spf1 a:d6660/24", Pass, errMatchedA}, + {"v=spf1 a:d6660", Neutral, nil}, + {"v=spf1 a:d6666", Pass, errMatchedA}, + {"v=spf1 a:nothing/24", Neutral, nil}, + {"v=spf1 mx:d6660/24 ~all", Pass, errMatchedMX}, + {"v=spf1 ip6:2001:db8::68 ~all", Pass, errMatchedIP}, + {"v=spf1 ip6:2001:db8::1/24 ~all", Pass, errMatchedIP}, + {"v=spf1 ip6:2001:db8::1/100 ~all", Pass, errMatchedIP}, + {"v=spf1 ptr -all", Pass, errMatchedPTR}, + {"v=spf1 ptr:d6666 -all", Pass, errMatchedPTR}, + {"v=spf1 ptr:sonlas6 -all", Pass, errMatchedPTR}, } ipResults["d6666"] = []net.IP{ip6666} @@ -137,25 +148,29 @@ func TestIPv6(t *testing.T) { } if res != c.res { t.Errorf("%q: expected %q, got %q", c.txt, c.res, res) - t.Logf("%q: error: %v", c.txt, err) + } + if err != c.err { + t.Errorf("%q: expected error [%v], got [%v]", c.txt, c.err, err) } } } func TestNotSupported(t *testing.T) { - cases := []string{ - "v=spf1 exists:blah -all", - "v=spf1 exp=blah -all", - "v=spf1 a:%{o} -all", - "v=spf1 redirect=_spf.%{d}", + cases := []struct { + txt string + err error + }{ + {"v=spf1 exists:blah -all", errExistsNotSupported}, + {"v=spf1 exp=blah -all", errExpNotSupported}, + {"v=spf1 a:%{o} -all", errMacrosNotSupported}, + {"v=spf1 redirect=_spf.%{d}", errMacrosNotSupported}, } - for _, txt := range cases { - txtResults["domain"] = []string{txt} + for _, c := range cases { + txtResults["domain"] = []string{c.txt} res, err := CheckHost(ip1111, "domain") - if res != Neutral { - t.Errorf("%q: expected neutral, got %v", txt, res) - t.Logf("%q: error: %v", txt, err) + if res != Neutral || err != c.err { + t.Errorf("%q: expected neutral/%q, got %v/%q", c.txt, c.err, res, err) } } } @@ -164,7 +179,7 @@ func TestRecursion(t *testing.T) { txtResults["domain"] = []string{"v=spf1 include:domain ~all"} res, err := CheckHost(ip1111, "domain") - if res != PermError { + if res != PermError || err != errLookupLimitReached { t.Errorf("expected permerror, got %v (%v)", res, err) } } @@ -191,7 +206,7 @@ func TestInvalidRedirect(t *testing.T) { } res, err = CheckHost(ip1111, "domain") - if res != PermError { + if res != PermError || err != nil { t.Errorf("expected permerror, got %v (%v)", res, err) } } @@ -203,13 +218,13 @@ func TestRedirectOrder(t *testing.T) { txtResults["domain"] = []string{"v=spf1 redirect=faildom"} res, err := CheckHost(ip1111, "domain") - if res != Fail { + if res != Fail || err != errMatchedAll { t.Errorf("expected fail, got %v (%v)", res, err) } txtResults["domain"] = []string{"v=spf1 redirect=faildom all"} res, err = CheckHost(ip1111, "domain") - if res != Pass { + if res != Pass || err != errMatchedAll { t.Errorf("expected pass, got %v (%v)", res, err) } }