author | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-02-25 19:54:29 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-02-28 22:27:15 UTC |
parent | 79c0a17328fd4da360621ef28ab620e148365339 |
internal/sts/sts.go | +213 | -3 |
internal/sts/sts_test.go | +225 | -14 |
diff --git a/internal/sts/sts.go b/internal/sts/sts.go index 8d91c20..a3fc5cf 100644 --- a/internal/sts/sts.go +++ b/internal/sts/sts.go @@ -4,7 +4,6 @@ // This is an EXPERIMENTAL implementation for now. // // It lacks (at least) the following: -// - Caching. // - DNS TXT checking. // - Facilities for reporting. // @@ -14,15 +13,40 @@ import ( "context" "encoding/json" "errors" + "expvar" + "fmt" "io/ioutil" "net/http" + "os" "strings" + "sync" "time" + "blitiri.com.ar/go/chasquid/internal/safeio" + "blitiri.com.ar/go/chasquid/internal/trace" + "golang.org/x/net/context/ctxhttp" "golang.org/x/net/idna" ) +// Exported variables. +var ( + cacheFetches = expvar.NewInt("chasquid/sts/cache/fetches") + cacheHits = expvar.NewInt("chasquid/sts/cache/hits") + cacheExpired = expvar.NewInt("chasquid/sts/cache/expired") + + cacheIOErrors = expvar.NewInt("chasquid/sts/cache/ioErrors") + cacheFailedFetch = expvar.NewInt("chasquid/sts/cache/failedFetch") + cacheInvalid = expvar.NewInt("chasquid/sts/cache/invalid") + + cacheMarshalErrors = expvar.NewInt("chasquid/sts/cache/marshalErrors") + cacheUnmarshalErrors = expvar.NewInt("chasquid/sts/cache/unmarshalErrors") + + cacheRefreshCycles = expvar.NewInt("chasquid/sts/cache/refreshCycles") + cacheRefreshes = expvar.NewInt("chasquid/sts/cache/refreshes") + cacheRefreshErrors = expvar.NewInt("chasquid/sts/cache/refreshErrors") +) + // Policy represents a parsed policy.
// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2 type Policy struct { @@ -169,9 +193,12 @@ func httpGet(ctx context.Context, url string) ([]byte, error) { if err != nil { return nil, err } - defer resp.Body.Close() - return ioutil.ReadAll(resp.Body) + + if resp.StatusCode == http.StatusOK { + return ioutil.ReadAll(resp.Body) + } + return nil, fmt.Errorf("HTTP response status code: %v", resp.StatusCode) } var errRejectRedirect = errors.New("redirects not allowed in MTA-STS") @@ -222,3 +249,209 @@ func domainToASCII(domain string) (string, error) { domain = strings.ToLower(domain) return idna.ToASCII(domain) } + +// PolicyCache is a caching layer for fetching policies. +// +// Policies are cached by domain, and stored in a single directory. +// The files will have as mtime the time when the policy expires, this makes +// the store simpler, as it can avoid keeping additional metadata. +// +// There is no in-memory caching. This may be added in the future, but for +// now disk is good enough for our purposes. +type PolicyCache struct { + dir string + + sync.Mutex // NOTE(review): no PolicyCache method in this change locks it; embedding also exports Lock/Unlock. +} + +// NewCache returns a PolicyCache that keeps its entries in dir, creating the +// directory (and any missing parents) if needed. Note that the returned +// cache is non-nil even when creating the directory fails. +func NewCache(dir string) (*PolicyCache, error) { + c := &PolicyCache{ + dir: dir, + } + err := os.MkdirAll(dir, 0770) + return c, err +} + +const pathPrefix = "pol:" + +func (c *PolicyCache) domainPath(domain string) string { + // We assume the domain is well formed, sanity check just in case.
+ if strings.Contains(domain, "/") { + panic("domain contains slash") + } + + return c.dir + "/" + pathPrefix + domain +} + +// ErrExpired is returned by load when a cache entry's expiration time, +// which is stored as the file's mtime, has already passed. +var ErrExpired = errors.New("cache entry expired") + +func (c *PolicyCache) load(domain string) (*Policy, error) { + fname := c.domainPath(domain) + + fi, err := os.Stat(fname) + if err != nil { + return nil, err + } + if time.Since(fi.ModTime()) > 0 { + cacheExpired.Add(1) + return nil, ErrExpired + } + + data, err := ioutil.ReadFile(fname) + if err != nil { + cacheIOErrors.Add(1) + return nil, err + } + + p := &Policy{} + err = json.Unmarshal(data, p) + if err != nil { + cacheUnmarshalErrors.Add(1) + return nil, err + } + + // The policy should always be valid, as we marshalled it ourselves; + // however, check it just to be safe. + if err := p.Check(); err != nil { + cacheInvalid.Add(1) + return nil, fmt.Errorf( + "%s unmarshalled invalid policy %v: %v", domain, p, err) + } + + return p, nil +} + +func (c *PolicyCache) store(domain string, p *Policy) error { + data, err := json.Marshal(p) + if err != nil { + cacheMarshalErrors.Add(1) + return fmt.Errorf("%s failed to marshal policy %v, error: %v", + domain, p, err) + } + + // Change the modification time to the future, when the policy expires. + // load will check for this to detect expired cache entries, see above for + // the details.
+ expires := time.Now().Add(p.MaxAge) + chTime := func(fname string) error { + return os.Chtimes(fname, expires, expires) + } + + fname := c.domainPath(domain) + err = safeio.WriteFile(fname, data, 0640, chTime) + if err != nil { + cacheIOErrors.Add(1) + } + return err +} + +// Fetch returns the policy for domain. A valid, unexpired cache entry is +// used if present; otherwise the policy is fetched from the network and +// stored in the cache. A failure to store is traced, but Fetch still +// returns the fetched policy. +func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error) { + cacheFetches.Add(1) + tr := trace.New("STSCache.Fetch", domain) + defer tr.Finish() + + p, err := c.load(domain) + if err == nil { + tr.Debugf("cache hit: %v", p) + cacheHits.Add(1) + return p, nil + } + + p, err = Fetch(ctx, domain) + if err != nil { + tr.Debugf("failed to fetch: %v", err) + cacheFailedFetch.Add(1) + return nil, err + } + tr.Debugf("fetched: %v", p) + + // We could do this asynchronously, as we got the policy to give to the + // caller. However, to make troubleshooting easier and the cost of storing + // entries easier to track down, we store synchronously. + // Note that even if the store returns an error, we pass on the policy: at + // this point we rather use the policy even if we couldn't store it in the + // cache. + err = c.store(domain, p) + if err != nil { + tr.Errorf("failed to store: %v", err) + } else { + tr.Debugf("stored") + } + + return p, nil +} + +// PeriodicallyRefresh refreshes the cached policies, pausing 10 minutes +// between passes, and only returns once ctx is done. Since it blocks, +// callers wanting a background refresh should run it in its own goroutine. +func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) { + for ctx.Err() == nil { + cacheRefreshCycles.Add(1) + + c.refresh(ctx) + + // Wait 10 minutes between passes; this is a background refresh and + // there's no need to poke the servers very often.
+ // Select on ctx.Done() instead of calling time.Sleep, so a cancelled + // ctx stops the loop promptly rather than after up to 10 minutes. + select { + case <-ctx.Done(): + case <-time.After(10 * time.Minute): + } + } +} + +func (c *PolicyCache) refresh(ctx context.Context) { + tr := trace.New("STSCache.Refresh", c.dir) + defer tr.Finish() + + entries, err := ioutil.ReadDir(c.dir) + if err != nil { + tr.Errorf("failed to list directory %q: %v", c.dir, err) + return + } + tr.Debugf("%d entries", len(entries)) + + for _, e := range entries { + // Stop early once ctx is done: any remaining fetches would only run + // with an already-cancelled context. + if ctx.Err() != nil { + tr.Debugf("context done, stopping refresh: %v", ctx.Err()) + break + } + if !strings.HasPrefix(e.Name(), pathPrefix) { + continue + } + domain := e.Name()[len(pathPrefix):] + cacheRefreshes.Add(1) + tr.Debugf("%v: refreshing", domain) + + fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + p, err := Fetch(fetchCtx, domain) + cancel() + if err != nil { + tr.Debugf("%v: failed to fetch: %v", domain, err) + cacheRefreshErrors.Add(1) + continue + } + tr.Debugf("%v: fetched", domain) + + err = c.store(domain, p) + if err != nil { + tr.Errorf("%v: failed to store: %v", domain, err) + } else { + tr.Debugf("%v: stored", domain) + } + } + + tr.Debugf("refresh done") +} diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go index bfd90e4..71b4fcc 100644 --- a/internal/sts/sts_test.go +++ b/internal/sts/sts_test.go @@ -2,10 +2,38 @@ package sts import ( "context" + "expvar" + "io/ioutil" + "os" "testing" "time" ) +func TestMain(m *testing.M) { + // Populate the fake policy contents, used by a few tests. + // httpGet will use this data instead of using the network. + + // domain.com -> valid, with reasonable policy. + fakeContent["https://mta-sts.domain.com/.well-known/mta-sts.json"] = ` + { + "version": "STSv1", + "mode": "enforce", + "mx": ["*.mail.domain.com"], + "max_age": 3600 + }` + + // version99 -> invalid policy (unknown version).
+ fakeContent["https://mta-sts.version99/.well-known/mta-sts.json"] = ` + { + "version": "STSv99", + "mode": "enforce", + "mx": ["*.mail.version99"], + "max_age": 999 + }` + + os.Exit(m.Run()) +} + func TestParsePolicy(t *testing.T) { const pol1 = `{ "version": "STSv1", @@ -84,14 +112,10 @@ func TestMatchDomain(t *testing.T) { } func TestFetch(t *testing.T) { + // Note the data "fetched" for each domain comes from fakeContent, defined + // in TestMain above. See httpGet for more details. + // Normal fetch, all valid. - fakeContent["https://mta-sts.domain.com/.well-known/mta-sts.json"] = ` - { - "version": "STSv1", - "mode": "enforce", - "mx": ["*.mail.example.com"], - "max_age": 123456 - }` p, err := Fetch(context.Background(), "domain.com") if err != nil { t.Errorf("failed to fetch policy: %v", err) @@ -106,13 +130,6 @@ func TestFetch(t *testing.T) { t.Logf("unknown: got error as expected: %v", err) // Domain with an invalid policy (unknown version). - fakeContent["https://mta-sts.version99/.well-known/mta-sts.json"] = ` - { - "version": "STSv99", - "mode": "enforce", - "mx": ["*.mail.example.com"], - "max_age": 123456 - }` p, err = Fetch(context.Background(), "version99") if err != ErrUnknownVersion { t.Errorf("expected error %v, got %v (and policy: %v)", @@ -120,3 +137,197 @@ func TestFetch(t *testing.T) { } t.Logf("version99: got expected error: %v", err) } + +// Tests for the policy cache.
+ +func mustTempDir(t *testing.T) string { + dir, err := ioutil.TempDir("", "sts_test") + if err != nil { + t.Fatal(err) + } + + t.Logf("test directory: %q", dir) + + return dir +} + +func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int64) { + value := v.Value() + if value != expected { + t.Errorf("%s is %d, expected %d", name, value, expected) + } +} + +func TestCacheBasics(t *testing.T) { + dir := mustTempDir(t) + c, err := NewCache(dir) + if err != nil { + t.Fatal(err) + } + + // Note the data "fetched" for each domain comes from fakeContent, defined + // in TestMain above. See httpGet for more details. + + // Reset the expvar counters that we use to validate hits, misses, etc. + cacheFetches.Set(0) + cacheHits.Set(0) + + ctx := context.Background() + + // Fetch domain.com, check we get a reasonable policy, and that it's a + // cache miss. + p, err := c.Fetch(ctx, "domain.com") + if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" { + t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err) + } + t.Logf("cache fetched domain.com: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 1) + expvarMustEq(t, "cacheHits", cacheHits, 0) + + // Fetch domain.com again, this time we should see a cache hit. + p, err = c.Fetch(ctx, "domain.com") + if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" { + t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err) + } + t.Logf("cache fetched domain.com: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 2) + expvarMustEq(t, "cacheHits", cacheHits, 1) + + // Simulate an expired cache entry by changing the mtime of domain.com's + // entry to the past. + expires := time.Now().Add(-1 * time.Minute) + os.Chtimes(c.domainPath("domain.com"), expires, expires) + + // Do a third fetch, check that we don't get a cache hit.
+ p, err = c.Fetch(ctx, "domain.com") + if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" { + t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err) + } + t.Logf("cache fetched domain.com: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 3) + expvarMustEq(t, "cacheHits", cacheHits, 1) + + if !t.Failed() { + os.RemoveAll(dir) + } +} + +// Test how the cache behaves when the files are corrupt. +func TestCacheBadData(t *testing.T) { + dir := mustTempDir(t) + c, err := NewCache(dir) + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + + cases := []string{ + // Case 1: A file with invalid json, which will fail unmarshalling. + "this is not valid json", + + // Case 2: A file with a parseable but invalid policy (unknown mode). + `{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], "max_age": 1}`, + } + + for _, badContent := range cases { + // Reset the expvar counters that we use to validate hits, misses, etc. + cacheFetches.Set(0) + cacheHits.Set(0) + + // Fetch domain.com, should result in the file being added to the + // cache. + p, err := c.Fetch(ctx, "domain.com") + if err != nil { + t.Fatalf("Fetch failed: %v", err) + } + t.Logf("cache fetched domain.com: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 1) + expvarMustEq(t, "cacheHits", cacheHits, 0) + + // Fill the file with the bad content for this case, and move its mtime + // to the future; otherwise load would just see an expired entry. + fname := c.domainPath("domain.com") + err = ioutil.WriteFile(fname, []byte(badContent), 0644) + if err != nil { + t.Fatalf("error writing file: %v", err) + } + expires := time.Now().Add(time.Hour) + if err := os.Chtimes(fname, expires, expires); err != nil { + t.Fatalf("error changing file times: %v", err) + } + + // We now expect Fetch to fall back to getting the policy from the + // network (in our case, from fakeContent). 
+ p, err = c.Fetch(ctx, "domain.com") + if err != nil { + t.Fatalf("Fetch failed: %v", err) + } + t.Logf("cache fetched domain.com: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 2) + expvarMustEq(t, "cacheHits", cacheHits, 0) + + // And now the file should be fine, resulting in a cache hit.
+ p, err = c.Fetch(ctx, "domain.com") + if err != nil { + t.Fatalf("Fetch failed: %v", err) + } + t.Logf("cache fetched domain.com: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 3) + expvarMustEq(t, "cacheHits", cacheHits, 1) + + // Remove the file, to start with a clean slate for the next case. + os.Remove(fname) + } + + if !t.Failed() { + os.RemoveAll(dir) + } +} + +func mustFetch(t *testing.T, c *PolicyCache, ctx context.Context, d string) *Policy { + p, err := c.Fetch(ctx, d) + if err != nil { + t.Fatalf("Fetch %q failed: %v", d, err) + } + t.Logf("Fetch %q: %v", d, p) + return p +} + +func TestCacheRefresh(t *testing.T) { + dir := mustTempDir(t) + c, err := NewCache(dir) + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + + fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = ` + {"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 100}` + p := mustFetch(t, c, ctx, "refresh-test") + if p.MaxAge != 100*time.Second { + t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge) + } + + // Change the "published" policy, check that we see the old version at + // fetch (should be cached), and a new version after a refresh. + fakeContent["https://mta-sts.refresh-test/.well-known/mta-sts.json"] = ` + {"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 200}` + + p = mustFetch(t, c, ctx, "refresh-test") + if p.MaxAge != 100*time.Second { + t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge) + } + + c.refresh(ctx) + + p = mustFetch(t, c, ctx, "refresh-test") + if p.MaxAge != 200*time.Second { + t.Fatalf("policy.MaxAge is %v, expected 200s", p.MaxAge) + } + + if !t.Failed() { + os.RemoveAll(dir) + } +}