author | Alberto Bertogli
<albertito@blitiri.com.ar> 2020-09-20 14:48:33 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2020-09-20 16:21:34 UTC |
parent | 55b54fc5ccc120596e452fb227b3764f5fc6c270 |
dns_test.go | +29 | -31 |
spf.go | +44 | -17 |
spf_test.go | +29 | -14 |
yml_test.go | +1 | -1 |
diff --git a/dns_test.go b/dns_test.go index c27d07f..da6d320 100644 --- a/dns_test.go +++ b/dns_test.go @@ -2,17 +2,13 @@ package spf import ( "context" - "flag" - "fmt" "net" - "os" "strings" - "testing" ) // DNS overrides for testing. -type DNS struct { +type TestResolver struct { txt map[string][]string mx map[string][]*net.MX ip map[string][]net.IP @@ -20,8 +16,8 @@ type DNS struct { errors map[string]error } -func NewDNS() DNS { - return DNS{ +func NewResolver() *TestResolver { + return &TestResolver{ txt: map[string][]string{}, mx: map[string][]*net.MX{}, ip: map[string][]net.IP{}, @@ -30,57 +26,59 @@ func NewDNS() DNS { } } -// Single global variable that the overridden resolvers use. -// This way it's easier to get a clean slate between tests. -var dns DNS +func NewDefaultResolver() *TestResolver { + dns := NewResolver() + defaultResolver = dns + return dns +} -func LookupTXT(ctx context.Context, domain string) (txts []string, err error) { +func (r *TestResolver) LookupTXT(ctx context.Context, domain string) (txts []string, err error) { if ctx.Err() != nil { return nil, ctx.Err() } domain = strings.ToLower(domain) domain = strings.TrimRight(domain, ".") - return dns.txt[domain], dns.errors[domain] + return r.txt[domain], r.errors[domain] } -func LookupMX(ctx context.Context, domain string) (mxs []*net.MX, err error) { +func (r *TestResolver) LookupMX(ctx context.Context, domain string) (mxs []*net.MX, err error) { if ctx.Err() != nil { return nil, ctx.Err() } domain = strings.ToLower(domain) domain = strings.TrimRight(domain, ".") - return dns.mx[domain], dns.errors[domain] + return r.mx[domain], r.errors[domain] } -func LookupIP(ctx context.Context, net, host string) (ips []net.IP, err error) { - if net != "ip" { - panic(fmt.Sprintf("got net %q, expected ip", net)) - } +func (r *TestResolver) LookupIPAddr(ctx context.Context, host string) (as []net.IPAddr, err error) { if ctx.Err() != nil { return nil, ctx.Err() } host = strings.ToLower(host) host = strings.TrimRight(host, ".") - return dns.ip[host], dns.errors[host] + return ipsToAddrs(r.ip[host]), r.errors[host] } -func LookupAddr(ctx context.Context, host string) (addrs []string, err error) { +func ipsToAddrs(ips []net.IP) []net.IPAddr { + as := []net.IPAddr{} + for _, ip := range ips { + as = append(as, net.IPAddr{IP: ip, Zone: ""}) + } + return as +} + +func (r *TestResolver) LookupAddr(ctx context.Context, host string) (addrs []string, err error) { if ctx.Err() != nil { return nil, ctx.Err() } host = strings.ToLower(host) host = strings.TrimRight(host, ".") - return dns.addr[host], dns.errors[host] + return r.addr[host], r.errors[host] } -func TestMain(m *testing.M) { - dns = NewDNS() - - lookupTXT = LookupTXT - lookupMX = LookupMX - lookupIP = LookupIP - lookupAddr = LookupAddr - - flag.Parse() - os.Exit(m.Run()) +func init() { + // Override the default resolver to make sure the tests are not using the + // one from net. Individual tests will override this as well, but just in + // case. + NewDefaultResolver() } diff --git a/spf.go b/spf.go index 3d9dfda..38718bf 100644 --- a/spf.go +++ b/spf.go @@ -38,10 +38,6 @@ import ( // Functions that we can override for testing purposes. var ( - lookupTXT = net.DefaultResolver.LookupTXT - lookupMX = net.DefaultResolver.LookupMX - lookupIP = net.DefaultResolver.LookupIP - lookupAddr = net.DefaultResolver.LookupAddr nullTrace = func(f string, a ...interface{}) {} trace = nullTrace ) @@ -132,6 +128,7 @@ func CheckHost(ip net.IP, domain string) (Result, error) { maxcount: defaultMaxLookups, sender: "@" + domain, ctx: context.TODO(), + resolver: defaultResolver, } return r.Check(domain) } @@ -156,6 +153,7 @@ func CheckHostWithSender(ip net.IP, helo, sender string, opts ...Option) (Result maxcount: defaultMaxLookups, sender: sender, ctx: context.TODO(), + resolver: defaultResolver, } for _, opt := range opts { @@ -188,6 +186,30 @@ func WithContext(ctx context.Context) Option { } } +// DNSResolver implements the methods we use to resolve DNS queries. +// It is intentionally compatible with *net.Resolver. +type DNSResolver interface { + LookupTXT(ctx context.Context, name string) ([]string, error) + LookupMX(ctx context.Context, name string) ([]*net.MX, error) + LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) + LookupAddr(ctx context.Context, addr string) (names []string, err error) +} + +var defaultResolver DNSResolver = net.DefaultResolver + +// WithResolver sets the resolver to use for DNS lookups. It can be useful for +// testing, and for customize DNS resolution specifically for this library. +// +// The default is to use net.DefaultResolver, which should be appropriate for +// most users. +// +// This is EXPERIMENTAL for now, and the API is subject to change. +func WithResolver(resolver DNSResolver) Option { + return func(r *resolution) { + r.resolver = resolver + } +} + // split an user@domain address into user and domain. func split(addr string) (string, string) { ps := strings.SplitN(addr, "@", 2) @@ -210,6 +232,9 @@ type resolution struct { // Context for this resolution. ctx context.Context + + // DNS resolver to use. + resolver DNSResolver } var aField = regexp.MustCompile(`^(a$|a:|a/)`) @@ -219,7 +244,7 @@ var ptrField = regexp.MustCompile(`^(ptr$|ptr:)`) func (r *resolution) Check(domain string) (Result, error) { r.count++ trace("check %s %d", domain, r.count) - txt, err := getDNSRecord(r.ctx, domain) + txt, err := r.getDNSRecord(domain) if err != nil { if isTemporary(err) { trace("dns temp error: %v", err) @@ -351,8 +376,8 @@ func (r *resolution) Check(domain string) (Result, error) { // https://tools.ietf.org/html/rfc7208#section-3 // https://tools.ietf.org/html/rfc7208#section-3.2 // https://tools.ietf.org/html/rfc7208#section-4.5 -func getDNSRecord(ctx context.Context, domain string) (string, error) { - txts, err := lookupTXT(ctx, domain) +func (r *resolution) getDNSRecord(domain string) (string, error) { + txts, err := r.resolver.LookupTXT(r.ctx, domain) if err != nil { return "", err } @@ -435,7 +460,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e if r.ipNames == nil { r.ipNames = []string{} r.count++ - ns, err := lookupAddr(r.ctx, r.ip.String()) + ns, err := r.resolver.LookupAddr(r.ctx, r.ip.String()) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -451,7 +476,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e return false, "", errLookupLimitReached } r.count++ - addrs, err := lookupIP(r.ctx, "ip", n) + addrs, err := r.resolver.LookupIPAddr(r.ctx, n) if err != nil { // RFC explicitly says to skip domains which error here. continue @@ -491,7 +516,7 @@ func (r *resolution) existsField(res Result, field, domain string) (bool, Result } r.count++ - ips, err := lookupIP(r.ctx, "ip", eDomain) + ips, err := r.resolver.LookupIPAddr(r.ctx, eDomain) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -502,7 +527,7 @@ func (r *resolution) existsField(res Result, field, domain string) (bool, Result // Exists only counts if there are IPv4 matches. for _, ip := range ips { - if ip.To4() != nil { + if ip.IP.To4() != nil { return true, res, errMatchedExists } } @@ -608,7 +633,7 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err } r.count++ - ips, err := lookupIP(r.ctx, "ip", aDomain) + ips, err := r.resolver.LookupIPAddr(r.ctx, aDomain) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -617,9 +642,9 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err return false, "", err } for _, ip := range ips { - ok, err := ipMatch(r.ip, ip, masks) + ok, err := ipMatch(r.ip, ip.IP, masks) if ok { - trace("mx matched %v, %v, %v", r.ip, ip, masks) + trace("mx matched %v, %v, %v", r.ip, ip.IP, masks) return true, res, errMatchedA } else if err != nil { return true, PermError, err @@ -642,7 +667,7 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er } r.count++ - mxs, err := lookupMX(r.ctx, mxDomain) + mxs, err := r.resolver.LookupMX(r.ctx, mxDomain) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -660,7 +685,7 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er mxips := []net.IP{} for _, mx := range mxs { r.count++ - ips, err := lookupIP(r.ctx, "ip", mx.Host) + ips, err := r.resolver.LookupIPAddr(r.ctx, mx.Host) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -668,7 +693,9 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er } return false, "", err } - mxips = append(mxips, ips...) + for _, ipaddr := range ips { + mxips = append(mxips, ipaddr.IP) + } } for _, ip := range mxips { ok, err := ipMatch(r.ip, ip, masks) diff --git a/spf_test.go b/spf_test.go index db94c25..6b7787f 100644 --- a/spf_test.go +++ b/spf_test.go @@ -13,7 +13,7 @@ var ip6666 = net.ParseIP("2001:db8::68") var ip6660 = net.ParseIP("2001:db8::0") func TestBasic(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() trace = t.Logf cases := []struct { @@ -88,7 +88,7 @@ func TestBasic(t *testing.T) { } func TestIPv6(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() trace = t.Logf cases := []struct { @@ -141,7 +141,7 @@ func TestIPv6(t *testing.T) { func TestInclude(t *testing.T) { // Test that the include is doing a recursive lookup. // If we got a match on 1.1.1.1, is because include:domain2 did not match. - dns = NewDNS() + dns := NewDefaultResolver() dns.txt["domain"] = []string{"v=spf1 include:domain2 ip4:1.1.1.1"} trace = t.Logf @@ -169,7 +169,7 @@ func TestInclude(t *testing.T) { } func TestRecursionLimit(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() dns.txt["domain"] = []string{"v=spf1 include:domain ~all"} trace = t.Logf @@ -180,7 +180,7 @@ func TestRecursionLimit(t *testing.T) { } func TestRedirect(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() dns.txt["domain"] = []string{"v=spf1 redirect=domain2"} dns.txt["domain2"] = []string{"v=spf1 ip4:1.1.1.1 -all"} trace = t.Logf @@ -195,7 +195,7 @@ func TestInvalidRedirect(t *testing.T) { // Redirect to a non-existing host; the inner check returns None, but due // to the redirection, this lookup should return PermError. // https://tools.ietf.org/html/rfc7208#section-6.1 - dns = NewDNS() + dns := NewDefaultResolver() dns.txt["domain"] = []string{"v=spf1 redirect=doesnotexist"} trace = t.Logf @@ -213,7 +213,7 @@ func TestInvalidRedirect(t *testing.T) { func TestRedirectOrder(t *testing.T) { // We should only check redirects after all mechanisms, even if the // redirect modifier appears before them. - dns = NewDNS() + dns := NewDefaultResolver() dns.txt["faildom"] = []string{"v=spf1 -all"} trace = t.Logf @@ -231,7 +231,7 @@ func TestRedirectOrder(t *testing.T) { } func TestNoRecord(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() dns.txt["d1"] = []string{""} dns.txt["d2"] = []string{"loco", "v=spf2"} dns.errors["nospf"] = fmt.Errorf("no such domain") @@ -246,7 +246,7 @@ func TestNoRecord(t *testing.T) { } func TestDNSTemporaryErrors(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() dnsError := &net.DNSError{ Err: "temporary error for testing", IsTemporary: true, @@ -280,7 +280,7 @@ func TestDNSTemporaryErrors(t *testing.T) { } func TestDNSPermanentErrors(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() dnsError := &net.DNSError{ Err: "permanent error for testing", IsTemporary: false, @@ -314,7 +314,7 @@ func TestDNSPermanentErrors(t *testing.T) { } func TestMacros(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() trace = t.Logf // Most of the cases are covered by the standard test suite, so this is @@ -358,7 +358,7 @@ func TestMacros(t *testing.T) { } func TestMacrosV4(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() trace = t.Logf // Like TestMacros above, but specifically for IPv4. @@ -469,7 +469,7 @@ func TestNullTrace(t *testing.T) { } func TestOverrideLookupLimit(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() trace = t.Logf dns.txt["domain1"] = []string{"v=spf1 include:domain2"} @@ -500,7 +500,7 @@ func TestOverrideLookupLimit(t *testing.T) { } func TestWithContext(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() trace = t.Logf dns.txt["domain1"] = []string{"v=spf1 include:domain2"} @@ -522,5 +522,20 @@ func TestWithContext(t *testing.T) { if res != None || err != context.Canceled { t.Errorf("expected none/context cancelled, got %q / %q", res, err) } +} + +func TestWithResolver(t *testing.T) { + // Use a custom resolver, making sure it's different from the default. + defaultResolver = NewResolver() + dns := NewResolver() + trace = t.Logf + + dns.txt["domain1"] = []string{"v=spf1 include:domain2"} + dns.txt["domain2"] = []string{"v=spf1 +all"} + res, err := CheckHostWithSender(ip1111, "helo", "user@domain1", + WithResolver(dns)) + if res != Pass { + t.Errorf("expected pass, got %q / %q", res, err) + } } diff --git a/yml_test.go b/yml_test.go index 3443e46..c52baf2 100644 --- a/yml_test.go +++ b/yml_test.go @@ -156,7 +156,7 @@ func testRFC(t *testing.T, fname string) { t.Logf("suite: %v", suite.Description) // Set up zone for the suite based on zonedata. - dns = NewDNS() + dns := NewDefaultResolver() for domain, records := range suite.ZoneData { t.Logf(" domain %v", domain) for _, record := range records {