git » chasquid » commit 46bce57

sts: Add miscellaneous tests

author Alberto Bertogli
2018-05-27 15:20:35 UTC
committer Alberto Bertogli
2018-07-14 09:08:09 UTC
parent 79a8cfc21c6155ad11f94c37f7cddaf698c1028d

sts: Add miscellaneous tests

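This patch adds a few miscellaneous tests to the sts package, covering
various previously-untested code paths.

It also makes two small changes to sts.go: httpGet no longer sets
client.Timeout from the context deadline (the context is already passed
to ctxhttp.Get), and PeriodicallyRefresh now increments
cacheRefreshCycles after each refresh pass instead of before it, so the
counter only advances once a pass has completed.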

internal/sts/sts.go +1 -8
internal/sts/sts_test.go +145 -26

diff --git a/internal/sts/sts.go b/internal/sts/sts.go
index 47f34c2..1c49880 100644
--- a/internal/sts/sts.go
+++ b/internal/sts/sts.go
@@ -225,12 +225,6 @@ func httpGet(ctx context.Context, url string) ([]byte, error) {
 		CheckRedirect: rejectRedirect,
 	}
 
-	// Note that http does not care for the context deadline, so we need to
-	// construct it here.
-	if deadline, ok := ctx.Deadline(); ok {
-		client.Timeout = deadline.Sub(time.Now())
-	}
-
 	resp, err := ctxhttp.Get(ctx, client, url)
 	if err != nil {
 		return nil, err
@@ -458,9 +452,8 @@ func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error)
 
 func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) {
 	for ctx.Err() == nil {
-		cacheRefreshCycles.Add(1)
-
 		c.refresh(ctx)
+		cacheRefreshCycles.Add(1)
 
 		// Wait 10 minutes between passes; this is a background refresh and
 		// there's no need to poke the servers very often.
diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go
index b907a88..c9cec37 100644
--- a/internal/sts/sts_test.go
+++ b/internal/sts/sts_test.go
@@ -4,13 +4,15 @@ import (
 	"context"
 	"expvar"
 	"fmt"
-	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
 	"os"
 	"strconv"
+	"strings"
 	"testing"
 	"time"
+
+	"blitiri.com.ar/go/chasquid/internal/testlib"
 )
 
 // Override the lookup function to control its results.
@@ -145,10 +147,18 @@ func TestMatchDomain(t *testing.T) {
 		{"x.ñaca.com", "x.xn--aca-6ma.com", true},
 		{"x.naca.com", "x.xn--aca-6ma.com", false},
 
+		// Triggers errors in domainToASCII.
+		{strings.Repeat("x", 65536) + "\uff00", "x.com", false},
+
 		// Examples from the RFC.
 		{"mail.example.com", "*.example.com", true},
 		{"example.com", "*.example.com", false},
 		{"foo.bar.example.com", "*.example.com", false},
+
+		// Missing "*" (invalid, seen in the wild).
+		{"aa.b.cc.com", ".aa.b.cc.com", false},
+		{"zz.aa.b.cc.com", ".aa.b.cc.com", false},
+		{"zz.aa.b.cc.com", "*.aa.b.cc.com", true},
 	}
 
 	for _, c := range cases {
@@ -159,6 +169,26 @@ func TestMatchDomain(t *testing.T) {
 	}
 }
 
+func TestMXIsAllowed(t *testing.T) {
+	p := Policy{Version: "STSv1", Mode: "enforce", MaxAge: 1 * time.Hour,
+		MXs: []string{"mx1", "mx2"}}
+	if p.MXIsAllowed("notamx") {
+		t.Errorf("notamx should not be allowed")
+	}
+	if !p.MXIsAllowed("mx1") {
+		t.Errorf("mx1 should be allowed")
+	}
+	if !p.MXIsAllowed("mx2") {
+		t.Errorf("mx2 should be allowed")
+	}
+
+	p = Policy{Version: "STSv1", Mode: "testing", MaxAge: 1 * time.Hour,
+		MXs: []string{"mx1"}}
+	if !p.MXIsAllowed("notamx") {
+		t.Errorf("notamx should be allowed (policy not enforced)")
+	}
+}
+
 func TestFetch(t *testing.T) {
 	// Note the data "fetched" for each domain comes from policyForDomain,
 	// defined in TestMain above. See httpGet for more details.
@@ -212,22 +242,6 @@ func TestPolicyTooBig(t *testing.T) {
 
 // Tests for the policy cache.
 
-func mustTempDir(t *testing.T) string {
-	dir, err := ioutil.TempDir("", "sts_test")
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	err = os.Chdir(dir)
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	t.Logf("test directory: %q", dir)
-
-	return dir
-}
-
 func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) {
 	// TODO: Use v.Value once we drop support of Go 1.7.
 	value, _ := strconv.Atoi(v.String())
@@ -237,7 +251,7 @@ func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) {
 }
 
 func TestCacheBasics(t *testing.T) {
-	dir := mustTempDir(t)
+	dir := testlib.MustTempDir(t)
 	c, err := NewCache(dir)
 	if err != nil {
 		t.Fatal(err)
@@ -285,6 +299,16 @@ func TestCacheBasics(t *testing.T) {
 	expvarMustEq(t, "cacheFetches", cacheFetches, 3)
 	expvarMustEq(t, "cacheHits", cacheHits, 1)
 
+	// Fetch for a domain without policy.
+	p, err = c.Fetch(ctx, "domErr")
+	if err == nil || p != nil {
+		t.Errorf("expected failure, got: policy = %v ; error = %v", p, err)
+	}
+	t.Logf("cache fetched domErr: %v", p)
+	expvarMustEq(t, "cacheFetches", cacheFetches, 4)
+	expvarMustEq(t, "cacheHits", cacheHits, 1)
+	expvarMustEq(t, "cacheFailedFetch", cacheFailedFetch, 1)
+
 	if !t.Failed() {
 		os.RemoveAll(dir)
 	}
@@ -292,7 +316,7 @@ func TestCacheBasics(t *testing.T) {
 
 // Test how the cache behaves when the files are corrupt.
 func TestCacheBadData(t *testing.T) {
-	dir := mustTempDir(t)
+	dir := testlib.MustTempDir(t)
 	c, err := NewCache(dir)
 	if err != nil {
 		t.Fatal(err)
@@ -300,12 +324,15 @@ func TestCacheBadData(t *testing.T) {
 
 	ctx := context.Background()
 
+	cacheUnmarshalErrors.Set(0)
+	cacheInvalid.Set(0)
+
 	cases := []string{
 		// Case 1: A file with invalid json, which will fail unmarshalling.
 		"this is not valid json",
 
 		// Case 2: A file with a parseable but invalid policy.
-		`{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], max_age": 1}`,
+		`{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], "max_age": 1}`,
 	}
 
 	for _, badContent := range cases {
@@ -325,10 +352,7 @@ func TestCacheBadData(t *testing.T) {
 
 		// Edit the file, filling it with the bad content for this case.
 		fname := c.domainPath("domain.com")
-		err = ioutil.WriteFile(fname, []byte(badContent), 0644)
-		if err != nil {
-			t.Fatalf("error writing file: %v", err)
-		}
+		mustRewriteAndChtime(t, fname, badContent)
 
 		// We now expect Fetch to fall back to getting the policy from the
 		// network (in our case, from policyForDomain).
@@ -353,6 +377,9 @@ func TestCacheBadData(t *testing.T) {
 		os.Remove(fname)
 	}
 
+	expvarMustEq(t, "cacheUnmarshalErrors", cacheUnmarshalErrors, 1)
+	expvarMustEq(t, "cacheInvalid", cacheInvalid, 1)
+
 	if !t.Failed() {
 		os.RemoveAll(dir)
 	}
@@ -367,8 +394,20 @@ func mustFetch(t *testing.T, c *PolicyCache, ctx context.Context, d string) *Pol
 	return p
 }
 
+// mustRewriteAndChtime rewrites fname with content (via testlib.Rewrite) and sets its atime/mtime 10s in the future.
+func mustRewriteAndChtime(t *testing.T, fname, content string) {
+	testlib.Rewrite(t, fname, content)
+
+	// Advance the expiration time to the future, so the rewritten policy is
+	// not considered expired.
+	expires := time.Now().Add(10 * time.Second)
+	if err := os.Chtimes(fname, expires, expires); err != nil {
+		t.Fatalf("failed to chtime %q to the future: %v", fname, err)
+	}
+}
+
 func TestCacheRefresh(t *testing.T) {
-	dir := mustTempDir(t)
+	dir := testlib.MustTempDir(t)
 	c, err := NewCache(dir)
 	if err != nil {
 		t.Fatal(err)
@@ -400,7 +439,16 @@ func TestCacheRefresh(t *testing.T) {
 		t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
 	}
 
-	c.refresh(ctx)
+	// Launch background refreshes, and wait for one to complete.
+	// TODO: change to cacheRefreshCycles.Value once we drop support for Go
+	// 1.7.
+	cacheRefreshCycles.Set(0)
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancel()
+	go c.PeriodicallyRefresh(ctx)
+	for cacheRefreshCycles.String() == "0" {
+		time.Sleep(5 * time.Millisecond)
+	}
 
 	p = mustFetch(t, c, ctx, "refresh-test")
 	if p.MaxAge != 200*time.Second {
@@ -412,6 +460,24 @@ func TestCacheRefresh(t *testing.T) {
 	}
 }
 
+func TestCacheSlashSafe(t *testing.T) {
+	dir := testlib.MustTempDir(t)
+	c, err := NewCache(dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	defer func() {
+		if r := recover(); r != nil {
+			t.Logf("recovered: %v", r)
+		} else {
+			t.Fatalf("check did not panic as expected")
+		}
+	}()
+
+	c.domainPath("a/b")
+}
+
 func TestURLForDomain(t *testing.T) {
 	// This function will behave differently if fakeURLForTesting is set, so
 	// temporarily unset it.
@@ -452,3 +518,56 @@ func TestHasSTSRecord(t *testing.T) {
 		}
 	}
 }
+
+func TestHTTPGet(t *testing.T) {
+	// Basic test, it should work.
+	srv1 := httptest.NewServer(
+		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Write([]byte(policyForDomain["domain.com"]))
+		}))
+	defer srv1.Close()
+
+	ctx := context.Background()
+	raw, err := httpGet(ctx, srv1.URL)
+	if err != nil {
+		t.Errorf("GET failed: got %q, %v", raw, err)
+	}
+
+	// Test that redirects are rejected.
+	srv2 := httptest.NewServer(
+		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			http.Redirect(w, r, fakeURLForTesting, http.StatusMovedPermanently)
+		}))
+	defer srv2.Close()
+
+	raw, err = httpGet(ctx, srv2.URL)
+	if err == nil {
+		t.Errorf("redirect allowed, should have failed: got %q, %v", raw, err)
+	}
+
+	// Content type != text/plain should be rejected.
+	srv3 := httptest.NewServer(
+		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Header().Set("Content-Type", "text/json")
+			w.Write([]byte(policyForDomain["domain.com"]))
+		}))
+	defer srv3.Close()
+
+	raw, err = httpGet(ctx, srv3.URL)
+	if err != ErrInvalidMediaType {
+		t.Errorf("content type != text/plain was allowed: got %q, %v", raw, err)
+	}
+
+	// Invalid (unparseable) media type.
+	srv4 := httptest.NewServer(
+		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Header().Set("Content-Type", "invalid/content/type")
+			w.Write([]byte(policyForDomain["domain.com"]))
+		}))
+	defer srv4.Close()
+
+	raw, err = httpGet(ctx, srv4.URL)
+	if err == nil || err == ErrInvalidMediaType {
+		t.Errorf("invalid content type was allowed: got %q, %v", raw, err)
+	}
+}