git » chasquid » commit d66b06d

sts: Add an on-disk cache implementation

author Alberto Bertogli
2017-02-25 19:54:29 UTC
committer Alberto Bertogli
2017-02-28 22:27:15 UTC
parent 79c0a17328fd4da360621ef28ab620e148365339

sts: Add an on-disk cache implementation

This patch adds an on-disk cache for STS policies.

Policies are cached by domain, and stored as files in a single
directory. Each file's mtime is set to the time when its policy
expires; this keeps the store simple, as it avoids having to keep
additional metadata.

There is no in-memory caching. This may be added in the future, but for
now disk is good enough for our purposes.

internal/sts/sts.go +213 -3
internal/sts/sts_test.go +225 -14

diff --git a/internal/sts/sts.go b/internal/sts/sts.go
index 8d91c20..a3fc5cf 100644
--- a/internal/sts/sts.go
+++ b/internal/sts/sts.go
@@ -4,7 +4,6 @@
 // This is an EXPERIMENTAL implementation for now.
 //
 // It lacks (at least) the following:
-// - Caching.
 // - DNS TXT checking.
 // - Facilities for reporting.
 //
@@ -14,15 +13,40 @@ import (
 	"context"
 	"encoding/json"
 	"errors"
+	"expvar"
+	"fmt"
 	"io/ioutil"
 	"net/http"
+	"os"
 	"strings"
+	"sync"
 	"time"
 
+	"blitiri.com.ar/go/chasquid/internal/safeio"
+	"blitiri.com.ar/go/chasquid/internal/trace"
+
 	"golang.org/x/net/context/ctxhttp"
 	"golang.org/x/net/idna"
 )
 
