git » chasquid » commit 23deaf1

Reinstate the MTA-STS (Strict Transport Security) implementation

author Alberto Bertogli
2017-04-11 00:03:05 UTC
committer Alberto Bertogli
2018-07-01 11:19:02 UTC
parent a94253ba2539890370894313c388261d798bb4c7

Reinstate the MTA-STS (Strict Transport Security) implementation

This commit brings back the experimental MTA-STS (Strict Transport
Security) implementation, removed in commit
7f5bedf4aa7f678f11e4547941cdbc6b846c8b87.

We will continue development in the "sts" branch, subject to rebase,
until it is ready to be integrated into "next" again.

chasquid.go +9 -1
cmd/smtp-check/smtp-check.go +25 -0
internal/courier/smtp.go +82 -6
internal/courier/smtp_test.go +1 -1
internal/sts/sts.go +435 -0
internal/sts/sts_test.go +384 -0

diff --git a/chasquid.go b/chasquid.go
index c39e416..63a3cea 100644
--- a/chasquid.go
+++ b/chasquid.go
@@ -7,6 +7,7 @@
 package main
 
 import (
+	"context"
 	"expvar"
 	"flag"
 	"fmt"
@@ -25,6 +26,7 @@ import (
 	"blitiri.com.ar/go/chasquid/internal/maillog"
 	"blitiri.com.ar/go/chasquid/internal/normalize"
 	"blitiri.com.ar/go/chasquid/internal/smtpsrv"
+	"blitiri.com.ar/go/chasquid/internal/sts"
 	"blitiri.com.ar/go/chasquid/internal/userdb"
 	"blitiri.com.ar/go/log"
 	"blitiri.com.ar/go/systemd"
@@ -146,12 +148,18 @@ func main() {
 
 	dinfo := s.InitDomainInfo(conf.DataDir + "/domaininfo")
 
+	stsCache, err := sts.NewCache(conf.DataDir + "/sts-cache")
+	if err != nil {
+		log.Fatalf("Failed to initialize STS cache: %v", err)
+	}
+	go stsCache.PeriodicallyRefresh(context.Background())
+
 	localC := &courier.Procmail{
 		Binary:  conf.MailDeliveryAgentBin,
 		Args:    conf.MailDeliveryAgentArgs,
 		Timeout: 30 * time.Second,
 	}
-	remoteC := &courier.SMTP{Dinfo: dinfo}
+	remoteC := &courier.SMTP{Dinfo: dinfo, STSCache: stsCache}
 	s.InitQueue(conf.DataDir+"/queue", localC, remoteC)
 
 	// Load the addresses and listeners.
diff --git a/cmd/smtp-check/smtp-check.go b/cmd/smtp-check/smtp-check.go
index 4ce912f..b9215fc 100644
--- a/cmd/smtp-check/smtp-check.go
+++ b/cmd/smtp-check/smtp-check.go
@@ -5,12 +5,15 @@
 package main
 
 import (
+	"context"
 	"crypto/tls"
 	"flag"
 	"log"
 	"net"
 	"net/smtp"
+	"time"
 
+	"blitiri.com.ar/go/chasquid/internal/sts"
 	"blitiri.com.ar/go/chasquid/internal/tlsconst"
 	"blitiri.com.ar/go/spf"
 
@@ -37,6 +40,21 @@ func main() {
 		log.Fatalf("IDNA conversion failed: %v", err)
 	}
 
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	log.Printf("=== STS policy")
+	policy, err := sts.UncheckedFetch(ctx, domain)
+	if err != nil {
+		log.Printf("Not available (%s)", err)
+	} else {
+		log.Printf("Parsed contents:  [%+v]\n", *policy)
+		if err := policy.Check(); err != nil {
+			log.Fatalf("Invalid: %v", err)
+		}
+		log.Printf("OK")
+	}
+
 	mxs, err := net.LookupMX(domain)
 	if err != nil {
 		log.Fatalf("MX lookup: %v", err)
@@ -86,6 +104,13 @@ func main() {
 			c.Close()
 		}
 
+		if policy != nil {
+			if !policy.MXIsAllowed(mx.Host) {
+				log.Fatalf("NOT allowed by STS policy")
+			}
+			log.Printf("Allowed by policy")
+		}
+
 		log.Printf("")
 	}
 
diff --git a/internal/courier/smtp.go b/internal/courier/smtp.go
index 4c779af..f85741d 100644
--- a/internal/courier/smtp.go
+++ b/internal/courier/smtp.go
@@ -1,6 +1,7 @@
 package courier
 
 import (
+	"context"
 	"crypto/tls"
 	"expvar"
 	"flag"
@@ -13,6 +14,7 @@ import (
 	"blitiri.com.ar/go/chasquid/internal/domaininfo"
 	"blitiri.com.ar/go/chasquid/internal/envelope"
 	"blitiri.com.ar/go/chasquid/internal/smtp"
+	"blitiri.com.ar/go/chasquid/internal/sts"
 	"blitiri.com.ar/go/chasquid/internal/trace"
 )
 
@@ -30,17 +32,26 @@ var (
 	// TODO: replace this with proper lookup interception once it is supported
 	// by Go.
 	netLookupMX = net.LookupMX
+
+	// Enable STS policy checking; this is an experimental flag and will be
+	// removed in the future, once this is made the default.
+	enableSTS = flag.Bool("experimental__enable_sts", false,
+		"enable STS policy checking; EXPERIMENTAL")
 )
 
 // Exported variables.
 var (
 	tlsCount   = expvar.NewMap("chasquid/smtpOut/tlsCount")
 	slcResults = expvar.NewMap("chasquid/smtpOut/securityLevelChecks")
+
+	stsSecurityModes   = expvar.NewMap("chasquid/smtpOut/sts/mode")
+	stsSecurityResults = expvar.NewMap("chasquid/smtpOut/sts/security")
 )
 
 // SMTP delivers remote mail via outgoing SMTP.
 type SMTP struct {
-	Dinfo *domaininfo.DB
+	// Dinfo is the domain information database.
+	Dinfo    *domaininfo.DB
+	// STSCache is used to fetch the MTA-STS policy of destination domains.
+	// It may be nil, in which case no STS policy is applied.
+	STSCache *sts.PolicyCache
 }
 
 // Deliver an email. On failures, returns an error, and whether or not it is
@@ -62,7 +73,9 @@ func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) {
 		a.from = ""
 	}
 
-	mxs, err := lookupMXs(a.tr, a.toDomain)
+	a.stsPolicy = s.fetchSTSPolicy(a.tr, a.toDomain)
+
+	mxs, err := lookupMXs(a.tr, a.toDomain, a.stsPolicy)
 	if err != nil || len(mxs) == 0 {
 		// Note this is considered a permanent error.
 		// This is in line with what other servers (Exim) do. However, the
@@ -108,6 +121,8 @@ type attempt struct {
 	toDomain    string
 	helloDomain string
 
+	stsPolicy *sts.Policy
+
 	tr *trace.Trace
 }
 
@@ -175,6 +190,18 @@ retry:
 	}
 	slcResults.Add("pass", 1)
 
+	if a.stsPolicy != nil && a.stsPolicy.Mode == sts.Enforce {
+		// The connection MUST be validated TLS.
+		// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-03#section-4.2
+		if secLevel != domaininfo.SecLevel_TLS_SECURE {
+			stsSecurityResults.Add("fail", 1)
+			return a.tr.Errorf("invalid security level (%v) for STS policy",
+				secLevel), false
+		}
+		stsSecurityResults.Add("pass", 1)
+		a.tr.Debugf("STS policy: connection is using valid TLS")
+	}
+
 	if err = c.MailAndRcpt(a.from, a.to); err != nil {
 		return a.tr.Errorf("MAIL+RCPT %v", err), smtp.IsPermanent(err)
 	}
@@ -199,7 +226,29 @@ retry:
 	return nil, false
 }
 
-func lookupMXs(tr *trace.Trace, domain string) ([]string, error) {
+// fetchSTSPolicy returns the MTA-STS policy for the given domain, or nil if
+// STS checking is disabled, there is no policy cache, or the policy could
+// not be obtained. Failures are traced but not fatal: delivery then proceeds
+// without a policy.
+func (s *SMTP) fetchSTSPolicy(tr *trace.Trace, domain string) *sts.Policy {
+	if !*enableSTS || s.STSCache == nil {
+		return nil
+	}
+
+	// Bound the fetch, so a slow or unresponsive policy server cannot stall
+	// delivery indefinitely.
+	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
+	defer cancel()
+
+	policy, err := s.STSCache.Fetch(ctx, domain)
+	if err != nil {
+		// Not fatal, but record why we are going ahead without a policy,
+		// to make troubleshooting possible.
+		tr.Debugf("no STS policy: %v", err)
+		return nil
+	}
+
+	tr.Debugf("got STS policy")
+	stsSecurityModes.Add(string(policy.Mode), 1)
+
+	return policy
+}
+
+func lookupMXs(tr *trace.Trace, domain string, policy *sts.Policy) ([]string, error) {
 	domain, err := idna.ToASCII(domain)
 	if err != nil {
 		return nil, err
@@ -239,12 +288,39 @@ func lookupMXs(tr *trace.Trace, domain string) ([]string, error) {
 	// This case is explicitly covered by the SMTP RFC.
 	// https://tools.ietf.org/html/rfc5321#section-5.1
 
-	// Cap the list of MXs to 5 hosts, to keep delivery attempt times sane
-	// and prevent abuse.
-	if len(mxs) > 5 {
+	mxs = filterMXs(tr, policy, mxs)
+	if len(mxs) == 0 {
+		tr.Errorf("domain %q has no valid MX/A record", domain)
+	} else if len(mxs) > 5 {
+		// Cap the list of MXs to 5 hosts, to keep delivery attempt times
+		// sane and prevent abuse.
 		mxs = mxs[:5]
 	}
 
 	tr.Debugf("MXs: %v", mxs)
 	return mxs, nil
 }
+
+// filterMXs returns the MXs in mxs that are allowed by the STS policy p,
+// keeping their original order. With a nil policy, mxs is returned as-is.
+//
+// If no MX is allowed and the policy is not in enforce mode, the full list
+// is returned instead, so policies in reporting mode never prevent delivery.
+// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-03#section-5.2
+func filterMXs(tr *trace.Trace, p *sts.Policy, mxs []string) []string {
+	if p == nil {
+		return mxs
+	}
+
+	allowed := []string{}
+	for _, host := range mxs {
+		if !p.MXIsAllowed(host) {
+			tr.Printf("MX %q not allowed by policy, skipping", host)
+			continue
+		}
+		allowed = append(allowed, host)
+	}
+
+	if len(allowed) > 0 || p.Mode == sts.Enforce {
+		return allowed
+	}
+	return mxs
+}
diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go
index b12d7be..6aaf38b 100644
--- a/internal/courier/smtp_test.go
+++ b/internal/courier/smtp_test.go
@@ -35,7 +35,7 @@ func newSMTP(t *testing.T) (*SMTP, string) {
 		t.Fatal(err)
 	}
 
-	return &SMTP{dinfo}, dir
+	return &SMTP{dinfo, nil}, dir
 }
 
 // Fake server, to test SMTP out.
diff --git a/internal/sts/sts.go b/internal/sts/sts.go
new file mode 100644
index 0000000..7a1a4a2
--- /dev/null
+++ b/internal/sts/sts.go
@@ -0,0 +1,435 @@
+// Package sts implements MTA-STS (SMTP MTA Strict Transport Security), based
+// on the current draft, https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02.
+//
+// This is an EXPERIMENTAL implementation for now.
+//
+// It lacks (at least) the following:
+// - DNS TXT checking.
+// - Facilities for reporting.
+//
+package sts
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"expvar"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"net/http"
+	"os"
+	"strings"
+	"sync"
+	"time"
+
+	"blitiri.com.ar/go/chasquid/internal/safeio"
+	"blitiri.com.ar/go/chasquid/internal/trace"
+
+	"golang.org/x/net/context/ctxhttp"
+	"golang.org/x/net/idna"
+)
+
+// Exported variables.
+var (
+	// PolicyCache.Fetch calls, how many of them were served from the disk
+	// cache, and how many found an expired cache entry.
+	cacheFetches = expvar.NewInt("chasquid/sts/cache/fetches")
+	cacheHits    = expvar.NewInt("chasquid/sts/cache/hits")
+	cacheExpired = expvar.NewInt("chasquid/sts/cache/expired")
+
+	// Errors reading/writing cache files, network fetches that failed in
+	// PolicyCache.Fetch, and cache entries that failed Policy.Check.
+	cacheIOErrors    = expvar.NewInt("chasquid/sts/cache/ioErrors")
+	cacheFailedFetch = expvar.NewInt("chasquid/sts/cache/failedFetch")
+	cacheInvalid     = expvar.NewInt("chasquid/sts/cache/invalid")
+
+	// JSON encoding/decoding errors of cache entries.
+	cacheMarshalErrors   = expvar.NewInt("chasquid/sts/cache/marshalErrors")
+	cacheUnmarshalErrors = expvar.NewInt("chasquid/sts/cache/unmarshalErrors")
+
+	// Background refresh: passes over the cache directory, entries
+	// refreshed, and refresh fetches that failed.
+	cacheRefreshCycles = expvar.NewInt("chasquid/sts/cache/refreshCycles")
+	cacheRefreshes     = expvar.NewInt("chasquid/sts/cache/refreshes")
+	cacheRefreshErrors = expvar.NewInt("chasquid/sts/cache/refreshErrors")
+)
+
+// Policy represents a parsed policy.
+// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
+//
+// NOTE: the "max_age" JSON key is used with two different units. Policies
+// fetched from the network carry it in seconds (parsePolicy converts it to
+// a time.Duration), while the on-disk cache stores the json.Marshal'ed
+// Policy, which holds the raw time.Duration, i.e. nanoseconds. Keep this in
+// mind when reading or editing cache files by hand.
+type Policy struct {
+	Version string        `json:"version"`
+	Mode    Mode          `json:"mode"`
+	MXs     []string      `json:"mx"`
+	MaxAge  time.Duration `json:"max_age"`
+}
+
+// Mode of a policy, indicating how strictly it is meant to be applied.
+type Mode string
+
+// Valid modes, as accepted by Policy.Check (any other value is rejected).
+// Enforce policies are meant to be applied strictly, so non-compliant
+// deliveries fail; Report policies are meant to be informational only.
+const (
+	Enforce = Mode("enforce")
+	Report  = Mode("report")
+)
+
+// parsePolicy parses a JSON representation of the policy, and returns the
+// corresponding Policy structure.
+//
+// The "max_age" field is interpreted as a number of seconds, as published by
+// the policy server. Values whose conversion to a time.Duration would
+// overflow are rejected with ErrInvalidMaxAge, since the overflowed result
+// would be meaningless (and could even look like a small positive duration).
+//
+// The returned policy is NOT otherwise validated; use Policy.Check for that.
+func parsePolicy(raw []byte) (*Policy, error) {
+	p := &Policy{}
+	if err := json.Unmarshal(raw, p); err != nil {
+		return nil, err
+	}
+
+	// MaxAge is in seconds; make sure converting it cannot overflow int64.
+	const maxSeconds = time.Duration(1<<63-1) / time.Second
+	if p.MaxAge > maxSeconds || p.MaxAge < -maxSeconds {
+		return nil, ErrInvalidMaxAge
+	}
+	p.MaxAge = p.MaxAge * time.Second
+
+	return p, nil
+}
+
+// Errors returned by Policy.Check, one per kind of invalid policy content.
+var (
+	ErrUnknownVersion = errors.New("unknown policy version")
+	ErrInvalidMaxAge  = errors.New("invalid max_age")
+	ErrInvalidMode    = errors.New("invalid mode")
+	ErrInvalidMX      = errors.New("invalid mx")
+)
+
+// Check that the policy contents are valid: a known version, a positive
+// max_age, a known mode, and at least one MX pattern. The first problem
+// found is reported as ErrUnknownVersion, ErrInvalidMaxAge, ErrInvalidMode
+// or ErrInvalidMX, in that order of precedence.
+func (p *Policy) Check() error {
+	switch {
+	case p.Version != "STSv1":
+		return ErrUnknownVersion
+	case p.MaxAge <= 0:
+		return ErrInvalidMaxAge
+	case p.Mode != Enforce && p.Mode != Report:
+		return ErrInvalidMode
+	case len(p.MXs) == 0:
+		// "mx" field is required, and the policy is invalid if it's not
+		// present.
+		// https://mailarchive.ietf.org/arch/msg/uta/Omqo1Bw6rJbrTMl2Zo69IJr35Qo
+		return ErrInvalidMX
+	}
+	return nil
+}
+
+// MXIsAllowed checks if the given MX is allowed, according to the policy:
+// that is, if it matches at least one of the policy's MX patterns (see
+// matchDomain for the matching rules, including "*" wildcards).
+// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-4.1
+func (p *Policy) MXIsAllowed(mx string) bool {
+	for _, pattern := range p.MXs {
+		if matchDomain(mx, pattern) {
+			return true
+		}
+	}
+
+	return false
+}
+
+// UncheckedFetch fetches and parses the policy for the given domain, but
+// does NOT check it.
+// This can be useful for debugging and troubleshooting, but you should always
+// call Check on the policy before using it.
+func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) {
+	// httpGet does not support IDNs in any other way, so build the URL from
+	// the ASCII form of the domain.
+	asciiDomain, err := idna.ToASCII(domain)
+	if err != nil {
+		return nil, err
+	}
+
+	raw, err := httpGet(ctx, urlForDomain(asciiDomain))
+	if err != nil {
+		return nil, err
+	}
+	return parsePolicy(raw)
+}
+
+// fakeURLForTesting, when non-empty, replaces the real policy location, so
+// tests can serve policies from a local HTTP server and exercise the
+// fetching code end to end.
+var fakeURLForTesting string
+
+// urlForDomain returns the URL from which to fetch the policy of domain.
+// The URL is composed from the domain, as explained in:
+// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3
+// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
+func urlForDomain(domain string) string {
+	if fakeURLForTesting == "" {
+		return "https://mta-sts." + domain + "/.well-known/mta-sts.json"
+	}
+	return fakeURLForTesting + "/" + domain
+}
+
+// Fetch a policy for the given domain. Note this results in various network
+// lookups and HTTPS GETs, so it can be slow.
+// The returned policy is parsed and validated with Policy.Check, so it should
+// be safe to use; a policy that fails the check is reported as an error.
+func Fetch(ctx context.Context, domain string) (*Policy, error) {
+	p, err := UncheckedFetch(ctx, domain)
+	if err == nil {
+		err = p.Check()
+	}
+	if err != nil {
+		return nil, err
+	}
+	return p, nil
+}
+
+// maxPolicySize is the largest policy body, in bytes, that we accept.
+// Policies should be way smaller than that, and having a limit prevents
+// abuse/accidents with very large replies.
+const maxPolicySize = 10 * 1024
+
+// httpGet performs an HTTP GET of the given URL, using the context and
+// rejecting redirects, as per the standard.
+// It returns the body of a 200 response; any other status code, or a body
+// larger than maxPolicySize, results in an error.
+func httpGet(ctx context.Context, url string) ([]byte, error) {
+	client := &http.Client{
+		// We MUST NOT follow redirects, see
+		// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3
+		CheckRedirect: rejectRedirect,
+	}
+
+	// Also bound the client itself by the context deadline, if there is one.
+	if deadline, ok := ctx.Deadline(); ok {
+		client.Timeout = deadline.Sub(time.Now())
+	}
+
+	resp, err := ctxhttp.Get(ctx, client, url)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("HTTP response status code: %v", resp.StatusCode)
+	}
+
+	// Read one byte past the limit, so an oversized reply is detected and
+	// rejected explicitly instead of being silently truncated.
+	body, err := ioutil.ReadAll(io.LimitReader(resp.Body, maxPolicySize+1))
+	if err != nil {
+		return nil, err
+	}
+	if len(body) > maxPolicySize {
+		return nil, fmt.Errorf("policy too large (more than %d bytes)",
+			maxPolicySize)
+	}
+	return body, nil
+}
+
+// errRejectRedirect is what rejectRedirect returns, making the HTTP client
+// stop at the first redirect instead of following it.
+var errRejectRedirect = errors.New("redirects not allowed in MTA-STS")
+
+// rejectRedirect is an http.Client CheckRedirect function that refuses
+// every redirect, regardless of the request or the redirect chain.
+func rejectRedirect(_ *http.Request, _ []*http.Request) error {
+	return errRejectRedirect
+}
+
+// matchDomain checks if the domain matches the given pattern, according to
+// https://tools.ietf.org/html/rfc6125#section-6.4
+// (from https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-4.1).
+//
+// Both arguments are normalized with domainToASCII first, so the comparison
+// is case-insensitive and ignores a trailing dot. A "*" is only treated as a
+// wildcard when it is the whole left-most label of the pattern, and it
+// matches exactly one non-empty label of the domain.
+func matchDomain(domain, pattern string) bool {
+	domain, dErr := domainToASCII(domain)
+	pattern, pErr := domainToASCII(pattern)
+	if dErr != nil || pErr != nil {
+		// Domains should already have been checked and normalized by the
+		// caller, exposing this is not worth the API complexity in this case.
+		return false
+	}
+
+	domainLabels := strings.Split(domain, ".")
+	patternLabels := strings.Split(pattern, ".")
+
+	// A wildcard stands for exactly one label, so the number of labels must
+	// always be the same.
+	if len(domainLabels) != len(patternLabels) {
+		return false
+	}
+
+	for i, p := range patternLabels {
+		// Wildcards only apply to the first part, see
+		// https://tools.ietf.org/html/rfc6125#section-6.4.3 #1 and #2.
+		// This also allows us to do the length comparison above.
+		// The wildcard must still match an actual label: "*.com" should not
+		// match ".com".
+		if p == "*" && i == 0 {
+			if domainLabels[0] == "" {
+				return false
+			}
+			continue
+		}
+
+		if p != domainLabels[i] {
+			return false
+		}
+	}
+
+	return true
+}
+
+// domainToASCII returns the ASCII (IDNA) form of domain, after dropping a
+// single trailing dot and lowercasing it, so that equivalent spellings of
+// the same domain compare equal.
+func domainToASCII(domain string) (string, error) {
+	normalized := strings.ToLower(strings.TrimSuffix(domain, "."))
+	return idna.ToASCII(normalized)
+}
+
+// PolicyCache is a caching layer for fetching policies.
+//
+// Policies are cached by domain, and stored in a single directory.
+// Each file's mtime is set to the time when its policy expires; this keeps
+// the store simple, as it avoids keeping additional metadata.
+//
+// There is no in-memory caching. This may be added in the future, but for
+// now disk is good enough for our purposes.
+//
+// Use NewCache to create one; the zero value has an empty dir.
+type PolicyCache struct {
+	// dir is the directory holding the cache entries, one file per domain.
+	dir string
+
+	// NOTE(review): this embedded Mutex is not used by any PolicyCache
+	// method in this file, but embedding it exports Lock/Unlock as part of
+	// the PolicyCache API. Confirm whether it is needed.
+	sync.Mutex
+}
+
+// NewCache returns a policy cache that keeps its entries in dir, creating
+// the directory (and any missing parents) if needed. On error, the returned
+// cache is nil.
+func NewCache(dir string) (*PolicyCache, error) {
+	if err := os.MkdirAll(dir, 0770); err != nil {
+		return nil, err
+	}
+	return &PolicyCache{dir: dir}, nil
+}
+
+// pathPrefix is prepended to the domain to form the cache file name; refresh
+// also relies on it to tell cache entries apart from other files.
+const pathPrefix = "pol:"
+
+// domainPath returns the path of the cache file for the given domain.
+// It panics if the domain contains a "/", as it would then not map to a
+// single file directly inside the cache directory; callers are expected to
+// pass well-formed domains.
+func (c *PolicyCache) domainPath(domain string) string {
+	if strings.ContainsRune(domain, '/') {
+		panic("domain contains slash")
+	}
+	return c.dir + "/" + pathPrefix + domain
+}
+
+// ErrExpired is returned when a cache entry exists, but is past its
+// expiration time.
+var ErrExpired = errors.New("cache entry expired")
+
+// load reads the cached policy for domain from disk. An entry whose file
+// mtime is already in the past is considered expired (store sets the mtime
+// to the expiration time). The loaded policy is validated with Check.
+func (c *PolicyCache) load(domain string) (*Policy, error) {
+	fname := c.domainPath(domain)
+
+	info, err := os.Stat(fname)
+	if err != nil {
+		return nil, err
+	}
+	if expiresAt := info.ModTime(); time.Since(expiresAt) > 0 {
+		cacheExpired.Add(1)
+		return nil, ErrExpired
+	}
+
+	data, err := ioutil.ReadFile(fname)
+	if err != nil {
+		cacheIOErrors.Add(1)
+		return nil, err
+	}
+
+	p := &Policy{}
+	if err := json.Unmarshal(data, p); err != nil {
+		cacheUnmarshalErrors.Add(1)
+		return nil, err
+	}
+
+	// The policy should always be valid, as we marshalled it ourselves;
+	// however, check it just to be safe.
+	if err := p.Check(); err != nil {
+		cacheInvalid.Add(1)
+		return nil, fmt.Errorf(
+			"%s unmarshalled invalid policy %v: %v", domain, p, err)
+	}
+	return p, nil
+}
+
+// store writes p as the cache entry for domain, using safeio.WriteFile with
+// a callback that sets the file's access and modification times to the
+// moment the policy expires; load uses the mtime to detect stale entries.
+func (c *PolicyCache) store(domain string, p *Policy) error {
+	data, err := json.Marshal(p)
+	if err != nil {
+		cacheMarshalErrors.Add(1)
+		return fmt.Errorf("%s failed to marshal policy %v, error: %v",
+			domain, p, err)
+	}
+
+	expires := time.Now().Add(p.MaxAge)
+	setExpiry := func(fname string) error {
+		return os.Chtimes(fname, expires, expires)
+	}
+
+	err = safeio.WriteFile(c.domainPath(domain), data, 0640, setExpiry)
+	if err != nil {
+		cacheIOErrors.Add(1)
+	}
+	return err
+}
+
+// Fetch returns the policy for domain: from the on-disk cache if there is a
+// valid, unexpired entry, or otherwise from the network (using the package
+// level Fetch), in which case the new policy is also stored in the cache.
+func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error) {
+	cacheFetches.Add(1)
+	tr := trace.New("STSCache.Fetch", domain)
+	defer tr.Finish()
+
+	if cached, err := c.load(domain); err == nil {
+		tr.Debugf("cache hit: %v", cached)
+		cacheHits.Add(1)
+		return cached, nil
+	}
+
+	p, err := Fetch(ctx, domain)
+	if err != nil {
+		tr.Debugf("failed to fetch: %v", err)
+		cacheFailedFetch.Add(1)
+		return nil, err
+	}
+	tr.Debugf("fetched: %v", p)
+
+	// Storing is done synchronously, even though it could be done in the
+	// background: it makes troubleshooting easier and the cost of storing
+	// entries easier to track down.
+	// A failure to store is not fatal; we would rather use the policy even
+	// if we couldn't store it in the cache.
+	if err := c.store(domain, p); err != nil {
+		tr.Errorf("failed to store: %v", err)
+	} else {
+		tr.Debugf("stored")
+	}
+	return p, nil
+}
+
+// PeriodicallyRefresh refreshes all the cache entries (see refresh) every 10
+// minutes. It blocks until ctx is done, and returns promptly once that
+// happens.
+func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) {
+	for ctx.Err() == nil {
+		cacheRefreshCycles.Add(1)
+
+		c.refresh(ctx)
+
+		// Wait 10 minutes between passes; this is a background refresh and
+		// there's no need to poke the servers very often.
+		// Wait on the context too (instead of time.Sleep), so cancellation
+		// is not delayed by up to a full interval.
+		timer := time.NewTimer(10 * time.Minute)
+		select {
+		case <-ctx.Done():
+			timer.Stop()
+			return
+		case <-timer.C:
+		}
+	}
+}
+
+// refresh re-fetches from the network the policy of every domain that has
+// an entry in the cache directory (expired or not), and stores the new
+// versions. Individual failures are traced and counted, but don't stop the
+// refresh of the remaining entries; the pass does stop early if ctx is done.
+func (c *PolicyCache) refresh(ctx context.Context) {
+	tr := trace.New("STSCache.Refresh", c.dir)
+	defer tr.Finish()
+
+	entries, err := ioutil.ReadDir(c.dir)
+	if err != nil {
+		tr.Errorf("failed to list directory %q: %v", c.dir, err)
+		return
+	}
+	tr.Debugf("%d entries", len(entries))
+
+	for _, e := range entries {
+		// Stop if we were canceled: every remaining fetch would be made
+		// with an already-done context.
+		if err := ctx.Err(); err != nil {
+			tr.Debugf("refresh interrupted: %v", err)
+			return
+		}
+
+		// Only files named by domainPath are cache entries.
+		if !strings.HasPrefix(e.Name(), pathPrefix) {
+			continue
+		}
+		domain := e.Name()[len(pathPrefix):]
+		cacheRefreshes.Add(1)
+		tr.Debugf("%v: refreshing", domain)
+
+		fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
+		p, err := Fetch(fetchCtx, domain)
+		cancel()
+		if err != nil {
+			tr.Debugf("%v: failed to fetch: %v", domain, err)
+			cacheRefreshErrors.Add(1)
+			continue
+		}
+		tr.Debugf("%v: fetched", domain)
+
+		err = c.store(domain, p)
+		if err != nil {
+			tr.Errorf("%v: failed to store: %v", domain, err)
+		} else {
+			tr.Debugf("%v: stored", domain)
+		}
+	}
+
+	tr.Debugf("refresh done")
+}
diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go
new file mode 100644
index 0000000..bc26940
--- /dev/null
+++ b/internal/sts/sts_test.go
@@ -0,0 +1,384 @@
+package sts
+
+import (
+	"context"
+	"expvar"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"strconv"
+	"testing"
+	"time"
+)
+
+// policyForDomain maps each test domain to the raw JSON policy served for it
+// by the test HTTP server (see testHTTPHandler); domains that are not in the
+// map get a 404.
+// NOTE: some tests add or replace entries at runtime ("toobig",
+// "refresh-test").
+var policyForDomain = map[string]string{
+	// domain.com -> valid, with reasonable policy.
+	"domain.com": `
+		{
+             "version": "STSv1",
+             "mode": "enforce",
+             "mx": ["*.mail.domain.com"],
+             "max_age": 3600
+        }`,
+
+	// version99 -> invalid policy (unknown version).
+	"version99": `
+		{
+             "version": "STSv99",
+             "mode": "enforce",
+             "mx": ["*.mail.version99"],
+             "max_age": 999
+        }`,
+}
+
+// testHTTPHandler serves the policy from policyForDomain for the domain given
+// as the request path (see urlForDomain), or a 404 if there is none.
+func testHTTPHandler(w http.ResponseWriter, r *http.Request) {
+	if policy, found := policyForDomain[r.URL.Path[1:]]; found {
+		fmt.Fprintln(w, policy)
+	} else {
+		http.Error(w, "not found", http.StatusNotFound)
+	}
+}
+
+// TestMain starts the HTTP server that serves policyForDomain, points
+// urlForDomain at it, and runs the tests.
+func TestMain(m *testing.M) {
+	// Create a test HTTP server, used by the more end-to-end tests.
+	httpServer := httptest.NewServer(http.HandlerFunc(testHTTPHandler))
+	fakeURLForTesting = httpServer.URL
+
+	// os.Exit does not run deferred calls, so shut the server down
+	// explicitly before exiting.
+	code := m.Run()
+	httpServer.Close()
+	os.Exit(code)
+}
+
+// TestParsePolicy checks that every field is parsed, including the
+// conversion of max_age from seconds to a time.Duration.
+func TestParsePolicy(t *testing.T) {
+	const pol1 = `{
+  "version": "STSv1",
+  "mode": "enforce",
+  "mx": ["*.mail.example.com"],
+  "max_age": 123456
+}
+`
+	p, err := parsePolicy([]byte(pol1))
+	if err != nil {
+		t.Fatalf("failed to parse policy: %v", err)
+	}
+	t.Logf("pol1: %+v", p)
+
+	if p.Version != "STSv1" || p.Mode != Enforce ||
+		len(p.MXs) != 1 || p.MXs[0] != "*.mail.example.com" ||
+		p.MaxAge != 123456*time.Second {
+		t.Errorf("unexpected parsed policy: %+v", p)
+	}
+}
+
+// TestCheckPolicy checks both policies that must pass Policy.Check, and
+// policies that must fail it with a specific error.
+func TestCheckPolicy(t *testing.T) {
+	good := []Policy{
+		{Version: "STSv1", Mode: "enforce", MaxAge: 1 * time.Hour,
+			MXs: []string{"mx1", "mx2"}},
+		{Version: "STSv1", Mode: "report", MaxAge: 1 * time.Hour,
+			MXs: []string{"mx1"}},
+	}
+	for i := range good {
+		if err := good[i].Check(); err != nil {
+			t.Errorf("%d policy %v failed check: %v", i, good[i], err)
+		}
+	}
+
+	bad := []struct {
+		p        Policy
+		expected error
+	}{
+		{p: Policy{Version: "STSv2"}, expected: ErrUnknownVersion},
+		{p: Policy{Version: "STSv1"}, expected: ErrInvalidMaxAge},
+		{p: Policy{Version: "STSv1", MaxAge: 1, Mode: "blah"},
+			expected: ErrInvalidMode},
+		{p: Policy{Version: "STSv1", MaxAge: 1, Mode: "enforce"},
+			expected: ErrInvalidMX},
+		{p: Policy{Version: "STSv1", MaxAge: 1, Mode: "enforce",
+			MXs: []string{}}, expected: ErrInvalidMX},
+	}
+	for i, tc := range bad {
+		if err := tc.p.Check(); err != tc.expected {
+			t.Errorf("%d policy %v check: expected %v, got %v", i, tc.p,
+				tc.expected, err)
+		}
+	}
+}
+
+// TestMatchDomain checks exact, wildcard and IDN matching of MX patterns.
+func TestMatchDomain(t *testing.T) {
+	type matchCase struct {
+		domain, pattern string
+		expected        bool
+	}
+	cases := []matchCase{
+		// Exact matches, with and without a trailing dot.
+		{"lalala", "lalala", true},
+		{"a.b.", "a.b", true},
+		{"a.b", "a.b.", true},
+		{"abc.com", "*.com", true},
+
+		// Wildcards only in the first label, and only for one label.
+		{"abc.com", "abc.*.com", false},
+		{"abc.com", "x.abc.com", false},
+		{"x.abc.com", "*.*.com", false},
+
+		// IDNs, in Unicode and punycode forms, case-insensitively.
+		{"ñaca.com", "ñaca.com", true},
+		{"Ñaca.com", "ñaca.com", true},
+		{"ñaca.com", "Ñaca.com", true},
+		{"x.ñaca.com", "x.xn--aca-6ma.com", true},
+		{"x.naca.com", "x.xn--aca-6ma.com", false},
+	}
+
+	for _, tc := range cases {
+		got := matchDomain(tc.domain, tc.pattern)
+		if got == tc.expected {
+			continue
+		}
+		t.Errorf("matchDomain(%q, %q) = %v, expected %v",
+			tc.domain, tc.pattern, got, tc.expected)
+	}
+}
+
+// TestFetch checks fetching a valid policy, a missing one, and an invalid
+// one.
+func TestFetch(t *testing.T) {
+	// The data "fetched" for each domain comes from policyForDomain, served
+	// by the test HTTP server started in TestMain.
+	ctx := context.Background()
+
+	// Normal fetch, all valid.
+	p, err := Fetch(ctx, "domain.com")
+	if err != nil {
+		t.Errorf("failed to fetch policy: %v", err)
+	}
+	t.Logf("domain.com: %+v", p)
+
+	// Domain without a policy (HTTP get fails).
+	if p, err = Fetch(ctx, "unknown"); err == nil {
+		t.Errorf("fetched unknown policy: %v", p)
+	}
+	t.Logf("unknown: got error as expected: %v", err)
+
+	// Domain with an invalid policy (unknown version).
+	if p, err = Fetch(ctx, "version99"); err != ErrUnknownVersion {
+		t.Errorf("expected error %v, got %v (and policy: %v)",
+			ErrUnknownVersion, err, p)
+	}
+	t.Logf("version99: got expected error: %v", err)
+}
+
+// TestPolicyTooBig checks that an oversized policy is not accepted.
+func TestPolicyTooBig(t *testing.T) {
+	// Construct a valid but very large JSON as a policy: 2000 MX entries
+	// take it well over the 10k size limit applied by httpGet.
+	mxList := ""
+	for i := 0; i < 2000; i++ {
+		mxList += fmt.Sprintf("\"mx%d\", ", i)
+	}
+	policyForDomain["toobig"] = `{"version": "STSv1", "mode": "enforce", "mx": [` +
+		mxList + `"mxlast"], "max_age": 100}`
+
+	_, err := Fetch(context.Background(), "toobig")
+	if err == nil {
+		t.Errorf("fetch worked, but should have failed")
+	}
+	t.Logf("got error as expected: %v", err)
+}
+
+// Tests for the policy cache.
+
+// mustTempDir creates a new temporary directory, makes it the current
+// working directory, and returns its path. The test is aborted on failure.
+func mustTempDir(t *testing.T) string {
+	dir, err := ioutil.TempDir("", "sts_test")
+	if err != nil {
+		t.Fatal(err)
+	}
+	if err := os.Chdir(dir); err != nil {
+		t.Fatal(err)
+	}
+	t.Logf("test directory: %q", dir)
+	return dir
+}
+
+// expvarMustEq reports a test error if v (called name in the message) does
+// not hold the expected value.
+func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) {
+	// TODO: Use v.Value once we drop support of Go 1.7.
+	if got, _ := strconv.Atoi(v.String()); got != expected {
+		t.Errorf("%s is %d, expected %d", name, got, expected)
+	}
+}
+
+// TestCacheBasics checks cache misses, hits and expiration, using the
+// cacheFetches and cacheHits counters to tell them apart.
+func TestCacheBasics(t *testing.T) {
+	dir := mustTempDir(t)
+	c, err := NewCache(dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Note the data "fetched" for each domain comes from policyForDomain,
+	// served by the test HTTP server started in TestMain. See httpGet for
+	// more details.
+
+	// Reset the expvar counters that we use to validate hits, misses, etc.
+	cacheFetches.Set(0)
+	cacheHits.Set(0)
+
+	ctx := context.Background()
+
+	// Fetch domain.com, check we get a reasonable policy, and that it's a
+	// cache miss.
+	p, err := c.Fetch(ctx, "domain.com")
+	if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
+		t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
+	}
+	t.Logf("cache fetched domain.com: %v", p)
+	expvarMustEq(t, "cacheFetches", cacheFetches, 1)
+	expvarMustEq(t, "cacheHits", cacheHits, 0)
+
+	// Fetch domain.com again, this time we should see a cache hit.
+	p, err = c.Fetch(ctx, "domain.com")
+	if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
+		t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
+	}
+	t.Logf("cache fetched domain.com: %v", p)
+	expvarMustEq(t, "cacheFetches", cacheFetches, 2)
+	expvarMustEq(t, "cacheHits", cacheHits, 1)
+
+	// Simulate an expired cache entry by changing the mtime of domain.com's
+	// entry to the past.
+	// NOTE(review): the os.Chtimes error is ignored; if it failed, the next
+	// fetch would be a cache hit, and the cacheHits check below would flag it.
+	expires := time.Now().Add(-1 * time.Minute)
+	os.Chtimes(c.domainPath("domain.com"), expires, expires)
+
+	// Do a third fetch, check that we don't get a cache hit.
+	p, err = c.Fetch(ctx, "domain.com")
+	if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" {
+		t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err)
+	}
+	t.Logf("cache fetched domain.com: %v", p)
+	expvarMustEq(t, "cacheFetches", cacheFetches, 3)
+	expvarMustEq(t, "cacheHits", cacheHits, 1)
+
+	// Keep the directory around on failure, to help debugging.
+	if !t.Failed() {
+		os.RemoveAll(dir)
+	}
+}
+
+// Test how the cache behaves when the files are corrupt: the bad entry must
+// be detected (and counted), and the policy fetched from the network again.
+func TestCacheBadData(t *testing.T) {
+	dir := mustTempDir(t)
+	c, err := NewCache(dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	ctx := context.Background()
+
+	cases := []struct {
+		content     string
+		counter     *expvar.Int // Counter load must bump for this content.
+		counterName string
+	}{
+		// Case 1: A file with invalid json, which will fail unmarshalling.
+		{"this is not valid json",
+			cacheUnmarshalErrors, "cacheUnmarshalErrors"},
+
+		// Case 2: A file with a parseable but invalid policy.
+		{`{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], "max_age": 1}`,
+			cacheInvalid, "cacheInvalid"},
+	}
+
+	for _, tc := range cases {
+		// Reset the expvar counters that we use to validate hits, misses, etc.
+		cacheFetches.Set(0)
+		cacheHits.Set(0)
+		tc.counter.Set(0)
+
+		// Fetch domain.com, should result in the file being added to the
+		// cache.
+		p, err := c.Fetch(ctx, "domain.com")
+		if err != nil {
+			t.Fatalf("Fetch failed: %v", err)
+		}
+		t.Logf("cache fetched domain.com: %v", p)
+		expvarMustEq(t, "cacheFetches", cacheFetches, 1)
+		expvarMustEq(t, "cacheHits", cacheHits, 0)
+
+		// Edit the file, filling it with the bad content for this case.
+		// Writing it sets its mtime to now, which the cache would treat as
+		// expired without ever reading it; so move the mtime back to the
+		// future, to make sure the bad content is actually loaded.
+		fname := c.domainPath("domain.com")
+		err = ioutil.WriteFile(fname, []byte(tc.content), 0644)
+		if err != nil {
+			t.Fatalf("error writing file: %v", err)
+		}
+		expires := time.Now().Add(1 * time.Hour)
+		if err := os.Chtimes(fname, expires, expires); err != nil {
+			t.Fatalf("error changing file times: %v", err)
+		}
+
+		// We now expect Fetch to detect the bad content, and fall back to
+		// getting the policy from the network (in our case, from
+		// policyForDomain).
+		p, err = c.Fetch(ctx, "domain.com")
+		if err != nil {
+			t.Fatalf("Fetch failed: %v", err)
+		}
+		t.Logf("cache fetched domain.com: %v", p)
+		expvarMustEq(t, "cacheFetches", cacheFetches, 2)
+		expvarMustEq(t, "cacheHits", cacheHits, 0)
+		expvarMustEq(t, tc.counterName, tc.counter, 1)
+
+		// And now the file should be fine, resulting in a cache hit.
+		p, err = c.Fetch(ctx, "domain.com")
+		if err != nil {
+			t.Fatalf("Fetch failed: %v", err)
+		}
+		t.Logf("cache fetched domain.com: %v", p)
+		expvarMustEq(t, "cacheFetches", cacheFetches, 3)
+		expvarMustEq(t, "cacheHits", cacheHits, 1)
+
+		// Remove the file, to start with a clean slate for the next case.
+		os.Remove(fname)
+	}
+
+	if !t.Failed() {
+		os.RemoveAll(dir)
+	}
+}
+
+// mustFetch fetches the policy for domain d through the cache c, aborting
+// the test if that fails.
+func mustFetch(t *testing.T, c *PolicyCache, ctx context.Context, d string) *Policy {
+	policy, err := c.Fetch(ctx, d)
+	if err != nil {
+		t.Fatalf("Fetch %q failed: %v", d, err)
+	}
+	t.Logf("Fetch %q: %v", d, policy)
+	return policy
+}
+
+// TestCacheRefresh checks that a cached policy keeps being served after the
+// published one changes, until refresh picks up the new version.
+func TestCacheRefresh(t *testing.T) {
+	dir := mustTempDir(t)
+	c, err := NewCache(dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	ctx := context.Background()
+
+	// Publish the initial policy (max_age 100s) and get it into the cache.
+	policyForDomain["refresh-test"] = `
+		{"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 100}`
+	p := mustFetch(t, c, ctx, "refresh-test")
+	if p.MaxAge != 100*time.Second {
+		t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
+	}
+
+	// Change the "published" policy, check that we see the old version at
+	// fetch (should be cached), and a new version after a refresh.
+	policyForDomain["refresh-test"] = `
+		{"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 200}`
+
+	p = mustFetch(t, c, ctx, "refresh-test")
+	if p.MaxAge != 100*time.Second {
+		t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge)
+	}
+
+	c.refresh(ctx)
+
+	p = mustFetch(t, c, ctx, "refresh-test")
+	if p.MaxAge != 200*time.Second {
+		t.Fatalf("policy.MaxAge is %v, expected 200s", p.MaxAge)
+	}
+
+	// Keep the directory around on failure, to help debugging.
+	if !t.Failed() {
+		os.RemoveAll(dir)
+	}
+}
+
+// TestURLForDomain checks the real (non-testing) policy URL composition.
+func TestURLForDomain(t *testing.T) {
+	// urlForDomain behaves differently while fakeURLForTesting is set, so
+	// clear it for the duration of this test, and restore it afterwards.
+	saved := fakeURLForTesting
+	defer func() { fakeURLForTesting = saved }()
+	fakeURLForTesting = ""
+
+	const want = "https://mta-sts.a-test-domain/.well-known/mta-sts.json"
+	if got := urlForDomain("a-test-domain"); got != want {
+		t.Errorf("got %q, expected %q", got, want)
+	}
+}