git » chasquid » commit 79a8cfc

sts: DNS TXT record support

author Alberto Bertogli
2018-05-27 13:23:11 UTC
committer Alberto Bertogli
2018-07-01 11:19:02 UTC
parent 8bf584bd86f68f1024f9d92cc3f15ab2d68bbaa4

sts: DNS TXT record support

This patch adds support for checking the MTA-STS TXT record before
fetching the policy via HTTPS.

Apart from checking that the record begins with "v=STSv1;", its content
is unused; in particular, the "id=" field is ignored.

internal/sts/sts.go +33 -3
internal/sts/sts_test.go +61 -2

diff --git a/internal/sts/sts.go b/internal/sts/sts.go
index d4f7298..47f34c2 100644
--- a/internal/sts/sts.go
+++ b/internal/sts/sts.go
@@ -3,9 +3,7 @@
 //
 // This is an EXPERIMENTAL implementation for now.
 //
-// It lacks (at least) the following:
-// - DNS TXT checking.
-// - Facilities for reporting.
+// Note that failure reporting (SMTP TLS Reporting, RFC 8460) is not supported.
 //
 package sts
 
@@ -20,6 +18,7 @@ import (
 	"io"
 	"io/ioutil"
 	"mime"
+	"net"
 	"net/http"
 	"os"
 	"strconv"
@@ -167,6 +166,14 @@ func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) {
 		return nil, err
 	}
 
