author | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-05-27 13:23:11 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-07-01 11:19:02 UTC |
parent | 8bf584bd86f68f1024f9d92cc3f15ab2d68bbaa4 |
internal/sts/sts.go | +33 | -3 |
internal/sts/sts_test.go | +61 | -2 |
diff --git a/internal/sts/sts.go b/internal/sts/sts.go
index d4f7298..47f34c2 100644
--- a/internal/sts/sts.go
+++ b/internal/sts/sts.go
@@ -3,9 +3,7 @@
 //
 // This is an EXPERIMENTAL implementation for now.
 //
-// It lacks (at least) the following:
-// - DNS TXT checking.
-// - Facilities for reporting.
+// Note that "report" mode is not supported.
 //
 package sts
 
@@ -20,6 +18,7 @@ import (
 	"io"
 	"io/ioutil"
 	"mime"
+	"net"
 	"net/http"
 	"os"
 	"strconv"
@@ -167,6 +166,14 @@ func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) {
 		return nil, err
 	}
 
+	ok, err := hasSTSRecord(domain)
+	if err != nil {
+		return nil, err
+	}
+	if !ok {
+		return nil, fmt.Errorf("MTA-STS TXT record missing")
+	}
+
 	url := urlForDomain(domain)
 	rawPolicy, err := httpGet(ctx, url)
 	if err != nil {
@@ -295,6 +302,29 @@ func domainToASCII(domain string) (string, error) {
 	return idna.ToASCII(domain)
 }
 
+// Function that we override for testing purposes.
+// In the future we will override net.DefaultResolver, but we don't do that
+// yet for backwards compatibility.
+var lookupTXT = net.LookupTXT
+
+// hasSTSRecord checks if there is a valid MTA-STS TXT record for the domain.
+// We don't do full parsing and don't care about the "id=" field, as it is
+// unused in this implementation.
+func hasSTSRecord(domain string) (bool, error) {
+	txts, err := lookupTXT("_mta-sts." + domain)
+	if err != nil {
+		return false, err
+	}
+
+	for _, txt := range txts {
+		if strings.HasPrefix(txt, "v=STSv1;") {
+			return true, nil
+		}
+	}
+
+	return false, nil
+}
+
 // PolicyCache is a caching layer for fetching policies.
 //
 // Policies are cached by domain, and stored in a single directory.
diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go
index 546f113..b907a88 100644
--- a/internal/sts/sts_test.go
+++ b/internal/sts/sts_test.go
@@ -13,6 +13,27 @@ import (
 	"time"
 )
 
+// Override the lookup function to control its results.
+var txtResults = map[string][]string{
+	"dom1": nil,
+	"dom2": {},
+	"dom3": {"abc", "def"},
+	"dom4": {"abc", "v=STSv1; id=blah;"},
+
+	// Matching policyForDomain below.
+	"_mta-sts.domain.com": {"v=STSv1; id=blah;"},
+	"_mta-sts.policy404":  {"v=STSv1; id=blah;"},
+	"_mta-sts.version99":  {"v=STSv1; id=blah;"},
+}
+var testError = fmt.Errorf("error for testing purposes")
+var txtErrors = map[string]error{
+	"_mta-sts.domErr": testError,
+}
+
+func testLookupTXT(domain string) ([]string, error) {
+	return txtResults[domain], txtErrors[domain]
+}
+
 // Test policy for each of the requested domains. Will be served by the test
 // HTTP server.
 var policyForDomain = map[string]string{
@@ -45,6 +66,8 @@ func testHTTPHandler(w http.ResponseWriter, r *http.Request) {
 }
 
 func TestMain(m *testing.M) {
+	lookupTXT = testLookupTXT
+
 	// Create a test HTTP server, used by the more end-to-end tests.
 	httpServer := httptest.NewServer(http.HandlerFunc(testHTTPHandler))
 
@@ -148,11 +171,11 @@ func TestFetch(t *testing.T) {
 	t.Logf("domain.com: %+v", p)
 
 	// Domain without a policy (HTTP get fails).
-	p, err = Fetch(context.Background(), "unknown")
+	p, err = Fetch(context.Background(), "policy404")
 	if err == nil {
 		t.Errorf("fetched unknown policy: %v", p)
 	}
-	t.Logf("unknown: got error as expected: %v", err)
+	t.Logf("policy404: got error as expected: %v", err)
 
 	// Domain with an invalid policy (unknown version).
 	p, err = Fetch(context.Background(), "version99")
@@ -161,6 +184,14 @@ func TestFetch(t *testing.T) {
 			ErrUnknownVersion, err, p)
 	}
 	t.Logf("version99: got expected error: %v", err)
+
+	// Error fetching TXT record for this domain.
+	p, err = Fetch(context.Background(), "domErr")
+	if err != testError {
+		t.Errorf("expected error %v, got %v (and policy: %v)",
+			testError, err, p)
+	}
+	t.Logf("domErr: got expected error: %v", err)
 }
 
 func TestPolicyTooBig(t *testing.T) {
@@ -345,6 +376,7 @@ func TestCacheRefresh(t *testing.T) {
 
 	ctx := context.Background()
 
+	txtResults["_mta-sts.refresh-test"] = []string{"v=STSv1; id=blah;"}
 	policyForDomain["refresh-test"] = `
 version: STSv1
 mode: enforce
@@ -393,3 +425,30 @@ func TestURLForDomain(t *testing.T) {
 		t.Errorf("got %q, expected %q", got, expected)
 	}
 }
+
+func TestHasSTSRecord(t *testing.T) {
+	txtResults["_mta-sts.dom1"] = nil
+	txtResults["_mta-sts.dom2"] = []string{}
+	txtResults["_mta-sts.dom3"] = []string{"abc", "def"}
+	txtResults["_mta-sts.dom4"] = []string{"abc", "v=STSv1; id=blah;"}
+
+	cases := []struct {
+		domain string
+		ok     bool
+		err    error
+	}{
+		{"", false, nil},
+		{"dom1", false, nil},
+		{"dom2", false, nil},
+		{"dom3", false, nil},
+		{"dom4", true, nil},
+		{"domErr", false, testError},
+	}
+	for _, c := range cases {
+		ok, err := hasSTSRecord(c.domain)
+		if ok != c.ok || err != c.err {
+			t.Errorf("%s: expected {%v, %v}, got {%v, %v}", c.domain,
+				c.ok, c.err, ok, err)
+		}
+	}
+}