git » chasquid » commit e66288e

sts: Make tests more end-to-end, to cover HTTP fetching

author Alberto Bertogli
2017-02-28 23:57:04 UTC
committer Alberto Bertogli
2017-03-01 00:10:10 UTC
parent 216cf47ffa1d6b6d6b7e763275a41e1ec4f42273

sts: Make tests more end-to-end, to cover HTTP fetching

The current tests stop short of fetching over HTTP, which is unfortunate
because that code is not trivial.

This patch changes the testing strategy to use a test HTTP server, and
points the policy URLs at it. That way the same tests cover much more
code, including the HTTP fetching.

internal/sts/sts.go +17 -18
internal/sts/sts_test.go +48 -15

diff --git a/internal/sts/sts.go b/internal/sts/sts.go
index a3fc5cf..70ebeaf 100644
--- a/internal/sts/sts.go
+++ b/internal/sts/sts.go
@@ -15,6 +15,7 @@ import (
 	"errors"
 	"expvar"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"net/http"
 	"os"
@@ -130,11 +131,7 @@ func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) {
 		return nil, err
 	}
 
-	// URL composed from the domain, as explained in:
-	// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3
-	// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
-	url := "https://mta-sts." + domain + "/.well-known/mta-sts.json"
-
+	url := urlForDomain(domain)
 	rawPolicy, err := httpGet(ctx, url)
 	if err != nil {
 		return nil, err
@@ -143,6 +140,21 @@ func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) {
 	return parsePolicy(rawPolicy)
 }
 
+// fakeURLForTesting, when non-empty, replaces the real policy URL so tests
+// can exercise the HTTP fetching code against a local server.
+var fakeURLForTesting string
+
+// urlForDomain returns the URL from which to fetch the domain's STS policy.
+func urlForDomain(domain string) string {
+	if base := fakeURLForTesting; base != "" {
+		return base + "/" + domain
+	}
+	// Composed from the domain, as explained in sections 3.2 and 3.3 of:
+	// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
+	// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3
+	return "https://mta-sts." + domain + "/.well-known/mta-sts.json"
+}
+
 // Fetch a policy for the given domain. Note this results in various network
 // lookups and HTTPS GETs, so it can be slow.
 // The returned policy is parsed and sanity-checked (using Policy.Check), so
@@ -161,9 +173,6 @@ func Fetch(ctx context.Context, domain string) (*Policy, error) {
 	return p, nil
 }
 
