author | Alberto Bertogli <albertito@blitiri.com.ar> | 2020-09-20 09:46:52 UTC |
committer | Alberto Bertogli <albertito@blitiri.com.ar> | 2020-09-20 14:50:43 UTC |
parent | f879074c9016e2c9af43b17c44903a2fffe56d83 |
dns_test.go | +21 | -4 |
spf.go | +30 | -13 |
spf_test.go | +27 | -0 |
diff --git a/dns_test.go b/dns_test.go index a28fc7e..c27d07f 100644 --- a/dns_test.go +++ b/dns_test.go @@ -1,7 +1,9 @@ package spf import ( + "context" "flag" + "fmt" "net" "os" "strings" @@ -32,25 +34,40 @@ func NewDNS() DNS { // This way it's easier to get a clean slate between tests. var dns DNS -func LookupTXT(domain string) (txts []string, err error) { +func LookupTXT(ctx context.Context, domain string) (txts []string, err error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } domain = strings.ToLower(domain) domain = strings.TrimRight(domain, ".") return dns.txt[domain], dns.errors[domain] } -func LookupMX(domain string) (mxs []*net.MX, err error) { +func LookupMX(ctx context.Context, domain string) (mxs []*net.MX, err error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } domain = strings.ToLower(domain) domain = strings.TrimRight(domain, ".") return dns.mx[domain], dns.errors[domain] } -func LookupIP(host string) (ips []net.IP, err error) { +func LookupIP(ctx context.Context, net, host string) (ips []net.IP, err error) { + if net != "ip" { + panic(fmt.Sprintf("got net %q, expected ip", net)) + } + if ctx.Err() != nil { + return nil, ctx.Err() + } host = strings.ToLower(host) host = strings.TrimRight(host, ".") return dns.ip[host], dns.errors[host] } -func LookupAddr(host string) (addrs []string, err error) { +func LookupAddr(ctx context.Context, host string) (addrs []string, err error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } host = strings.ToLower(host) host = strings.TrimRight(host, ".") return dns.addr[host], dns.errors[host] diff --git a/spf.go b/spf.go index 8a9f108..42ee8d5 100644 --- a/spf.go +++ b/spf.go @@ -27,6 +27,7 @@ package spf // import "blitiri.com.ar/go/spf" import ( + "context" "fmt" "net" "net/url" @@ -37,10 +38,10 @@ import ( // Functions that we can override for testing purposes. 
var ( - lookupTXT = net.LookupTXT - lookupMX = net.LookupMX - lookupIP = net.LookupIP - lookupAddr = net.LookupAddr + lookupTXT = net.DefaultResolver.LookupTXT + lookupMX = net.DefaultResolver.LookupMX + lookupIP = net.DefaultResolver.LookupIP + lookupAddr = net.DefaultResolver.LookupAddr trace = func(f string, a ...interface{}) {} ) @@ -129,6 +130,7 @@ func CheckHost(ip net.IP, domain string) (Result, error) { ip: ip, maxcount: defaultMaxLookups, sender: "@" + domain, + ctx: context.TODO(), } return r.Check(domain) } @@ -152,6 +154,7 @@ func CheckHostWithSender(ip net.IP, helo, sender string, opts ...Option) (Result ip: ip, maxcount: defaultMaxLookups, sender: sender, + ctx: context.TODO(), } for _, opt := range opts { @@ -173,6 +176,17 @@ func OverrideLookupLimit(limit uint) Option { } } +// WithContext is an option to set the context for this operation, which will +// be passed along to the resolver functions and other external calls if +// needed. +// +// This is EXPERIMENTAL for now, and the API is subject to change. +func WithContext(ctx context.Context) Option { + return func(r *resolution) { + r.ctx = ctx + } +} + // split an user@domain address into user and domain. func split(addr string) (string, string) { ps := strings.SplitN(addr, "@", 2) @@ -192,6 +206,9 @@ type resolution struct { // Result of doing a reverse lookup for ip (so we only do it once). ipNames []string + + // Context for this resolution. 
+ ctx context.Context } var aField = regexp.MustCompile(`^(a$|a:|a/)`) @@ -201,7 +218,7 @@ var ptrField = regexp.MustCompile(`^(ptr$|ptr:)`) func (r *resolution) Check(domain string) (Result, error) { r.count++ trace("check %s %d", domain, r.count) - txt, err := getDNSRecord(domain) + txt, err := getDNSRecord(r.ctx, domain) if err != nil { if isTemporary(err) { trace("dns temp error: %v", err) @@ -333,8 +350,8 @@ func (r *resolution) Check(domain string) (Result, error) { // https://tools.ietf.org/html/rfc7208#section-3 // https://tools.ietf.org/html/rfc7208#section-3.2 // https://tools.ietf.org/html/rfc7208#section-4.5 -func getDNSRecord(domain string) (string, error) { - txts, err := lookupTXT(domain) +func getDNSRecord(ctx context.Context, domain string) (string, error) { + txts, err := lookupTXT(ctx, domain) if err != nil { return "", err } @@ -417,7 +434,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e if r.ipNames == nil { r.ipNames = []string{} r.count++ - ns, err := lookupAddr(r.ip.String()) + ns, err := lookupAddr(r.ctx, r.ip.String()) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -433,7 +450,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e return false, "", errLookupLimitReached } r.count++ - addrs, err := lookupIP(n) + addrs, err := lookupIP(r.ctx, "ip", n) if err != nil { // RFC explicitly says to skip domains which error here. 
continue @@ -473,7 +490,7 @@ func (r *resolution) existsField(res Result, field, domain string) (bool, Result } r.count++ - ips, err := lookupIP(eDomain) + ips, err := lookupIP(r.ctx, "ip", eDomain) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -590,7 +607,7 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err } r.count++ - ips, err := lookupIP(aDomain) + ips, err := lookupIP(r.ctx, "ip", aDomain) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -624,7 +641,7 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er } r.count++ - mxs, err := lookupMX(mxDomain) + mxs, err := lookupMX(r.ctx, mxDomain) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -642,7 +659,7 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er mxips := []net.IP{} for _, mx := range mxs { r.count++ - ips, err := lookupIP(mx.Host) + ips, err := lookupIP(r.ctx, "ip", mx.Host) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { diff --git a/spf_test.go b/spf_test.go index efd5e5b..1d7c9f7 100644 --- a/spf_test.go +++ b/spf_test.go @@ -1,6 +1,7 @@ package spf import ( + "context" "fmt" "net" "testing" @@ -481,3 +482,29 @@ func TestOverrideLookupLimit(t *testing.T) { res, err) } } + +func TestWithContext(t *testing.T) { + dns = NewDNS() + trace = t.Logf + + dns.txt["domain1"] = []string{"v=spf1 include:domain2"} + dns.txt["domain2"] = []string{"v=spf1 +all"} + + // With a normal context. + ctx := context.Background() + res, err := CheckHostWithSender(ip1111, "helo", "user@domain1", + WithContext(ctx)) + if res != Pass { + t.Errorf("expected pass, got %q / %q", res, err) + } + + // With a cancelled context. 
+ ctx, cancelF := context.WithCancel(context.Background()) + cancelF() + res, err = CheckHostWithSender(ip1111, "helo", "user@domain1", + WithContext(ctx)) + if res != None || err != context.Canceled { + t.Errorf("expected none/context cancelled, got %q / %q", res, err) + } + +}