author | Alberto Bertogli <albertito@blitiri.com.ar> | 2021-06-04 09:23:03 UTC |
committer | Alberto Bertogli <albertito@blitiri.com.ar> | 2021-07-12 17:22:31 UTC |
parent | c059c43277f6497118733725fb8c352b7085c78b |
spf.go | +52 | -33 |
spf_test.go | +42 | -17 |
yml_test.go | +1 | -1 |
diff --git a/spf.go b/spf.go index 9b535ed..fb7cd83 100644 --- a/spf.go +++ b/spf.go @@ -36,12 +36,6 @@ import ( "strings" ) -// Functions that we can override for testing purposes. -var ( - nullTrace = func(f string, a ...interface{}) {} - trace = nullTrace -) - // The Result of an SPF check. Note the values have meaning, we use them in // headers. https://tools.ietf.org/html/rfc7208#section-8 type Result string @@ -108,6 +102,14 @@ var ( // https://tools.ietf.org/html/rfc7208#section-4.6.4 const defaultMaxLookups = 10 +// TraceFunc is the type of tracing functions. +type TraceFunc func(f string, a ...interface{}) + +var ( + nullTrace = func(f string, a ...interface{}) {} + defaultTrace = nullTrace +) + // Option type, for setting options. Users are expected to treat this as an // opaque type and not rely on the implementation, which is subject to change. type Option func(*resolution) @@ -126,13 +128,13 @@ type Option func(*resolution) // // Deprecated: use CheckHostWithSender instead. func CheckHost(ip net.IP, domain string) (Result, error) { - trace("check host %q %q", ip, domain) r := &resolution{ ip: ip, maxcount: defaultMaxLookups, sender: "@" + domain, ctx: context.TODO(), resolver: defaultResolver, + trace: defaultTrace, } return r.Check(domain) } @@ -155,13 +157,13 @@ func CheckHostWithSender(ip net.IP, helo, sender string, opts ...Option) (Result domain = helo } - trace("check host with sender %q %q %q (%q)", ip, helo, sender, domain) r := &resolution{ ip: ip, maxcount: defaultMaxLookups, sender: sender, ctx: context.TODO(), resolver: defaultResolver, + trace: defaultTrace, } for _, opt := range opts { @@ -218,6 +220,19 @@ func WithResolver(resolver DNSResolver) Option { } } +// WithTraceFunc sets the resolver's trace function. +// +// This can be used for debugging. The trace messages are NOT machine +// parseable, and are NOT stable. They should also NOT be included in +// user-visible output, as they may include sensitive details. 
+// +// This is EXPERIMENTAL for now, and the API is subject to change. +func WithTraceFunc(trace TraceFunc) Option { + return func(r *resolution) { + r.trace = trace + } +} + // split an user@domain address into user and domain. func split(addr string) (string, string) { ps := strings.SplitN(addr, "@", 2) @@ -243,6 +258,9 @@ type resolution struct { // DNS resolver to use. resolver DNSResolver + + // Trace function, used for debugging. + trace TraceFunc } var aField = regexp.MustCompile(`^(a$|a:|a/)`) @@ -251,23 +269,23 @@ var ptrField = regexp.MustCompile(`^(ptr$|ptr:)`) func (r *resolution) Check(domain string) (Result, error) { r.count++ - trace("check %s %d", domain, r.count) + r.trace("check %q %d", domain, r.count) txt, err := r.getDNSRecord(domain) if err != nil { if isTemporary(err) { - trace("dns temp error: %v", err) + r.trace("dns temp error: %v", err) return TempError, err } if err == errMultipleRecords { - trace("multiple dns records") + r.trace("multiple dns records") return PermError, err } // Could not resolve the name, it may be missing the record. // https://tools.ietf.org/html/rfc7208#section-2.6.1 - trace("dns perm error: %v", err) + r.trace("dns perm error: %v", err) return None, err } - trace("dns record %q", txt) + r.trace("dns record %q", txt) if txt == "" { // No record => None. @@ -309,7 +327,7 @@ func (r *resolution) Check(domain string) (Result, error) { // Limit the number of resolutions. 
// https://tools.ietf.org/html/rfc7208#section-4.6.4 if r.count > r.maxcount { - trace("lookup limit reached") + r.trace("lookup limit reached") return PermError, errLookupLimitReached } @@ -328,54 +346,54 @@ func (r *resolution) Check(domain string) (Result, error) { if lfield == "all" { // https://tools.ietf.org/html/rfc7208#section-5.1 - trace("%v matched all", result) + r.trace("%v matched all", result) return result, errMatchedAll } else if strings.HasPrefix(lfield, "include:") { if ok, res, err := r.includeField(result, field, domain); ok { - trace("include ok, %v %v", res, err) + r.trace("include ok, %v %v", res, err) return res, err } } else if aField.MatchString(lfield) { if ok, res, err := r.aField(result, field, domain); ok { - trace("a ok, %v %v", res, err) + r.trace("a ok, %v %v", res, err) return res, err } } else if mxField.MatchString(lfield) { if ok, res, err := r.mxField(result, field, domain); ok { - trace("mx ok, %v %v", res, err) + r.trace("mx ok, %v %v", res, err) return res, err } } else if strings.HasPrefix(lfield, "ip4:") || strings.HasPrefix(lfield, "ip6:") { if ok, res, err := r.ipField(result, field); ok { - trace("ip ok, %v %v", res, err) + r.trace("ip ok, %v %v", res, err) return res, err } } else if ptrField.MatchString(lfield) { if ok, res, err := r.ptrField(result, field, domain); ok { - trace("ptr ok, %v %v", res, err) + r.trace("ptr ok, %v %v", res, err) return res, err } } else if strings.HasPrefix(lfield, "exists:") { if ok, res, err := r.existsField(result, field, domain); ok { - trace("exists ok, %v %v", res, err) + r.trace("exists ok, %v %v", res, err) return res, err } } else if strings.HasPrefix(lfield, "exp=") { - trace("exp= not used, skipping") + r.trace("exp= not used, skipping") continue } else if strings.HasPrefix(lfield, "redirect=") { - trace("redirect, %q", field) + r.trace("redirect, %q", field) return r.redirectField(field, domain) } else { // http://www.openspf.org/SPF_Record_Syntax - trace("permerror, unknown 
field") + r.trace("permerror, unknown field") return PermError, errUnknownField } } // Got to the end of the evaluation without a result => Neutral. // https://tools.ietf.org/html/rfc7208#section-4.7 - trace("fallback to neutral") + r.trace("fallback to neutral") return Neutral, nil } @@ -489,7 +507,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e // RFC explicitly says to skip domains which error here. continue } - trace("ptr forward resolution %q -> %q", n, addrs) + r.trace("ptr forward resolution %q -> %q", n, addrs) if len(addrs) > 0 { // Append the lower-case variants so we do a case-insensitive // lookup below. @@ -498,7 +516,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e } } - trace("ptr evaluating %q in %q", ptrDomain, r.ipNames) + r.trace("ptr evaluating %q in %q", ptrDomain, r.ipNames) ptrDomain = strings.ToLower(ptrDomain) for _, n := range r.ipNames { if strings.HasSuffix(n, ptrDomain+".") { @@ -615,7 +633,6 @@ func domainAndMask(re *regexp.Regexp, field, domain string) (string, dualMasks, masks.v6 = mask6 } } - trace("masks on %q: %q %q %v", field, groups, domain, masks) // Test to catch malformed entries: if there's a /, there must be at least // one mask. 
@@ -630,6 +647,7 @@ func domainAndMask(re *regexp.Regexp, field, domain string) (string, dualMasks, func (r *resolution) aField(res Result, field, domain string) (bool, Result, error) { // https://tools.ietf.org/html/rfc7208#section-5.3 aDomain, masks, err := domainAndMask(aRegexp, field, domain) + r.trace("masks on %q, %q: %q %v", field, domain, aDomain, masks) if err != nil { return true, PermError, err } @@ -649,7 +667,7 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err } for _, ip := range ips { if ipMatch(r.ip, ip.IP, masks) { - trace("a matched %v, %v, %v", r.ip, ip.IP, masks) + r.trace("a matched %v, %v, %v", r.ip, ip.IP, masks) return true, res, errMatchedA } } @@ -661,6 +679,7 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err func (r *resolution) mxField(res Result, field, domain string) (bool, Result, error) { // https://tools.ietf.org/html/rfc7208#section-5.4 mxDomain, masks, err := domainAndMask(mxRegexp, field, domain) + r.trace("masks on %q, %q: %q %v", field, domain, mxDomain, masks) if err != nil { return true, PermError, err } @@ -702,7 +721,7 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er } for _, ip := range mxips { if ipMatch(r.ip, ip, masks) { - trace("mx matched %v, %v, %v", r.ip, ip, masks) + r.trace("mx matched %v, %v, %v", r.ip, ip, masks) return true, res, errMatchedMX } } @@ -744,7 +763,7 @@ func (r *resolution) expandMacros(s, domain string) (string, error) { // from happening in case where it matters (a, mx), but for the ones which // doesn't, prevent them from sneaking through. if strings.Contains(s, "/") { - trace("macro contains /") + r.trace("macro contains /") return "", errInvalidDomain } @@ -793,7 +812,7 @@ func (r *resolution) expandMacros(s, domain string) (string, error) { // Extract letter, digit transformer, reverse transformer, and // delimiters. 
groups := macroRegexp.FindStringSubmatch(macroS) - trace("macro %q: %q", macroS, groups) + r.trace("macro %q: %q", macroS, groups) macroS = "" if groups == nil { return "", errInvalidMacro @@ -890,7 +909,7 @@ func (r *resolution) expandMacros(s, domain string) (string, error) { n += string(c) } - trace("macro expanded %q to %q", s, n) + r.trace("macro expanded %q to %q", s, n) return n, nil } diff --git a/spf_test.go b/spf_test.go index 022a06c..c44b12e 100644 --- a/spf_test.go +++ b/spf_test.go @@ -29,7 +29,7 @@ var ip6660 = net.ParseIP("2001:db8::0") func TestBasic(t *testing.T) { dns := NewDefaultResolver() - trace = t.Logf + defaultTrace = t.Logf cases := []struct { txt string @@ -104,7 +104,7 @@ func TestBasic(t *testing.T) { func TestIPv6(t *testing.T) { dns := NewDefaultResolver() - trace = t.Logf + defaultTrace = t.Logf cases := []struct { txt string @@ -158,7 +158,7 @@ func TestInclude(t *testing.T) { // If we got a match on 1.1.1.1, is because include:domain2 did not match. dns := NewDefaultResolver() dns.Txt["domain"] = []string{"v=spf1 include:domain2 ip4:1.1.1.1"} - trace = t.Logf + defaultTrace = t.Logf cases := []struct { txt string @@ -186,7 +186,7 @@ func TestInclude(t *testing.T) { func TestRecursionLimit(t *testing.T) { dns := NewDefaultResolver() dns.Txt["domain"] = []string{"v=spf1 include:domain ~all"} - trace = t.Logf + defaultTrace = t.Logf res, err := CheckHost(ip1111, "domain") if res != PermError || err != errLookupLimitReached { @@ -198,7 +198,7 @@ func TestRedirect(t *testing.T) { dns := NewDefaultResolver() dns.Txt["domain"] = []string{"v=spf1 redirect=domain2"} dns.Txt["domain2"] = []string{"v=spf1 ip4:1.1.1.1 -all"} - trace = t.Logf + defaultTrace = t.Logf res, err := CheckHost(ip1111, "domain") if res != Pass { @@ -212,7 +212,7 @@ func TestInvalidRedirect(t *testing.T) { // https://tools.ietf.org/html/rfc7208#section-6.1 dns := NewDefaultResolver() dns.Txt["domain"] = []string{"v=spf1 redirect=doesnotexist"} - trace = t.Logf + 
defaultTrace = t.Logf res, err := CheckHost(ip1111, "doesnotexist") if res != None { @@ -230,7 +230,7 @@ func TestRedirectOrder(t *testing.T) { // redirect modifier appears before them. dns := NewDefaultResolver() dns.Txt["faildom"] = []string{"v=spf1 -all"} - trace = t.Logf + defaultTrace = t.Logf dns.Txt["domain"] = []string{"v=spf1 redirect=faildom"} res, err := CheckHost(ip1111, "domain") @@ -250,7 +250,7 @@ func TestNoRecord(t *testing.T) { dns.Txt["d1"] = []string{""} dns.Txt["d2"] = []string{"loco", "v=spf2"} dns.Errors["nospf"] = fmt.Errorf("no such domain") - trace = t.Logf + defaultTrace = t.Logf for _, domain := range []string{"d1", "d2", "d3", "nospf"} { res, err := CheckHost(ip1111, domain) @@ -271,7 +271,7 @@ func TestDNSTemporaryErrors(t *testing.T) { dns.Errors["tmperr"] = dnsError dns.Errors["1.1.1.1"] = dnsError dns.Mx["tmpmx"] = []*net.MX{mx("tmperr", 10)} - trace = t.Logf + defaultTrace = t.Logf cases := []struct { txt string @@ -305,7 +305,7 @@ func TestDNSPermanentErrors(t *testing.T) { dns.Errors["tmperr"] = dnsError dns.Errors["1.1.1.1"] = dnsError dns.Mx["tmpmx"] = []*net.MX{mx("tmperr", 10)} - trace = t.Logf + defaultTrace = t.Logf cases := []struct { txt string @@ -330,7 +330,7 @@ func TestDNSPermanentErrors(t *testing.T) { func TestMacros(t *testing.T) { dns := NewDefaultResolver() - trace = t.Logf + defaultTrace = t.Logf // Most of the cases are covered by the standard test suite, so this is // targeted at gaps in coverage. @@ -376,7 +376,7 @@ func TestMacros(t *testing.T) { func TestMacrosV4(t *testing.T) { dns := NewDefaultResolver() - trace = t.Logf + defaultTrace = t.Logf // Like TestMacros above, but specifically for IPv4. // It's easier to have a separate suite. @@ -462,6 +462,7 @@ func TestInvalidMacro(t *testing.T) { ip: ip1111, count: 0, sender: "sender.com", + trace: t.Logf, } out, err := r.expandMacros(macro, "sender.com") @@ -476,7 +477,7 @@ func TestInvalidMacro(t *testing.T) { // other tests override it. 
func TestNullTrace(t *testing.T) { dns := NewDefaultResolver() - trace = nullTrace + defaultTrace = nullTrace dns.Txt["domain1"] = []string{"v=spf1 include:domain2"} dns.Txt["domain2"] = []string{"v=spf1 +all"} @@ -490,7 +491,7 @@ func TestNullTrace(t *testing.T) { func TestOverrideLookupLimit(t *testing.T) { dns := NewDefaultResolver() - trace = t.Logf + defaultTrace = t.Logf dns.Txt["domain1"] = []string{"v=spf1 include:domain2"} dns.Txt["domain2"] = []string{"v=spf1 include:domain3"} @@ -521,7 +522,7 @@ func TestOverrideLookupLimit(t *testing.T) { func TestWithContext(t *testing.T) { dns := NewDefaultResolver() - trace = t.Logf + defaultTrace = t.Logf dns.Txt["domain1"] = []string{"v=spf1 include:domain2"} dns.Txt["domain2"] = []string{"v=spf1 +all"} @@ -548,7 +549,7 @@ func TestWithResolver(t *testing.T) { // Use a custom resolver, making sure it's different from the default. defaultResolver = dnstest.NewResolver() dns := dnstest.NewResolver() - trace = t.Logf + defaultTrace = t.Logf dns.Txt["domain1"] = []string{"v=spf1 include:domain2"} dns.Txt["domain2"] = []string{"v=spf1 +all"} @@ -564,7 +565,7 @@ func TestWithResolver(t *testing.T) { // address. This can happen if using a buggy custom resolver. func TestBadResolverResponse(t *testing.T) { dns := dnstest.NewResolver() - trace = t.Logf + defaultTrace = t.Logf // When LookupIPAddr returns an invalid ip, for an "a" field. dns.Ip["domain1"] = []net.IP{nil} @@ -604,3 +605,27 @@ func TestBadResolverResponse(t *testing.T) { t.Errorf("expected fail, got %q / %q", res, err) } } + +func TestWithTraceFunc(t *testing.T) { + calls := 0 + var trace TraceFunc = func(f string, a ...interface{}) { + calls++ + t.Logf("tracing "+f, a...) + } + + dns := NewDefaultResolver() + + dns.Txt["domain1"] = []string{"v=spf1 include:domain2"} + dns.Txt["domain2"] = []string{"v=spf1 +all"} + + // Do a normal resolution, check it passes. 
+ res, err := CheckHostWithSender(ip1111, "helo", "user@domain1", + WithTraceFunc(trace)) + if res != Pass { + t.Errorf("expected pass, got %q / %q", res, err) + } + + if calls == 0 { + t.Errorf("expected >0 trace function calls, got 0") + } +} diff --git a/yml_test.go b/yml_test.go index e617006..111b8ac 100644 --- a/yml_test.go +++ b/yml_test.go @@ -150,7 +150,7 @@ func testRFC(t *testing.T, fname string) { suites = append(suites, s) } - trace = t.Logf + defaultTrace = t.Logf for _, suite := range suites { t.Logf("suite: %v", suite.Description)