author | Alberto Bertogli <albertito@blitiri.com.ar> | 2019-10-13 20:34:18 UTC |
committer | Alberto Bertogli <albertito@blitiri.com.ar> | 2019-10-14 12:31:32 UTC |
parent | fcd50be18d1a8aeeb5b31c7dfba8a1c4d4bb411f |
spf.go | +47 | -20 |
spf_test.go | +10 | -3 |
diff --git a/spf.go b/spf.go index c169110..67ab070 100644 --- a/spf.go +++ b/spf.go @@ -402,16 +402,26 @@ func (r *resolution) includeField(res Result, field string) (bool, Result, error return false, "", fmt.Errorf("This should never be reached") } -func ipMatch(ip, tomatch net.IP, mask int) (bool, error) { +type DualMasks struct { + v4 int + v6 int +} + +func ipMatch(ip, tomatch net.IP, masks DualMasks) (bool, error) { + mask := -1 + if tomatch.To4() != nil && masks.v4 >= 0 { + mask = masks.v4 + } else if tomatch.To4() == nil && masks.v6 >= 0 { + mask = masks.v6 + } + if mask >= 0 { - _, ipnet, err := net.ParseCIDR(fmt.Sprintf("%s/%d", tomatch.String(), mask)) + _, ipnet, err := net.ParseCIDR( + fmt.Sprintf("%s/%d", tomatch.String(), mask)) if err != nil { return false, errInvalidMask } - if ipnet.Contains(ip) { - return true, nil - } - return false, nil + return ipnet.Contains(ip), nil } else { if ip.Equal(tomatch) { return true, nil @@ -420,31 +430,46 @@ func ipMatch(ip, tomatch net.IP, mask int) (bool, error) { } } -var aRegexp = regexp.MustCompile("[aA](:([^/]+))?(/(.+))?") -var mxRegexp = regexp.MustCompile("[mM][xX](:([^/]+))?(/(.+))?") +var aRegexp = regexp.MustCompile(`^[aA](:([^/]+))?(/(\w+))?(//(\w+))?$`) +var mxRegexp = regexp.MustCompile(`^[mM][xX](:([^/]+))?(/(\w+))?(//(\w+))?$`) -func domainAndMask(re *regexp.Regexp, field, domain string) (string, int, error) { - var err error - mask := -1 - if groups := re.FindStringSubmatch(field); groups != nil { +func domainAndMask(re *regexp.Regexp, field, domain string) (string, DualMasks, error) { + masks := DualMasks{-1, -1} + groups := re.FindStringSubmatch(field) + if groups != nil { if groups[2] != "" { domain = groups[2] } if groups[4] != "" { - mask, err = strconv.Atoi(groups[4]) - if err != nil { - return "", -1, errInvalidMask + mask4, err := strconv.Atoi(groups[4]) + if err != nil || mask4 < 0 || mask4 > 32 { + return "", masks, errInvalidMask + } + masks.v4 = mask4 + } + if groups[6] != "" { + 
mask6, err := strconv.Atoi(groups[6]) + if err != nil || mask6 < 0 || mask6 > 128 { + return "", masks, errInvalidMask } + masks.v6 = mask6 } } + trace("masks on %q: %q %q %v", field, groups, domain, masks) + + // Test to catch malformed entries: if there's a /, there must be at least + // one mask. + if strings.Contains(field, "/") && masks.v4 == -1 && masks.v6 == -1 { + return "", masks, errInvalidMask + } - return domain, mask, nil + return domain, masks, nil } // aField processes an "a" field. func (r *resolution) aField(res Result, field, domain string) (bool, Result, error) { // https://tools.ietf.org/html/rfc7208#section-5.3 - domain, mask, err := domainAndMask(aRegexp, field, domain) + domain, masks, err := domainAndMask(aRegexp, field, domain) if err != nil { return true, PermError, err } @@ -459,8 +484,9 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err return false, "", err } for _, ip := range ips { - ok, err := ipMatch(r.ip, ip, mask) + ok, err := ipMatch(r.ip, ip, masks) if ok { + trace("mx matched %v, %v, %v", r.ip, ip, masks) return true, res, errMatchedA } else if err != nil { return true, PermError, err @@ -473,7 +499,7 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err // mxField processes an "mx" field. func (r *resolution) mxField(res Result, field, domain string) (bool, Result, error) { // https://tools.ietf.org/html/rfc7208#section-5.4 - domain, mask, err := domainAndMask(mxRegexp, field, domain) + domain, masks, err := domainAndMask(mxRegexp, field, domain) if err != nil { return true, PermError, err } @@ -501,8 +527,9 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er mxips = append(mxips, ips...) 
} for _, ip := range mxips { - ok, err := ipMatch(r.ip, ip, mask) + ok, err := ipMatch(r.ip, ip, masks) if ok { + trace("mx matched %v, %v, %v", r.ip, ip, masks) return true, res, errMatchedMX } else if err != nil { return true, PermError, err diff --git a/spf_test.go b/spf_test.go index aa80a4b..acbff91 100644 --- a/spf_test.go +++ b/spf_test.go @@ -44,12 +44,16 @@ func TestBasic(t *testing.T) { {"v=spf1 mx/24", Neutral, nil}, {"v=spf1 mx:a/montoto ~all", PermError, errInvalidMask}, {"v=spf1 mx:d1110/24 ~all", Pass, errMatchedMX}, + {"v=spf1 mx:d1110/24//100 ~all", Pass, errMatchedMX}, + {"v=spf1 mx:d1110/24//129 ~all", PermError, errInvalidMask}, + {"v=spf1 mx:d1110/24/100 ~all", PermError, errInvalidMask}, {"v=spf1 mx:d1110/99 ~all", PermError, errInvalidMask}, {"v=spf1 ip4:1.2.3.4 ~all", SoftFail, errMatchedAll}, {"v=spf1 ip6:12 ~all", PermError, errInvalidIP}, {"v=spf1 ip4:1.1.1.1 -all", Pass, errMatchedIP}, {"v=spf1 ip4:1.1.1.1/24 -all", Pass, errMatchedIP}, {"v=spf1 ip4:1.1.1.1/lala -all", PermError, errInvalidMask}, + {"v=spf1 ip4:1.1.1.1/33 -all", PermError, errInvalidMask}, {"v=spf1 include:doesnotexist", PermError, errNoResult}, {"v=spf1 ptr -all", Pass, errMatchedPTR}, {"v=spf1 ptr:d1111 -all", Pass, errMatchedPTR}, @@ -90,11 +94,14 @@ func TestIPv6(t *testing.T) { {"v=spf1 all", Pass, errMatchedAll}, {"v=spf1 a ~all", SoftFail, errMatchedAll}, {"v=spf1 a/24", Neutral, nil}, - {"v=spf1 a:d6660/24", Pass, errMatchedA}, + {"v=spf1 a:d6660//24", Pass, errMatchedA}, + {"v=spf1 a:d6660/24//100", Pass, errMatchedA}, {"v=spf1 a:d6660", Neutral, nil}, {"v=spf1 a:d6666", Pass, errMatchedA}, - {"v=spf1 a:nothing/24", Neutral, nil}, - {"v=spf1 mx:d6660/24 ~all", Pass, errMatchedMX}, + {"v=spf1 a:nothing//24", Neutral, nil}, + {"v=spf1 mx:d6660//24 ~all", Pass, errMatchedMX}, + {"v=spf1 mx:d6660/24//100 ~all", Pass, errMatchedMX}, + {"v=spf1 mx:d6660/24/100 ~all", PermError, errInvalidMask}, {"v=spf1 ip6:2001:db8::68 ~all", Pass, errMatchedIP}, {"v=spf1 
ip6:2001:db8::1/24 ~all", Pass, errMatchedIP}, {"v=spf1 ip6:2001:db8::1/100 ~all", Pass, errMatchedIP},