git » spf » commit 41d86b0

Allow passing a context.Context to the resolver functions

author Alberto Bertogli
2020-09-20 09:46:52 UTC
committer Alberto Bertogli
2020-09-20 14:50:43 UTC
parent f879074c9016e2c9af43b17c44903a2fffe56d83

Allow passing a context.Context to the resolver functions

In some cases it can be useful to propagate a context through the SPF
checks, in particular to make sure that the resolver functions honour
context cancellations.

This patch adds a new WithContext option, which lets users provide a context
for each operation; the context is passed along to the resolver functions,
which now default to the context-aware methods of net.DefaultResolver.

dns_test.go +21 -4
spf.go +30 -13
spf_test.go +27 -0

diff --git a/dns_test.go b/dns_test.go
index a28fc7e..c27d07f 100644
--- a/dns_test.go
+++ b/dns_test.go
@@ -1,7 +1,9 @@
 package spf
 
 import (
+	"context"
 	"flag"
+	"fmt"
 	"net"
 	"os"
 	"strings"
@@ -32,25 +34,40 @@ func NewDNS() DNS {
 // This way it's easier to get a clean slate between tests.
 var dns DNS
 
-func LookupTXT(domain string) (txts []string, err error) {
+func LookupTXT(ctx context.Context, domain string) (txts []string, err error) {
+	if err = ctx.Err(); err != nil { // Fail fast once ctx is cancelled or expired.
+		return nil, err
+	}
 	domain = strings.ToLower(domain)
 	domain = strings.TrimRight(domain, ".")
 	return dns.txt[domain], dns.errors[domain]
 }
 
-func LookupMX(domain string) (mxs []*net.MX, err error) {
+func LookupMX(ctx context.Context, domain string) (mxs []*net.MX, err error) {
+	if err = ctx.Err(); err != nil { // Fail fast once ctx is cancelled or expired.
+		return nil, err
+	}
 	domain = strings.ToLower(domain)
 	domain = strings.TrimRight(domain, ".")
 	return dns.mx[domain], dns.errors[domain]
 }
 
-func LookupIP(host string) (ips []net.IP, err error) {
+func LookupIP(ctx context.Context, network, host string) (ips []net.IP, err error) {
+	if network != "ip" { // spf.go only issues "ip" lookups; anything else is a bug.
+		panic(fmt.Sprintf("got net %q, expected ip", network))
+	}
+	if ctx.Err() != nil {
+		return nil, ctx.Err()
+	}
 	host = strings.ToLower(host)
 	host = strings.TrimRight(host, ".")
 	return dns.ip[host], dns.errors[host]
 }
 
-func LookupAddr(host string) (addrs []string, err error) {
+func LookupAddr(ctx context.Context, host string) (addrs []string, err error) {
+	if ctx.Err() != nil {
+		return nil, ctx.Err()
+	}
 	host = strings.ToLower(host)
 	host = strings.TrimRight(host, ".")
 	return dns.addr[host], dns.errors[host]
diff --git a/spf.go b/spf.go
index 8a9f108..42ee8d5 100644
--- a/spf.go
+++ b/spf.go
@@ -27,6 +27,7 @@
 package spf // import "blitiri.com.ar/go/spf"
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"net/url"
@@ -37,10 +38,10 @@ import (
 
 // Functions that we can override for testing purposes.
 var (
-	lookupTXT  = net.LookupTXT
-	lookupMX   = net.LookupMX
-	lookupIP   = net.LookupIP
-	lookupAddr = net.LookupAddr
+	lookupTXT  = func(ctx context.Context, name string) ([]string, error) { return net.DefaultResolver.LookupTXT(ctx, name) }
+	lookupMX   = func(ctx context.Context, name string) ([]*net.MX, error) { return net.DefaultResolver.LookupMX(ctx, name) }
+	lookupIP   = func(ctx context.Context, network, host string) ([]net.IP, error) { return net.DefaultResolver.LookupIP(ctx, network, host) }
+	lookupAddr = func(ctx context.Context, addr string) ([]string, error) { return net.DefaultResolver.LookupAddr(ctx, addr) }
 	trace      = func(f string, a ...interface{}) {}
 )
 
@@ -129,6 +130,7 @@ func CheckHost(ip net.IP, domain string) (Result, error) {
 		ip:       ip,
 		maxcount: defaultMaxLookups,
 		sender:   "@" + domain,
+		ctx:      context.TODO(),
 	}
 	return r.Check(domain)
 }
@@ -152,6 +154,7 @@ func CheckHostWithSender(ip net.IP, helo, sender string, opts ...Option) (Result
 		ip:       ip,
 		maxcount: defaultMaxLookups,
 		sender:   sender,
+		ctx:      context.TODO(),
 	}
 
 	for _, opt := range opts {
@@ -173,6 +176,17 @@ func OverrideLookupLimit(limit uint) Option {
 	}
 }
 
+// WithContext returns an Option that makes the check use ctx: it is handed
+// to the resolver functions, and to other external calls where needed.
+//
+// This is EXPERIMENTAL for now, and the API is subject to change.
+func WithContext(ctx context.Context) Option {
+	setCtx := func(res *resolution) {
+		res.ctx = ctx
+	}
+	return setCtx
+}
+
 // split an user@domain address into user and domain.
 func split(addr string) (string, string) {
 	ps := strings.SplitN(addr, "@", 2)
@@ -192,6 +206,9 @@ type resolution struct {
 
 	// Result of doing a reverse lookup for ip (so we only do it once).
 	ipNames []string
+
+	// Context for this resolution.
+	ctx context.Context
 }
 
 var aField = regexp.MustCompile(`^(a$|a:|a/)`)
@@ -201,7 +218,7 @@ var ptrField = regexp.MustCompile(`^(ptr$|ptr:)`)
 func (r *resolution) Check(domain string) (Result, error) {
 	r.count++
 	trace("check %s %d", domain, r.count)
-	txt, err := getDNSRecord(domain)
+	txt, err := getDNSRecord(r.ctx, domain)
 	if err != nil {
 		if isTemporary(err) {
 			trace("dns temp error: %v", err)
@@ -333,8 +350,8 @@ func (r *resolution) Check(domain string) (Result, error) {
 // https://tools.ietf.org/html/rfc7208#section-3
 // https://tools.ietf.org/html/rfc7208#section-3.2
 // https://tools.ietf.org/html/rfc7208#section-4.5
-func getDNSRecord(domain string) (string, error) {
-	txts, err := lookupTXT(domain)
+func getDNSRecord(ctx context.Context, domain string) (string, error) {
+	txts, err := lookupTXT(ctx, domain)
 	if err != nil {
 		return "", err
 	}
@@ -417,7 +434,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e
 	if r.ipNames == nil {
 		r.ipNames = []string{}
 		r.count++
-		ns, err := lookupAddr(r.ip.String())
+		ns, err := lookupAddr(r.ctx, r.ip.String())
 		if err != nil {
 			// https://tools.ietf.org/html/rfc7208#section-5
 			if isTemporary(err) {
@@ -433,7 +450,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e
 				return false, "", errLookupLimitReached
 			}
 			r.count++
-			addrs, err := lookupIP(n)
+			addrs, err := lookupIP(r.ctx, "ip", n)
 			if err != nil {
 				// RFC explicitly says to skip domains which error here.
 				continue
@@ -473,7 +490,7 @@ func (r *resolution) existsField(res Result, field, domain string) (bool, Result
 	}
 
 	r.count++
-	ips, err := lookupIP(eDomain)
+	ips, err := lookupIP(r.ctx, "ip", eDomain)
 	if err != nil {
 		// https://tools.ietf.org/html/rfc7208#section-5
 		if isTemporary(err) {
@@ -590,7 +607,7 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err
 	}
 
 	r.count++
-	ips, err := lookupIP(aDomain)
+	ips, err := lookupIP(r.ctx, "ip", aDomain)
 	if err != nil {
 		// https://tools.ietf.org/html/rfc7208#section-5
 		if isTemporary(err) {
@@ -624,7 +641,7 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er
 	}
 
 	r.count++
-	mxs, err := lookupMX(mxDomain)
+	mxs, err := lookupMX(r.ctx, mxDomain)
 	if err != nil {
 		// https://tools.ietf.org/html/rfc7208#section-5
 		if isTemporary(err) {
@@ -642,7 +659,7 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er
 	mxips := []net.IP{}
 	for _, mx := range mxs {
 		r.count++
-		ips, err := lookupIP(mx.Host)
+		ips, err := lookupIP(r.ctx, "ip", mx.Host)
 		if err != nil {
 			// https://tools.ietf.org/html/rfc7208#section-5
 			if isTemporary(err) {
diff --git a/spf_test.go b/spf_test.go
index efd5e5b..1d7c9f7 100644
--- a/spf_test.go
+++ b/spf_test.go
@@ -1,6 +1,7 @@
 package spf
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"testing"
@@ -481,3 +482,29 @@ func TestOverrideLookupLimit(t *testing.T) {
 			res, err)
 	}
 }
+
+func TestWithContext(t *testing.T) {
+	dns = NewDNS()
+	trace = t.Logf
+
+	dns.txt["domain1"] = []string{"v=spf1 include:domain2"}
+	dns.txt["domain2"] = []string{"v=spf1 +all"}
+
+	// With a normal context, the check must pass and report no error.
+	ctx := context.Background()
+	res, err := CheckHostWithSender(ip1111, "helo", "user@domain1",
+		WithContext(ctx))
+	if res != Pass || err != nil {
+		t.Errorf("expected pass/nil, got %q / %v", res, err)
+	}
+
+	// With an already-cancelled context, the test resolvers fail every
+	// lookup with ctx.Err(), and that error must reach the caller.
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	res, err = CheckHostWithSender(ip1111, "helo", "user@domain1",
+		WithContext(ctx))
+	if res != None || err != context.Canceled {
+		t.Errorf("expected none/context cancelled, got %q / %v", res, err)
+	}
+}