git » spf » commit 872f2ef

Allow using a custom DNS resolver

author Alberto Bertogli
2020-09-20 14:48:33 UTC
committer Alberto Bertogli
2020-09-20 16:21:34 UTC
parent 55b54fc5ccc120596e452fb227b3764f5fc6c270

Allow using a custom DNS resolver

In some cases, such as integration testing or running in special
environments, users may need to force the library to use a non-default
DNS resolver.

This patch adds a new WithResolver option for that purpose.

Thanks to Thierry Fournier <thierry.fournier@ozon.io> for the bug report
and alternative patch.

dns_test.go +29 -31
spf.go +44 -17
spf_test.go +29 -14
yml_test.go +1 -1

diff --git a/dns_test.go b/dns_test.go
index c27d07f..da6d320 100644
--- a/dns_test.go
+++ b/dns_test.go
@@ -2,17 +2,15 @@ package spf
 
 import (
 	"context"
-	"flag"
-	"fmt"
 	"net"
-	"os"
 	"strings"
-	"testing"
 )
 
 // DNS overrides for testing.
 
-type DNS struct {
+// TestResolver is a DNSResolver backed by in-memory maps, so tests can define
+// the records (and errors) they need without querying real DNS.
+type TestResolver struct {
 	txt    map[string][]string
 	mx     map[string][]*net.MX
 	ip     map[string][]net.IP
@@ -20,8 +18,9 @@ type DNS struct {
 	errors map[string]error
 }
 
-func NewDNS() DNS {
-	return DNS{
+// NewResolver returns a TestResolver with no records.
+func NewResolver() *TestResolver {
+	return &TestResolver{
 		txt:    map[string][]string{},
 		mx:     map[string][]*net.MX{},
 		ip:     map[string][]net.IP{},
@@ -30,57 +29,63 @@ func NewDNS() DNS {
 	}
 }
 
-// Single global variable that the overridden resolvers use.
-// This way it's easier to get a clean slate between tests.
-var dns DNS
+// NewDefaultResolver returns a TestResolver with no records, and installs it
+// as the package default resolver, so each test starts from a clean slate.
+func NewDefaultResolver() *TestResolver {
+	dns := NewResolver()
+	defaultResolver = dns
+	return dns
+}
 
-func LookupTXT(ctx context.Context, domain string) (txts []string, err error) {
+func (r *TestResolver) LookupTXT(ctx context.Context, domain string) (txts []string, err error) {
 	if ctx.Err() != nil {
 		return nil, ctx.Err()
 	}
 	domain = strings.ToLower(domain)
 	domain = strings.TrimRight(domain, ".")
-	return dns.txt[domain], dns.errors[domain]
+	return r.txt[domain], r.errors[domain]
 }
 
-func LookupMX(ctx context.Context, domain string) (mxs []*net.MX, err error) {
+func (r *TestResolver) LookupMX(ctx context.Context, domain string) (mxs []*net.MX, err error) {
 	if ctx.Err() != nil {
 		return nil, ctx.Err()
 	}
 	domain = strings.ToLower(domain)
 	domain = strings.TrimRight(domain, ".")
-	return dns.mx[domain], dns.errors[domain]
+	return r.mx[domain], r.errors[domain]
 }
 
-func LookupIP(ctx context.Context, net, host string) (ips []net.IP, err error) {
-	if net != "ip" {
-		panic(fmt.Sprintf("got net %q, expected ip", net))
-	}
+func (r *TestResolver) LookupIPAddr(ctx context.Context, host string) (as []net.IPAddr, err error) {
 	if ctx.Err() != nil {
 		return nil, ctx.Err()
 	}
 	host = strings.ToLower(host)
 	host = strings.TrimRight(host, ".")
-	return dns.ip[host], dns.errors[host]
+	return ipsToAddrs(r.ip[host]), r.errors[host]
 }
 
-func LookupAddr(ctx context.Context, host string) (addrs []string, err error) {
+// ipsToAddrs converts ips into the []net.IPAddr form that LookupIPAddr
+// returns, leaving the zone empty.
+func ipsToAddrs(ips []net.IP) []net.IPAddr {
+	as := make([]net.IPAddr, 0, len(ips))
+	for _, ip := range ips {
+		as = append(as, net.IPAddr{IP: ip, Zone: ""})
+	}
+	return as
+}
+
+func (r *TestResolver) LookupAddr(ctx context.Context, host string) (addrs []string, err error) {
 	if ctx.Err() != nil {
 		return nil, ctx.Err()
 	}
 	host = strings.ToLower(host)
 	host = strings.TrimRight(host, ".")
-	return dns.addr[host], dns.errors[host]
+	return r.addr[host], r.errors[host]
 }
 
-func TestMain(m *testing.M) {
-	dns = NewDNS()
-
-	lookupTXT = LookupTXT
-	lookupMX = LookupMX
-	lookupIP = LookupIP
-	lookupAddr = LookupAddr
-
-	flag.Parse()
-	os.Exit(m.Run())
+func init() {
+	// Override the default resolver to make sure the tests are not using the
+	// one from net. Individual tests will override this as well, but just in
+	// case.
+	NewDefaultResolver()
 }
diff --git a/spf.go b/spf.go
index 3d9dfda..38718bf 100644
--- a/spf.go
+++ b/spf.go
@@ -38,10 +38,6 @@ import (
 
 // Functions that we can override for testing purposes.
 var (
-	lookupTXT  = net.DefaultResolver.LookupTXT
-	lookupMX   = net.DefaultResolver.LookupMX
-	lookupIP   = net.DefaultResolver.LookupIP
-	lookupAddr = net.DefaultResolver.LookupAddr
 	nullTrace = func(f string, a ...interface{}) {}
 	trace     = nullTrace
 )
@@ -132,6 +128,7 @@ func CheckHost(ip net.IP, domain string) (Result, error) {
 		maxcount: defaultMaxLookups,
 		sender:   "@" + domain,
 		ctx:      context.TODO(),
+		resolver: defaultResolver,
 	}
 	return r.Check(domain)
 }
@@ -156,6 +153,7 @@ func CheckHostWithSender(ip net.IP, helo, sender string, opts ...Option) (Result
 		maxcount: defaultMaxLookups,
 		sender:   sender,
 		ctx:      context.TODO(),
+		resolver: defaultResolver,
 	}
 
 	for _, opt := range opts {
@@ -188,6 +186,34 @@ func WithContext(ctx context.Context) Option {
 	}
 }
 
+// DNSResolver contains the methods we use to resolve DNS queries.
+// It is intentionally compatible with *net.Resolver.
+type DNSResolver interface {
+	LookupTXT(ctx context.Context, name string) ([]string, error)
+	LookupMX(ctx context.Context, name string) ([]*net.MX, error)
+	LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error)
+	LookupAddr(ctx context.Context, addr string) (names []string, err error)
+}
+
+// defaultResolver is used for DNS lookups unless overridden via WithResolver.
+var defaultResolver DNSResolver = net.DefaultResolver
+
+// WithResolver sets the resolver to use for DNS lookups. It can be useful for
+// testing, and for customizing DNS resolution specifically for this library.
+//
+// The default is to use net.DefaultResolver, which should be appropriate for
+// most users. A nil resolver is ignored.
+//
+// This is EXPERIMENTAL for now, and the API is subject to change.
+func WithResolver(resolver DNSResolver) Option {
+	return func(r *resolution) {
+		// Calling methods on a nil interface would panic on the first lookup.
+		if resolver != nil {
+			r.resolver = resolver
+		}
+	}
+}
+
 // split an user@domain address into user and domain.
 func split(addr string) (string, string) {
 	ps := strings.SplitN(addr, "@", 2)
@@ -210,6 +236,9 @@ type resolution struct {
 
 	// Context for this resolution.
 	ctx context.Context
+
+	// DNS resolver to use.
+	resolver DNSResolver
 }
 
 var aField = regexp.MustCompile(`^(a$|a:|a/)`)
@@ -219,7 +248,7 @@ var ptrField = regexp.MustCompile(`^(ptr$|ptr:)`)
 func (r *resolution) Check(domain string) (Result, error) {
 	r.count++
 	trace("check %s %d", domain, r.count)
-	txt, err := getDNSRecord(r.ctx, domain)
+	txt, err := r.getDNSRecord(domain)
 	if err != nil {
 		if isTemporary(err) {
 			trace("dns temp error: %v", err)
@@ -351,8 +380,8 @@ func (r *resolution) Check(domain string) (Result, error) {
 // https://tools.ietf.org/html/rfc7208#section-3
 // https://tools.ietf.org/html/rfc7208#section-3.2
 // https://tools.ietf.org/html/rfc7208#section-4.5
-func getDNSRecord(ctx context.Context, domain string) (string, error) {
-	txts, err := lookupTXT(ctx, domain)
+func (r *resolution) getDNSRecord(domain string) (string, error) {
+	txts, err := r.resolver.LookupTXT(r.ctx, domain)
 	if err != nil {
 		return "", err
 	}
@@ -435,7 +464,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e
 	if r.ipNames == nil {
 		r.ipNames = []string{}
 		r.count++
-		ns, err := lookupAddr(r.ctx, r.ip.String())
+		ns, err := r.resolver.LookupAddr(r.ctx, r.ip.String())
 		if err != nil {
 			// https://tools.ietf.org/html/rfc7208#section-5
 			if isTemporary(err) {
@@ -451,7 +480,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e
 				return false, "", errLookupLimitReached
 			}
 			r.count++
-			addrs, err := lookupIP(r.ctx, "ip", n)
+			addrs, err := r.resolver.LookupIPAddr(r.ctx, n)
 			if err != nil {
 				// RFC explicitly says to skip domains which error here.
 				continue
@@ -491,7 +520,7 @@ func (r *resolution) existsField(res Result, field, domain string) (bool, Result
 	}
 
 	r.count++
-	ips, err := lookupIP(r.ctx, "ip", eDomain)
+	ips, err := r.resolver.LookupIPAddr(r.ctx, eDomain)
 	if err != nil {
 		// https://tools.ietf.org/html/rfc7208#section-5
 		if isTemporary(err) {
@@ -502,7 +531,7 @@ func (r *resolution) existsField(res Result, field, domain string) (bool, Result
 
 	// Exists only counts if there are IPv4 matches.
 	for _, ip := range ips {
-		if ip.To4() != nil {
+		if ip.IP.To4() != nil {
 			return true, res, errMatchedExists
 		}
 	}
@@ -608,7 +637,7 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err
 	}
 
 	r.count++
-	ips, err := lookupIP(r.ctx, "ip", aDomain)
+	ips, err := r.resolver.LookupIPAddr(r.ctx, aDomain)
 	if err != nil {
 		// https://tools.ietf.org/html/rfc7208#section-5
 		if isTemporary(err) {
@@ -617,9 +646,9 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err
 		return false, "", err
 	}
 	for _, ip := range ips {
-		ok, err := ipMatch(r.ip, ip, masks)
+		ok, err := ipMatch(r.ip, ip.IP, masks)
 		if ok {
-			trace("mx matched %v, %v, %v", r.ip, ip, masks)
+			trace("a matched %v, %v, %v", r.ip, ip.IP, masks)
 			return true, res, errMatchedA
 		} else if err != nil {
 			return true, PermError, err
@@ -642,7 +671,7 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er
 	}
 
 	r.count++
-	mxs, err := lookupMX(r.ctx, mxDomain)
+	mxs, err := r.resolver.LookupMX(r.ctx, mxDomain)
 	if err != nil {
 		// https://tools.ietf.org/html/rfc7208#section-5
 		if isTemporary(err) {
@@ -660,7 +689,7 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er
 	mxips := []net.IP{}
 	for _, mx := range mxs {
 		r.count++
-		ips, err := lookupIP(r.ctx, "ip", mx.Host)
+		ips, err := r.resolver.LookupIPAddr(r.ctx, mx.Host)
 		if err != nil {
 			// https://tools.ietf.org/html/rfc7208#section-5
 			if isTemporary(err) {
@@ -668,7 +697,9 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er
 			}
 			return false, "", err
 		}
-		mxips = append(mxips, ips...)
+		for _, ipaddr := range ips {
+			mxips = append(mxips, ipaddr.IP)
+		}
 	}
 	for _, ip := range mxips {
 		ok, err := ipMatch(r.ip, ip, masks)
diff --git a/spf_test.go b/spf_test.go
index db94c25..6b7787f 100644
--- a/spf_test.go
+++ b/spf_test.go
@@ -13,7 +13,7 @@ var ip6666 = net.ParseIP("2001:db8::68")
 var ip6660 = net.ParseIP("2001:db8::0")
 
 func TestBasic(t *testing.T) {
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	trace = t.Logf
 
 	cases := []struct {
@@ -88,7 +88,7 @@ func TestBasic(t *testing.T) {
 }
 
 func TestIPv6(t *testing.T) {
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	trace = t.Logf
 
 	cases := []struct {
@@ -141,7 +141,7 @@ func TestIPv6(t *testing.T) {
 func TestInclude(t *testing.T) {
 	// Test that the include is doing a recursive lookup.
 	// If we got a match on 1.1.1.1, is because include:domain2 did not match.
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	dns.txt["domain"] = []string{"v=spf1 include:domain2 ip4:1.1.1.1"}
 	trace = t.Logf
 
@@ -169,7 +169,7 @@ func TestInclude(t *testing.T) {
 }
 
 func TestRecursionLimit(t *testing.T) {
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	dns.txt["domain"] = []string{"v=spf1 include:domain ~all"}
 	trace = t.Logf
 
@@ -180,7 +180,7 @@ func TestRecursionLimit(t *testing.T) {
 }
 
 func TestRedirect(t *testing.T) {
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	dns.txt["domain"] = []string{"v=spf1 redirect=domain2"}
 	dns.txt["domain2"] = []string{"v=spf1 ip4:1.1.1.1 -all"}
 	trace = t.Logf
@@ -195,7 +195,7 @@ func TestInvalidRedirect(t *testing.T) {
 	// Redirect to a non-existing host; the inner check returns None, but due
 	// to the redirection, this lookup should return PermError.
 	// https://tools.ietf.org/html/rfc7208#section-6.1
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	dns.txt["domain"] = []string{"v=spf1 redirect=doesnotexist"}
 	trace = t.Logf
 
@@ -213,7 +213,7 @@ func TestInvalidRedirect(t *testing.T) {
 func TestRedirectOrder(t *testing.T) {
 	// We should only check redirects after all mechanisms, even if the
 	// redirect modifier appears before them.
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	dns.txt["faildom"] = []string{"v=spf1 -all"}
 	trace = t.Logf
 
@@ -231,7 +231,7 @@ func TestRedirectOrder(t *testing.T) {
 }
 
 func TestNoRecord(t *testing.T) {
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	dns.txt["d1"] = []string{""}
 	dns.txt["d2"] = []string{"loco", "v=spf2"}
 	dns.errors["nospf"] = fmt.Errorf("no such domain")
@@ -246,7 +246,7 @@ func TestNoRecord(t *testing.T) {
 }
 
 func TestDNSTemporaryErrors(t *testing.T) {
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	dnsError := &net.DNSError{
 		Err:         "temporary error for testing",
 		IsTemporary: true,
@@ -280,7 +280,7 @@ func TestDNSTemporaryErrors(t *testing.T) {
 }
 
 func TestDNSPermanentErrors(t *testing.T) {
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	dnsError := &net.DNSError{
 		Err:         "permanent error for testing",
 		IsTemporary: false,
@@ -314,7 +314,7 @@ func TestDNSPermanentErrors(t *testing.T) {
 }
 
 func TestMacros(t *testing.T) {
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	trace = t.Logf
 
 	// Most of the cases are covered by the standard test suite, so this is
@@ -358,7 +358,7 @@ func TestMacros(t *testing.T) {
 }
 
 func TestMacrosV4(t *testing.T) {
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	trace = t.Logf
 
 	// Like TestMacros above, but specifically for IPv4.
@@ -469,7 +469,7 @@ func TestNullTrace(t *testing.T) {
 }
 
 func TestOverrideLookupLimit(t *testing.T) {
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	trace = t.Logf
 
 	dns.txt["domain1"] = []string{"v=spf1 include:domain2"}
@@ -500,7 +500,7 @@ func TestOverrideLookupLimit(t *testing.T) {
 }
 
 func TestWithContext(t *testing.T) {
-	dns = NewDNS()
+	dns := NewDefaultResolver()
 	trace = t.Logf
 
 	dns.txt["domain1"] = []string{"v=spf1 include:domain2"}
@@ -522,5 +522,20 @@ func TestWithContext(t *testing.T) {
 	if res != None || err != context.Canceled {
 		t.Errorf("expected none/context cancelled, got %q / %q", res, err)
 	}
+}
+
+func TestWithResolver(t *testing.T) {
+	// Reset the default to an empty resolver; only the custom one has records.
+	defaultResolver = NewResolver()
+	dns := NewResolver()
+	trace = t.Logf
+
+	dns.txt["domain1"] = []string{"v=spf1 include:domain2"}
+	dns.txt["domain2"] = []string{"v=spf1 +all"}

+	res, err := CheckHostWithSender(ip1111, "helo", "user@domain1",
+		WithResolver(dns))
+	if res != Pass {
+		t.Errorf("expected pass, got %q / %q", res, err)
+	}
 }
diff --git a/yml_test.go b/yml_test.go
index 3443e46..c52baf2 100644
--- a/yml_test.go
+++ b/yml_test.go
@@ -156,7 +156,7 @@ func testRFC(t *testing.T, fname string) {
 		t.Logf("suite: %v", suite.Description)
 
-		// Set up zone for the suite based on zonedata.
-		dns = NewDNS()
+		// Set up a fresh default resolver for the suite, based on zonedata.
+		dns := NewDefaultResolver()
 		for domain, records := range suite.ZoneData {
 			t.Logf("  domain %v", domain)
 			for _, record := range records {