git » spf » commit 4cb0f61

test: Make the test DNS resolver return NXDOMAIN automatically

author Alberto Bertogli
2022-08-01 20:12:08 UTC
committer Alberto Bertogli
2022-08-01 20:13:12 UTC
parent a691aa8177e1c2d82606912e722808d18fad1d8b

test: Make the test DNS resolver return NXDOMAIN automatically

When a domain has no entry in the test DNS resolver's data for the
record type being looked up (TXT, MX, IP or Addr), and no error is
configured for it either, return NXDOMAIN instead of a nil result with
a nil error.

This helps prevent accidental test passes when we're testing against
domains that we expect not to exist.

internal/dnstest/dns.go +17 -0

diff --git a/internal/dnstest/dns.go b/internal/dnstest/dns.go
index 8892a83..a870de5 100644
--- a/internal/dnstest/dns.go
+++ b/internal/dnstest/dns.go
@@ -35,12 +35,20 @@ func NewResolver() *TestResolver {
 	}
 }
 
+var nxDomainErr = &net.DNSError{
+	Err:        "domain not found (for testing)",
+	IsNotFound: true,
+}
+
 func (r *TestResolver) LookupTXT(ctx context.Context, domain string) (txts []string, err error) {
 	if ctx.Err() != nil {
 		return nil, ctx.Err()
 	}
 	domain = strings.ToLower(domain)
 	domain = strings.TrimRight(domain, ".")
+	if _, ok := r.Txt[domain]; !ok && r.Errors[domain] == nil {
+		return nil, nxDomainErr
+	}
 	return r.Txt[domain], r.Errors[domain]
 }
 
@@ -50,6 +58,9 @@ func (r *TestResolver) LookupMX(ctx context.Context, domain string) (mxs []*net.
 	}
 	domain = strings.ToLower(domain)
 	domain = strings.TrimRight(domain, ".")
+	if _, ok := r.Mx[domain]; !ok && r.Errors[domain] == nil {
+		return nil, nxDomainErr
+	}
 	return r.Mx[domain], r.Errors[domain]
 }
 
@@ -59,6 +70,9 @@ func (r *TestResolver) LookupIPAddr(ctx context.Context, host string) (as []net.
 	}
 	host = strings.ToLower(host)
 	host = strings.TrimRight(host, ".")
+	if _, ok := r.Ip[host]; !ok && r.Errors[host] == nil {
+		return nil, nxDomainErr
+	}
 	return ipsToAddrs(r.Ip[host]), r.Errors[host]
 }
 
@@ -76,5 +90,8 @@ func (r *TestResolver) LookupAddr(ctx context.Context, host string) (addrs []str
 	}
 	host = strings.ToLower(host)
 	host = strings.TrimRight(host, ".")
+	if _, ok := r.Addr[host]; !ok && r.Errors[host] == nil {
+		return nil, nxDomainErr
+	}
 	return r.Addr[host], r.Errors[host]
 }