author | Alberto Bertogli
<albertito@blitiri.com.ar> 2021-05-09 09:10:19 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2021-05-09 09:38:54 UTC |
parent | 25345a2e2784a550b613409b0abd49f59b209ab2 |
spf.go | +20 | -28 |
spf_test.go | +64 | -16 |
diff --git a/spf.go b/spf.go index 38710e2..0cf1319 100644 --- a/spf.go +++ b/spf.go @@ -568,50 +568,48 @@ func (r *resolution) includeField(res Result, field, domain string) (bool, Resul } type dualMasks struct { - v4 int - v6 int + v4 net.IPMask + v6 net.IPMask } -func ipMatch(ip, tomatch net.IP, masks dualMasks) (bool, error) { - mask := -1 - if tomatch.To4() != nil && masks.v4 >= 0 { +func ipMatch(ip, tomatch net.IP, masks dualMasks) bool { + mask := net.IPMask(nil) + if tomatch.To4() != nil && masks.v4 != nil { mask = masks.v4 - } else if tomatch.To4() == nil && masks.v6 >= 0 { + } else if tomatch.To4() == nil && masks.v6 != nil { mask = masks.v6 } - if mask >= 0 { - _, ipnet, err := net.ParseCIDR( - fmt.Sprintf("%s/%d", tomatch.String(), mask)) - if err != nil { - return false, errInvalidMask - } - return ipnet.Contains(ip), nil + if mask != nil { + ipnet := net.IPNet{IP: tomatch, Mask: mask} + return ipnet.Contains(ip) } - return ip.Equal(tomatch), nil + return ip.Equal(tomatch) } var aRegexp = regexp.MustCompile(`^[aA](:([^/]+))?(/(\w+))?(//(\w+))?$`) var mxRegexp = regexp.MustCompile(`^[mM][xX](:([^/]+))?(/(\w+))?(//(\w+))?$`) func domainAndMask(re *regexp.Regexp, field, domain string) (string, dualMasks, error) { - masks := dualMasks{-1, -1} + masks := dualMasks{} groups := re.FindStringSubmatch(field) if groups != nil { if groups[2] != "" { domain = groups[2] } if groups[4] != "" { - mask4, err := strconv.Atoi(groups[4]) - if err != nil || mask4 < 0 || mask4 > 32 { + i, err := strconv.Atoi(groups[4]) + mask4 := net.CIDRMask(i, 32) + if err != nil || mask4 == nil { return "", masks, errInvalidMask } masks.v4 = mask4 } if groups[6] != "" { - mask6, err := strconv.Atoi(groups[6]) - if err != nil || mask6 < 0 || mask6 > 128 { + i, err := strconv.Atoi(groups[6]) + mask6 := net.CIDRMask(i, 128) + if err != nil || mask6 == nil { return "", masks, errInvalidMask } masks.v6 = mask6 @@ -621,7 +619,7 @@ func domainAndMask(re *regexp.Regexp, field, domain string) (string, dualMasks, // Test to catch malformed entries: if there's a /, there must be at least // one mask. - if strings.Contains(field, "/") && masks.v4 == -1 && masks.v6 == -1 { + if strings.Contains(field, "/") && masks.v4 == nil && masks.v6 == nil { return "", masks, errInvalidMask } @@ -650,12 +648,9 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err return false, "", err } for _, ip := range ips { - ok, err := ipMatch(r.ip, ip.IP, masks) - if ok { + if ipMatch(r.ip, ip.IP, masks) { trace("a matched %v, %v, %v", r.ip, ip.IP, masks) return true, res, errMatchedA - } else if err != nil { - return true, PermError, err } } @@ -706,12 +701,9 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er } } for _, ip := range mxips { - ok, err := ipMatch(r.ip, ip, masks) - if ok { + if ipMatch(r.ip, ip, masks) { trace("mx matched %v, %v, %v", r.ip, ip, masks) return true, res, errMatchedMX - } else if err != nil { - return true, PermError, err } } diff --git a/spf_test.go b/spf_test.go index 6b7787f..6d76d04 100644 --- a/spf_test.go +++ b/spf_test.go @@ -403,31 +403,34 @@ func mx(host string, pref uint16) *net.MX { return &net.MX{Host: host, Pref: pref} } +func mkDM(v4, v6 int) dualMasks { + return dualMasks{net.CIDRMask(v4, 32), net.CIDRMask(v6, 128)} +} + func TestIPMatchHelper(t *testing.T) { cases := []struct { ip net.IP tomatch net.IP masks dualMasks ok bool - err error }{ - {ip1111, ip1110, dualMasks{24, -1}, true, nil}, - {ip1111, ip1111, dualMasks{-1, -1}, true, nil}, - {ip1111, ip1110, dualMasks{-1, -1}, false, nil}, - {ip1111, ip1110, dualMasks{32, -1}, false, nil}, - {ip1111, ip1110, dualMasks{99, -1}, false, errInvalidMask}, - - {ip6666, ip6660, dualMasks{-1, 100}, true, nil}, - {ip6666, ip6666, dualMasks{-1, -1}, true, nil}, - {ip6666, ip6660, dualMasks{-1, -1}, false, nil}, - {ip6666, ip6660, dualMasks{-1, 128}, false, nil}, - {ip6666, ip6660, dualMasks{-1, 200}, false, errInvalidMask}, + {ip1111, ip1110, mkDM(24, -1), true}, + {ip1111, ip1111, mkDM(-1, -1), true}, + {ip1111, ip1110, mkDM(-1, -1), false}, + {ip1111, ip1110, mkDM(32, -1), false}, + {ip1111, ip1110, mkDM(99, -1), false}, + + {ip6666, ip6660, mkDM(-1, 100), true}, + {ip6666, ip6666, mkDM(-1, -1), true}, + {ip6666, ip6660, mkDM(-1, -1), false}, + {ip6666, ip6660, mkDM(-1, 128), false}, + {ip6666, ip6660, mkDM(-1, 200), false}, } for _, c := range cases { - ok, err := ipMatch(c.ip, c.tomatch, c.masks) - if ok != c.ok || err != c.err { - t.Errorf("[%s %s/%v]: expected %v/%v, got %v/%v", - c.ip, c.tomatch, c.masks, c.ok, c.err, ok, err) + ok := ipMatch(c.ip, c.tomatch, c.masks) + if ok != c.ok { + t.Errorf("[%s %s/%v]: expected %v, got %v", + c.ip, c.tomatch, c.masks, c.ok, ok) } } } @@ -539,3 +542,48 @@ func TestWithResolver(t *testing.T) { t.Errorf("expected pass, got %q / %q", res, err) } } + +// Test some corner cases when resolver.LookupIPAddr returns an invalid +// address. This can happen if using a buggy custom resolver. +func TestBadResolverResponse(t *testing.T) { + dns := NewResolver() + trace = t.Logf + + // When LookupIPAddr returns an invalid ip, for an "a" field. + dns.ip["domain1"] = []net.IP{nil} + dns.txt["domain1"] = []string{"v=spf1 a:domain1 -all"} + res, err := CheckHostWithSender(ip1111, "helo", "user@domain1", + WithResolver(dns)) + if res != Fail { + t.Errorf("expected fail, got %q / %q", res, err) + } + + // Same as above, except the field has a mask. + dns.ip["domain1"] = []net.IP{nil} + dns.txt["domain1"] = []string{"v=spf1 a:domain1//24 -all"} + res, err = CheckHostWithSender(ip1111, "helo", "user@domain1", + WithResolver(dns)) + if res != Fail { + t.Errorf("expected fail, got %q / %q", res, err) + } + + // When LookupIPAddr returns an invalid ip, for an "mx" field. + dns.ip["mx.domain1"] = []net.IP{nil} + dns.mx["domain1"] = []*net.MX{mx("mx.domain1", 5)} + dns.txt["domain1"] = []string{"v=spf1 mx:domain1 -all"} + res, err = CheckHostWithSender(ip1111, "helo", "user@domain1", + WithResolver(dns)) + if res != Fail { + t.Errorf("expected fail, got %q / %q", res, err) + } + + // Same as above, except the field has a mask. + dns.ip["mx.domain1"] = []net.IP{nil} + dns.mx["domain1"] = []*net.MX{mx("mx.domain1", 5)} + dns.txt["domain1"] = []string{"v=spf1 mx:domain1//24 -all"} + res, err = CheckHostWithSender(ip1111, "helo", "user@domain1", + WithResolver(dns)) + if res != Fail { + t.Errorf("expected fail, got %q / %q", res, err) + } +}