author | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-05-27 15:20:35 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-07-14 09:08:09 UTC |
parent | 79a8cfc21c6155ad11f94c37f7cddaf698c1028d |
internal/sts/sts.go | +1 | -8 |
internal/sts/sts_test.go | +145 | -26 |
diff --git a/internal/sts/sts.go b/internal/sts/sts.go index 47f34c2..1c49880 100644 --- a/internal/sts/sts.go +++ b/internal/sts/sts.go @@ -225,12 +225,6 @@ func httpGet(ctx context.Context, url string) ([]byte, error) { CheckRedirect: rejectRedirect, } - // Note that http does not care for the context deadline, so we need to - // construct it here. - if deadline, ok := ctx.Deadline(); ok { - client.Timeout = deadline.Sub(time.Now()) - } - resp, err := ctxhttp.Get(ctx, client, url) if err != nil { return nil, err @@ -458,9 +452,8 @@ func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error) func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) { for ctx.Err() == nil { - cacheRefreshCycles.Add(1) - c.refresh(ctx) + cacheRefreshCycles.Add(1) // Wait 10 minutes between passes; this is a background refresh and // there's no need to poke the servers very often. diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go index b907a88..c9cec37 100644 --- a/internal/sts/sts_test.go +++ b/internal/sts/sts_test.go @@ -4,13 +4,15 @@ import ( "context" "expvar" "fmt" - "io/ioutil" "net/http" "net/http/httptest" "os" "strconv" + "strings" "testing" "time" + + "blitiri.com.ar/go/chasquid/internal/testlib" ) // Override the lookup function to control its results. @@ -145,10 +147,18 @@ func TestMatchDomain(t *testing.T) { {"x.ñaca.com", "x.xn--aca-6ma.com", true}, {"x.naca.com", "x.xn--aca-6ma.com", false}, + // Triggers errors in domainToASCII. + {strings.Repeat("x", 65536) + "\uff00", "x.com", false}, + // Examples from the RFC. {"mail.example.com", "*.example.com", true}, {"example.com", "*.example.com", false}, {"foo.bar.example.com", "*.example.com", false}, + + // Missing "*" (invalid, seen in the wild). 
+ {"aa.b.cc.com", ".aa.b.cc.com", false}, + {"zz.aa.b.cc.com", ".aa.b.cc.com", false}, + {"zz.aa.b.cc.com", "*.aa.b.cc.com", true}, } for _, c := range cases { @@ -159,6 +169,26 @@ func TestMatchDomain(t *testing.T) { } } +func TestMXIsAllowed(t *testing.T) { + p := Policy{Version: "STSv1", Mode: "enforce", MaxAge: 1 * time.Hour, + MXs: []string{"mx1", "mx2"}} + if p.MXIsAllowed("notamx") { + t.Errorf("notamx should not be allowed") + } + if !p.MXIsAllowed("mx1") { + t.Errorf("mx1 should be allowed") + } + if !p.MXIsAllowed("mx2") { + t.Errorf("mx2 should be allowed") + } + + p = Policy{Version: "STSv1", Mode: "testing", MaxAge: 1 * time.Hour, + MXs: []string{"mx1"}} + if !p.MXIsAllowed("notamx") { + t.Errorf("notamx should be allowed (policy not enforced)") + } +} + func TestFetch(t *testing.T) { // Note the data "fetched" for each domain comes from policyForDomain, // defined in TestMain above. See httpGet for more details. @@ -212,22 +242,6 @@ func TestPolicyTooBig(t *testing.T) { // Tests for the policy cache. -func mustTempDir(t *testing.T) string { - dir, err := ioutil.TempDir("", "sts_test") - if err != nil { - t.Fatal(err) - } - - err = os.Chdir(dir) - if err != nil { - t.Fatal(err) - } - - t.Logf("test directory: %q", dir) - - return dir -} - func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) { // TODO: Use v.Value once we drop support of Go 1.7. value, _ := strconv.Atoi(v.String()) @@ -237,7 +251,7 @@ func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) { } func TestCacheBasics(t *testing.T) { - dir := mustTempDir(t) + dir := testlib.MustTempDir(t) c, err := NewCache(dir) if err != nil { t.Fatal(err) @@ -285,6 +299,16 @@ func TestCacheBasics(t *testing.T) { expvarMustEq(t, "cacheFetches", cacheFetches, 3) expvarMustEq(t, "cacheHits", cacheHits, 1) + // Fetch for a domain without policy. 
+ p, err = c.Fetch(ctx, "domErr") + if err == nil || p != nil { + t.Errorf("expected failure, got: policy = %v ; error = %v", p, err) + } + t.Logf("cache fetched domErr: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 4) + expvarMustEq(t, "cacheHits", cacheHits, 1) + expvarMustEq(t, "cacheFailedFetch", cacheFailedFetch, 1) + if !t.Failed() { os.RemoveAll(dir) } @@ -292,7 +316,7 @@ func TestCacheBasics(t *testing.T) { // Test how the cache behaves when the files are corrupt. func TestCacheBadData(t *testing.T) { - dir := mustTempDir(t) + dir := testlib.MustTempDir(t) c, err := NewCache(dir) if err != nil { t.Fatal(err) @@ -300,12 +324,15 @@ func TestCacheBadData(t *testing.T) { ctx := context.Background() + cacheUnmarshalErrors.Set(0) + cacheInvalid.Set(0) + cases := []string{ // Case 1: A file with invalid json, which will fail unmarshalling. "this is not valid json", // Case 2: A file with a parseable but invalid policy. - `{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], max_age": 1}`, + `{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], "max_age": 1}`, } for _, badContent := range cases { @@ -325,10 +352,7 @@ func TestCacheBadData(t *testing.T) { // Edit the file, filling it with the bad content for this case. fname := c.domainPath("domain.com") - err = ioutil.WriteFile(fname, []byte(badContent), 0644) - if err != nil { - t.Fatalf("error writing file: %v", err) - } + mustRewriteAndChtime(t, fname, badContent) // We now expect Fetch to fall back to getting the policy from the // network (in our case, from policyForDomain). 
@@ -353,6 +377,9 @@ func TestCacheBadData(t *testing.T) { os.Remove(fname) } + expvarMustEq(t, "cacheUnmarshalErrors", cacheUnmarshalErrors, 1) + expvarMustEq(t, "cacheInvalid", cacheInvalid, 1) + if !t.Failed() { os.RemoveAll(dir) } @@ -367,8 +394,20 @@ func mustFetch(t *testing.T, c *PolicyCache, ctx context.Context, d string) *Pol return p } +func mustRewriteAndChtime(t *testing.T, fname, content string) { + testlib.Rewrite(t, fname, content) + + // Advance the expiration time to the future, so the rewritten policy is + // not considered expired. + expires := time.Now().Add(10 * time.Second) + err := os.Chtimes(fname, expires, expires) + if err != nil { + t.Fatalf("failed to chtime %q to the future: %v", fname, err) + } +} + func TestCacheRefresh(t *testing.T) { - dir := mustTempDir(t) + dir := testlib.MustTempDir(t) c, err := NewCache(dir) if err != nil { t.Fatal(err) @@ -400,7 +439,16 @@ func TestCacheRefresh(t *testing.T) { t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge) } - c.refresh(ctx) + // Launch background refreshes, and wait for one to complete. + // TODO: change to cacheRefreshCycles.Value once we drop support for Go + // 1.7. + cacheRefreshCycles.Set(0) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + go c.PeriodicallyRefresh(ctx) + for cacheRefreshCycles.String() == "0" { + time.Sleep(5 * time.Millisecond) + } p = mustFetch(t, c, ctx, "refresh-test") if p.MaxAge != 200*time.Second { @@ -412,6 +460,24 @@ func TestCacheRefresh(t *testing.T) { } } +func TestCacheSlashSafe(t *testing.T) { + dir := testlib.MustTempDir(t) + c, err := NewCache(dir) + if err != nil { + t.Fatal(err) + } + + defer func() { + if r := recover(); r != nil { + t.Logf("recovered: %v", r) + } else { + t.Fatalf("check did not panic as expected") + } + }() + + c.domainPath("a/b") +} + func TestURLForDomain(t *testing.T) { // This function will behave differently if fakeURLForTesting is set, so // temporarily unset it.
@@ -452,3 +518,56 @@ func TestHasSTSRecord(t *testing.T) { } } } + +func TestHTTPGet(t *testing.T) { + // Basic test, it should work. + srv1 := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(policyForDomain["domain.com"])) + })) + defer srv1.Close() + + ctx := context.Background() + raw, err := httpGet(ctx, srv1.URL) + if err != nil { + t.Errorf("GET failed: got %q, %v", raw, err) + } + + // Test that redirects are rejected. + srv2 := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, fakeURLForTesting, http.StatusMovedPermanently) + })) + defer srv2.Close() + + raw, err = httpGet(ctx, srv2.URL) + if err == nil { + t.Errorf("redirect allowed, should have failed: got %q, %v", raw, err) + } + + // Content type != text/plain should be rejected. + srv3 := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/json") + w.Write([]byte(policyForDomain["domain.com"])) + })) + defer srv3.Close() + + raw, err = httpGet(ctx, srv3.URL) + if err != ErrInvalidMediaType { + t.Errorf("content type != text/plain was allowed: got %q, %v", raw, err) + } + + // Invalid (unparseable) media type. + srv4 := httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "invalid/content/type") + w.Write([]byte(policyForDomain["domain.com"])) + })) + defer srv4.Close() + + raw, err = httpGet(ctx, srv4.URL) + if err == nil || err == ErrInvalidMediaType { + t.Errorf("invalid content type was allowed: got %q, %v", raw, err) + } +}