author | Alberto Bertogli
<albertito@blitiri.com.ar> 2016-10-21 23:07:07 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2016-10-21 23:07:07 UTC |
parent | b45726f5cbfa404e5fd79b29eb20ba7a6f8532fc |
spf.go | +70 | -7 |
spf_test.go | +63 | -1 |
diff --git a/spf.go b/spf.go index fbc28d6..adae64b 100644 --- a/spf.go +++ b/spf.go @@ -11,7 +11,6 @@ // // Not supported (return Neutral if used): // - "exists". -// - "ptr". // - "exp". // - Macros. // @@ -30,9 +29,10 @@ import ( // Functions that we can override for testing purposes. var ( - lookupTXT = net.LookupTXT - lookupMX = net.LookupMX - lookupIP = net.LookupIP + lookupTXT = net.LookupTXT + lookupMX = net.LookupMX + lookupIP = net.LookupIP + lookupAddr = net.LookupAddr ) // Results and Errors. Note the values have meaning, we use them in headers. @@ -81,13 +81,16 @@ var QualToResult = map[byte]Result{ // with a given identity. // Reference: https://tools.ietf.org/html/rfc7208#section-4 func CheckHost(ip net.IP, domain string) (Result, error) { - r := &resolution{ip, 0} + r := &resolution{ip, 0, nil} return r.Check(domain) } type resolution struct { ip net.IP count uint + + // Validated names from the reverse lookup of ip (so we only do it once). + ipNames []string } func (r *resolution) Check(domain string) (Result, error) { @@ -167,10 +170,12 @@ func (r *resolution) Check(domain string) (Result, error) { if ok, res, err := r.ipField(result, field); ok { return res, err } + } else if strings.HasPrefix(field, "ptr") { + if ok, res, err := r.ptrField(result, field, domain); ok { + return res, err + } } else if strings.HasPrefix(field, "exists") { return Neutral, fmt.Errorf("'exists' not supported") - } else if strings.HasPrefix(field, "ptr") { - return Neutral, fmt.Errorf("'ptr' not supported") } else if strings.HasPrefix(field, "exp=") { return Neutral, fmt.Errorf("'exp' not supported") } else if strings.HasPrefix(field, "redirect=") { @@ -246,6 +251,64 @@ func (r *resolution) ipField(res Result, field string) (bool, Result, error) { return false, "", nil } +// ptrField processes a "ptr" field. 
+func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, error) { + // The target is the current domain, unless the field is in the form + // "ptr:domain". + // https://tools.ietf.org/html/rfc7208#section-5.5 + ptrDomain := domain + if field != "ptr" { + if !strings.HasPrefix(field, "ptr:") || len(field) == 4 { + return true, PermError, fmt.Errorf("invalid ptr field %q", field) + } + ptrDomain = field[4:] + } + target := strings.ToLower(strings.TrimSuffix(ptrDomain, ".")) + + if r.ipNames == nil { + // Mark the lookup as done up front, so that a failure is not + // retried (and counted again) by a later "ptr" term. + r.ipNames = []string{} + r.count++ + names, err := lookupAddr(r.ip.String()) + if err != nil { + // A DNS error during the PTR lookup means the mechanism does + // not match. + return false, "", err + } + + // Check at most the first 10 names (section 4.6.4). Each must be + // validated by a forward lookup leading back to the IP; names whose + // forward lookup fails are skipped. + if len(names) > 10 { + names = names[:10] + } + for _, name := range names { + ips, err := lookupIP(name) + if err != nil { + continue + } + for _, ip := range ips { + if ip.Equal(r.ip) { + n := strings.TrimSuffix(name, ".") + r.ipNames = append(r.ipNames, strings.ToLower(n)) + break + } + } + } + } + + // Match the target or any subdomain of it, on a label boundary: + // "ptr:example.com" must not match "badexample.com". + for _, n := range r.ipNames { + if n == target || strings.HasSuffix(n, "."+target) { + return true, res, fmt.Errorf("matched ptr:%s", ptrDomain) + } + } + + return false, "", nil +} + // includeField processes an "include" field. func (r *resolution) includeField(res Result, field string) (bool, Result, error) { // https://tools.ietf.org/html/rfc7208#section-5.2 diff --git a/spf_test.go b/spf_test.go index a5007dc..8d8d813 100644 --- a/spf_test.go +++ b/spf_test.go @@ -29,10 +29,18 @@ func LookupIP(host string) (ips []net.IP, err error) { return ipResults[host], ipErrors[host] } +var addrResults = map[string][]string{} +var addrErrors = map[string]error{} + +func LookupAddr(host string) (addrs []string, err error) { + return addrResults[host], addrErrors[host] +} + func TestMain(m *testing.M) { lookupTXT = LookupTXT lookupMX = LookupMX lookupIP = LookupIP + lookupAddr = LookupAddr flag.Parse() os.Exit(m.Run()) @@ -41,6 +49,7 @@ func TestMain(m *testing.M) { var ip1110 = net.ParseIP("1.1.1.0") var ip1111 = net.ParseIP("1.1.1.1") var ip6666 = net.ParseIP("2001:db8::68") +var ip6660 = net.ParseIP("2001:db8::0") func TestBasic(t *testing.T) { cases := []struct { @@ -70,12 +79,24 @@ func TestBasic(t *testing.T) { {"v=spf1 ip4:1.2.3.4 ~all", SoftFail}, {"v=spf1 ip6:12 ~all", PermError}, {"v=spf1 ip4:1.1.1.1 -all", Pass}, + {"v=spf1 ptr -all", Pass}, + {"v=spf1 ptr:d1111 -all", Pass}, + {"v=spf1 ptr:lalala -all", Pass}, + {"v=spf1 ptr:lala ~all", SoftFail}, + {"v=spf1 ptr:nofwd ~all", SoftFail}, + {"v=spf1 ptrx ~all", PermError}, + {"v=spf1 ptr: ~all", PermError}, {"v=spf1 blah", PermError}, } ipResults["d1111"] = []net.IP{ip1111} ipResults["d1110"] = []net.IP{ip1110} 
mxResults["d1110"] = []*net.MX{{"d1110", 5}, {"nothing", 10}} + addrResults["1.1.1.1"] = []string{"lalala.", "domain.", "d1111.", "nofwd."} + // Forward records validating the PTR names above, except "nofwd.". + ipResults["lalala."] = []net.IP{ip1111} + ipResults["domain."] = []net.IP{ip1111} + ipResults["d1111."] = []net.IP{ip1111} for _, c := range cases { txtResults["domain"] = []string{c.txt} @@ -90,10 +111,51 @@ func TestBasic(t *testing.T) { } } +func TestIPv6(t *testing.T) { + cases := []struct { + txt string + res Result + }{ + {"v=spf1 all", Pass}, + {"v=spf1 a ~all", SoftFail}, + {"v=spf1 a/24", Neutral}, + {"v=spf1 a:d6660/24", Pass}, + {"v=spf1 a:d6660", Neutral}, + {"v=spf1 a:d6666", Pass}, + {"v=spf1 a:nothing/24", Neutral}, + {"v=spf1 mx:d6660/24 ~all", Pass}, + {"v=spf1 ip6:2001:db8::68 ~all", Pass}, + {"v=spf1 ip6:2001:db8::1/24 ~all", Pass}, + {"v=spf1 ip6:2001:db8::1/100 ~all", Pass}, + {"v=spf1 ptr -all", Pass}, + {"v=spf1 ptr:d6666 -all", Pass}, + {"v=spf1 ptr:sonlas6 -all", Pass}, + } + + ipResults["d6666"] = []net.IP{ip6666} + ipResults["d6660"] = []net.IP{ip6660} + mxResults["d6660"] = []*net.MX{{"d6660", 5}, {"nothing", 10}} + addrResults["2001:db8::68"] = []string{"sonlas6.", "domain.", "d6666."} + ipResults["sonlas6."] = []net.IP{ip6666} + ipResults["domain."] = []net.IP{ip6666} + ipResults["d6666."] = []net.IP{ip6666} + + for _, c := range cases { + txtResults["domain"] = []string{c.txt} + res, err := CheckHost(ip6666, "domain") + if (res == TempError || res == PermError) && (err == nil) { + t.Errorf("%q: expected error, got nil", c.txt) + } + if res != c.res { + t.Errorf("%q: expected %q, got %q", c.txt, c.res, res) + t.Logf("%q: error: %v", c.txt, err) + } + } +} + func TestNotSupported(t *testing.T) { cases := []string{ "v=spf1 exists:blah -all", - "v=spf1 ptr -all", "v=spf1 exp=blah -all", "v=spf1 a:%{o} -all", }