+// Exported variables.
+//
+// These expvar counters track the behaviour of PolicyCache; the strings are
+// the keys under which they are published.
+var (
+	// Calls to PolicyCache.Fetch, how many of them were served from the
+	// cache, and how many found an entry on disk that had already expired.
+	cacheFetches = expvar.NewInt("chasquid/sts/cache/fetches")
+	cacheHits    = expvar.NewInt("chasquid/sts/cache/hits")
+	cacheExpired = expvar.NewInt("chasquid/sts/cache/expired")
+
+	// Errors: disk I/O failures reading or writing cache entries, network
+	// fetches that failed on a cache miss, and cached entries that failed
+	// Policy.Check when loaded.
+	cacheIOErrors    = expvar.NewInt("chasquid/sts/cache/ioErrors")
+	cacheFailedFetch = expvar.NewInt("chasquid/sts/cache/failedFetch")
+	cacheInvalid     = expvar.NewInt("chasquid/sts/cache/invalid")
+
+	// JSON encoding (store) and decoding (load) failures of cache entries.
+	cacheMarshalErrors   = expvar.NewInt("chasquid/sts/cache/marshalErrors")
+	cacheUnmarshalErrors = expvar.NewInt("chasquid/sts/cache/unmarshalErrors")
+
+	// Background refresh: passes over the whole cache, per-entry refresh
+	// attempts, and attempts whose fetch failed.
+	cacheRefreshCycles = expvar.NewInt("chasquid/sts/cache/refreshCycles")
+	cacheRefreshes     = expvar.NewInt("chasquid/sts/cache/refreshes")
+	cacheRefreshErrors = expvar.NewInt("chasquid/sts/cache/refreshErrors")
+)
+
 // Policy represents a parsed policy.
 // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
 type Policy struct {
@@ -169,9 +193,12 @@ func httpGet(ctx context.Context, url string) ([]byte, error) {
 	if err != nil {
 		return nil, err
 	}
-
 	defer resp.Body.Close()
-	return ioutil.ReadAll(resp.Body)
+
+	if resp.StatusCode == http.StatusOK {
+		return ioutil.ReadAll(resp.Body)
+	}
+	return nil, fmt.Errorf("HTTP response status code: %v", resp.StatusCode)
 }
 
 var errRejectRedirect = errors.New("redirects not allowed in MTA-STS")
@@ -222,3 +249,186 @@ func domainToASCII(domain string) (string, error) {
 	domain = strings.ToLower(domain)
 	return idna.ToASCII(domain)
 }
+
+// PolicyCache is a caching layer for fetching policies.
+//
+// Policies are cached by domain, each in its own file within a single
+// directory. A file's mtime is set to the time when its policy expires,
+// which keeps the store simple, as no additional metadata is needed.
+//
+// There is no in-memory caching. This may be added in the future, but for
+// now disk is good enough for our purposes.
+type PolicyCache struct {
+	// Directory holding the cache entries.
+	dir string
+
+	// NOTE(review): none of the PolicyCache methods in this change lock
+	// this mutex. Because it is embedded, Lock and Unlock are part of
+	// PolicyCache's exported method set, so it is kept for compatibility.
+	sync.Mutex
+}
+
+// NewCache returns a PolicyCache that keeps its entries in dir, creating
+// the directory (with any missing parents) if it does not exist.
+//
+// On error, the returned cache is nil.
+func NewCache(dir string) (*PolicyCache, error) {
+	if err := os.MkdirAll(dir, 0770); err != nil {
+		return nil, err
+	}
+	return &PolicyCache{dir: dir}, nil
+}
+
+// pathPrefix is prepended to the domain to build a cache entry's file name;
+// refresh relies on it to recognize cache entries in the directory.
+const pathPrefix = "pol:"
+
+// domainPath returns the path of the file holding domain's cached policy.
+// It panics if domain contains a slash, as such a name could point outside
+// the cache directory.
+func (c *PolicyCache) domainPath(domain string) string {
+	// Callers are expected to pass well-formed domains; this is only a
+	// last-resort sanity check.
+	if strings.IndexByte(domain, '/') != -1 {
+		panic("domain contains slash")
+	}
+
+	name := pathPrefix + domain
+	return c.dir + "/" + name
+}
+
+// ErrExpired is the error load returns when a cache entry exists but its
+// policy has already expired.
+var ErrExpired = errors.New("cache entry expired")
+
+// load reads the cached policy for domain from disk.
+//
+// The entry's mtime is the policy's expiration time (see store), so entries
+// whose mtime is in the past are rejected with ErrExpired. The decoded policy
+// is re-validated with Check before being returned. Any error means the
+// entry is not usable and the policy should be obtained some other way.
+func (c *PolicyCache) load(domain string) (*Policy, error) {
+	fname := c.domainPath(domain)
+
+	fi, err := os.Stat(fname)
+	if err != nil {
+		// A missing file is the normal cache miss. Anything else (e.g. a
+		// permission problem) is an I/O error, counted like the read
+		// failures below so that it doesn't go unnoticed.
+		if !os.IsNotExist(err) {
+			cacheIOErrors.Add(1)
+		}
+		return nil, err
+	}
+	if time.Since(fi.ModTime()) > 0 {
+		cacheExpired.Add(1)
+		return nil, ErrExpired
+	}
+
+	data, err := ioutil.ReadFile(fname)
+	if err != nil {
+		cacheIOErrors.Add(1)
+		return nil, err
+	}
+
+	p := &Policy{}
+	err = json.Unmarshal(data, p)
+	if err != nil {
+		cacheUnmarshalErrors.Add(1)
+		return nil, err
+	}
+
+	// The policy should always be valid, as we marshalled it ourselves;
+	// however, check it just to be safe.
+	if err := p.Check(); err != nil {
+		cacheInvalid.Add(1)
+		return nil, fmt.Errorf(
+			"%s unmarshalled invalid policy %v: %v", domain, p, err)
+	}
+
+	return p, nil
+}
+
+// store saves p as the cache entry for domain.
+//
+// The entry's mtime (and atime) is set to the moment the policy expires,
+// now + p.MaxAge, which is what load uses to detect expired entries. The
+// time change is passed to safeio.WriteFile as an extra operation on the
+// file it writes.
+func (c *PolicyCache) store(domain string, p *Policy) error {
+	data, err := json.Marshal(p)
+	if err != nil {
+		cacheMarshalErrors.Add(1)
+		return fmt.Errorf("%s failed to marshal policy %v, error: %v",
+			domain, p, err)
+	}
+
+	expiry := time.Now().Add(p.MaxAge)
+	setExpiry := func(path string) error {
+		return os.Chtimes(path, expiry, expiry)
+	}
+
+	if err = safeio.WriteFile(c.domainPath(domain), data, 0640, setExpiry); err != nil {
+		cacheIOErrors.Add(1)
+		return err
+	}
+	return nil
+}
+
+// Fetch returns the MTA-STS policy for domain, using the cache when possible.
+//
+// A cached entry is used if it can be loaded, is not expired, and passes
+// Check. Otherwise the policy is retrieved with the package-level Fetch and
+// stored in the cache; a failure to store is traced but does not make Fetch
+// fail. An error is returned only if no policy could be obtained.
+func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error) {
+	cacheFetches.Add(1)
+	tr := trace.New("STSCache.Fetch", domain)
+	defer tr.Finish()
+
+	// domainPath panics on domains containing a slash, as they could point
+	// outside the cache directory. Reject them here so that an unexpected
+	// domain results in an error rather than a panic.
+	if strings.Contains(domain, "/") {
+		err := fmt.Errorf("invalid domain %q: contains a slash", domain)
+		tr.Errorf("%v", err)
+		return nil, err
+	}
+
+	p, err := c.load(domain)
+	if err == nil {
+		tr.Debugf("cache hit: %v", p)
+		cacheHits.Add(1)
+		return p, nil
+	}
+	tr.Debugf("cache miss: %v", err)
+
+	p, err = Fetch(ctx, domain)
+	if err != nil {
+		tr.Debugf("failed to fetch: %v", err)
+		cacheFailedFetch.Add(1)
+		return nil, err
+	}
+	tr.Debugf("fetched: %v", p)
+
+	// We could do this asynchronously, as we got the policy to give to the
+	// caller. However, to make troubleshooting easier and the cost of storing
+	// entries easier to track down, we store synchronously.
+	// Note that even if the store returns an error, we pass on the policy: at
+	// this point we rather use the policy even if we couldn't store it in the
+	// cache.
+	err = c.store(domain, p)
+	if err != nil {
+		tr.Errorf("failed to store: %v", err)
+	} else {
+		tr.Debugf("stored")
+	}
+
+	return p, nil
+}
+
+// PeriodicallyRefresh refreshes every cached policy, then waits 10 minutes,
+// repeating until ctx is done. It blocks until then, so it is meant to be
+// run in its own goroutine.
+func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) {
+	for ctx.Err() == nil {
+		cacheRefreshCycles.Add(1)
+
+		c.refresh(ctx)
+
+		// Wait 10 minutes between passes; this is a background refresh and
+		// there's no need to poke the servers very often.
+		// A timer + select (instead of time.Sleep) lets cancellation of ctx
+		// stop the loop right away, rather than up to 10 minutes later.
+		t := time.NewTimer(10 * time.Minute)
+		select {
+		case <-ctx.Done():
+			t.Stop()
+			return
+		case <-t.C:
+		}
+	}
+}
+
+// refresh re-fetches the policy of every domain with an entry in the cache
+// directory, and stores the result (which also resets its expiration).
+// Domains whose fetch fails keep their current entry.
+//
+// Each fetch is bounded by a 30 second timeout. The pass stops early once
+// ctx is done, instead of attempting a fetch for every remaining entry.
+//
+// NOTE: entries whose refresh keeps failing are never removed from the
+// directory; they are retried on every pass.
+func (c *PolicyCache) refresh(ctx context.Context) {
+	tr := trace.New("STSCache.Refresh", c.dir)
+	defer tr.Finish()
+
+	entries, err := ioutil.ReadDir(c.dir)
+	if err != nil {
+		tr.Errorf("failed to list directory %q: %v", c.dir, err)
+		return
+	}
+	tr.Debugf("%d entries", len(entries))
+
+	for _, e := range entries {
+		// Once cancelled, remaining fetches would run with a done context
+		// and only add noise to cacheRefreshErrors; stop here instead.
+		if err := ctx.Err(); err != nil {
+			tr.Debugf("stopping early: %v", err)
+			return
+		}
+
+		if !strings.HasPrefix(e.Name(), pathPrefix) {
+			continue
+		}
+		domain := strings.TrimPrefix(e.Name(), pathPrefix)
+		cacheRefreshes.Add(1)
+		tr.Debugf("%v: refreshing", domain)
+
+		fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
+		p, err := Fetch(fetchCtx, domain)
+		cancel()
+		if err != nil {
+			tr.Debugf("%v: failed to fetch: %v", domain, err)
+			cacheRefreshErrors.Add(1)
+			continue
+		}
+		tr.Debugf("%v: fetched", domain)
+
+		err = c.store(domain, p)
+		if err != nil {
+			tr.Errorf("%v: failed to store: %v", domain, err)
+		} else {
+			tr.Debugf("%v: stored", domain)
+		}
+	}
+
+	tr.Debugf("refresh done")
+}
diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go
index bfd90e4..71b4fcc 100644
--- a/internal/sts/sts_test.go
+++ b/internal/sts/sts_test.go
@@ -2,10 +2,38 @@ package sts
 
 import (
 	"context"
+	"expvar"
+	"io/ioutil"
+	"os"
 	"testing"
 	"time"
 )
 
