git » chasquid » commit 3d3b771

internal/spf: Add an SPF package

author Alberto Bertogli
2016-10-07 22:54:23 UTC
committer Alberto Bertogli
2016-10-09 23:51:05 UTC
parent 498bb29585d4efbdde6c4a8b8df66625efe6bfe1

internal/spf: Add an SPF package

This patch adds a package for evaluating SPF, as defined by RFC 7208
(https://tools.ietf.org/html/rfc7208).

It doesn't implement 100% of the RFC, but it covers enough to handle the
most common cases, and will fail open on the others.

cmd/spf-check/spf-check.go +20 -0
internal/spf/spf.go +379 -0
internal/spf/spf_test.go +131 -0

diff --git a/cmd/spf-check/spf-check.go b/cmd/spf-check/spf-check.go
new file mode 100644
index 0000000..9328930
--- /dev/null
+++ b/cmd/spf-check/spf-check.go
@@ -0,0 +1,20 @@
+// Command line tool for playing with the SPF library.
+//
+// Not for use in production, just development and experimentation.
+package main
+
+import (
+	"flag"
+	"fmt"
+	"net"
+	"os"
+
+	"blitiri.com.ar/go/chasquid/internal/spf"
+)
+
+// main checks whether the IP address given as the first argument may send
+// mail for the domain given as the second argument, and prints the SPF
+// result followed by the explanation/error returned by the library.
+func main() {
+	flag.Parse()
+
+	if flag.NArg() != 2 {
+		fmt.Fprintln(os.Stderr, "usage: spf-check <ip> <domain>")
+		os.Exit(2)
+	}
+
+	// net.ParseIP returns nil on malformed input; reject it here instead of
+	// silently evaluating the policy against a nil address.
+	ip := net.ParseIP(flag.Arg(0))
+	if ip == nil {
+		fmt.Fprintf(os.Stderr, "invalid IP address %q\n", flag.Arg(0))
+		os.Exit(2)
+	}
+
+	r, err := spf.CheckHost(ip, flag.Arg(1))
+	fmt.Println(r)
+	fmt.Println(err)
+}
diff --git a/internal/spf/spf.go b/internal/spf/spf.go
new file mode 100644
index 0000000..974ab09
--- /dev/null
+++ b/internal/spf/spf.go
@@ -0,0 +1,379 @@
+// Package spf implements SPF (Sender Policy Framework) lookup and validation.
+//
+// Supported:
+//  - "all".
+//  - "include".
+//  - "a".
+//  - "mx".
+//  - "ip4".
+//  - "ip6".
+//  - "redirect".
+//
+// Not supported (return Neutral if used):
+//  - "exists".
+//  - "ptr".
+//  - "exp".
+//  - Macros.
+//
+// References:
+// https://tools.ietf.org/html/rfc7208
+// https://en.wikipedia.org/wiki/Sender_Policy_Framework
+package spf
+
+import (
+	"fmt"
+	"net"
+	"regexp"
+	"strconv"
+	"strings"
+)
+
+// DNS resolution functions used by this package. They are variables rather
+// than direct calls so tests can swap in fakes; by default they are the net
+// package's resolvers. (Function types ignore parameter names, so these have
+// exactly the types of the net functions.)
+var (
+	lookupTXT = net.LookupTXT
+	lookupMX  = net.LookupMX
+	lookupIP  = net.LookupIP
+)
+
+// Result is the outcome of an SPF check. Note the values have meaning, we use
+// them in headers.
+// https://tools.ietf.org/html/rfc7208#section-8
+type Result string
+
+// The possible results. They are constants (rather than variables) so that
+// they cannot be reassigned from outside the package.
+const (
+	// None means we were not able to reach any conclusion.
+	// https://tools.ietf.org/html/rfc7208#section-8.1
+	None = Result("none")
+
+	// Neutral means there is no definite assertion (positive or negative).
+	// https://tools.ietf.org/html/rfc7208#section-8.2
+	Neutral = Result("neutral")
+
+	// Pass means the client is authorized to inject mail.
+	// https://tools.ietf.org/html/rfc7208#section-8.3
+	Pass = Result("pass")
+
+	// Fail means the client is *not* authorized to use the domain.
+	// https://tools.ietf.org/html/rfc7208#section-8.4
+	Fail = Result("fail")
+
+	// SoftFail means the client is not authorized, but the domain owner is
+	// unwilling to make a strong policy statement.
+	// https://tools.ietf.org/html/rfc7208#section-8.5
+	SoftFail = Result("softfail")
+
+	// TempError means a transient error occurred while performing the check.
+	// https://tools.ietf.org/html/rfc7208#section-8.6
+	TempError = Result("temperror")
+
+	// PermError means the records could not be correctly interpreted.
+	// https://tools.ietf.org/html/rfc7208#section-8.7
+	PermError = Result("permerror")
+)
+
+// QualToResult maps each SPF qualifier character to the Result produced when
+// a mechanism carrying that qualifier matches. Mechanisms written without a
+// qualifier are evaluated as if they had "+".
+// https://tools.ietf.org/html/rfc7208#section-4.6.2
+var QualToResult = map[byte]Result{
+	'?': Neutral,
+	'~': SoftFail,
+	'-': Fail,
+	'+': Pass,
+}
+
+// CheckHost fetches the SPF record of domain, parses it, and evaluates it to
+// determine whether ip is or is not permitted to send mail with the given
+// identity.
+//
+// The returned error may be non-nil even for a conclusive result: it also
+// carries a description of what matched (e.g. "matched 'all'").
+// Reference: https://tools.ietf.org/html/rfc7208#section-4
+func CheckHost(ip net.IP, domain string) (Result, error) {
+	res := &resolution{ip: ip}
+	return res.Check(domain)
+}
+
+// resolution holds the state of one CheckHost evaluation; it is shared by
+// the nested checks triggered by "include" and "redirect".
+type resolution struct {
+	// ip is the client address being checked.
+	ip net.IP
+
+	// count is the number of lookups performed so far, used to enforce the
+	// lookup limit.
+	count uint
+}
+
+// Check evaluates the SPF record published by domain against r.ip.
+//
+// Mechanisms are evaluated left to right and the first one that matches
+// decides the result; "redirect=" modifiers are only followed when nothing
+// else matched. The returned error doubles as an explanation: on a match it
+// describes what matched (e.g. "matched 'all'"), not only failures.
+func (r *resolution) Check(domain string) (Result, error) {
+	// Limit the number of resolutions to 10
+	// https://tools.ietf.org/html/rfc7208#section-4.6.4
+	if r.count > 10 {
+		return PermError, fmt.Errorf("lookup limit reached")
+	}
+	r.count++
+
+	txt, err := getDNSRecord(domain)
+	if err != nil {
+		if isTemporary(err) {
+			return TempError, err
+		}
+		// Could not resolve the name, it may be missing the record.
+		// https://tools.ietf.org/html/rfc7208#section-2.6.1
+		return None, err
+	}
+
+	if txt == "" {
+		// No record => None.
+		// https://tools.ietf.org/html/rfc7208#section-4.6
+		return None, nil
+	}
+
+	fields := strings.Fields(txt)
+
+	// Redirects must be handled after the rest; instead of having two loops,
+	// we just move them to the end. The modifier is spelled "redirect=" (the
+	// same prefix handled below). Matching "redirect:" here would never move
+	// anything, and a leading redirect would then shadow later mechanisms.
+	// https://tools.ietf.org/html/rfc7208#section-6.1
+	var newfields, redirects []string
+	for _, field := range fields {
+		if strings.HasPrefix(field, "redirect=") {
+			redirects = append(redirects, field)
+		} else {
+			newfields = append(newfields, field)
+		}
+	}
+	fields = append(newfields, redirects...)
+
+	for _, field := range fields {
+		if strings.HasPrefix(field, "v=") {
+			continue
+		}
+		if r.count > 10 {
+			return PermError, fmt.Errorf("lookup limit reached")
+		}
+		if strings.Contains(field, "%") {
+			return Neutral, fmt.Errorf("macros not supported")
+		}
+
+		// See if we have a qualifier, defaulting to + (pass).
+		// https://tools.ietf.org/html/rfc7208#section-4.6.2
+		result, hasQual := QualToResult[field[0]]
+		if hasQual {
+			field = field[1:]
+		} else {
+			result = Pass
+		}
+
+		switch {
+		case field == "all":
+			// https://tools.ietf.org/html/rfc7208#section-5.1
+			return result, fmt.Errorf("matched 'all'")
+		case strings.HasPrefix(field, "include:"):
+			if ok, res, err := r.includeField(result, field); ok {
+				return res, err
+			}
+		case isMechanism(field, "a"):
+			if ok, res, err := r.aField(result, field, domain); ok {
+				return res, err
+			}
+		case isMechanism(field, "mx"):
+			if ok, res, err := r.mxField(result, field, domain); ok {
+				return res, err
+			}
+		case strings.HasPrefix(field, "ip4:") || strings.HasPrefix(field, "ip6:"):
+			if ok, res, err := r.ipField(result, field); ok {
+				return res, err
+			}
+		case strings.HasPrefix(field, "exists"):
+			return Neutral, fmt.Errorf("'exists' not supported")
+		case strings.HasPrefix(field, "ptr"):
+			return Neutral, fmt.Errorf("'ptr' not supported")
+		case strings.HasPrefix(field, "exp="):
+			return Neutral, fmt.Errorf("'exp' not supported")
+		case strings.HasPrefix(field, "redirect="):
+			// https://tools.ietf.org/html/rfc7208#section-6.1
+			result, err := r.Check(field[len("redirect="):])
+			if result == None {
+				result = PermError
+			}
+			return result, err
+		default:
+			// http://www.openspf.org/SPF_Record_Syntax
+			return PermError, fmt.Errorf("unknown field %q", field)
+		}
+	}
+
+	// Got to the end of the evaluation without a result => Neutral.
+	// https://tools.ietf.org/html/rfc7208#section-4.7
+	return Neutral, nil
+}
+
+// isMechanism reports whether field is the mechanism called name, either
+// bare or followed by a ":<domain>" or "/<mask>" argument. A plain prefix
+// test would also accept unrelated words such as "abc" or "mxyz".
+func isMechanism(field, name string) bool {
+	return field == name ||
+		strings.HasPrefix(field, name+":") ||
+		strings.HasPrefix(field, name+"/")
+}
+
+// getDNSRecord looks up the TXT records of domain and returns the first one
+// that is an SPF record: either exactly "v=spf1" or starting with "v=spf1 ".
+// It returns "" and a nil error when no record qualifies, and the lookup
+// error unchanged when the TXT lookup itself fails.
+//
+// At most one SPF record is allowed per domain; this function does not
+// enforce that and simply returns the first match.
+// https://tools.ietf.org/html/rfc7208#section-3
+// https://tools.ietf.org/html/rfc7208#section-3.2
+// https://tools.ietf.org/html/rfc7208#section-4.5
+func getDNSRecord(domain string) (string, error) {
+	records, err := lookupTXT(domain)
+	if err != nil {
+		return "", err
+	}
+
+	for _, rec := range records {
+		// The bare "v=spf1" (an empty record) is explicitly allowed:
+		// https://tools.ietf.org/html/rfc7208#section-4.5
+		if rec == "v=spf1" || strings.HasPrefix(rec, "v=spf1 ") {
+			return rec, nil
+		}
+	}
+	return "", nil
+}
+
+// isTemporary reports whether err is a *net.DNSError whose Temporary method
+// returns true. It returns false for any other error, including nil.
+func isTemporary(err error) bool {
+	if dnsErr, isDNS := err.(*net.DNSError); isDNS {
+		return dnsErr.Temporary()
+	}
+	return false
+}
+
+// ipField processes an "ip4:" or "ip6:" field (qualifier already removed),
+// matching r.ip against the single address or CIDR network it names.
+//
+// It returns (true, res, description) on a match, (true, PermError, err) if
+// the value cannot be parsed or is of the wrong address family for the
+// mechanism, and (false, "", nil) otherwise.
+// https://tools.ietf.org/html/rfc7208#section-5.6
+func (r *resolution) ipField(res Result, field string) (bool, Result, error) {
+	// "ip4:" and "ip6:" have the same length.
+	fip := field[4:]
+
+	// The value must be of the family named by the mechanism: IPv6 addresses
+	// always contain ':' and IPv4 ones never do. Without this check,
+	// "ip4:2001:db8::/32" would be silently accepted as an IPv6 network.
+	addr := fip
+	if i := strings.Index(fip, "/"); i >= 0 {
+		addr = fip[:i]
+	}
+	if strings.Contains(addr, ":") != strings.HasPrefix(field, "ip6:") {
+		return true, PermError, fmt.Errorf("invalid ipX value")
+	}
+
+	if strings.Contains(fip, "/") {
+		_, ipnet, err := net.ParseCIDR(fip)
+		if err != nil {
+			return true, PermError, err
+		}
+		if ipnet.Contains(r.ip) {
+			return true, res, fmt.Errorf("matched %v", ipnet)
+		}
+		return false, "", nil
+	}
+
+	ip := net.ParseIP(fip)
+	if ip == nil {
+		return true, PermError, fmt.Errorf("invalid ipX value")
+	}
+	if ip.Equal(r.ip) {
+		return true, res, fmt.Errorf("matched %v", ip)
+	}
+	return false, "", nil
+}
+
+// includeField processes an "include:<domain>" field by recursively checking
+// <domain> and mapping its result as follows:
+//
+//   - Pass: the mechanism matches and the result is res (the qualifier of
+//     the including record).
+//   - Fail, SoftFail, Neutral: no match; evaluation continues.
+//   - TempError: TempError.
+//   - PermError, None: PermError.
+//
+// The first return value reports whether evaluation should stop.
+// https://tools.ietf.org/html/rfc7208#section-5.2
+func (r *resolution) includeField(res Result, field string) (bool, Result, error) {
+	target := field[len("include:"):]
+	included, err := r.Check(target)
+
+	if included == Pass {
+		return true, res, err
+	}
+	if included == Fail || included == SoftFail || included == Neutral {
+		return false, included, err
+	}
+	if included == TempError {
+		return true, TempError, err
+	}
+	if included == PermError || included == None {
+		return true, PermError, err
+	}
+	return false, "", fmt.Errorf("This should never be reached")
+}
+
+// ipMatch reports whether ip matches tomatch. When mask is non-negative and
+// tomatch is an IPv4 address, ip is compared against the network
+// tomatch/mask; otherwise the two addresses must be equal.
+//
+// mask comes from the single "/<n>" suffix of "a" and "mx", which RFC 7208
+// defines as the IPv4 prefix length (dual-cidr-length). Applying it to an
+// IPv6 address would turn e.g. "a/24" into an IPv6 /24 network and match a
+// huge range, so IPv6 addresses are always matched exactly.
+// https://tools.ietf.org/html/rfc7208#section-5.6
+//
+// On a match, the returned error describes what matched. A non-nil error
+// without a match means the mask was invalid.
+func ipMatch(ip, tomatch net.IP, mask int) (bool, error) {
+	if mask < 0 || tomatch.To4() == nil {
+		if ip.Equal(tomatch) {
+			return true, fmt.Errorf("%v", tomatch)
+		}
+		return false, nil
+	}
+
+	_, ipnet, err := net.ParseCIDR(fmt.Sprintf("%s/%d", tomatch.String(), mask))
+	if err != nil {
+		return false, err
+	}
+	if ipnet.Contains(ip) {
+		return true, fmt.Errorf("%v", ipnet)
+	}
+	return false, nil
+}
+
+// Patterns for the "a" and "mx" mechanisms: the optional ":<domain>" is
+// captured in group 2 and the optional "/<mask>" in group 4. They are not
+// anchored; Check only hands over fields starting with the mechanism name.
+var (
+	aRegexp  = regexp.MustCompile("a(:([^/]+))?(/(.+))?")
+	mxRegexp = regexp.MustCompile("mx(:([^/]+))?(/(.+))?")
+)
+
+func domainAndMask(re *regexp.Regexp, field, domain string) (string, int, error) {
+	var err error
+	mask := -1
+	if groups := re.FindStringSubmatch(field); groups != nil {
+		if groups[2] != "" {
+			domain = groups[2]
+		}
+		if groups[4] != "" {
+			mask, err = strconv.Atoi(groups[4])
+			if err != nil {
+				return "", -1, fmt.Errorf("error parsing mask")
+			}
+		}
+	}
+
+	return domain, mask, nil
+}
+
+// aField processes an "a" field: "a", optionally followed by ":<domain>"
+// and/or "/<mask>". It matches r.ip against the addresses of the target
+// domain, which defaults to the domain being evaluated.
+//
+// The first return value reports whether evaluation should stop:
+//   - a malformed mask gives PermError;
+//   - a temporary DNS failure gives TempError;
+//   - a match gives res.
+// Any other lookup failure counts as "no match"; its error is returned next
+// to false, and Check ignores it.
+// https://tools.ietf.org/html/rfc7208#section-5.3
+func (r *resolution) aField(res Result, field, domain string) (bool, Result, error) {
+	target, mask, err := domainAndMask(aRegexp, field, domain)
+	if err != nil {
+		return true, PermError, err
+	}
+
+	// This mechanism queries DNS, so it counts against the lookup limit.
+	r.count++
+	addrs, err := lookupIP(target)
+	if err != nil && isTemporary(err) {
+		// https://tools.ietf.org/html/rfc7208#section-5
+		return true, TempError, err
+	}
+	if err != nil {
+		return false, "", err
+	}
+
+	for _, addr := range addrs {
+		matched, detail := ipMatch(r.ip, addr, mask)
+		switch {
+		case matched:
+			return true, res, fmt.Errorf("matched 'a' (%v)", detail)
+		case detail != nil:
+			return true, PermError, detail
+		}
+	}
+	return false, "", nil
+}
+
+// mxField processes an "mx" field: "mx", optionally followed by
+// ":<domain>" and/or "/<mask>". It matches r.ip against the addresses of the
+// MX hosts of the target domain, which defaults to the domain being
+// evaluated.
+//
+// The first return value reports whether evaluation should stop:
+//   - a malformed mask or more than 10 MX records gives PermError;
+//   - a temporary DNS failure gives TempError;
+//   - a match gives res.
+// https://tools.ietf.org/html/rfc7208#section-5.4
+func (r *resolution) mxField(res Result, field, domain string) (bool, Result, error) {
+	domain, mask, err := domainAndMask(mxRegexp, field, domain)
+	if err != nil {
+		return true, PermError, err
+	}
+
+	// The "mx" mechanism counts once against the global lookup limit; the
+	// address lookups of its hosts do not.
+	// https://tools.ietf.org/html/rfc7208#section-4.6.4
+	r.count++
+	mxs, err := lookupMX(domain)
+	if err != nil {
+		// https://tools.ietf.org/html/rfc7208#section-5
+		if isTemporary(err) {
+			return true, TempError, err
+		}
+		return false, "", err
+	}
+
+	// Instead, each mechanism may resolve at most 10 MX hosts.
+	// https://tools.ietf.org/html/rfc7208#section-4.6.4
+	if len(mxs) > 10 {
+		return true, PermError, fmt.Errorf("too many MX records (%d)", len(mxs))
+	}
+
+	for _, mx := range mxs {
+		ips, err := lookupIP(mx.Host)
+		if err != nil {
+			// https://tools.ietf.org/html/rfc7208#section-5
+			if isTemporary(err) {
+				return true, TempError, err
+			}
+			// One unresolvable MX host must not stop the others from
+			// matching; skip it.
+			continue
+		}
+		for _, ip := range ips {
+			ok, err := ipMatch(r.ip, ip, mask)
+			if ok {
+				return true, res, fmt.Errorf("matched 'mx' (%v)", err)
+			} else if err != nil {
+				return true, PermError, err
+			}
+		}
+	}
+
+	return false, "", nil
+}
diff --git a/internal/spf/spf_test.go b/internal/spf/spf_test.go
new file mode 100644
index 0000000..a5007dc
--- /dev/null
+++ b/internal/spf/spf_test.go
@@ -0,0 +1,131 @@
+package spf
+
+import (
+	"flag"
+	"fmt"
+	"net"
+	"os"
+	"testing"
+)
+
+// In-memory DNS data served by the fake resolvers below, keyed by the name
+// being looked up. Names without an entry resolve to nil results and a nil
+// error.
+var (
+	txtResults = map[string][]string{}
+	txtErrors  = map[string]error{}
+
+	mxResults = map[string][]*net.MX{}
+	mxErrors  = map[string]error{}
+
+	ipResults = map[string][]net.IP{}
+	ipErrors  = map[string]error{}
+)
+
+// LookupTXT is a fake net.LookupTXT backed by txtResults and txtErrors.
+func LookupTXT(domain string) (txts []string, err error) {
+	txts = txtResults[domain]
+	err = txtErrors[domain]
+	return txts, err
+}
+
+// LookupMX is a fake net.LookupMX backed by mxResults and mxErrors.
+func LookupMX(domain string) (mxs []*net.MX, err error) {
+	mxs = mxResults[domain]
+	err = mxErrors[domain]
+	return mxs, err
+}
+
+// LookupIP is a fake net.LookupIP backed by ipResults and ipErrors.
+func LookupIP(host string) (ips []net.IP, err error) {
+	ips = ipResults[host]
+	err = ipErrors[host]
+	return ips, err
+}
+
+// TestMain replaces the package's DNS lookup functions with the in-memory
+// fakes (LookupTXT, LookupMX, LookupIP) before running the tests, so no real
+// DNS queries are made.
+func TestMain(m *testing.M) {
+	lookupTXT, lookupMX, lookupIP = LookupTXT, LookupMX, LookupIP
+
+	flag.Parse()
+	exitCode := m.Run()
+	os.Exit(exitCode)
+}
+
+// Addresses shared by the tests.
+var (
+	ip1110 = net.ParseIP("1.1.1.0")
+	ip1111 = net.ParseIP("1.1.1.1")
+	ip6666 = net.ParseIP("2001:db8::68")
+)
+
+// TestBasic evaluates a single record published at "domain" against
+// 1.1.1.1. It covers qualifiers and the "all", "a", "mx", "ip4" and "ip6"
+// mechanisms.
+func TestBasic(t *testing.T) {
+	cases := []struct {
+		txt string
+		res Result
+	}{
+		{"", None},
+		{"blah", None},
+		{"v=spf1", Neutral},
+		{"v=spf1 ", Neutral},
+		{"v=spf1 -", PermError},
+		{"v=spf1 all", Pass},
+		{"v=spf1  +all", Pass},
+		{"v=spf1 -all ", Fail},
+		{"v=spf1 ~all", SoftFail},
+		{"v=spf1 ?all", Neutral},
+		{"v=spf1 a ~all", SoftFail},
+		{"v=spf1 a/24", Neutral},
+		{"v=spf1 a:d1110/24", Pass},
+		{"v=spf1 a:d1110", Neutral},
+		{"v=spf1 a:d1111", Pass},
+		{"v=spf1 a:nothing/24", Neutral},
+		{"v=spf1 mx", Neutral},
+		{"v=spf1 mx/24", Neutral},
+		{"v=spf1 mx:a/montoto ~all", PermError},
+		{"v=spf1 mx:d1110/24 ~all", Pass},
+		{"v=spf1 ip4:1.2.3.4 ~all", SoftFail},
+		{"v=spf1 ip6:12 ~all", PermError},
+		{"v=spf1 ip4:1.1.1.1 -all", Pass},
+		{"v=spf1 blah", PermError},
+	}
+
+	ipResults["d1111"] = []net.IP{ip1111}
+	ipResults["d1110"] = []net.IP{ip1110}
+	// Keyed fields: net.MX comes from another package, and go vet's
+	// composites check flags unkeyed literals of such types.
+	mxResults["d1110"] = []*net.MX{
+		{Host: "d1110", Pref: 5},
+		{Host: "nothing", Pref: 10},
+	}
+
+	for _, c := range cases {
+		txtResults["domain"] = []string{c.txt}
+		res, err := CheckHost(ip1111, "domain")
+		if (res == TempError || res == PermError) && (err == nil) {
+			t.Errorf("%q: expected error, got nil", c.txt)
+		}
+		if res != c.res {
+			t.Errorf("%q: expected %q, got %q", c.txt, c.res, res)
+			t.Logf("%q:   error: %v", c.txt, err)
+		}
+	}
+}
+
+// TestNotSupported checks that records using features this package does not
+// implement ("exists", "ptr", "exp" and macros) evaluate to Neutral.
+func TestNotSupported(t *testing.T) {
+	records := []string{
+		"v=spf1 exists:blah -all",
+		"v=spf1 ptr -all",
+		"v=spf1 exp=blah -all",
+		"v=spf1 a:%{o} -all",
+	}
+
+	for i := range records {
+		record := records[i]
+		txtResults["domain"] = []string{record}
+		got, err := CheckHost(ip1111, "domain")
+		if got == Neutral {
+			continue
+		}
+		t.Errorf("%q: expected neutral, got %v", record, got)
+		t.Logf("%q:   error: %v", record, err)
+	}
+}
+
+// TestRecursion checks that a record that includes itself hits the lookup
+// limit and yields PermError instead of recursing forever.
+func TestRecursion(t *testing.T) {
+	// "domain" includes itself.
+	txtResults["domain"] = []string{"v=spf1 include:domain ~all"}
+
+	if res, err := CheckHost(ip1111, "domain"); res != PermError {
+		t.Errorf("expected permerror, got %v (%v)", res, err)
+	}
+}
+
+// TestNoRecord checks that domains without a usable SPF record evaluate to
+// None. The cases are:
+//   - an empty TXT record;
+//   - only non-SPF TXT records;
+//   - no TXT data at all;
+//   - a failed lookup whose error is not a temporary DNS error.
+func TestNoRecord(t *testing.T) {
+	txtResults["d1"] = []string{""}
+	txtResults["d2"] = []string{"loco", "v=spf2"}
+	txtErrors["nospf"] = fmt.Errorf("no such domain")
+
+	domains := []string{"d1", "d2", "d3", "nospf"}
+	for i := 0; i < len(domains); i++ {
+		if res, err := CheckHost(ip1111, domains[i]); res != None {
+			t.Errorf("expected none, got %v (%v)", res, err)
+		}
+	}
+}