git » spf » commit 1f35621

tests: Isolate DNS overrides

author Alberto Bertogli
2019-10-13 13:49:47 UTC
committer Alberto Bertogli
2019-10-14 12:13:57 UTC
parent 8289dc24c98df04a67a7d1d1ad6f27fcaace4b92

tests: Isolate DNS overrides

The DNS lookups are overridden for testing purposes. Today the tests do
this with a set of global maps, which are difficult to clear between
tests and awkward to manage.

This is OK for the test suite as it is today, but with the extended
automated tests we want an easier way to start from a clean slate, so
this patch moves the overrides into their own structure and file.

The existing tests are adjusted to reset the overrides on each test.

dns_test.go +69 -0
spf_test.go +43 -75

diff --git a/dns_test.go b/dns_test.go
new file mode 100644
index 0000000..a28fc7e
--- /dev/null
+++ b/dns_test.go
@@ -0,0 +1,69 @@
+package spf
+
+import (
+	"flag"
+	"net"
+	"os"
+	"strings"
+	"testing"
+)
+
+// DNS holds the canned records and errors served by the fake resolvers.
+type DNS struct {
+	txt    map[string][]string
+	mx     map[string][]*net.MX
+	ip     map[string][]net.IP
+	addr   map[string][]string
+	errors map[string]error // Returned by all four lookups for a name.
+}
+
+// NewDNS returns a DNS with no records and no errors configured.
+func NewDNS() DNS {
+	return DNS{
+		txt:    map[string][]string{},
+		mx:     map[string][]*net.MX{},
+		ip:     map[string][]net.IP{},
+		addr:   map[string][]string{},
+		errors: map[string]error{},
+	}
+}
+
+// Single global variable that the overridden resolvers use.
+// Tests assign dns = NewDNS() to start from a clean slate.
+var dns DNS
+
+// dnsKey normalizes a name for map lookups: lowercase, no trailing dots.
+func dnsKey(name string) string {
+	return strings.TrimRight(strings.ToLower(name), ".")
+}
+
+// LookupTXT, LookupMX, LookupIP and LookupAddr replace the real resolvers
+// (TestMain installs them) and answer from the global dns.
+func LookupTXT(domain string) (txts []string, err error) {
+	return dns.txt[dnsKey(domain)], dns.errors[dnsKey(domain)]
+}
+
+func LookupMX(domain string) (mxs []*net.MX, err error) {
+	return dns.mx[dnsKey(domain)], dns.errors[dnsKey(domain)]
+}
+
+func LookupIP(host string) (ips []net.IP, err error) {
+	return dns.ip[dnsKey(host)], dns.errors[dnsKey(host)]
+}
+
+func LookupAddr(host string) (addrs []string, err error) {
+	return dns.addr[dnsKey(host)], dns.errors[dnsKey(host)]
+}
+
+// TestMain installs the fake resolvers, then runs the tests.
+func TestMain(m *testing.M) {
+	dns = NewDNS()
+
+	lookupTXT = LookupTXT
+	lookupMX = LookupMX
+	lookupIP = LookupIP
+	lookupAddr = LookupAddr
+
+	flag.Parse()
+	os.Exit(m.Run())
+}
diff --git a/spf_test.go b/spf_test.go
index 76ca65c..54383ec 100644
--- a/spf_test.go
+++ b/spf_test.go
@@ -1,57 +1,19 @@
 package spf
 
 import (
-	"flag"
 	"fmt"
 	"net"
-	"os"
 	"testing"
 )
 
