author | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-03-18 12:33:31 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-03-18 13:16:26 UTC |
parent | 54adca3936c096a3f67eea5551f3f1b91d1b887e |
spf.go | +4 | -2 |
spf_test.go | +30 | -1 |
diff --git a/spf.go b/spf.go
index 62f2eca..2edcb6d 100644
--- a/spf.go
+++ b/spf.go
@@ -96,6 +96,7 @@ var (
 	errUnknownField = fmt.Errorf("unknown field")
 	errInvalidIP = fmt.Errorf("invalid ipX value")
 	errInvalidMask = fmt.Errorf("invalid mask")
+	errNoResult = fmt.Errorf("lookup yielded no result")
 
 	errMatchedAll = fmt.Errorf("matched 'all'")
 	errMatchedA = fmt.Errorf("matched 'a'")
@@ -318,12 +319,13 @@ func (r *resolution) includeField(res Result, field string) (bool, Result, error
 		return false, ir, err
 	case TempError:
 		return true, TempError, err
-	case PermError, None:
+	case PermError:
 		return true, PermError, err
+	case None:
+		return true, PermError, errNoResult
 	}
 
 	return false, "", fmt.Errorf("This should never be reached")
-
 }
 
 func ipMatch(ip, tomatch net.IP, mask int) (bool, error) {
diff --git a/spf_test.go b/spf_test.go
index 65ae45a..b385eec 100644
--- a/spf_test.go
+++ b/spf_test.go
@@ -86,6 +86,7 @@ func TestBasic(t *testing.T) {
 		{"v=spf1 ip4:1.1.1.1 -all", Pass, errMatchedIP},
 		{"v=spf1 ip4:1.1.1.1/24 -all", Pass, errMatchedIP},
 		{"v=spf1 ip4:1.1.1.1/lala -all", PermError, errInvalidMask},
+		{"v=spf1 include:doesnotexist", PermError, errNoResult},
 		{"v=spf1 ptr -all", Pass, errMatchedPTR},
 		{"v=spf1 ptr:d1111 -all", Pass, errMatchedPTR},
 		{"v=spf1 ptr:lalala -all", Pass, errMatchedPTR},
@@ -175,7 +176,35 @@ func TestNotSupported(t *testing.T) {
 	}
 }
 
-func TestRecursion(t *testing.T) {
+func TestInclude(t *testing.T) {
+	// Test that the include is doing a recursive lookup.
+	// If we got a match on 1.1.1.1, is because include:domain2 did not match.
+	txtResults["domain"] = []string{"v=spf1 include:domain2 ip4:1.1.1.1"}
+
+	cases := []struct {
+		txt string
+		res Result
+		err error
+	}{
+		{"", PermError, errNoResult},
+		{"v=spf1 all", Pass, errMatchedAll},
+
+		// domain2 did not pass, so continued and matched parent's ip4.
+		{"v=spf1", Pass, errMatchedIP},
+		{"v=spf1 -all", Pass, errMatchedIP},
+	}
+
+	for _, c := range cases {
+		txtResults["domain2"] = []string{c.txt}
+		res, err := CheckHost(ip1111, "domain")
+		if res != c.res || err != c.err {
+			t.Errorf("%q: expected [%v/%v], got [%v/%v]",
+				c.txt, c.res, c.err, res, err)
+		}
+	}
+}
+
+func TestRecursionLimit(t *testing.T) {
 	txtResults["domain"] = []string{"v=spf1 include:domain ~all"}
 
 	res, err := CheckHost(ip1111, "domain")