git » spf » commit a71dd86

Add an option for a custom tracing function

author Alberto Bertogli
2021-06-04 09:23:03 UTC
committer Alberto Bertogli
2021-07-12 17:22:31 UTC
parent c059c43277f6497118733725fb8c352b7085c78b

Add an option for a custom tracing function

This patch implements a per-check option for a custom tracing function.

With it, callers can propagate SPF debugging information, which can help
troubleshoot issues.

The output is not machine-parsable, should not be included in
user-visible output, and is intended only for debugging purposes.

spf.go +52 -33
spf_test.go +42 -17
yml_test.go +1 -1

diff --git a/spf.go b/spf.go
index 9b535ed..fb7cd83 100644
--- a/spf.go
+++ b/spf.go
@@ -36,12 +36,6 @@ import (
 	"strings"
 )
 
-// Functions that we can override for testing purposes.
-var (
-	nullTrace = func(f string, a ...interface{}) {}
-	trace     = nullTrace
-)
-
 // The Result of an SPF check. Note the values have meaning, we use them in
 // headers.  https://tools.ietf.org/html/rfc7208#section-8
 type Result string
@@ -108,6 +102,14 @@ var (
 // https://tools.ietf.org/html/rfc7208#section-4.6.4
 const defaultMaxLookups = 10
 
+// TraceFunc is a printf-style function (format string plus args) that receives debug trace messages.
+type TraceFunc func(f string, a ...interface{})
+
+var (
+	nullTrace    = func(f string, a ...interface{}) {} // Discards every trace message.
+	defaultTrace = nullTrace                           // Initial trace func of each check; tests override it.
+)
+
 // Option type, for setting options. Users are expected to treat this as an
 // opaque type and not rely on the implementation, which is subject to change.
 type Option func(*resolution)
@@ -126,13 +128,13 @@ type Option func(*resolution)
 //
 // Deprecated: use CheckHostWithSender instead.
 func CheckHost(ip net.IP, domain string) (Result, error) {
-	trace("check host %q %q", ip, domain)
 	r := &resolution{
 		ip:       ip,
 		maxcount: defaultMaxLookups,
 		sender:   "@" + domain,
 		ctx:      context.TODO(),
 		resolver: defaultResolver,
+		trace:    defaultTrace,
 	}
 	return r.Check(domain)
 }
@@ -155,13 +157,13 @@ func CheckHostWithSender(ip net.IP, helo, sender string, opts ...Option) (Result
 		domain = helo
 	}
 
-	trace("check host with sender %q %q %q (%q)", ip, helo, sender, domain)
 	r := &resolution{
 		ip:       ip,
 		maxcount: defaultMaxLookups,
 		sender:   sender,
 		ctx:      context.TODO(),
 		resolver: defaultResolver,
+		trace:    defaultTrace,
 	}
 
 	for _, opt := range opts {
@@ -218,6 +220,19 @@ func WithResolver(resolver DNSResolver) Option {
 	}
 }
 
+// WithTraceFunc sets the function that receives trace messages for this check; it must not be nil.
+//
+// This can be used for debugging. The trace messages are NOT machine
+// parseable, and are NOT stable. They should also NOT be included in
+// user-visible output, as they may include sensitive details.
+//
+// This is EXPERIMENTAL for now, and the API is subject to change.
+func WithTraceFunc(trace TraceFunc) Option {
+	return func(r *resolution) {
+		r.trace = trace // NOTE(review): no nil fallback; r.trace is called unconditionally, so nil panics.
+	}
+}
+
 // split an user@domain address into user and domain.
 func split(addr string) (string, string) {
 	ps := strings.SplitN(addr, "@", 2)
@@ -243,6 +258,9 @@ type resolution struct {
 
 	// DNS resolver to use.
 	resolver DNSResolver
+
+	// Trace function, used for debugging.
+	trace TraceFunc
 }
 
 var aField = regexp.MustCompile(`^(a$|a:|a/)`)
@@ -251,23 +269,23 @@ var ptrField = regexp.MustCompile(`^(ptr$|ptr:)`)
 
 func (r *resolution) Check(domain string) (Result, error) {
 	r.count++
-	trace("check %s %d", domain, r.count)
+	r.trace("check %q %d", domain, r.count)
 	txt, err := r.getDNSRecord(domain)
 	if err != nil {
 		if isTemporary(err) {
-			trace("dns temp error: %v", err)
+			r.trace("dns temp error: %v", err)
 			return TempError, err
 		}
 		if err == errMultipleRecords {
-			trace("multiple dns records")
+			r.trace("multiple dns records")
 			return PermError, err
 		}
 		// Could not resolve the name, it may be missing the record.
 		// https://tools.ietf.org/html/rfc7208#section-2.6.1
-		trace("dns perm error: %v", err)
+		r.trace("dns perm error: %v", err)
 		return None, err
 	}
-	trace("dns record %q", txt)
+	r.trace("dns record %q", txt)
 
 	if txt == "" {
 		// No record => None.
@@ -309,7 +327,7 @@ func (r *resolution) Check(domain string) (Result, error) {
 		// Limit the number of resolutions.
 		// https://tools.ietf.org/html/rfc7208#section-4.6.4
 		if r.count > r.maxcount {
-			trace("lookup limit reached")
+			r.trace("lookup limit reached")
 			return PermError, errLookupLimitReached
 		}
 
@@ -328,54 +346,54 @@ func (r *resolution) Check(domain string) (Result, error) {
 
 		if lfield == "all" {
 			// https://tools.ietf.org/html/rfc7208#section-5.1
-			trace("%v matched all", result)
+			r.trace("%v matched all", result)
 			return result, errMatchedAll
 		} else if strings.HasPrefix(lfield, "include:") {
 			if ok, res, err := r.includeField(result, field, domain); ok {
-				trace("include ok, %v %v", res, err)
+				r.trace("include ok, %v %v", res, err)
 				return res, err
 			}
 		} else if aField.MatchString(lfield) {
 			if ok, res, err := r.aField(result, field, domain); ok {
-				trace("a ok, %v %v", res, err)
+				r.trace("a ok, %v %v", res, err)
 				return res, err
 			}
 		} else if mxField.MatchString(lfield) {
 			if ok, res, err := r.mxField(result, field, domain); ok {
-				trace("mx ok, %v %v", res, err)
+				r.trace("mx ok, %v %v", res, err)
 				return res, err
 			}
 		} else if strings.HasPrefix(lfield, "ip4:") || strings.HasPrefix(lfield, "ip6:") {
 			if ok, res, err := r.ipField(result, field); ok {
-				trace("ip ok, %v %v", res, err)
+				r.trace("ip ok, %v %v", res, err)
 				return res, err
 			}
 		} else if ptrField.MatchString(lfield) {
 			if ok, res, err := r.ptrField(result, field, domain); ok {
-				trace("ptr ok, %v %v", res, err)
+				r.trace("ptr ok, %v %v", res, err)
 				return res, err
 			}
 		} else if strings.HasPrefix(lfield, "exists:") {
 			if ok, res, err := r.existsField(result, field, domain); ok {
-				trace("exists ok, %v %v", res, err)
+				r.trace("exists ok, %v %v", res, err)
 				return res, err
 			}
 		} else if strings.HasPrefix(lfield, "exp=") {
-			trace("exp= not used, skipping")
+			r.trace("exp= not used, skipping")
 			continue
 		} else if strings.HasPrefix(lfield, "redirect=") {
-			trace("redirect, %q", field)
+			r.trace("redirect, %q", field)
 			return r.redirectField(field, domain)
 		} else {
 			// http://www.openspf.org/SPF_Record_Syntax
-			trace("permerror, unknown field")
+			r.trace("permerror, unknown field")
 			return PermError, errUnknownField
 		}
 	}
 
 	// Got to the end of the evaluation without a result => Neutral.
 	// https://tools.ietf.org/html/rfc7208#section-4.7
-	trace("fallback to neutral")
+	r.trace("fallback to neutral")
 	return Neutral, nil
 }
 
@@ -489,7 +507,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e
 				// RFC explicitly says to skip domains which error here.
 				continue
 			}
-			trace("ptr forward resolution %q -> %q", n, addrs)
+			r.trace("ptr forward resolution %q -> %q", n, addrs)
 			if len(addrs) > 0 {
 				// Append the lower-case variants so we do a case-insensitive
 				// lookup below.
@@ -498,7 +516,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e
 		}
 	}
 
-	trace("ptr evaluating %q in %q", ptrDomain, r.ipNames)
+	r.trace("ptr evaluating %q in %q", ptrDomain, r.ipNames)
 	ptrDomain = strings.ToLower(ptrDomain)
 	for _, n := range r.ipNames {
 		if strings.HasSuffix(n, ptrDomain+".") {
@@ -615,7 +633,6 @@ func domainAndMask(re *regexp.Regexp, field, domain string) (string, dualMasks,
 			masks.v6 = mask6
 		}
 	}
-	trace("masks on %q: %q %q %v", field, groups, domain, masks)
 
 	// Test to catch malformed entries: if there's a /, there must be at least
 	// one mask.
@@ -630,6 +647,7 @@ func domainAndMask(re *regexp.Regexp, field, domain string) (string, dualMasks,
 func (r *resolution) aField(res Result, field, domain string) (bool, Result, error) {
 	// https://tools.ietf.org/html/rfc7208#section-5.3
 	aDomain, masks, err := domainAndMask(aRegexp, field, domain)
+	r.trace("masks on %q, %q: %q %v", field, domain, aDomain, masks)
 	if err != nil {
 		return true, PermError, err
 	}
@@ -649,7 +667,7 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err
 	}
 	for _, ip := range ips {
 		if ipMatch(r.ip, ip.IP, masks) {
-			trace("a matched %v, %v, %v", r.ip, ip.IP, masks)
+			r.trace("a matched %v, %v, %v", r.ip, ip.IP, masks)
 			return true, res, errMatchedA
 		}
 	}
@@ -661,6 +679,7 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err
 func (r *resolution) mxField(res Result, field, domain string) (bool, Result, error) {
 	// https://tools.ietf.org/html/rfc7208#section-5.4
 	mxDomain, masks, err := domainAndMask(mxRegexp, field, domain)
+	r.trace("masks on %q, %q: %q %v", field, domain, mxDomain, masks)
 	if err != nil {
 		return true, PermError, err
 	}
@@ -702,7 +721,7 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er
 	}
 	for _, ip := range mxips {
 		if ipMatch(r.ip, ip, masks) {
-			trace("mx matched %v, %v, %v", r.ip, ip, masks)
+			r.trace("mx matched %v, %v, %v", r.ip, ip, masks)
 			return true, res, errMatchedMX
 		}
 	}
@@ -744,7 +763,7 @@ func (r *resolution) expandMacros(s, domain string) (string, error) {
 	// from happening in case where it matters (a, mx), but for the ones which
 	// doesn't, prevent them from sneaking through.
 	if strings.Contains(s, "/") {
-		trace("macro contains /")
+		r.trace("macro contains /")
 		return "", errInvalidDomain
 	}
 
@@ -793,7 +812,7 @@ func (r *resolution) expandMacros(s, domain string) (string, error) {
 			// Extract letter, digit transformer, reverse transformer, and
 			// delimiters.
 			groups := macroRegexp.FindStringSubmatch(macroS)
-			trace("macro %q: %q", macroS, groups)
+			r.trace("macro %q: %q", macroS, groups)
 			macroS = ""
 			if groups == nil {
 				return "", errInvalidMacro
@@ -890,7 +909,7 @@ func (r *resolution) expandMacros(s, domain string) (string, error) {
 		n += string(c)
 	}
 
-	trace("macro expanded %q to %q", s, n)
+	r.trace("macro expanded %q to %q", s, n)
 	return n, nil
 }
 
diff --git a/spf_test.go b/spf_test.go
index 022a06c..c44b12e 100644
--- a/spf_test.go
+++ b/spf_test.go
@@ -29,7 +29,7 @@ var ip6660 = net.ParseIP("2001:db8::0")
 
 func TestBasic(t *testing.T) {
 	dns := NewDefaultResolver()
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	cases := []struct {
 		txt string
@@ -104,7 +104,7 @@ func TestBasic(t *testing.T) {
 
 func TestIPv6(t *testing.T) {
 	dns := NewDefaultResolver()
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	cases := []struct {
 		txt string
@@ -158,7 +158,7 @@ func TestInclude(t *testing.T) {
 	// If we got a match on 1.1.1.1, is because include:domain2 did not match.
 	dns := NewDefaultResolver()
 	dns.Txt["domain"] = []string{"v=spf1 include:domain2 ip4:1.1.1.1"}
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	cases := []struct {
 		txt string
@@ -186,7 +186,7 @@ func TestInclude(t *testing.T) {
 func TestRecursionLimit(t *testing.T) {
 	dns := NewDefaultResolver()
 	dns.Txt["domain"] = []string{"v=spf1 include:domain ~all"}
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	res, err := CheckHost(ip1111, "domain")
 	if res != PermError || err != errLookupLimitReached {
@@ -198,7 +198,7 @@ func TestRedirect(t *testing.T) {
 	dns := NewDefaultResolver()
 	dns.Txt["domain"] = []string{"v=spf1 redirect=domain2"}
 	dns.Txt["domain2"] = []string{"v=spf1 ip4:1.1.1.1 -all"}
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	res, err := CheckHost(ip1111, "domain")
 	if res != Pass {
@@ -212,7 +212,7 @@ func TestInvalidRedirect(t *testing.T) {
 	// https://tools.ietf.org/html/rfc7208#section-6.1
 	dns := NewDefaultResolver()
 	dns.Txt["domain"] = []string{"v=spf1 redirect=doesnotexist"}
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	res, err := CheckHost(ip1111, "doesnotexist")
 	if res != None {
@@ -230,7 +230,7 @@ func TestRedirectOrder(t *testing.T) {
 	// redirect modifier appears before them.
 	dns := NewDefaultResolver()
 	dns.Txt["faildom"] = []string{"v=spf1 -all"}
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	dns.Txt["domain"] = []string{"v=spf1 redirect=faildom"}
 	res, err := CheckHost(ip1111, "domain")
@@ -250,7 +250,7 @@ func TestNoRecord(t *testing.T) {
 	dns.Txt["d1"] = []string{""}
 	dns.Txt["d2"] = []string{"loco", "v=spf2"}
 	dns.Errors["nospf"] = fmt.Errorf("no such domain")
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	for _, domain := range []string{"d1", "d2", "d3", "nospf"} {
 		res, err := CheckHost(ip1111, domain)
@@ -271,7 +271,7 @@ func TestDNSTemporaryErrors(t *testing.T) {
 	dns.Errors["tmperr"] = dnsError
 	dns.Errors["1.1.1.1"] = dnsError
 	dns.Mx["tmpmx"] = []*net.MX{mx("tmperr", 10)}
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	cases := []struct {
 		txt string
@@ -305,7 +305,7 @@ func TestDNSPermanentErrors(t *testing.T) {
 	dns.Errors["tmperr"] = dnsError
 	dns.Errors["1.1.1.1"] = dnsError
 	dns.Mx["tmpmx"] = []*net.MX{mx("tmperr", 10)}
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	cases := []struct {
 		txt string
@@ -330,7 +330,7 @@ func TestDNSPermanentErrors(t *testing.T) {
 
 func TestMacros(t *testing.T) {
 	dns := NewDefaultResolver()
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	// Most of the cases are covered by the standard test suite, so this is
 	// targeted at gaps in coverage.
@@ -376,7 +376,7 @@ func TestMacros(t *testing.T) {
 
 func TestMacrosV4(t *testing.T) {
 	dns := NewDefaultResolver()
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	// Like TestMacros above, but specifically for IPv4.
 	// It's easier to have a separate suite.
@@ -462,6 +462,7 @@ func TestInvalidMacro(t *testing.T) {
 			ip:     ip1111,
 			count:  0,
 			sender: "sender.com",
+			trace:  t.Logf,
 		}
 
 		out, err := r.expandMacros(macro, "sender.com")
@@ -476,7 +477,7 @@ func TestInvalidMacro(t *testing.T) {
 // other tests override it.
 func TestNullTrace(t *testing.T) {
 	dns := NewDefaultResolver()
-	trace = nullTrace
+	defaultTrace = nullTrace
 
 	dns.Txt["domain1"] = []string{"v=spf1 include:domain2"}
 	dns.Txt["domain2"] = []string{"v=spf1 +all"}
@@ -490,7 +491,7 @@ func TestNullTrace(t *testing.T) {
 
 func TestOverrideLookupLimit(t *testing.T) {
 	dns := NewDefaultResolver()
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	dns.Txt["domain1"] = []string{"v=spf1 include:domain2"}
 	dns.Txt["domain2"] = []string{"v=spf1 include:domain3"}
@@ -521,7 +522,7 @@ func TestOverrideLookupLimit(t *testing.T) {
 
 func TestWithContext(t *testing.T) {
 	dns := NewDefaultResolver()
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	dns.Txt["domain1"] = []string{"v=spf1 include:domain2"}
 	dns.Txt["domain2"] = []string{"v=spf1 +all"}
@@ -548,7 +549,7 @@ func TestWithResolver(t *testing.T) {
 	// Use a custom resolver, making sure it's different from the default.
 	defaultResolver = dnstest.NewResolver()
 	dns := dnstest.NewResolver()
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	dns.Txt["domain1"] = []string{"v=spf1 include:domain2"}
 	dns.Txt["domain2"] = []string{"v=spf1 +all"}
@@ -564,7 +565,7 @@ func TestWithResolver(t *testing.T) {
 // address. This can happen if using a buggy custom resolver.
 func TestBadResolverResponse(t *testing.T) {
 	dns := dnstest.NewResolver()
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	// When LookupIPAddr returns an invalid ip, for an "a" field.
 	dns.Ip["domain1"] = []net.IP{nil}
@@ -604,3 +605,27 @@ func TestBadResolverResponse(t *testing.T) {
 		t.Errorf("expected fail, got %q / %q", res, err)
 	}
 }
+
+func TestWithTraceFunc(t *testing.T) {
+	calls := 0
+	var trace TraceFunc = func(f string, a ...interface{}) {
+		calls++ // Counted to verify the option routes tracing to this function.
+		t.Logf("tracing "+f, a...)
+	}
+
+	dns := NewDefaultResolver()
+
+	dns.Txt["domain1"] = []string{"v=spf1 include:domain2"}
+	dns.Txt["domain2"] = []string{"v=spf1 +all"}
+
+	// Resolve with the custom trace function; the result must still be Pass.
+	res, err := CheckHostWithSender(ip1111, "helo", "user@domain1",
+		WithTraceFunc(trace))
+	if res != Pass {
+		t.Errorf("expected pass, got %q / %q", res, err)
+	}
+
+	if calls == 0 {
+		t.Errorf("expected >0 trace function calls, got 0")
+	}
+}
diff --git a/yml_test.go b/yml_test.go
index e617006..111b8ac 100644
--- a/yml_test.go
+++ b/yml_test.go
@@ -150,7 +150,7 @@ func testRFC(t *testing.T, fname string) {
 		suites = append(suites, s)
 	}
 
-	trace = t.Logf
+	defaultTrace = t.Logf
 
 	for _, suite := range suites {
 		t.Logf("suite: %v", suite.Description)