git » spf » commit aa9c766

Implement macros

author Alberto Bertogli
2019-10-14 00:32:14 UTC
committer Alberto Bertogli
2019-10-14 12:35:13 UTC
parent f8b7a67464e5669d5ec52ed4d9afb594caabb8c8

Implement macros

This patch adds support for macro expansion to the library.

spf.go +217 -22
spf_test.go +0 -2

diff --git a/spf.go b/spf.go
index 67ab070..1d22bff 100644
--- a/spf.go
+++ b/spf.go
@@ -17,10 +17,10 @@
 //   ip6
 //   redirect
 //   exp (ignored)
+//   Macros
 //
 // Not supported (return Neutral if used):
 //   exists
-//   Macros
 //
 // This is intentional and there are no plans to add them for now, as they are
 // very rare, convoluted and not worth the additional complexity.
@@ -33,6 +33,7 @@ package spf // import "blitiri.com.ar/go/spf"
 import (
 	"fmt"
 	"net"
+	"net/url"
 	"regexp"
 	"strconv"
 	"strings"
@@ -91,11 +92,11 @@ var qualToResult = map[byte]Result{
 
 var (
 	errLookupLimitReached = fmt.Errorf("lookup limit reached")
-	errMacrosNotSupported = fmt.Errorf("macros not supported")
 	errExistsNotSupported = fmt.Errorf("'exists' not supported")
 	errUnknownField       = fmt.Errorf("unknown field")
 	errInvalidIP          = fmt.Errorf("invalid ipX value")
 	errInvalidMask        = fmt.Errorf("invalid mask")
+	errInvalidMacro       = fmt.Errorf("invalid macro")
 	errNoResult           = fmt.Errorf("lookup yielded no result")
 	errMultipleRecords    = fmt.Errorf("multiple matching DNS records")
 
@@ -111,7 +112,7 @@ var (
 // Reference: https://tools.ietf.org/html/rfc7208#section-4
 func CheckHost(ip net.IP, domain string) (Result, error) {
 	trace("check host %q %q", ip, domain)
-	r := &resolution{ip, 0, "", nil}
+	// With no sender available, start from "@" + domain as a placeholder
+	// sender; the sender-based macros (%{s}, %{l}, %{o}) in expandMacros
+	// expand from r.sender.
+	// NOTE(review): assumes the third positional field of resolution is
+	// that sender field; confirm against the struct definition.
+	r := &resolution{ip, 0, "@" + domain, nil}
 	return r.Check(domain)
 }
 
@@ -209,10 +210,6 @@ func (r *resolution) Check(domain string) (Result, error) {
 			return PermError, errLookupLimitReached
 		}
 
-		if strings.Contains(field, "%") {
-			return Neutral, errMacrosNotSupported
-		}
-
 		// See if we have a qualifier, defaulting to + (pass).
 		// https://tools.ietf.org/html/rfc7208#section-4.6.2
 		result, ok := qualToResult[field[0]]
@@ -231,7 +228,7 @@ func (r *resolution) Check(domain string) (Result, error) {
 			trace("%v matched all", result)
 			return result, errMatchedAll
 		} else if strings.HasPrefix(lfield, "include:") {
-			if ok, res, err := r.includeField(result, field); ok {
+			if ok, res, err := r.includeField(result, field, domain); ok {
 				trace("include ok, %v %v", res, err)
 				return res, err
 			}
@@ -263,12 +260,7 @@ func (r *resolution) Check(domain string) (Result, error) {
 			continue
 		} else if strings.HasPrefix(lfield, "redirect=") {
 			trace("redirect, %q", field)
-			// https://tools.ietf.org/html/rfc7208#section-6.1
-			result, err := r.Check(field[len("redirect="):])
-			if result == None {
-				result = PermError
-			}
-			return result, err
+			return r.redirectField(field, domain)
 		} else {
 			// http://www.openspf.org/SPF_Record_Syntax
 			trace("permerror, unknown field")
@@ -353,11 +345,16 @@ func (r *resolution) ipField(res Result, field string) (bool, Result, error) {
 
 // ptrField processes a "ptr" field.
 func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, error) {
-	// Extract the domain if the field is in the form "ptr:domain"
+	// Extract the domain if the field is in the form "ptr:domain".
+	ptrDomain := domain
 	if len(field) >= 4 {
-		domain = field[4:]
+		ptrDomain = field[4:]
 
 	}
+	ptrDomain, err := r.expandMacros(ptrDomain, domain)
+	if err != nil {
+		return true, PermError, errInvalidMacro
+	}
 
 	if r.ipNames == nil {
 		r.count++
@@ -373,7 +370,7 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e
 	}
 
 	for _, n := range r.ipNames {
-		if strings.HasSuffix(n, domain+".") {
+		if strings.HasSuffix(n, ptrDomain+".") {
 			return true, res, errMatchedPTR
 		}
 	}
@@ -382,9 +379,13 @@ func (r *resolution) ptrField(res Result, field, domain string) (bool, Result, e
 }
 
 // includeField processes an "include" field.
-func (r *resolution) includeField(res Result, field string) (bool, Result, error) {
+func (r *resolution) includeField(res Result, field, domain string) (bool, Result, error) {
 	// https://tools.ietf.org/html/rfc7208#section-5.2
 	incdomain := field[len("include:"):]
+	incdomain, err := r.expandMacros(incdomain, domain)
+	if err != nil {
+		return true, PermError, errInvalidMacro
+	}
 	ir, err := r.Check(incdomain)
 	switch ir {
 	case Pass:
@@ -469,13 +470,17 @@ func domainAndMask(re *regexp.Regexp, field, domain string) (string, DualMasks,
 // aField processes an "a" field.
 func (r *resolution) aField(res Result, field, domain string) (bool, Result, error) {
 	// https://tools.ietf.org/html/rfc7208#section-5.3
-	domain, masks, err := domainAndMask(aRegexp, field, domain)
+	aDomain, masks, err := domainAndMask(aRegexp, field, domain)
 	if err != nil {
 		return true, PermError, err
 	}
+	aDomain, err = r.expandMacros(aDomain, domain)
+	if err != nil {
+		return true, PermError, errInvalidMacro
+	}
 
 	r.count++
-	ips, err := lookupIP(domain)
+	ips, err := lookupIP(aDomain)
 	if err != nil {
 		// https://tools.ietf.org/html/rfc7208#section-5
 		if isTemporary(err) {
@@ -499,13 +504,17 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, err
 // mxField processes an "mx" field.
 func (r *resolution) mxField(res Result, field, domain string) (bool, Result, error) {
 	// https://tools.ietf.org/html/rfc7208#section-5.4
-	domain, masks, err := domainAndMask(mxRegexp, field, domain)
+	mxDomain, masks, err := domainAndMask(mxRegexp, field, domain)
 	if err != nil {
 		return true, PermError, err
 	}
+	mxDomain, err = r.expandMacros(mxDomain, domain)
+	if err != nil {
+		return true, PermError, errInvalidMacro
+	}
 
 	r.count++
-	mxs, err := lookupMX(domain)
+	mxs, err := lookupMX(mxDomain)
 	if err != nil {
 		// https://tools.ietf.org/html/rfc7208#section-5
 		if isTemporary(err) {
@@ -538,3 +547,189 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, er
 
 	return false, "", nil
 }
+
+// redirectField processes a "redirect=" modifier: the domain-spec after the
+// "=" is macro-expanded, and the result of checking that domain is returned.
+func (r *resolution) redirectField(field, domain string) (Result, error) {
+	target, err := r.expandMacros(field[len("redirect="):], domain)
+	if err != nil {
+		return PermError, errInvalidMacro
+	}
+
+	// If checking the target yields None (e.g. it has no SPF record), the
+	// redirect is a permanent error.
+	// https://tools.ietf.org/html/rfc7208#section-6.1
+	result, err := r.Check(target)
+	if result == None {
+		return PermError, err
+	}
+	return result, err
+}
+
+// Group extraction of macro-string from the formal specification: the macro
+// letter, the optional digit and reverse transformers, and the optional
+// delimiters. It is anchored at both ends so that only a string which is
+// entirely a single macro matches; otherwise junk around a valid letter (for
+// example "xd" or "d.junk") would be silently accepted.
+// https://tools.ietf.org/html/rfc7208#section-7.1
+var macroRegexp = regexp.MustCompile(
+	`^([slodiphcrtvSLODIPHCRTV])([0-9]+)?([rR])?([-.+,/_=]+)?$`)
+
+// Expand macros, return the expanded string.
+// This expects to be passed the domain-spec within a field, not an entire
+// field or larger (that has problematic security implications).
+//
+// Returns errInvalidMacro if s contains "/" (which also rules out "/" as a
+// macro delimiter), if a "%" is not followed by "%", "_", "-" or "{", if a
+// "%" or "%{" is left unterminated at the end of s, or if a macro is
+// malformed or uses a letter we don't expand.
+// https://tools.ietf.org/html/rfc7208#section-7
+func (r *resolution) expandMacros(s, domain string) (string, error) {
+	// Macros/domains shouldn't contain CIDR. Our parsing should prevent it
+	// from happening in case where it matters (a, mx), but for the ones which
+	// don't, prevent them from sneaking through.
+	if strings.Contains(s, "/") {
+		trace("macro contains /")
+		return "", errInvalidMacro
+	}
+
+	// Bypass the complex logic if there are no macros present.
+	if !strings.Contains(s, "%") {
+		return s, nil
+	}
+
+	// Are we processing the character right after "%"?
+	afterPercent := false
+
+	// Are we inside a macro definition (%{...}) ?
+	inMacroDefinition := false
+
+	// Macro string, where we accumulate the values inside the definition.
+	var macroS strings.Builder
+
+	// Expanded output.
+	var n strings.Builder
+	for _, c := range s {
+		if afterPercent {
+			afterPercent = false
+			switch c {
+			case '%':
+				n.WriteByte('%')
+			case '_':
+				n.WriteByte(' ')
+			case '-':
+				n.WriteString("%20")
+			case '{':
+				inMacroDefinition = true
+			default:
+				return "", errInvalidMacro
+			}
+			continue
+		}
+		if inMacroDefinition {
+			if c != '}' {
+				macroS.WriteRune(c)
+				continue
+			}
+			inMacroDefinition = false
+
+			str, err := r.expandOneMacro(macroS.String(), domain)
+			if err != nil {
+				return "", err
+			}
+			macroS.Reset()
+			n.WriteString(str)
+			continue
+		}
+		if c == '%' {
+			afterPercent = true
+			continue
+		}
+		n.WriteRune(c)
+	}
+
+	// A trailing "%" or an unterminated "%{..." is not a valid macro-string;
+	// reject it instead of silently dropping it.
+	if afterPercent || inMacroDefinition {
+		trace("macro %q is unterminated", s)
+		return "", errInvalidMacro
+	}
+
+	trace("macro expanded %q to %q", s, n.String())
+	return n.String(), nil
+}
+
+// expandOneMacro expands the body of a single "%{...}" macro (the text
+// between the braces, e.g. "d", "ir" or "l1r-"), returning errInvalidMacro
+// if it is malformed or uses a letter we don't expand.
+// https://tools.ietf.org/html/rfc7208#section-7.3
+func (r *resolution) expandOneMacro(macroS, domain string) (string, error) {
+	// Extract letter, digit transformer, reverse transformer, and
+	// delimiters. The whole macro must match: a partial match means there
+	// is junk around an otherwise valid macro (e.g. "xd" or "d.junk").
+	groups := macroRegexp.FindStringSubmatch(macroS)
+	trace("macro %q: %q", macroS, groups)
+	if groups == nil || groups[0] != macroS {
+		return "", errInvalidMacro
+	}
+	letter := groups[1]
+
+	digits := 0
+	if groups[2] != "" {
+		// Use 0 as "no digits given"; an explicit value of 0 is not
+		// valid.
+		var err error
+		digits, err = strconv.Atoi(groups[2])
+		if err != nil || digits <= 0 {
+			return "", errInvalidMacro
+		}
+	}
+	reverse := groups[3] == "r" || groups[3] == "R"
+	delimiters := groups[4]
+	if delimiters == "" {
+		// By default, split strings by ".".
+		delimiters = "."
+	}
+
+	// Uppercase letters indicate URL escaping of the results.
+	urlEscape := letter == strings.ToUpper(letter)
+
+	str := ""
+	switch strings.ToLower(letter) {
+	case "s":
+		str = r.sender
+	case "l":
+		str, _ = split(r.sender)
+	case "o":
+		_, str = split(r.sender)
+	case "d":
+		str = domain
+	case "i":
+		str = macroIPString(r.ip)
+	case "p":
+		// This shouldn't be used, we don't want to support it, it's
+		// risky. "unknown" is a safe value.
+		// https://tools.ietf.org/html/rfc7208#section-7.3
+		str = "unknown"
+	case "v":
+		if r.ip.To4() != nil {
+			str = "in-addr"
+		} else {
+			str = "ip6"
+		}
+	case "h":
+		// NOTE(review): per RFC 7208 "h" is the HELO/EHLO domain, which
+		// does not appear to be tracked here; the current domain is used.
+		str = domain
+	default:
+		// c, r, t are allowed in exp only, and we don't expand macros
+		// in exp so they are just as invalid as the rest.
+		return "", errInvalidMacro
+	}
+
+	// Split str using the given delimiters (empty parts are dropped).
+	parts := strings.FieldsFunc(str, func(c rune) bool {
+		return strings.ContainsRune(delimiters, c)
+	})
+
+	// Reverse if requested.
+	if reverse {
+		reverseStrings(parts)
+	}
+
+	// Keep only the rightmost $digits parts, if given; asking for more
+	// parts than there are keeps them all.
+	if digits > 0 && digits < len(parts) {
+		parts = parts[len(parts)-digits:]
+	}
+
+	// Join back, always with ".".
+	str = strings.Join(parts, ".")
+
+	// Escape everything outside the RFC 3986 "unreserved" set if requested.
+	// url.QueryEscape leaves exactly that set alone, except that it turns a
+	// space into "+"; since a literal "+" is escaped as "%2B", any "+" left
+	// in the output came from a space, and must be "%20" instead.
+	if urlEscape {
+		str = strings.Replace(url.QueryEscape(str), "+", "%20", -1)
+	}
+
+	return str, nil
+}
+
+// macroIPString returns the representation of ip used by the "i" macro: the
+// usual dotted-quad for IPv4, and the dot-separated nibble format for IPv6
+// (e.g. 2001:db8::cb01 -> "2.0.0.1.0.d.b.8.0.0.[...].0.0.c.b.0.1").
+// https://tools.ietf.org/html/rfc7208#section-7.3
+func macroIPString(ip net.IP) string {
+	if ip4 := ip.To4(); ip4 != nil {
+		return ip4.String()
+	}
+	ip16 := ip.To16()
+	if ip16 == nil {
+		// Not a valid IP; fall back to the standard representation.
+		return ip.String()
+	}
+
+	const hexDigits = "0123456789abcdef"
+	var sb strings.Builder
+	for i, b := range ip16 {
+		if i > 0 {
+			sb.WriteByte('.')
+		}
+		sb.WriteByte(hexDigits[b>>4])
+		sb.WriteByte('.')
+		sb.WriteByte(hexDigits[b&0x0f])
+	}
+	return sb.String()
+}
+
+// reverseStrings reverses the order of the elements of a, in place.
+func reverseStrings(a []string) {
+	last := len(a) - 1
+	for i := 0; i < len(a)/2; i++ {
+		a[i], a[last-i] = a[last-i], a[i]
+	}
+}
diff --git a/spf_test.go b/spf_test.go
index acbff91..c2393e1 100644
--- a/spf_test.go
+++ b/spf_test.go
@@ -137,8 +137,6 @@ func TestNotSupported(t *testing.T) {
 		err error
 	}{
 		{"v=spf1 exists:blah -all", errExistsNotSupported},
-		{"v=spf1 a:%{o} -all", errMacrosNotSupported},
-		{"v=spf1 redirect=_spf.%{d}", errMacrosNotSupported},
 	}
 
 	for _, c := range cases {