author | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-04-11 00:03:05 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-07-01 11:19:02 UTC |
parent | a94253ba2539890370894313c388261d798bb4c7 |
chasquid.go | +9 | -1 |
cmd/smtp-check/smtp-check.go | +25 | -0 |
internal/courier/smtp.go | +82 | -6 |
internal/courier/smtp_test.go | +1 | -1 |
internal/sts/sts.go | +435 | -0 |
internal/sts/sts_test.go | +384 | -0 |
diff --git a/chasquid.go b/chasquid.go index c39e416..63a3cea 100644 --- a/chasquid.go +++ b/chasquid.go @@ -7,6 +7,7 @@ package main import ( + "context" "expvar" "flag" "fmt" @@ -25,6 +26,7 @@ import ( "blitiri.com.ar/go/chasquid/internal/maillog" "blitiri.com.ar/go/chasquid/internal/normalize" "blitiri.com.ar/go/chasquid/internal/smtpsrv" + "blitiri.com.ar/go/chasquid/internal/sts" "blitiri.com.ar/go/chasquid/internal/userdb" "blitiri.com.ar/go/log" "blitiri.com.ar/go/systemd" @@ -146,12 +148,18 @@ func main() { dinfo := s.InitDomainInfo(conf.DataDir + "/domaininfo") + stsCache, err := sts.NewCache(conf.DataDir + "/sts-cache") + if err != nil { + log.Fatalf("Failed to initialize STS cache: %v", err) + } + go stsCache.PeriodicallyRefresh(context.Background()) + localC := &courier.Procmail{ Binary: conf.MailDeliveryAgentBin, Args: conf.MailDeliveryAgentArgs, Timeout: 30 * time.Second, } - remoteC := &courier.SMTP{Dinfo: dinfo} + remoteC := &courier.SMTP{Dinfo: dinfo, STSCache: stsCache} s.InitQueue(conf.DataDir+"/queue", localC, remoteC) // Load the addresses and listeners. 
diff --git a/cmd/smtp-check/smtp-check.go b/cmd/smtp-check/smtp-check.go index 4ce912f..b9215fc 100644 --- a/cmd/smtp-check/smtp-check.go +++ b/cmd/smtp-check/smtp-check.go @@ -5,12 +5,15 @@ package main import ( + "context" "crypto/tls" "flag" "log" "net" "net/smtp" + "time" + "blitiri.com.ar/go/chasquid/internal/sts" "blitiri.com.ar/go/chasquid/internal/tlsconst" "blitiri.com.ar/go/spf" @@ -37,6 +40,21 @@ func main() { log.Fatalf("IDNA conversion failed: %v", err) } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + log.Printf("=== STS policy") + policy, err := sts.UncheckedFetch(ctx, domain) + if err != nil { + log.Printf("Not available (%s)", err) + } else { + log.Printf("Parsed contents: [%+v]\n", *policy) + if err := policy.Check(); err != nil { + log.Fatalf("Invalid: %v", err) + } + log.Printf("OK") + } + mxs, err := net.LookupMX(domain) if err != nil { log.Fatalf("MX lookup: %v", err) @@ -86,6 +104,13 @@ func main() { c.Close() } + if policy != nil { + if !policy.MXIsAllowed(mx.Host) { + log.Fatalf("NOT allowed by STS policy") + } + log.Printf("Allowed by policy") + } + log.Printf("") } diff --git a/internal/courier/smtp.go b/internal/courier/smtp.go index 4c779af..f85741d 100644 --- a/internal/courier/smtp.go +++ b/internal/courier/smtp.go @@ -1,6 +1,7 @@ package courier import ( + "context" "crypto/tls" "expvar" "flag" @@ -13,6 +14,7 @@ import ( "blitiri.com.ar/go/chasquid/internal/domaininfo" "blitiri.com.ar/go/chasquid/internal/envelope" "blitiri.com.ar/go/chasquid/internal/smtp" + "blitiri.com.ar/go/chasquid/internal/sts" "blitiri.com.ar/go/chasquid/internal/trace" ) @@ -30,17 +32,26 @@ var ( // TODO: replace this with proper lookup interception once it is supported // by Go. netLookupMX = net.LookupMX + + // Enable STS policy checking; this is an experimental flag and will be + // removed in the future, once this is made the default. 
+ enableSTS = flag.Bool("experimental__enable_sts", false, + "enable STS policy checking; EXPERIMENTAL") ) // Exported variables. var ( tlsCount = expvar.NewMap("chasquid/smtpOut/tlsCount") slcResults = expvar.NewMap("chasquid/smtpOut/securityLevelChecks") + + stsSecurityModes = expvar.NewMap("chasquid/smtpOut/sts/mode") + stsSecurityResults = expvar.NewMap("chasquid/smtpOut/sts/security") ) // SMTP delivers remote mail via outgoing SMTP. type SMTP struct { - Dinfo *domaininfo.DB + Dinfo *domaininfo.DB + STSCache *sts.PolicyCache } // Deliver an email. On failures, returns an error, and whether or not it is @@ -62,7 +73,9 @@ func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) { a.from = "" } - mxs, err := lookupMXs(a.tr, a.toDomain) + a.stsPolicy = s.fetchSTSPolicy(a.tr, a.toDomain) + + mxs, err := lookupMXs(a.tr, a.toDomain, a.stsPolicy) if err != nil || len(mxs) == 0 { // Note this is considered a permanent error. // This is in line with what other servers (Exim) do. However, the @@ -108,6 +121,8 @@ type attempt struct { toDomain string helloDomain string + stsPolicy *sts.Policy + tr *trace.Trace } @@ -175,6 +190,18 @@ retry: } slcResults.Add("pass", 1) + if a.stsPolicy != nil && a.stsPolicy.Mode == sts.Enforce { + // The connection MUST be validated TLS. 
+ // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-03#section-4.2 + if secLevel != domaininfo.SecLevel_TLS_SECURE { + stsSecurityResults.Add("fail", 1) + return a.tr.Errorf("invalid security level (%v) for STS policy", + secLevel), false + } + stsSecurityResults.Add("pass", 1) + a.tr.Debugf("STS policy: connection is using valid TLS") + } + if err = c.MailAndRcpt(a.from, a.to); err != nil { return a.tr.Errorf("MAIL+RCPT %v", err), smtp.IsPermanent(err) } @@ -199,7 +226,29 @@ retry: return nil, false } -func lookupMXs(tr *trace.Trace, domain string) ([]string, error) { +func (s *SMTP) fetchSTSPolicy(tr *trace.Trace, domain string) *sts.Policy { + if !*enableSTS { + return nil + } + if s.STSCache == nil { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() + + policy, err := s.STSCache.Fetch(ctx, domain) + if err != nil { + return nil + } + + tr.Debugf("got STS policy") + stsSecurityModes.Add(string(policy.Mode), 1) + + return policy +} + +func lookupMXs(tr *trace.Trace, domain string, policy *sts.Policy) ([]string, error) { domain, err := idna.ToASCII(domain) if err != nil { return nil, err @@ -239,12 +288,39 @@ func lookupMXs(tr *trace.Trace, domain string) ([]string, error) { // This case is explicitly covered by the SMTP RFC. // https://tools.ietf.org/html/rfc5321#section-5.1 - // Cap the list of MXs to 5 hosts, to keep delivery attempt times sane - // and prevent abuse. - if len(mxs) > 5 { + mxs = filterMXs(tr, policy, mxs) + if len(mxs) == 0 { + tr.Errorf("domain %q has no valid MX/A record", domain) + } else if len(mxs) > 5 { + // Cap the list of MXs to 5 hosts, to keep delivery attempt times + // sane and prevent abuse. 
mxs = mxs[:5] } tr.Debugf("MXs: %v", mxs) return mxs, nil } + +func filterMXs(tr *trace.Trace, p *sts.Policy, mxs []string) []string { + if p == nil { + return mxs + } + + filtered := []string{} + for _, mx := range mxs { + if p.MXIsAllowed(mx) { + filtered = append(filtered, mx) + } else { + tr.Printf("MX %q not allowed by policy, skipping", mx) + } + } + + // We don't want to return an empty set if the mode is not enforce. + // This prevents failures for policies in reporting mode. + // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-03#section-5.2 + if len(filtered) == 0 && p.Mode != sts.Enforce { + filtered = mxs + } + + return filtered +} diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go index b12d7be..6aaf38b 100644 --- a/internal/courier/smtp_test.go +++ b/internal/courier/smtp_test.go @@ -35,7 +35,7 @@ func newSMTP(t *testing.T) (*SMTP, string) { t.Fatal(err) } - return &SMTP{dinfo}, dir + return &SMTP{dinfo, nil}, dir } // Fake server, to test SMTP out. diff --git a/internal/sts/sts.go b/internal/sts/sts.go new file mode 100644 index 0000000..7a1a4a2 --- /dev/null +++ b/internal/sts/sts.go @@ -0,0 +1,435 @@ +// Package sts implements the MTA-STS (Strict Transport Security), based on +// the current draft, https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02. +// +// This is an EXPERIMENTAL implementation for now. +// +// It lacks (at least) the following: +// - DNS TXT checking. +// - Facilities for reporting. +// +package sts + +import ( + "context" + "encoding/json" + "errors" + "expvar" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "strings" + "sync" + "time" + + "blitiri.com.ar/go/chasquid/internal/safeio" + "blitiri.com.ar/go/chasquid/internal/trace" + + "golang.org/x/net/context/ctxhttp" + "golang.org/x/net/idna" +) + +// Exported variables. 
+var ( + cacheFetches = expvar.NewInt("chasquid/sts/cache/fetches") + cacheHits = expvar.NewInt("chasquid/sts/cache/hits") + cacheExpired = expvar.NewInt("chasquid/sts/cache/expired") + + cacheIOErrors = expvar.NewInt("chasquid/sts/cache/ioErrors") + cacheFailedFetch = expvar.NewInt("chasquid/sts/cache/failedFetch") + cacheInvalid = expvar.NewInt("chasquid/sts/cache/invalid") + + cacheMarshalErrors = expvar.NewInt("chasquid/sts/cache/marshalErrors") + cacheUnmarshalErrors = expvar.NewInt("chasquid/sts/cache/unmarshalErrors") + + cacheRefreshCycles = expvar.NewInt("chasquid/sts/cache/refreshCycles") + cacheRefreshes = expvar.NewInt("chasquid/sts/cache/refreshes") + cacheRefreshErrors = expvar.NewInt("chasquid/sts/cache/refreshErrors") +) + +// Policy represents a parsed policy. +// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2 +type Policy struct { + Version string `json:"version"` + Mode Mode `json:"mode"` + MXs []string `json:"mx"` + MaxAge time.Duration `json:"max_age"` +} + +type Mode string + +// Valid modes. +const ( + Enforce = Mode("enforce") + Report = Mode("report") +) + +// parsePolicy parses a JSON representation of the policy, and returns the +// corresponding Policy structure. +func parsePolicy(raw []byte) (*Policy, error) { + p := &Policy{} + if err := json.Unmarshal(raw, p); err != nil { + return nil, err + } + + // MaxAge is in seconds. + p.MaxAge = p.MaxAge * time.Second + + return p, nil +} + +var ( + ErrUnknownVersion = errors.New("unknown policy version") + ErrInvalidMaxAge = errors.New("invalid max_age") + ErrInvalidMode = errors.New("invalid mode") + ErrInvalidMX = errors.New("invalid mx") +) + +// Check that the policy contents are valid. 
+func (p *Policy) Check() error { + if p.Version != "STSv1" { + return ErrUnknownVersion + } + if p.MaxAge <= 0 { + return ErrInvalidMaxAge + } + + if p.Mode != Enforce && p.Mode != Report { + return ErrInvalidMode + } + + // "mx" field is required, and the policy is invalid if it's not present. + // https://mailarchive.ietf.org/arch/msg/uta/Omqo1Bw6rJbrTMl2Zo69IJr35Qo + if len(p.MXs) == 0 { + return ErrInvalidMX + } + + return nil +} + +// MXIsAllowed checks if the given MX is allowed, according to the policy. +// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-4.1 +func (p *Policy) MXIsAllowed(mx string) bool { + for _, pattern := range p.MXs { + if matchDomain(mx, pattern) { + return true + } + } + + return false +} + +// UncheckedFetch fetches and parses the policy, but does NOT check it. +// This can be useful for debugging and troubleshooting, but you should always +// call Check on the policy before using it. +func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) { + // Convert the domain to ascii form, as httpGet does not support IDNs in + // any other way. + domain, err := idna.ToASCII(domain) + if err != nil { + return nil, err + } + + url := urlForDomain(domain) + rawPolicy, err := httpGet(ctx, url) + if err != nil { + return nil, err + } + + return parsePolicy(rawPolicy) +} + +// Fake URL for testing purposes, so we can do more end-to-end tests, +// including the HTTP fetching code. +var fakeURLForTesting string + +func urlForDomain(domain string) string { + if fakeURLForTesting != "" { + return fakeURLForTesting + "/" + domain + } + + // URL composed from the domain, as explained in: + // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3 + // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2 + return "https://mta-sts." + domain + "/.well-known/mta-sts.json" +} + +// Fetch a policy for the given domain. Note this results in various network +// lookups and HTTPS GETs, so it can be slow. 
+// The returned policy is parsed and sanity-checked (using Policy.Check), so +// it should be safe to use. +func Fetch(ctx context.Context, domain string) (*Policy, error) { + p, err := UncheckedFetch(ctx, domain) + if err != nil { + return nil, err + } + + err = p.Check() + if err != nil { + return nil, err + } + + return p, nil +} + +// httpGet performs an HTTP GET of the given URL, using the context and +// rejecting redirects, as per the standard. +func httpGet(ctx context.Context, url string) ([]byte, error) { + client := &http.Client{ + // We MUST NOT follow redirects, see + // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3 + CheckRedirect: rejectRedirect, + } + + // Note that http does not care for the context deadline, so we need to + // construct it here. + if deadline, ok := ctx.Deadline(); ok { + client.Timeout = deadline.Sub(time.Now()) + } + + resp, err := ctxhttp.Get(ctx, client, url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + // Read but up to 10k; policies should be way smaller than that, and + // having a limit prevents abuse/accidents with very large replies. + return ioutil.ReadAll(&io.LimitedReader{resp.Body, 10 * 1024}) + } + return nil, fmt.Errorf("HTTP response status code: %v", resp.StatusCode) +} + +var errRejectRedirect = errors.New("redirects not allowed in MTA-STS") + +func rejectRedirect(req *http.Request, via []*http.Request) error { + return errRejectRedirect +} + +// matchDomain checks if the domain matches the given pattern, according to +// https://tools.ietf.org/html/rfc6125#section-6.4 +// (from https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-4.1). 
+func matchDomain(domain, pattern string) bool { + domain, dErr := domainToASCII(domain) + pattern, pErr := domainToASCII(pattern) + if dErr != nil || pErr != nil { + // Domains should already have been checked and normalized by the + // caller, exposing this is not worth the API complexity in this case. + return false + } + + domainLabels := strings.Split(domain, ".") + patternLabels := strings.Split(pattern, ".") + + if len(domainLabels) != len(patternLabels) { + return false + } + + for i, p := range patternLabels { + // Wildcards only apply to the first part, see + // https://tools.ietf.org/html/rfc6125#section-6.4.3 #1 and #2. + // This also allows us to do the length comparison above. + if p == "*" && i == 0 { + continue + } + + if p != domainLabels[i] { + return false + } + } + + return true +} + +// domainToASCII converts the domain to ASCII form, similar to idna.ToASCII +// but with some preprocessing convenient for our use cases. +func domainToASCII(domain string) (string, error) { + domain = strings.TrimSuffix(domain, ".") + domain = strings.ToLower(domain) + return idna.ToASCII(domain) +} + +// PolicyCache is a caching layer for fetching policies. +// +// Policies are cached by domain, and stored in a single directory. +// The files will have as mtime the time when the policy expires, this makes +// the store simpler, as it can avoid keeping additional metadata. +// +// There is no in-memory caching. This may be added in the future, but for +// now disk is good enough for our purposes. +type PolicyCache struct { + dir string + + sync.Mutex +} + +func NewCache(dir string) (*PolicyCache, error) { + c := &PolicyCache{ + dir: dir, + } + err := os.MkdirAll(dir, 0770) + return c, err +} + +const pathPrefix = "pol:" + +func (c *PolicyCache) domainPath(domain string) string { + // We assume the domain is well formed, sanity check just in case. 
+ if strings.Contains(domain, "/") { + panic("domain contains slash") + } + + return c.dir + "/" + pathPrefix + domain +} + +var ErrExpired = errors.New("cache entry expired") + +func (c *PolicyCache) load(domain string) (*Policy, error) { + fname := c.domainPath(domain) + + fi, err := os.Stat(fname) + if err != nil { + return nil, err + } + if time.Since(fi.ModTime()) > 0 { + cacheExpired.Add(1) + return nil, ErrExpired + } + + data, err := ioutil.ReadFile(fname) + if err != nil { + cacheIOErrors.Add(1) + return nil, err + } + + p := &Policy{} + err = json.Unmarshal(data, p) + if err != nil { + cacheUnmarshalErrors.Add(1) + return nil, err + } + + // The policy should always be valid, as we marshalled it ourselves; + // however, check it just to be safe. + if err := p.Check(); err != nil { + cacheInvalid.Add(1) + return nil, fmt.Errorf( + "%s unmarshalled invalid policy %v: %v", domain, p, err) + } + + return p, nil +} + +func (c *PolicyCache) store(domain string, p *Policy) error { + data, err := json.Marshal(p) + if err != nil { + cacheMarshalErrors.Add(1) + return fmt.Errorf("%s failed to marshal policy %v, error: %v", + domain, p, err) + } + + // Change the modification time to the future, when the policy expires. + // load will check for this to detect expired cache entries, see above for + // the details. 
+ expires := time.Now().Add(p.MaxAge) + chTime := func(fname string) error { + return os.Chtimes(fname, expires, expires) + } + + fname := c.domainPath(domain) + err = safeio.WriteFile(fname, data, 0640, chTime) + if err != nil { + cacheIOErrors.Add(1) + } + return err +} + +func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error) { + cacheFetches.Add(1) + tr := trace.New("STSCache.Fetch", domain) + defer tr.Finish() + + p, err := c.load(domain) + if err == nil { + tr.Debugf("cache hit: %v", p) + cacheHits.Add(1) + return p, nil + } + + p, err = Fetch(ctx, domain) + if err != nil { + tr.Debugf("failed to fetch: %v", err) + cacheFailedFetch.Add(1) + return nil, err + } + tr.Debugf("fetched: %v", p) + + // We could do this asynchronously, as we got the policy to give to the + // caller. However, to make troubleshooting easier and the cost of storing + // entries easier to track down, we store synchronously. + // Note that even if the store returns an error, we pass on the policy: at + // this point we rather use the policy even if we couldn't store it in the + // cache. + err = c.store(domain, p) + if err != nil { + tr.Errorf("failed to store: %v", err) + } else { + tr.Debugf("stored") + } + + return p, nil +} + +func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) { + for ctx.Err() == nil { + cacheRefreshCycles.Add(1) + + c.refresh(ctx) + + // Wait 10 minutes between passes; this is a background refresh and + // there's no need to poke the servers very often. 
+ time.Sleep(10 * time.Minute) + } +} + +func (c *PolicyCache) refresh(ctx context.Context) { + tr := trace.New("STSCache.Refresh", c.dir) + defer tr.Finish() + + entries, err := ioutil.ReadDir(c.dir) + if err != nil { + tr.Errorf("failed to list directory %q: %v", c.dir, err) + return + } + tr.Debugf("%d entries", len(entries)) + + for _, e := range entries { + if !strings.HasPrefix(e.Name(), pathPrefix) { + continue + } + domain := e.Name()[len(pathPrefix):] + cacheRefreshes.Add(1) + tr.Debugf("%v: refreshing", domain) + + fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + p, err := Fetch(fetchCtx, domain) + cancel() + if err != nil { + tr.Debugf("%v: failed to fetch: %v", domain, err) + cacheRefreshErrors.Add(1) + continue + } + tr.Debugf("%v: fetched", domain) + + err = c.store(domain, p) + if err != nil { + tr.Errorf("%v: failed to store: %v", domain, err) + } else { + tr.Debugf("%v: stored", domain) + } + } + + tr.Debugf("refresh done") +} diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go new file mode 100644 index 0000000..bc26940 --- /dev/null +++ b/internal/sts/sts_test.go @@ -0,0 +1,384 @@ +package sts + +import ( + "context" + "expvar" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" +) + +// Test policy for each of the requested domains. Will be served by the test +// HTTP server. +var policyForDomain = map[string]string{ + // domain.com -> valid, with reasonable policy. + "domain.com": ` + { + "version": "STSv1", + "mode": "enforce", + "mx": ["*.mail.domain.com"], + "max_age": 3600 + }`, + + // version99 -> invalid policy (unknown version). + "version99": ` + { + "version": "STSv99", + "mode": "enforce", + "mx": ["*.mail.version99"], + "max_age": 999 + }`, +} + +func testHTTPHandler(w http.ResponseWriter, r *http.Request) { + // For testing, the domain in the path (see urlForDomain). 
+ policy, ok := policyForDomain[r.URL.Path[1:]] + if !ok { + http.Error(w, "not found", 404) + return + } + fmt.Fprintln(w, policy) + return +} + +func TestMain(m *testing.M) { + // Create a test HTTP server, used by the more end-to-end tests. + httpServer := httptest.NewServer(http.HandlerFunc(testHTTPHandler)) + + fakeURLForTesting = httpServer.URL + os.Exit(m.Run()) +} + +func TestParsePolicy(t *testing.T) { + const pol1 = `{ + "version": "STSv1", + "mode": "enforce", + "mx": ["*.mail.example.com"], + "max_age": 123456 +} +` + p, err := parsePolicy([]byte(pol1)) + if err != nil { + t.Errorf("failed to parse policy: %v", err) + } + + t.Logf("pol1: %+v", p) +} + +func TestCheckPolicy(t *testing.T) { + validPs := []Policy{ + {Version: "STSv1", Mode: "enforce", MaxAge: 1 * time.Hour, + MXs: []string{"mx1", "mx2"}}, + {Version: "STSv1", Mode: "report", MaxAge: 1 * time.Hour, + MXs: []string{"mx1"}}, + } + for i, p := range validPs { + if err := p.Check(); err != nil { + t.Errorf("%d policy %v failed check: %v", i, p, err) + } + } + + invalid := []struct { + p Policy + expected error + }{ + {Policy{Version: "STSv2"}, ErrUnknownVersion}, + {Policy{Version: "STSv1"}, ErrInvalidMaxAge}, + {Policy{Version: "STSv1", MaxAge: 1, Mode: "blah"}, ErrInvalidMode}, + {Policy{Version: "STSv1", MaxAge: 1, Mode: "enforce"}, ErrInvalidMX}, + {Policy{Version: "STSv1", MaxAge: 1, Mode: "enforce", MXs: []string{}}, + ErrInvalidMX}, + } + for i, c := range invalid { + if err := c.p.Check(); err != c.expected { + t.Errorf("%d policy %v check: expected %v, got %v", i, c.p, + c.expected, err) + } + } +} + +func TestMatchDomain(t *testing.T) { + cases := []struct { + domain, pattern string + expected bool + }{ + {"lalala", "lalala", true}, + {"a.b.", "a.b", true}, + {"a.b", "a.b.", true}, + {"abc.com", "*.com", true}, + + {"abc.com", "abc.*.com", false}, + {"abc.com", "x.abc.com", false}, + {"x.abc.com", "*.*.com", false}, + + {"ñaca.com", "ñaca.com", true}, + {"Ñaca.com", "ñaca.com", true}, 
+ {"ñaca.com", "Ñaca.com", true}, + {"x.ñaca.com", "x.xn--aca-6ma.com", true}, + {"x.naca.com", "x.xn--aca-6ma.com", false}, + } + + for _, c := range cases { + if r := matchDomain(c.domain, c.pattern); r != c.expected { + t.Errorf("matchDomain(%q, %q) = %v, expected %v", + c.domain, c.pattern, r, c.expected) + } + } +} + +func TestFetch(t *testing.T) { + // Note the data "fetched" for each domain comes from policyForDomain, + // defined in TestMain above. See httpGet for more details. + + // Normal fetch, all valid. + p, err := Fetch(context.Background(), "domain.com") + if err != nil { + t.Errorf("failed to fetch policy: %v", err) + } + t.Logf("domain.com: %+v", p) + + // Domain without a policy (HTTP get fails). + p, err = Fetch(context.Background(), "unknown") + if err == nil { + t.Errorf("fetched unknown policy: %v", p) + } + t.Logf("unknown: got error as expected: %v", err) + + // Domain with an invalid policy (unknown version). + p, err = Fetch(context.Background(), "version99") + if err != ErrUnknownVersion { + t.Errorf("expected error %v, got %v (and policy: %v)", + ErrUnknownVersion, err, p) + } + t.Logf("version99: got expected error: %v", err) +} + +func TestPolicyTooBig(t *testing.T) { + // Construct a valid but very large JSON as a policy. + raw := `{"version": "STSv1", "mode": "enforce", "mx": [` + for i := 0; i < 2000; i++ { + raw += fmt.Sprintf("\"mx%d\", ", i) + } + raw += `"mxlast"], "max_age": 100}` + policyForDomain["toobig"] = raw + + _, err := Fetch(context.Background(), "toobig") + if err == nil { + t.Errorf("fetch worked, but should have failed") + } + t.Logf("got error as expected: %v", err) +} + +// Tests for the policy cache. 
+ +func mustTempDir(t *testing.T) string { + dir, err := ioutil.TempDir("", "sts_test") + if err != nil { + t.Fatal(err) + } + + err = os.Chdir(dir) + if err != nil { + t.Fatal(err) + } + + t.Logf("test directory: %q", dir) + + return dir +} + +func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) { + // TODO: Use v.Value once we drop support of Go 1.7. + value, _ := strconv.Atoi(v.String()) + if value != expected { + t.Errorf("%s is %d, expected %d", name, value, expected) + } +} + +func TestCacheBasics(t *testing.T) { + dir := mustTempDir(t) + c, err := NewCache(dir) + if err != nil { + t.Fatal(err) + } + + // Note the data "fetched" for each domain comes from policyForDomain, + // defined in TestMain above. See httpGet for more details. + + // Reset the expvar counters that we use to validate hits, misses, etc. + cacheFetches.Set(0) + cacheHits.Set(0) + + ctx := context.Background() + + // Fetch domain.com, check we get a reasonable policy, and that it's a + // cache miss. + p, err := c.Fetch(ctx, "domain.com") + if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" { + t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err) + } + t.Logf("cache fetched domain.com: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 1) + expvarMustEq(t, "cacheHits", cacheHits, 0) + + // Fetch domain.com again, this time we should see a cache hit. + p, err = c.Fetch(ctx, "domain.com") + if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" { + t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err) + } + t.Logf("cache fetched domain.com: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 2) + expvarMustEq(t, "cacheHits", cacheHits, 1) + + // Simulate an expired cache entry by changing the mtime of domain.com's + // entry to the past. 
+ expires := time.Now().Add(-1 * time.Minute) + os.Chtimes(c.domainPath("domain.com"), expires, expires) + + // Do a third fetch, check that we don't get a cache hit. + p, err = c.Fetch(ctx, "domain.com") + if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" { + t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err) + } + t.Logf("cache fetched domain.com: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 3) + expvarMustEq(t, "cacheHits", cacheHits, 1) + + if !t.Failed() { + os.RemoveAll(dir) + } +} + +// Test how the cache behaves when the files are corrupt. +func TestCacheBadData(t *testing.T) { + dir := mustTempDir(t) + c, err := NewCache(dir) + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + + cases := []string{ + // Case 1: A file with invalid json, which will fail unmarshalling. + "this is not valid json", + + // Case 2: A file with a parseable but invalid policy. + `{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], max_age": 1}`, + } + + for _, badContent := range cases { + // Reset the expvar counters that we use to validate hits, misses, etc. + cacheFetches.Set(0) + cacheHits.Set(0) + + // Fetch domain.com, should result in the file being added to the + // cache. + p, err := c.Fetch(ctx, "domain.com") + if err != nil { + t.Fatalf("Fetch failed: %v", err) + } + t.Logf("cache fetched domain.com: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 1) + expvarMustEq(t, "cacheHits", cacheHits, 0) + + // Edit the file, filling it with the bad content for this case. + fname := c.domainPath("domain.com") + err = ioutil.WriteFile(fname, []byte(badContent), 0644) + if err != nil { + t.Fatalf("error writing file: %v", err) + } + + // We now expect Fetch to fall back to getting the policy from the + // network (in our case, from policyForDomain). 
+ p, err = c.Fetch(ctx, "domain.com") + if err != nil { + t.Fatalf("Fetch failed: %v", err) + } + t.Logf("cache fetched domain.com: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 2) + expvarMustEq(t, "cacheHits", cacheHits, 0) + + // And now the file should be fine, resulting in a cache hit. + p, err = c.Fetch(ctx, "domain.com") + if err != nil { + t.Fatalf("Fetch failed: %v", err) + } + t.Logf("cache fetched domain.com: %v", p) + expvarMustEq(t, "cacheFetches", cacheFetches, 3) + expvarMustEq(t, "cacheHits", cacheHits, 1) + + // Remove the file, to start with a clean slate for the next case. + os.Remove(fname) + } + + if !t.Failed() { + os.RemoveAll(dir) + } +} + +func mustFetch(t *testing.T, c *PolicyCache, ctx context.Context, d string) *Policy { + p, err := c.Fetch(ctx, d) + if err != nil { + t.Fatalf("Fetch %q failed: %v", d, err) + } + t.Logf("Fetch %q: %v", d, p) + return p +} + +func TestCacheRefresh(t *testing.T) { + dir := mustTempDir(t) + c, err := NewCache(dir) + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + + policyForDomain["refresh-test"] = ` + {"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 100}` + p := mustFetch(t, c, ctx, "refresh-test") + if p.MaxAge != 100*time.Second { + t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge) + } + + // Change the "published" policy, check that we see the old version at + // fetch (should be cached), and a new version after a refresh. 
+ policyForDomain["refresh-test"] = ` + {"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 200}` + + p = mustFetch(t, c, ctx, "refresh-test") + if p.MaxAge != 100*time.Second { + t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge) + } + + c.refresh(ctx) + + p = mustFetch(t, c, ctx, "refresh-test") + if p.MaxAge != 200*time.Second { + t.Fatalf("policy.MaxAge is %v, expected 200s", p.MaxAge) + } + + if !t.Failed() { + os.RemoveAll(dir) + } +} + +func TestURLForDomain(t *testing.T) { + // This function will behave differently if fakeURLForTesting is set, so + // temporarily unset it. + oldURL := fakeURLForTesting + fakeURLForTesting = "" + defer func() { fakeURLForTesting = oldURL }() + + got := urlForDomain("a-test-domain") + expected := "https://mta-sts.a-test-domain/.well-known/mta-sts.json" + if got != expected { + t.Errorf("got %q, expected %q", got, expected) + } +}