-// Fake HTTP content for testing purposes only.
-var fakeContent = map[string]string{}
-
 // httpGet performs an HTTP GET of the given URL, using the context and
 // rejecting redirects, as per the standard.
 func httpGet(ctx context.Context, url string) ([]byte, error) {
@@ -179,16 +188,6 @@ func httpGet(ctx context.Context, url string) ([]byte, error) {
 		client.Timeout = deadline.Sub(time.Now())
 	}
 
-	if len(fakeContent) > 0 {
-		// If we have fake content for testing, then return the content for
-		// the URL, or an error if it's missing.
-		// This makes sure we don't make actual requests for testing.
-		if d, ok := fakeContent[url]; ok {
-			return []byte(d), nil
-		}
-		return nil, errors.New("error for testing")
-	}
-
 	resp, err := ctxhttp.Get(ctx, client, url)
 	if err != nil {
 		return nil, err
diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go
index 71b4fcc..5ec0877 100644
--- a/internal/sts/sts_test.go
+++ b/internal/sts/sts_test.go
@@ -3,34 +3,53 @@ package sts
 import (
 	"context"
 	"expvar"
+	"fmt"
 	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
 	"os"
 	"testing"
 	"time"
 )
 
-func TestMain(m *testing.M) {
-	// Populate the fake policy contents, used by a few tests.
-	// httpGet will use this data instead of using the network.
-
+// policyForDomain maps each test domain to the policy that the test HTTP
+// server (testHTTPHandler) serves for it. Some tests modify it at runtime.
+var policyForDomain = map[string]string{
 	// domain.com -> valid, with reasonable policy.
-	fakeContent["https://mta-sts.domain.com/.well-known/mta-sts.json"] = `
+	"domain.com": `
 		{
              "version": "STSv1",
              "mode": "enforce",
              "mx": ["*.mail.domain.com"],
              "max_age": 3600
-        }`
+        }`,

 	// version99 -> invalid policy (unknown version).
-	fakeContent["https://mta-sts.version99/.well-known/mta-sts.json"] = `
+	"version99": `
 		{
              "version": "STSv99",
              "mode": "enforce",
              "mx": ["*.mail.version99"],
              "max_age": 999
-        }`
+        }`,
+}
+
+// testHTTPHandler serves policyForDomain[domain] for requests to "/<domain>",
+// which is the URL shape urlForDomain produces when fakeURLForTesting is set.
+func testHTTPHandler(w http.ResponseWriter, r *http.Request) {
+	policy, ok := policyForDomain[r.URL.Path[1:]]
+	if !ok {
+		http.Error(w, "not found", http.StatusNotFound)
+		return
+	}
+	fmt.Fprintln(w, policy)
+}
 
+func TestMain(m *testing.M) {
+	httpServer := httptest.NewServer(http.HandlerFunc(testHTTPHandler))
+	fakeURLForTesting = httpServer.URL // Make urlForDomain use our server.
+	code := m.Run()
+	httpServer.Close() // Close explicitly: os.Exit skips deferred calls.
+	os.Exit(code)
+}
 
@@ -112,8 +131,8 @@ func TestMatchDomain(t *testing.T) {
 }
 
 func TestFetch(t *testing.T) {
-	// Note the data "fetched" for each domain comes from fakeContent, defined
-	// in TestMain above. See httpGet for more details.
+	// Note the data "fetched" for each domain comes from policyForDomain,
+	// which the test HTTP server created in TestMain serves.
 
 	// Normal fetch, all valid.
 	p, err := Fetch(context.Background(), "domain.com")
@@ -170,8 +189,8 @@ func TestCacheBasics(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	// Note the data "fetched" for each domain comes from fakeContent, defined
-	// in TestMain above. See httpGet for more details.
+	// Note the data "fetched" for each domain comes from policyForDomain,
+	// which the test HTTP server created in TestMain serves.
 
 	// Reset the expvar counters that we use to validate hits, misses, etc.
 	cacheFetches.Set(0)
@@ -258,7 +277,7 @@ func TestCacheBadData(t *testing.T) {
 		}
 
 		// We now expect Fetch to fall back to getting the policy from the
-		// network (in our case, from fakeContent).
+		// network (in our case, from policyForDomain).
 		p, err = c.Fetch(ctx, "domain.com")
 		if err != nil {
 			t.Fatalf("Fetch failed: %v", err)
@@ -303,7 +322,7 @@ func TestCacheRefresh(t *testing.T) {
 
 	ctx := context.Background()
 
-	fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = `
+	policyForDomain["refresh-test"] = `
 		{"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 100}`
 	p := mustFetch(t, c, ctx, "refresh-test")
 	if p.MaxAge != 100*time.Second {
@@ -312,7 +331,7 @@ func TestCacheRefresh(t *testing.T) {
 
 	// Change the "published" policy, check that we see the old version at
 	// fetch (should be cached), and a new version after a refresh.
-	fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = `
+	policyForDomain["refresh-test"] = `
 		{"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 200}`
 
 	p = mustFetch(t, c, ctx, "refresh-test")
@@ -331,3 +350,17 @@ func TestCacheRefresh(t *testing.T) {
 		os.RemoveAll(dir)
 	}
 }
+
+func TestURLForDomain(t *testing.T) {
+	// urlForDomain returns the test server URL while fakeURLForTesting is
+	// set, so clear it for the duration of this test and restore it after.
+	saved := fakeURLForTesting
+	defer func() { fakeURLForTesting = saved }()
+	fakeURLForTesting = ""
+
+	// With no override, the URL is derived from the domain name alone.
+	const want = "https://mta-sts.a-test-domain/.well-known/mta-sts.json"
+	if got := urlForDomain("a-test-domain"); got != want {
+		t.Errorf("got %q, expected %q", got, want)
+	}
+}