author | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-02-28 23:57:04 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-03-01 00:10:10 UTC |
parent | 216cf47ffa1d6b6d6b7e763275a41e1ec4f42273 |
internal/sts/sts.go | +17 | -18 |
internal/sts/sts_test.go | +48 | -15 |
diff --git a/internal/sts/sts.go b/internal/sts/sts.go index a3fc5cf..70ebeaf 100644 --- a/internal/sts/sts.go +++ b/internal/sts/sts.go @@ -15,6 +15,7 @@ import ( "errors" "expvar" "fmt" + "io" "io/ioutil" "net/http" "os" @@ -130,11 +131,7 @@ func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) { return nil, err } - // URL composed from the domain, as explained in: - // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3 - // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2 - url := "https://mta-sts." + domain + "/.well-known/mta-sts.json" - + url := urlForDomain(domain) rawPolicy, err := httpGet(ctx, url) if err != nil { return nil, err @@ -143,6 +140,21 @@ func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) { return parsePolicy(rawPolicy) } +// Fake URL for testing purposes, so we can do more end-to-end tests, +// including the HTTP fetching code. +var fakeURLForTesting string + +func urlForDomain(domain string) string { + if fakeURLForTesting != "" { + return fakeURLForTesting + "/" + domain + } + + // URL composed from the domain, as explained in: + // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3 + // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2 + return "https://mta-sts." + domain + "/.well-known/mta-sts.json" +} + // Fetch a policy for the given domain. Note this results in various network // lookups and HTTPS GETs, so it can be slow. // The returned policy is parsed and sanity-checked (using Policy.Check), so @@ -161,9 +173,6 @@ func Fetch(ctx context.Context, domain string) (*Policy, error) { return p, nil } -// Fake HTTP content for testing purposes only. -var fakeContent = map[string]string{} - // httpGet performs an HTTP GET of the given URL, using the context and // rejecting redirects, as per the standard. func httpGet(ctx context.Context, url string) ([]byte, error) { @@ -179,16 +188,6 @@ func httpGet(ctx context.Context, url string) ([]byte, error) { client.Timeout = deadline.Sub(time.Now()) } - if len(fakeContent) > 0 { - // If we have fake content for testing, then return the content for - // the URL, or an error if it's missing. - // This makes sure we don't make actual requests for testing. - if d, ok := fakeContent[url]; ok { - return []byte(d), nil - } - return nil, errors.New("error for testing") - } - resp, err := ctxhttp.Get(ctx, client, url) if err != nil { return nil, err diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go index 71b4fcc..5ec0877 100644 --- a/internal/sts/sts_test.go +++ b/internal/sts/sts_test.go @@ -3,34 +3,53 @@ package sts import ( "context" "expvar" + "fmt" "io/ioutil" + "net/http" + "net/http/httptest" "os" "testing" "time" ) -func TestMain(m *testing.M) { - // Populate the fake policy contents, used by a few tests. - // httpGet will use this data instead of using the network. - +// Test policy for each of the requested domains. Will be served by the test +// HTTP server. +var policyForDomain = map[string]string{ // domain.com -> valid, with reasonable policy. - fakeContent["https://mta-sts.domain.com/.well-known/mta-sts.json"] = ` + "domain.com": ` { "version": "STSv1", "mode": "enforce", "mx": ["*.mail.domain.com"], "max_age": 3600 - }` + }`, // version99 -> invalid policy (unknown version). - fakeContent["https://mta-sts.version99/.well-known/mta-sts.json"] = ` + "version99": ` { "version": "STSv99", "mode": "enforce", "mx": ["*.mail.version99"], "max_age": 999 - }` + }`, +} + +func testHTTPHandler(w http.ResponseWriter, r *http.Request) { + // For testing, the domain in the path (see urlForDomain). + policy, ok := policyForDomain[r.URL.Path[1:]] + if !ok { + http.Error(w, "not found", 404) + return + } + fmt.Fprintln(w, policy) + return +} +func TestMain(m *testing.M) { + // Create a test HTTP server, used by the more end-to-end tests. + httpServer := httptest.NewServer(http.HandlerFunc(testHTTPHandler)) + + fakeURLForTesting = httpServer.URL os.Exit(m.Run()) } @@ -112,8 +131,8 @@ func TestMatchDomain(t *testing.T) { } func TestFetch(t *testing.T) { - // Note the data "fetched" for each domain comes from fakeContent, defined - // in TestMain above. See httpGet for more details. + // Note the data "fetched" for each domain comes from policyForDomain, + // defined in TestMain above. See httpGet for more details. // Normal fetch, all valid. p, err := Fetch(context.Background(), "domain.com") @@ -170,8 +189,8 @@ func TestCacheBasics(t *testing.T) { t.Fatal(err) } - // Note the data "fetched" for each domain comes from fakeContent, defined - // in TestMain above. See httpGet for more details. + // Note the data "fetched" for each domain comes from policyForDomain, + // defined in TestMain above. See httpGet for more details. // Reset the expvar counters that we use to validate hits, misses, etc. cacheFetches.Set(0) @@ -258,7 +277,7 @@ func TestCacheBadData(t *testing.T) { } // We now expect Fetch to fall back to getting the policy from the - // network (in our case, from fakeContent). + // network (in our case, from policyForDomain). p, err = c.Fetch(ctx, "domain.com") if err != nil { t.Fatalf("Fetch failed: %v", err) @@ -303,7 +322,7 @@ func TestCacheRefresh(t *testing.T) { ctx := context.Background() - fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = ` + policyForDomain["refresh-test"] = ` {"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 100}` p := mustFetch(t, c, ctx, "refresh-test") if p.MaxAge != 100*time.Second { @@ -312,7 +331,7 @@ func TestCacheRefresh(t *testing.T) { // Change the "published" policy, check that we see the old version at // fetch (should be cached), and a new version after a refresh. - fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = ` + policyForDomain["refresh-test"] = ` {"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 200}` p = mustFetch(t, c, ctx, "refresh-test") @@ -331,3 +350,17 @@ func TestCacheRefresh(t *testing.T) { os.RemoveAll(dir) } } + +func TestURLForDomain(t *testing.T) { + // This function will behave differently if fakeURLForTesting is set, so + // temporarily unset it. + oldURL := fakeURLForTesting + fakeURLForTesting = "" + defer func() { fakeURLForTesting = oldURL }() + + got := urlForDomain("a-test-domain") + expected := "https://mta-sts.a-test-domain/.well-known/mta-sts.json" + if got != expected { + t.Errorf("got %q, expected %q", got, expected) + } +}