+	ok, err := hasSTSRecord(domain)
+	if err != nil {
+		return nil, err
+	}
+	if !ok {
+		return nil, fmt.Errorf("MTA-STS TXT record missing")
+	}
+
 	url := urlForDomain(domain)
 	rawPolicy, err := httpGet(ctx, url)
 	if err != nil {
@@ -295,6 +302,29 @@ func domainToASCII(domain string) (string, error) {
 	return idna.ToASCII(domain)
 }
 
+// lookupTXT resolves the TXT records for a DNS name. It is a variable so
+// that tests can override it.
+// In the future we will override net.DefaultResolver, but we don't do that
+// yet for backwards compatibility.
+var lookupTXT = net.LookupTXT
+
+// hasSTSRecord checks if there is a valid MTA-STS TXT record for the domain,
+// by looking up the TXT records of "_mta-sts.<domain>".
+//
+// As required by RFC 8461 section 3.1, records that do not begin with
+// "v=STSv1;" are discarded, and the domain is considered to publish MTA-STS
+// only if exactly one record remains.
+// We don't do full parsing and don't care about the "id=" field, as it is
+// unused in this implementation.
+//
+// Errors from the DNS lookup are returned unmodified (callers compare them
+// directly).
+func hasSTSRecord(domain string) (bool, error) {
+	txts, err := lookupTXT("_mta-sts." + domain)
+	if err != nil {
+		return false, err
+	}
+
+	found := 0
+	for _, txt := range txts {
+		if strings.HasPrefix(txt, "v=STSv1;") {
+			found++
+		}
+	}
+
+	// More than one STS record is ambiguous; the RFC says to behave as if
+	// the domain did not support MTA-STS.
+	return found == 1, nil
+}
+
 // PolicyCache is a caching layer for fetching policies.
 //
 // Policies are cached by domain, and stored in a single directory.
diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go
index 546f113..b907a88 100644
--- a/internal/sts/sts_test.go
+++ b/internal/sts/sts_test.go
@@ -13,6 +13,27 @@ import (
 	"time"
 )
 
+// Override the lookup function to control its results.
+//
+// txtResults maps the full name being queried (for MTA-STS that is
+// "_mta-sts.<domain>") to the TXT records returned for it. Names that are
+// not present return no records and no error.
+var txtResults = map[string][]string{
+	// Matching policyForDomain below.
+	"_mta-sts.domain.com": {"v=STSv1; id=blah;"},
+	"_mta-sts.policy404":  {"v=STSv1; id=blah;"},
+	"_mta-sts.version99":  {"v=STSv1; id=blah;"},
+}
+
+// testError is the error returned by the fake resolver for names listed in
+// txtErrors.
+var testError = fmt.Errorf("error for testing purposes")
+
+// txtErrors maps queried names to the error the fake resolver returns.
+var txtErrors = map[string]error{
+	"_mta-sts.domErr": testError,
+}
+
+// testLookupTXT is the fake replacement for lookupTXT, backed by txtResults
+// and txtErrors.
+func testLookupTXT(domain string) ([]string, error) {
+	return txtResults[domain], txtErrors[domain]
+}
+
 // Test policy for each of the requested domains.  Will be served by the test
 // HTTP server.
 var policyForDomain = map[string]string{
@@ -45,6 +66,8 @@ func testHTTPHandler(w http.ResponseWriter, r *http.Request) {
 }
 
 func TestMain(m *testing.M) {
+	lookupTXT = testLookupTXT
+
 	// Create a test HTTP server, used by the more end-to-end tests.
 	httpServer := httptest.NewServer(http.HandlerFunc(testHTTPHandler))
 
@@ -148,11 +171,11 @@ func TestFetch(t *testing.T) {
 	t.Logf("domain.com: %+v", p)
 
 	// Domain without a policy (HTTP get fails).
-	p, err = Fetch(context.Background(), "unknown")
+	p, err = Fetch(context.Background(), "policy404")
 	if err == nil {
 		t.Errorf("fetched unknown policy: %v", p)
 	}
-	t.Logf("unknown: got error as expected: %v", err)
+	t.Logf("policy404: got error as expected: %v", err)
 
 	// Domain with an invalid policy (unknown version).
 	p, err = Fetch(context.Background(), "version99")
@@ -161,6 +184,14 @@ func TestFetch(t *testing.T) {
 			ErrUnknownVersion, err, p)
 	}
 	t.Logf("version99: got expected error: %v", err)
+
+	// Error fetching TXT record for this domain.
+	p, err = Fetch(context.Background(), "domErr")
+	if err != testError {
+		t.Errorf("expected error %v, got %v (and policy: %v)",
+			testError, err, p)
+	}
+	t.Logf("domErr: got expected error: %v", err)
 }
 
 func TestPolicyTooBig(t *testing.T) {
@@ -345,6 +376,7 @@ func TestCacheRefresh(t *testing.T) {
 
 	ctx := context.Background()
 
+	txtResults["_mta-sts.refresh-test"] = []string{"v=STSv1; id=blah;"}
 	policyForDomain["refresh-test"] = `
 		version: STSv1
 		mode: enforce
@@ -393,3 +425,30 @@ func TestURLForDomain(t *testing.T) {
 		t.Errorf("got %q, expected %q", got, expected)
 	}
 }
+
+// TestHasSTSRecord checks hasSTSRecord against the fake DNS data served by
+// testLookupTXT.
+func TestHasSTSRecord(t *testing.T) {
+	// hasSTSRecord queries "_mta-sts.<domain>", so the fake records must be
+	// registered under that full name.
+	txtResults["_mta-sts.dom1"] = nil
+	txtResults["_mta-sts.dom2"] = []string{}
+	txtResults["_mta-sts.dom3"] = []string{"abc", "def"}
+	txtResults["_mta-sts.dom4"] = []string{"abc", "v=STSv1; id=blah;"}
+	txtResults["_mta-sts.dom5"] = []string{"v=STSv1"}
+
+	// dom5 has the version but lacks the ";" separator, so it must not
+	// count as an STS record. domErr is expected to propagate the lookup
+	// error as-is.
+	cases := []struct {
+		domain string
+		ok     bool
+		err    error
+	}{
+		{"", false, nil},
+		{"dom1", false, nil},
+		{"dom2", false, nil},
+		{"dom3", false, nil},
+		{"dom4", true, nil},
+		{"dom5", false, nil},
+		{"domErr", false, testError},
+	}
+	for _, c := range cases {
+		ok, err := hasSTSRecord(c.domain)
+		if ok != c.ok || err != c.err {
+			// %q keeps the empty domain visible in the failure message.
+			t.Errorf("%q: expected {%v, %v}, got {%v, %v}", c.domain,
+				c.ok, c.err, ok, err)
+		}
+	}
+}