author | Alberto Bertogli
<albertito@blitiri.com.ar> 2019-10-13 13:49:47 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2019-10-14 12:13:57 UTC |
parent | 8289dc24c98df04a67a7d1d1ad6f27fcaace4b92 |
dns_test.go | +69 | -0 |
spf_test.go | +43 | -75 |
diff --git a/dns_test.go b/dns_test.go new file mode 100644 index 0000000..a28fc7e --- /dev/null +++ b/dns_test.go @@ -0,0 +1,69 @@ +package spf + +import ( + "flag" + "net" + "os" + "strings" + "testing" +) + +// DNS overrides for testing. + +type DNS struct { + txt map[string][]string + mx map[string][]*net.MX + ip map[string][]net.IP + addr map[string][]string + errors map[string]error +} + +func NewDNS() DNS { + return DNS{ + txt: map[string][]string{}, + mx: map[string][]*net.MX{}, + ip: map[string][]net.IP{}, + addr: map[string][]string{}, + errors: map[string]error{}, + } +} + +// Single global variable that the overridden resolvers use. +// This way it's easier to get a clean slate between tests. +var dns DNS + +func LookupTXT(domain string) (txts []string, err error) { + domain = strings.ToLower(domain) + domain = strings.TrimRight(domain, ".") + return dns.txt[domain], dns.errors[domain] +} + +func LookupMX(domain string) (mxs []*net.MX, err error) { + domain = strings.ToLower(domain) + domain = strings.TrimRight(domain, ".") + return dns.mx[domain], dns.errors[domain] +} + +func LookupIP(host string) (ips []net.IP, err error) { + host = strings.ToLower(host) + host = strings.TrimRight(host, ".") + return dns.ip[host], dns.errors[host] +} + +func LookupAddr(host string) (addrs []string, err error) { + host = strings.ToLower(host) + host = strings.TrimRight(host, ".") + return dns.addr[host], dns.errors[host] +} + +func TestMain(m *testing.M) { + dns = NewDNS() + + lookupTXT = LookupTXT + lookupMX = LookupMX + lookupIP = LookupIP + lookupAddr = LookupAddr + + flag.Parse() + os.Exit(m.Run()) +} diff --git a/spf_test.go b/spf_test.go index 76ca65c..54383ec 100644 --- a/spf_test.go +++ b/spf_test.go @@ -1,57 +1,19 @@ package spf import ( - "flag" "fmt" "net" - "os" "testing" ) -var txtResults = map[string][]string{} -var txtErrors = map[string]error{} - -func LookupTXT(domain string) (txts []string, err error) { - return txtResults[domain], txtErrors[domain] -} - -var mxResults = map[string][]*net.MX{} -var mxErrors = map[string]error{} - -func LookupMX(domain string) (mxs []*net.MX, err error) { - return mxResults[domain], mxErrors[domain] -} - -var ipResults = map[string][]net.IP{} -var ipErrors = map[string]error{} - -func LookupIP(host string) (ips []net.IP, err error) { - return ipResults[host], ipErrors[host] -} - -var addrResults = map[string][]string{} -var addrErrors = map[string]error{} - -func LookupAddr(host string) (addrs []string, err error) { - return addrResults[host], addrErrors[host] -} - -func TestMain(m *testing.M) { - lookupTXT = LookupTXT - lookupMX = LookupMX - lookupIP = LookupIP - lookupAddr = LookupAddr - - flag.Parse() - os.Exit(m.Run()) -} - var ip1110 = net.ParseIP("1.1.1.0") var ip1111 = net.ParseIP("1.1.1.1") var ip6666 = net.ParseIP("2001:db8::68") var ip6660 = net.ParseIP("2001:db8::0") func TestBasic(t *testing.T) { + dns = NewDNS() + cases := []struct { txt string res Result @@ -94,13 +56,13 @@ func TestBasic(t *testing.T) { {"v=spf1 blah", PermError, errUnknownField}, } - ipResults["d1111"] = []net.IP{ip1111} - ipResults["d1110"] = []net.IP{ip1110} - mxResults["d1110"] = []*net.MX{{"d1110", 5}, {"nothing", 10}} - addrResults["1.1.1.1"] = []string{"lalala.", "domain.", "d1111."} + dns.ip["d1111"] = []net.IP{ip1111} + dns.ip["d1110"] = []net.IP{ip1110} + dns.mx["d1110"] = []*net.MX{{"d1110", 5}, {"nothing", 10}} + dns.addr["1.1.1.1"] = []string{"lalala.", "domain.", "d1111."} for _, c := range cases { - txtResults["domain"] = []string{c.txt} + dns.txt["domain"] = []string{c.txt} res, err := CheckHost(ip1111, "domain") if (res == TempError || res == PermError) && (err == nil) { t.Errorf("%q: expected error, got nil", c.txt) @@ -115,6 +77,8 @@ func TestBasic(t *testing.T) { } func TestIPv6(t *testing.T) { + dns = NewDNS() + cases := []struct { txt string res Result @@ -136,13 +100,13 @@ func TestIPv6(t *testing.T) { {"v=spf1 ptr:sonlas6 -all", Pass, errMatchedPTR}, } - ipResults["d6666"] = []net.IP{ip6666} - ipResults["d6660"] = []net.IP{ip6660} - mxResults["d6660"] = []*net.MX{{"d6660", 5}, {"nothing", 10}} - addrResults["2001:db8::68"] = []string{"sonlas6.", "domain.", "d6666."} + dns.ip["d6666"] = []net.IP{ip6666} + dns.ip["d6660"] = []net.IP{ip6660} + dns.mx["d6660"] = []*net.MX{{"d6660", 5}, {"nothing", 10}} + dns.addr["2001:db8::68"] = []string{"sonlas6.", "domain.", "d6666."} for _, c := range cases { - txtResults["domain"] = []string{c.txt} + dns.txt["domain"] = []string{c.txt} res, err := CheckHost(ip6666, "domain") if (res == TempError || res == PermError) && (err == nil) { t.Errorf("%q: expected error, got nil", c.txt) @@ -168,7 +132,7 @@ func TestNotSupported(t *testing.T) { } for _, c := range cases { - txtResults["domain"] = []string{c.txt} + dns.txt["domain"] = []string{c.txt} res, err := CheckHost(ip1111, "domain") if res != Neutral || err != c.err { t.Errorf("%q: expected neutral/%q, got %v/%q", c.txt, c.err, res, err) @@ -179,7 +143,8 @@ func TestNotSupported(t *testing.T) { func TestInclude(t *testing.T) { // Test that the include is doing a recursive lookup. // If we got a match on 1.1.1.1, is because include:domain2 did not match. - txtResults["domain"] = []string{"v=spf1 include:domain2 ip4:1.1.1.1"} + dns = NewDNS() + dns.txt["domain"] = []string{"v=spf1 include:domain2 ip4:1.1.1.1"} cases := []struct { txt string @@ -195,7 +160,7 @@ func TestInclude(t *testing.T) { } for _, c := range cases { - txtResults["domain2"] = []string{c.txt} + dns.txt["domain2"] = []string{c.txt} res, err := CheckHost(ip1111, "domain") if res != c.res || err != c.err { t.Errorf("%q: expected [%v/%v], got [%v/%v]", @@ -205,7 +170,8 @@ func TestInclude(t *testing.T) { } func TestRecursionLimit(t *testing.T) { - txtResults["domain"] = []string{"v=spf1 include:domain ~all"} + dns = NewDNS() + dns.txt["domain"] = []string{"v=spf1 include:domain ~all"} res, err := CheckHost(ip1111, "domain") if res != PermError || err != errLookupLimitReached { @@ -214,8 +180,9 @@ func TestRecursionLimit(t *testing.T) { } func TestRedirect(t *testing.T) { - txtResults["domain"] = []string{"v=spf1 redirect=domain2"} - txtResults["domain2"] = []string{"v=spf1 ip4:1.1.1.1 -all"} + dns = NewDNS() + dns.txt["domain"] = []string{"v=spf1 redirect=domain2"} + dns.txt["domain2"] = []string{"v=spf1 ip4:1.1.1.1 -all"} res, err := CheckHost(ip1111, "domain") if res != Pass { @@ -227,7 +194,8 @@ func TestInvalidRedirect(t *testing.T) { // Redirect to a non-existing host; the inner check returns None, but due // to the redirection, this lookup should return PermError. // https://tools.ietf.org/html/rfc7208#section-6.1 - txtResults["domain"] = []string{"v=spf1 redirect=doesnotexist"} + dns = NewDNS() + dns.txt["domain"] = []string{"v=spf1 redirect=doesnotexist"} res, err := CheckHost(ip1111, "doesnotexist") if res != None { @@ -243,15 +211,16 @@ func TestInvalidRedirect(t *testing.T) { func TestRedirectOrder(t *testing.T) { // We should only check redirects after all mechanisms, even if the // redirect modifier appears before them. - txtResults["faildom"] = []string{"v=spf1 -all"} + dns = NewDNS() + dns.txt["faildom"] = []string{"v=spf1 -all"} - txtResults["domain"] = []string{"v=spf1 redirect=faildom"} + dns.txt["domain"] = []string{"v=spf1 redirect=faildom"} res, err := CheckHost(ip1111, "domain") if res != Fail || err != errMatchedAll { t.Errorf("expected fail, got %v (%v)", res, err) } - txtResults["domain"] = []string{"v=spf1 redirect=faildom all"} + dns.txt["domain"] = []string{"v=spf1 redirect=faildom all"} res, err = CheckHost(ip1111, "domain") if res != Pass || err != errMatchedAll { t.Errorf("expected pass, got %v (%v)", res, err) @@ -259,9 +228,10 @@ func TestRedirectOrder(t *testing.T) { } func TestNoRecord(t *testing.T) { - txtResults["d1"] = []string{""} - txtResults["d2"] = []string{"loco", "v=spf2"} - txtErrors["nospf"] = fmt.Errorf("no such domain") + dns = NewDNS() + dns.txt["d1"] = []string{""} + dns.txt["d2"] = []string{"loco", "v=spf2"} + dns.errors["nospf"] = fmt.Errorf("no such domain") for _, domain := range []string{"d1", "d2", "d3", "nospf"} { res, err := CheckHost(ip1111, domain) @@ -272,17 +242,16 @@ func TestNoRecord(t *testing.T) { } func TestDNSTemporaryErrors(t *testing.T) { + dns = NewDNS() dnsError := &net.DNSError{ Err: "temporary error for testing", IsTemporary: true, } // Domain "tmperr" will fail resolution with a temporary error. - txtErrors["tmperr"] = dnsError - ipErrors["tmperr"] = dnsError - mxErrors["tmperr"] = dnsError - mxResults["tmpmx"] = []*net.MX{{"tmperr", 10}} - addrErrors["1.1.1.1"] = dnsError + dns.errors["tmperr"] = dnsError + dns.errors["1.1.1.1"] = dnsError + dns.mx["tmpmx"] = []*net.MX{{"tmperr", 10}} cases := []struct { txt string @@ -296,7 +265,7 @@ func TestDNSTemporaryErrors(t *testing.T) { } for _, c := range cases { - txtResults["domain"] = []string{c.txt} + dns.txt["domain"] = []string{c.txt} res, err := CheckHost(ip1111, "domain") if res != c.res { t.Errorf("%q: expected %v, got %v (%v)", @@ -306,17 +275,16 @@ func TestDNSTemporaryErrors(t *testing.T) { } func TestDNSPermanentErrors(t *testing.T) { + dns = NewDNS() dnsError := &net.DNSError{ Err: "permanent error for testing", IsTemporary: false, } // Domain "tmperr" will fail resolution with a temporary error. - txtErrors["tmperr"] = dnsError - ipErrors["tmperr"] = dnsError - mxErrors["tmperr"] = dnsError - mxResults["tmpmx"] = []*net.MX{{"tmperr", 10}} - addrErrors["1.1.1.1"] = dnsError + dns.errors["tmperr"] = dnsError + dns.errors["1.1.1.1"] = dnsError + dns.mx["tmpmx"] = []*net.MX{{"tmperr", 10}} cases := []struct { txt string @@ -330,7 +298,7 @@ func TestDNSPermanentErrors(t *testing.T) { } for _, c := range cases { - txtResults["domain"] = []string{c.txt} + dns.txt["domain"] = []string{c.txt} res, err := CheckHost(ip1111, "domain") if res != c.res { t.Errorf("%q: expected %v, got %v (%v)",