-var txtResults = map[string][]string{}
-var txtErrors = map[string]error{}
-
-func LookupTXT(domain string) (txts []string, err error) {
-	return txtResults[domain], txtErrors[domain]
-}
-
-var mxResults = map[string][]*net.MX{}
-var mxErrors = map[string]error{}
-
-func LookupMX(domain string) (mxs []*net.MX, err error) {
-	return mxResults[domain], mxErrors[domain]
-}
-
-var ipResults = map[string][]net.IP{}
-var ipErrors = map[string]error{}
-
-func LookupIP(host string) (ips []net.IP, err error) {
-	return ipResults[host], ipErrors[host]
-}
-
-var addrResults = map[string][]string{}
-var addrErrors = map[string]error{}
-
-func LookupAddr(host string) (addrs []string, err error) {
-	return addrResults[host], addrErrors[host]
-}
-
-func TestMain(m *testing.M) {
-	lookupTXT = LookupTXT
-	lookupMX = LookupMX
-	lookupIP = LookupIP
-	lookupAddr = LookupAddr
-
-	flag.Parse()
-	os.Exit(m.Run())
-}
-
 var ip1110 = net.ParseIP("1.1.1.0")
 var ip1111 = net.ParseIP("1.1.1.1")
 var ip6666 = net.ParseIP("2001:db8::68")
 var ip6660 = net.ParseIP("2001:db8::0")
 
 func TestBasic(t *testing.T) {
+	dns = NewDNS()
+
 	cases := []struct {
 		txt string
 		res Result
@@ -94,13 +56,13 @@ func TestBasic(t *testing.T) {
 		{"v=spf1 blah", PermError, errUnknownField},
 	}
 
-	ipResults["d1111"] = []net.IP{ip1111}
-	ipResults["d1110"] = []net.IP{ip1110}
-	mxResults["d1110"] = []*net.MX{{"d1110", 5}, {"nothing", 10}}
-	addrResults["1.1.1.1"] = []string{"lalala.", "domain.", "d1111."}
+	dns.ip["d1111"] = []net.IP{ip1111}
+	dns.ip["d1110"] = []net.IP{ip1110}
+	dns.mx["d1110"] = []*net.MX{{"d1110", 5}, {"nothing", 10}}
+	dns.addr["1.1.1.1"] = []string{"lalala.", "domain.", "d1111."}
 
 	for _, c := range cases {
-		txtResults["domain"] = []string{c.txt}
+		dns.txt["domain"] = []string{c.txt}
 		res, err := CheckHost(ip1111, "domain")
 		if (res == TempError || res == PermError) && (err == nil) {
 			t.Errorf("%q: expected error, got nil", c.txt)
@@ -115,6 +77,8 @@ func TestBasic(t *testing.T) {
 }
 
 func TestIPv6(t *testing.T) {
+	dns = NewDNS()
+
 	cases := []struct {
 		txt string
 		res Result
@@ -136,13 +100,13 @@ func TestIPv6(t *testing.T) {
 		{"v=spf1 ptr:sonlas6 -all", Pass, errMatchedPTR},
 	}
 
-	ipResults["d6666"] = []net.IP{ip6666}
-	ipResults["d6660"] = []net.IP{ip6660}
-	mxResults["d6660"] = []*net.MX{{"d6660", 5}, {"nothing", 10}}
-	addrResults["2001:db8::68"] = []string{"sonlas6.", "domain.", "d6666."}
+	dns.ip["d6666"] = []net.IP{ip6666}
+	dns.ip["d6660"] = []net.IP{ip6660}
+	dns.mx["d6660"] = []*net.MX{{"d6660", 5}, {"nothing", 10}}
+	dns.addr["2001:db8::68"] = []string{"sonlas6.", "domain.", "d6666."}
 
 	for _, c := range cases {
-		txtResults["domain"] = []string{c.txt}
+		dns.txt["domain"] = []string{c.txt}
 		res, err := CheckHost(ip6666, "domain")
 		if (res == TempError || res == PermError) && (err == nil) {
 			t.Errorf("%q: expected error, got nil", c.txt)
@@ -168,7 +132,7 @@ func TestNotSupported(t *testing.T) {
 	}
 
 	for _, c := range cases {
-		txtResults["domain"] = []string{c.txt}
+		dns.txt["domain"] = []string{c.txt}
 		res, err := CheckHost(ip1111, "domain")
 		if res != Neutral || err != c.err {
 			t.Errorf("%q: expected neutral/%q, got %v/%q", c.txt, c.err, res, err)
@@ -179,7 +143,8 @@ func TestNotSupported(t *testing.T) {
 func TestInclude(t *testing.T) {
 	// Test that the include is doing a recursive lookup.
 	// If we got a match on 1.1.1.1, is because include:domain2 did not match.
-	txtResults["domain"] = []string{"v=spf1 include:domain2 ip4:1.1.1.1"}
+	dns = NewDNS()
+	dns.txt["domain"] = []string{"v=spf1 include:domain2 ip4:1.1.1.1"}
 
 	cases := []struct {
 		txt string
@@ -195,7 +160,7 @@ func TestInclude(t *testing.T) {
 	}
 
 	for _, c := range cases {
-		txtResults["domain2"] = []string{c.txt}
+		dns.txt["domain2"] = []string{c.txt}
 		res, err := CheckHost(ip1111, "domain")
 		if res != c.res || err != c.err {
 			t.Errorf("%q: expected [%v/%v], got [%v/%v]",
@@ -205,7 +170,8 @@ func TestInclude(t *testing.T) {
 }
 
 func TestRecursionLimit(t *testing.T) {
-	txtResults["domain"] = []string{"v=spf1 include:domain ~all"}
+	dns = NewDNS()
+	dns.txt["domain"] = []string{"v=spf1 include:domain ~all"}
 
 	res, err := CheckHost(ip1111, "domain")
 	if res != PermError || err != errLookupLimitReached {
@@ -214,8 +180,9 @@ func TestRecursionLimit(t *testing.T) {
 }
 
 func TestRedirect(t *testing.T) {
-	txtResults["domain"] = []string{"v=spf1 redirect=domain2"}
-	txtResults["domain2"] = []string{"v=spf1 ip4:1.1.1.1 -all"}
+	dns = NewDNS()
+	dns.txt["domain"] = []string{"v=spf1 redirect=domain2"}
+	dns.txt["domain2"] = []string{"v=spf1 ip4:1.1.1.1 -all"}
 
 	res, err := CheckHost(ip1111, "domain")
 	if res != Pass {
@@ -227,7 +194,8 @@ func TestInvalidRedirect(t *testing.T) {
 	// Redirect to a non-existing host; the inner check returns None, but due
 	// to the redirection, this lookup should return PermError.
 	// https://tools.ietf.org/html/rfc7208#section-6.1
-	txtResults["domain"] = []string{"v=spf1 redirect=doesnotexist"}
+	dns = NewDNS()
+	dns.txt["domain"] = []string{"v=spf1 redirect=doesnotexist"}
 
 	res, err := CheckHost(ip1111, "doesnotexist")
 	if res != None {
@@ -243,15 +211,16 @@ func TestInvalidRedirect(t *testing.T) {
 func TestRedirectOrder(t *testing.T) {
 	// We should only check redirects after all mechanisms, even if the
 	// redirect modifier appears before them.
-	txtResults["faildom"] = []string{"v=spf1 -all"}
+	dns = NewDNS()
+	dns.txt["faildom"] = []string{"v=spf1 -all"}
 
-	txtResults["domain"] = []string{"v=spf1 redirect=faildom"}
+	dns.txt["domain"] = []string{"v=spf1 redirect=faildom"}
 	res, err := CheckHost(ip1111, "domain")
 	if res != Fail || err != errMatchedAll {
 		t.Errorf("expected fail, got %v (%v)", res, err)
 	}
 
-	txtResults["domain"] = []string{"v=spf1 redirect=faildom all"}
+	dns.txt["domain"] = []string{"v=spf1 redirect=faildom all"}
 	res, err = CheckHost(ip1111, "domain")
 	if res != Pass || err != errMatchedAll {
 		t.Errorf("expected pass, got %v (%v)", res, err)
@@ -259,9 +228,10 @@ func TestRedirectOrder(t *testing.T) {
 }
 
 func TestNoRecord(t *testing.T) {
-	txtResults["d1"] = []string{""}
-	txtResults["d2"] = []string{"loco", "v=spf2"}
-	txtErrors["nospf"] = fmt.Errorf("no such domain")
+	dns = NewDNS()
+	dns.txt["d1"] = []string{""}
+	dns.txt["d2"] = []string{"loco", "v=spf2"}
+	dns.errors["nospf"] = fmt.Errorf("no such domain")
 
 	for _, domain := range []string{"d1", "d2", "d3", "nospf"} {
 		res, err := CheckHost(ip1111, domain)
@@ -272,17 +242,16 @@ func TestNoRecord(t *testing.T) {
 }
 
 func TestDNSTemporaryErrors(t *testing.T) {
+	dns = NewDNS()
 	dnsError := &net.DNSError{
 		Err:         "temporary error for testing",
 		IsTemporary: true,
 	}
 
 	// Domain "tmperr" will fail resolution with a temporary error.
-	txtErrors["tmperr"] = dnsError
-	ipErrors["tmperr"] = dnsError
-	mxErrors["tmperr"] = dnsError
-	mxResults["tmpmx"] = []*net.MX{{"tmperr", 10}}
-	addrErrors["1.1.1.1"] = dnsError
+	dns.errors["tmperr"] = dnsError
+	dns.errors["1.1.1.1"] = dnsError
+	dns.mx["tmpmx"] = []*net.MX{{"tmperr", 10}}
 
 	cases := []struct {
 		txt string
@@ -296,7 +265,7 @@ func TestDNSTemporaryErrors(t *testing.T) {
 	}
 
 	for _, c := range cases {
-		txtResults["domain"] = []string{c.txt}
+		dns.txt["domain"] = []string{c.txt}
 		res, err := CheckHost(ip1111, "domain")
 		if res != c.res {
 			t.Errorf("%q: expected %v, got %v (%v)",
@@ -306,17 +275,16 @@ func TestDNSTemporaryErrors(t *testing.T) {
 }
 
 func TestDNSPermanentErrors(t *testing.T) {
+	dns = NewDNS()
 	dnsError := &net.DNSError{
 		Err:         "permanent error for testing",
 		IsTemporary: false,
 	}
 
 	// Domain "tmperr" will fail resolution with a temporary error.
-	txtErrors["tmperr"] = dnsError
-	ipErrors["tmperr"] = dnsError
-	mxErrors["tmperr"] = dnsError
-	mxResults["tmpmx"] = []*net.MX{{"tmperr", 10}}
-	addrErrors["1.1.1.1"] = dnsError
+	dns.errors["tmperr"] = dnsError
+	dns.errors["1.1.1.1"] = dnsError
+	dns.mx["tmpmx"] = []*net.MX{{"tmperr", 10}}
 
 	cases := []struct {
 		txt string
@@ -330,7 +298,7 @@ func TestDNSPermanentErrors(t *testing.T) {
 	}
 
 	for _, c := range cases {
-		txtResults["domain"] = []string{c.txt}
+		dns.txt["domain"] = []string{c.txt}
 		res, err := CheckHost(ip1111, "domain")
 		if res != c.res {
 			t.Errorf("%q: expected %v, got %v (%v)",