+func TestMain(m *testing.M) {
+	// httpGet serves requests from fakeContent instead of the network. Fill
+	// it in with the policies that several tests below rely on.
+	policies := map[string]string{
+		// domain.com -> valid, with reasonable policy.
+		"https://mta-sts.domain.com/.well-known/mta-sts.json": `
+		{
+             "version": "STSv1",
+             "mode": "enforce",
+             "mx": ["*.mail.domain.com"],
+             "max_age": 3600
+        }`,
+
+		// version99 -> invalid policy (unknown version).
+		"https://mta-sts.version99/.well-known/mta-sts.json": `
+		{
+             "version": "STSv99",
+             "mode": "enforce",
+             "mx": ["*.mail.version99"],
+             "max_age": 999
+        }`,
+	}
+	for url, content := range policies {
+		fakeContent[url] = content
+	}
+
+	os.Exit(m.Run())
+}
+
 func TestParsePolicy(t *testing.T) {
 	const pol1 = `{
   "version": "STSv1",
@@ -84,14 +112,10 @@ func TestMatchDomain(t *testing.T) {
 }
 
 func TestFetch(t *testing.T) {
+	// Note the data "fetched" for each domain comes from fakeContent, defined
+	// in TestMain above. See httpGet for more details.
+
 	// Normal fetch, all valid.
-	fakeContent["https://mta-sts.domain.com/.well-known/mta-sts.json"] = `
-		{
-             "version": "STSv1",
-             "mode": "enforce",
-             "mx": ["*.mail.example.com"],
-             "max_age": 123456
-        }`
 	p, err := Fetch(context.Background(), "domain.com")
 	if err != nil {
 		t.Errorf("failed to fetch policy: %v", err)
@@ -106,13 +130,6 @@ func TestFetch(t *testing.T) {
 	t.Logf("unknown: got error as expected: %v", err)
 
 	// Domain with an invalid policy (unknown version).
-	fakeContent["https://mta-sts.version99/.well-known/mta-sts.json"] = `
-		{
-             "version": "STSv99",
-             "mode": "enforce",
-             "mx": ["*.mail.example.com"],
-             "max_age": 123456
-        }`
 	p, err = Fetch(context.Background(), "version99")
 	if err != ErrUnknownVersion {
 		t.Errorf("expected error %v, got %v (and policy: %v)",
@@ -120,3 +137,197 @@ func TestFetch(t *testing.T) {
 	}
 	t.Logf("version99: got expected error: %v", err)
 }
+
+// Tests for the policy cache.
+
+// mustTempDir creates a fresh temporary directory for a test and returns
+// its path, aborting the test if it can't be created. Removing it is up to
+// the caller.
+//
+// The process working directory is deliberately left untouched: the cache
+// builds every path from the directory it is given, and changing the cwd
+// would leak into all later tests (leaving it pointing at a deleted
+// directory once the test cleans up).
+func mustTempDir(t *testing.T) string {
+	dir, err := ioutil.TempDir("", "sts_test")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Logf("test directory: %q", dir)
+
+	return dir
+}
+
+// expvarMustEq reports a (non-fatal) test error if the counter v, referred
+// to as name in the message, doesn't hold the expected value.
+func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int64) {
+	if got := v.Value(); got != expected {
+		t.Errorf("%s is %d, expected %d", name, got, expected)
+	}
+}
+
+func TestCacheBasics(t *testing.T) {
+	dir := mustTempDir(t)
+	c, err := NewCache(dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Note the data "fetched" for each domain comes from fakeContent, defined
+	// in TestMain above. See httpGet for more details.
+
+	// Reset the expvar counters that we use to validate hits, misses, etc.
+	cacheFetches.Set(0)
+	cacheHits.Set(0)
+	cacheExpired.Set(0)
+
+	ctx := context.Background()
+
+	// fetchDomain fetches domain.com through the cache, checks that we get
+	// a reasonable policy, and then checks the fetch and hit counters.
+	fetchDomain := func(wantFetches, wantHits int64) {
+		p, err := c.Fetch(ctx, "domain.com")
+		if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
+			t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
+		}
+		t.Logf("cache fetched domain.com: %v", p)
+		expvarMustEq(t, "cacheFetches", cacheFetches, wantFetches)
+		expvarMustEq(t, "cacheHits", cacheHits, wantHits)
+	}
+
+	// Fetch domain.com, check we get a reasonable policy, and that it's a
+	// cache miss.
+	fetchDomain(1, 0)
+
+	// Fetch domain.com again, this time we should see a cache hit.
+	fetchDomain(2, 1)
+
+	// Simulate an expired cache entry by changing the mtime of domain.com's
+	// entry to the past. If this fails, the checks below would be
+	// meaningless, so abort.
+	expires := time.Now().Add(-1 * time.Minute)
+	err = os.Chtimes(c.domainPath("domain.com"), expires, expires)
+	if err != nil {
+		t.Fatalf("failed to change mtime: %v", err)
+	}
+
+	// Do a third fetch, check that we don't get a cache hit, and that the
+	// miss was caused by the entry having expired.
+	fetchDomain(3, 1)
+	expvarMustEq(t, "cacheExpired", cacheExpired, 1)
+
+	if !t.Failed() {
+		os.RemoveAll(dir)
+	}
+}
+
+// Test how the cache behaves when the files are corrupt.
+func TestCacheBadData(t *testing.T) {
+	dir := mustTempDir(t)
+	c, err := NewCache(dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	ctx := context.Background()
+
+	cases := []string{
+		// Case 1: A file with invalid json, which will fail unmarshalling.
+		"this is not valid json",
+
+		// Case 2: A file with a parseable but invalid policy.
+		`{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], max_age": 1}`,
+	}
+
+	for _, badContent := range cases {
+		// Reset the expvar counters that we use to validate hits, misses, etc.
+		cacheFetches.Set(0)
+		cacheHits.Set(0)
+
+		// Fetch domain.com, should result in the file being added to the
+		// cache.
+		p, err := c.Fetch(ctx, "domain.com")
+		if err != nil {
+			t.Fatalf("Fetch failed: %v", err)
+		}
+		t.Logf("cache fetched domain.com: %v", p)
+		expvarMustEq(t, "cacheFetches", cacheFetches, 1)
+		expvarMustEq(t, "cacheHits", cacheHits, 0)
+
+		// Edit the file, filling it with the bad content for this case.
+		fname := c.domainPath("domain.com")
+		err = ioutil.WriteFile(fname, []byte(badContent), 0644)
+		if err != nil {
+			t.Fatalf("error writing file: %v", err)
+		}
+
+		// We now expect Fetch to fall back to getting the policy from the
+		// network (in our case, from fakeContent).
+		p, err = c.Fetch(ctx, "domain.com")
+		if err != nil {
+			t.Fatalf("Fetch failed: %v", err)
+		}
+		t.Logf("cache fetched domain.com: %v", p)
+		expvarMustEq(t, "cacheFetches", cacheFetches, 2)
+		expvarMustEq(t, "cacheHits", cacheHits, 0)
+
+		// And now the file should be fine, resulting in a cache hit.
+		p, err = c.Fetch(ctx, "domain.com")
+		if err != nil {
+			t.Fatalf("Fetch failed: %v", err)
+		}
+		t.Logf("cache fetched domain.com: %v", p)
+		expvarMustEq(t, "cacheFetches", cacheFetches, 3)
+		expvarMustEq(t, "cacheHits", cacheHits, 1)
+
+		// Remove the file, to start with a clean slate for the next case.
+		os.Remove(fname)
+	}
+
+	if !t.Failed() {
+		os.RemoveAll(dir)
+	}
+}
+
+// mustFetch fetches domain d through the cache c and logs the result,
+// aborting the test if the fetch fails.
+func mustFetch(t *testing.T, c *PolicyCache, ctx context.Context, d string) *Policy {
+	pol, err := c.Fetch(ctx, d)
+	if err == nil {
+		t.Logf("Fetch %q: %v", d, pol)
+		return pol
+	}
+	t.Fatalf("Fetch %q failed: %v", d, err)
+	return nil
+}
+
+func TestCacheRefresh(t *testing.T) {
+	dir := mustTempDir(t)
+	c, err := NewCache(dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	ctx := context.Background()
+
+	fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = `
+		{"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 100}`
+	p := mustFetch(t, c, ctx, "refresh-test")
+	if p.MaxAge != 100*time.Second {
+		t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
+	}
+
+	// Change the "published" policy, check that we see the old version at
+	// fetch (should be cached), and a new version after a refresh.
+	fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = `
+		{"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 200}`
+
+	p = mustFetch(t, c, ctx, "refresh-test")
+	if p.MaxAge != 100*time.Second {
+		t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
+	}
+
+	c.refresh(ctx)
+
+	p = mustFetch(t, c, ctx, "refresh-test")
+	if p.MaxAge != 200*time.Second {
+		t.Fatalf("policy.MaxAge is %v, expected 200s", p.MaxAge)
+	}
+
+	if !t.Failed() {
+		os.RemoveAll(dir)
+	}
+}