author | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-04-10 23:39:21 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-04-10 23:58:59 UTC |
parent | 213bc63a95bae17048b284fe247c20c7f3c625f4 |
chasquid.go | +1 | -9 |
cmd/smtp-check/smtp-check.go | +0 | -25 |
internal/courier/smtp.go | +12 | -88 |
internal/courier/smtp_test.go | +1 | -1 |
internal/sts/sts.go | +0 | -435 |
internal/sts/sts_test.go | +0 | -384 |
diff --git a/chasquid.go b/chasquid.go index e9a3b22..9b073a2 100644 --- a/chasquid.go +++ b/chasquid.go @@ -1,7 +1,6 @@ package main import ( - "context" "expvar" "flag" "fmt" @@ -20,7 +19,6 @@ import ( "blitiri.com.ar/go/chasquid/internal/maillog" "blitiri.com.ar/go/chasquid/internal/normalize" "blitiri.com.ar/go/chasquid/internal/smtpsrv" - "blitiri.com.ar/go/chasquid/internal/sts" "blitiri.com.ar/go/chasquid/internal/systemd" "blitiri.com.ar/go/chasquid/internal/userdb" @@ -137,18 +135,12 @@ func main() { dinfo := s.InitDomainInfo(conf.DataDir + "/domaininfo") - stsCache, err := sts.NewCache(conf.DataDir + "/sts-cache") - if err != nil { - log.Fatalf("Failed to initialize STS cache: %v", err) - } - go stsCache.PeriodicallyRefresh(context.Background()) - localC := &courier.Procmail{ Binary: conf.MailDeliveryAgentBin, Args: conf.MailDeliveryAgentArgs, Timeout: 30 * time.Second, } - remoteC := &courier.SMTP{Dinfo: dinfo, STSCache: stsCache} + remoteC := &courier.SMTP{Dinfo: dinfo} s.InitQueue(conf.DataDir+"/queue", localC, remoteC) // Load the addresses and listeners. diff --git a/cmd/smtp-check/smtp-check.go b/cmd/smtp-check/smtp-check.go index 2dd2471..c7f03a7 100644 --- a/cmd/smtp-check/smtp-check.go +++ b/cmd/smtp-check/smtp-check.go @@ -2,16 +2,13 @@ package main import ( - "context" "crypto/tls" "flag" "log" "net" "net/smtp" - "time" "blitiri.com.ar/go/chasquid/internal/spf" - "blitiri.com.ar/go/chasquid/internal/sts" "blitiri.com.ar/go/chasquid/internal/tlsconst" "golang.org/x/net/idna" @@ -37,21 +34,6 @@ func main() { log.Fatalf("IDNA conversion failed: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - log.Printf("=== STS policy") - policy, err := sts.UncheckedFetch(ctx, domain) - if err != nil { - log.Printf("Not available (%s)", err) - } else { - log.Printf("Parsed contents: [%+v]\n", *policy) - if err := policy.Check(); err != nil { - log.Fatalf("Invalid: %v", err) - } - log.Printf("OK") - } - mxs, err := net.LookupMX(domain) if err != nil { log.Fatalf("MX lookup: %v", err) @@ -101,13 +83,6 @@ func main() { c.Close() } - if policy != nil { - if !policy.MXIsAllowed(mx.Host) { - log.Fatalf("NOT allowed by STS policy") - } - log.Printf("Allowed by policy") - } - log.Printf("") } diff --git a/internal/courier/smtp.go b/internal/courier/smtp.go index d633649..ebb2ce7 100644 --- a/internal/courier/smtp.go +++ b/internal/courier/smtp.go @@ -1,7 +1,6 @@ package courier import ( - "context" "crypto/tls" "expvar" "flag" @@ -14,7 +13,6 @@ import ( "blitiri.com.ar/go/chasquid/internal/domaininfo" "blitiri.com.ar/go/chasquid/internal/envelope" "blitiri.com.ar/go/chasquid/internal/smtp" - "blitiri.com.ar/go/chasquid/internal/sts" "blitiri.com.ar/go/chasquid/internal/trace" ) @@ -28,11 +26,6 @@ var ( smtpPort = flag.String("testing__outgoing_smtp_port", "25", "port to use for outgoing SMTP connections, ONLY FOR TESTING") - // Enable STS policy checking; this is an experimental flag and will be - // removed in the future, once this is made the default. - enableSTS = flag.Bool("experimental__enable_sts", false, - "enable STS policy checking; EXPERIMENTAL") - // Fake MX records, used for testing only. fakeMX = map[string][]string{} ) @@ -41,25 +34,20 @@ var ( var ( tlsCount = expvar.NewMap("chasquid/smtpOut/tlsCount") slcResults = expvar.NewMap("chasquid/smtpOut/securityLevelChecks") - - stsSecurityModes = expvar.NewMap("chasquid/smtpOut/sts/mode") - stsSecurityResults = expvar.NewMap("chasquid/smtpOut/sts/security") ) // SMTP delivers remote mail via outgoing SMTP. type SMTP struct { - Dinfo *domaininfo.DB - STSCache *sts.PolicyCache + Dinfo *domaininfo.DB } func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) { a := &attempt{ - courier: s, - from: from, - to: to, - data: data, - toDomain: envelope.DomainOf(to), - tr: trace.New("Courier.SMTP", to), + courier: s, + from: from, + to: to, + data: data, + tr: trace.New("Courier.SMTP", to), } defer a.tr.Finish() a.tr.Debugf("%s -> %s", from, to) @@ -69,9 +57,8 @@ func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) { a.from = "" } - a.stsPolicy = s.fetchSTSPolicy(a.tr, a.toDomain) - - mxs, err := lookupMXs(a.tr, a.toDomain, a.stsPolicy) + toDomain := envelope.DomainOf(to) + mxs, err := lookupMXs(a.tr, toDomain) if err != nil || len(mxs) == 0 { // Note this is considered a permanent error. // This is in line with what other servers (Exim) do. However, the @@ -117,8 +104,6 @@ type attempt struct { toDomain string helloDomain string - stsPolicy *sts.Policy - tr *trace.Trace } @@ -186,18 +171,6 @@ retry: } slcResults.Add("pass", 1) - if a.stsPolicy != nil && a.stsPolicy.Mode == sts.Enforce { - // The connection MUST be validated TLS. - // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-03#section-4.2 - if secLevel != domaininfo.SecLevel_TLS_SECURE { - stsSecurityResults.Add("fail", 1) - return a.tr.Errorf("invalid security level (%v) for STS policy", - secLevel), false - } - stsSecurityResults.Add("pass", 1) - a.tr.Debugf("STS policy: connection is using valid TLS") - } - if err = c.MailAndRcpt(a.from, a.to); err != nil { return a.tr.Errorf("MAIL+RCPT %v", err), smtp.IsPermanent(err) } @@ -222,29 +195,7 @@ retry: return nil, false } -func (s *SMTP) fetchSTSPolicy(tr *trace.Trace, domain string) *sts.Policy { - if !*enableSTS { - return nil - } - if s.STSCache == nil { - return nil - } - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) - defer cancel() - - policy, err := s.STSCache.Fetch(ctx, domain) - if err != nil { - return nil - } - - tr.Debugf("got STS policy") - stsSecurityModes.Add(string(policy.Mode), 1) - - return policy -} - -func lookupMXs(tr *trace.Trace, domain string, policy *sts.Policy) ([]string, error) { +func lookupMXs(tr *trace.Trace, domain string) ([]string, error) { if v, ok := fakeMX[domain]; ok { return v, nil } @@ -288,39 +239,12 @@ func lookupMXs(tr *trace.Trace, domain string, policy *sts.Policy) ([]string, er // This case is explicitly covered by the SMTP RFC. // https://tools.ietf.org/html/rfc5321#section-5.1 - mxs = filterMXs(tr, policy, mxs) - if len(mxs) == 0 { - tr.Errorf("domain %q has no valid MX/A record", domain) - } else if len(mxs) > 5 { - // Cap the list of MXs to 5 hosts, to keep delivery attempt times - // sane and prevent abuse. + // Cap the list of MXs to 5 hosts, to keep delivery attempt times sane + // and prevent abuse. + if len(mxs) > 5 { mxs = mxs[:5] } tr.Debugf("MXs: %v", mxs) return mxs, nil } - -func filterMXs(tr *trace.Trace, p *sts.Policy, mxs []string) []string { - if p == nil { - return mxs - } - - filtered := []string{} - for _, mx := range mxs { - if p.MXIsAllowed(mx) { - filtered = append(filtered, mx) - } else { - tr.Printf("MX %q not allowed by policy, skipping", mx) - } - } - - // We don't want to return an empty set if the mode is not enforce. - // This prevents failures for policies in reporting mode. - // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-03#section-5.2 - if len(filtered) == 0 && p.Mode != sts.Enforce { - filtered = mxs - } - - return filtered -} diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go index 8b16b60..72fbe10 100644 --- a/internal/courier/smtp_test.go +++ b/internal/courier/smtp_test.go @@ -23,7 +23,7 @@ func newSMTP(t *testing.T) (*SMTP, string) { t.Fatal(err) } - return &SMTP{dinfo, nil}, dir + return &SMTP{dinfo}, dir } // Fake server, to test SMTP out. diff --git a/internal/sts/sts.go b/internal/sts/sts.go deleted file mode 100644 index 7a1a4a2..0000000 --- a/internal/sts/sts.go +++ /dev/null @@ -1,435 +0,0 @@ -// Package sts implements the MTA-STS (Strict Transport Security), based on -// the current draft, https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02. -// -// This is an EXPERIMENTAL implementation for now. -// -// It lacks (at least) the following: -// - DNS TXT checking. -// - Facilities for reporting. -// -package sts - -import ( - "context" - "encoding/json" - "errors" - "expvar" - "fmt" - "io" - "io/ioutil" - "net/http" - "os" - "strings" - "sync" - "time" - - "blitiri.com.ar/go/chasquid/internal/safeio" - "blitiri.com.ar/go/chasquid/internal/trace" - - "golang.org/x/net/context/ctxhttp" - "golang.org/x/net/idna" -) - -// Exported variables. -var ( - cacheFetches = expvar.NewInt("chasquid/sts/cache/fetches") - cacheHits = expvar.NewInt("chasquid/sts/cache/hits") - cacheExpired = expvar.NewInt("chasquid/sts/cache/expired") - - cacheIOErrors = expvar.NewInt("chasquid/sts/cache/ioErrors") - cacheFailedFetch = expvar.NewInt("chasquid/sts/cache/failedFetch") - cacheInvalid = expvar.NewInt("chasquid/sts/cache/invalid") - - cacheMarshalErrors = expvar.NewInt("chasquid/sts/cache/marshalErrors") - cacheUnmarshalErrors = expvar.NewInt("chasquid/sts/cache/unmarshalErrors") - - cacheRefreshCycles = expvar.NewInt("chasquid/sts/cache/refreshCycles") - cacheRefreshes = expvar.NewInt("chasquid/sts/cache/refreshes") - cacheRefreshErrors = expvar.NewInt("chasquid/sts/cache/refreshErrors") -) - -// Policy represents a parsed policy. -// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2 -type Policy struct { - Version string `json:"version"` - Mode Mode `json:"mode"` - MXs []string `json:"mx"` - MaxAge time.Duration `json:"max_age"` -} - -type Mode string - -// Valid modes. -const ( - Enforce = Mode("enforce") - Report = Mode("report") -) - -// parsePolicy parses a JSON representation of the policy, and returns the -// corresponding Policy structure. -func parsePolicy(raw []byte) (*Policy, error) { - p := &Policy{} - if err := json.Unmarshal(raw, p); err != nil { - return nil, err - } - - // MaxAge is in seconds. - p.MaxAge = p.MaxAge * time.Second - - return p, nil -} - -var ( - ErrUnknownVersion = errors.New("unknown policy version") - ErrInvalidMaxAge = errors.New("invalid max_age") - ErrInvalidMode = errors.New("invalid mode") - ErrInvalidMX = errors.New("invalid mx") -) - -// Check that the policy contents are valid. -func (p *Policy) Check() error { - if p.Version != "STSv1" { - return ErrUnknownVersion - } - if p.MaxAge <= 0 { - return ErrInvalidMaxAge - } - - if p.Mode != Enforce && p.Mode != Report { - return ErrInvalidMode - } - - // "mx" field is required, and the policy is invalid if it's not present. - // https://mailarchive.ietf.org/arch/msg/uta/Omqo1Bw6rJbrTMl2Zo69IJr35Qo - if len(p.MXs) == 0 { - return ErrInvalidMX - } - - return nil -} - -// MXMatches checks if the given MX is allowed, according to the policy. -// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-4.1 -func (p *Policy) MXIsAllowed(mx string) bool { - for _, pattern := range p.MXs { - if matchDomain(mx, pattern) { - return true - } - } - - return false -} - -// UncheckedFetch fetches and parses the policy, but does NOT check it. -// This can be useful for debugging and troubleshooting, but you should always -// call Check on the policy before using it. -func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) { - // Convert the domain to ascii form, as httpGet does not support IDNs in - // any other way. - domain, err := idna.ToASCII(domain) - if err != nil { - return nil, err - } - - url := urlForDomain(domain) - rawPolicy, err := httpGet(ctx, url) - if err != nil { - return nil, err - } - - return parsePolicy(rawPolicy) -} - -// Fake URL for testing purposes, so we can do more end-to-end tests, -// including the HTTP fetching code. -var fakeURLForTesting string - -func urlForDomain(domain string) string { - if fakeURLForTesting != "" { - return fakeURLForTesting + "/" + domain - } - - // URL composed from the domain, as explained in: - // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3 - // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2 - return "https://mta-sts." + domain + "/.well-known/mta-sts.json" -} - -// Fetch a policy for the given domain. Note this results in various network -// lookups and HTTPS GETs, so it can be slow. -// The returned policy is parsed and sanity-checked (using Policy.Check), so -// it should be safe to use. -func Fetch(ctx context.Context, domain string) (*Policy, error) { - p, err := UncheckedFetch(ctx, domain) - if err != nil { - return nil, err - } - - err = p.Check() - if err != nil { - return nil, err - } - - return p, nil -} - -// httpGet performs an HTTP GET of the given URL, using the context and -// rejecting redirects, as per the standard. -func httpGet(ctx context.Context, url string) ([]byte, error) { - client := &http.Client{ - // We MUST NOT follow redirects, see - // https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3 - CheckRedirect: rejectRedirect, - } - - // Note that http does not care for the context deadline, so we need to - // construct it here. - if deadline, ok := ctx.Deadline(); ok { - client.Timeout = deadline.Sub(time.Now()) - } - - resp, err := ctxhttp.Get(ctx, client, url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusOK { - // Read but up to 10k; policies should be way smaller than that, and - // having a limit prevents abuse/accidents with very large replies. - return ioutil.ReadAll(&io.LimitedReader{resp.Body, 10 * 1024}) - } - return nil, fmt.Errorf("HTTP response status code: %v", resp.StatusCode) -} - -var errRejectRedirect = errors.New("redirects not allowed in MTA-STS") - -func rejectRedirect(req *http.Request, via []*http.Request) error { - return errRejectRedirect -} - -// matchDomain checks if the domain matches the given pattern, according to -// https://tools.ietf.org/html/rfc6125#section-6.4 -// (from https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-4.1). -func matchDomain(domain, pattern string) bool { - domain, dErr := domainToASCII(domain) - pattern, pErr := domainToASCII(pattern) - if dErr != nil || pErr != nil { - // Domains should already have been checked and normalized by the - // caller, exposing this is not worth the API complexity in this case. - return false - } - - domainLabels := strings.Split(domain, ".") - patternLabels := strings.Split(pattern, ".") - - if len(domainLabels) != len(patternLabels) { - return false - } - - for i, p := range patternLabels { - // Wildcards only apply to the first part, see - // https://tools.ietf.org/html/rfc6125#section-6.4.3 #1 and #2. - // This also allows us to do the lenght comparison above. - if p == "*" && i == 0 { - continue - } - - if p != domainLabels[i] { - return false - } - } - - return true -} - -// domainToASCII converts the domain to ASCII form, similar to idna.ToASCII -// but with some preprocessing convenient for our use cases. -func domainToASCII(domain string) (string, error) { - domain = strings.TrimSuffix(domain, ".") - domain = strings.ToLower(domain) - return idna.ToASCII(domain) -} - -// PolicyCache is a caching layer for fetching policies. -// -// Policies are cached by domain, and stored in a single directory. -// The files will have as mtime the time when the policy expires, this makes -// the store simpler, as it can avoid keeping additional metadata. -// -// There is no in-memory caching. This may be added in the future, but for -// now disk is good enough for our purposes. -type PolicyCache struct { - dir string - - sync.Mutex -} - -func NewCache(dir string) (*PolicyCache, error) { - c := &PolicyCache{ - dir: dir, - } - err := os.MkdirAll(dir, 0770) - return c, err -} - -const pathPrefix = "pol:" - -func (c *PolicyCache) domainPath(domain string) string { - // We assume the domain is well formed, sanity check just in case. - if strings.Contains(domain, "/") { - panic("domain contains slash") - } - - return c.dir + "/" + pathPrefix + domain -} - -var ErrExpired = errors.New("cache entry expired") - -func (c *PolicyCache) load(domain string) (*Policy, error) { - fname := c.domainPath(domain) - - fi, err := os.Stat(fname) - if err != nil { - return nil, err - } - if time.Since(fi.ModTime()) > 0 { - cacheExpired.Add(1) - return nil, ErrExpired - } - - data, err := ioutil.ReadFile(fname) - if err != nil { - cacheIOErrors.Add(1) - return nil, err - } - - p := &Policy{} - err = json.Unmarshal(data, p) - if err != nil { - cacheUnmarshalErrors.Add(1) - return nil, err - } - - // The policy should always be valid, as we marshalled it ourselves; - // however, check it just to be safe. - if err := p.Check(); err != nil { - cacheInvalid.Add(1) - return nil, fmt.Errorf( - "%s unmarshalled invalid policy %v: %v", domain, p, err) - } - - return p, nil -} - -func (c *PolicyCache) store(domain string, p *Policy) error { - data, err := json.Marshal(p) - if err != nil { - cacheMarshalErrors.Add(1) - return fmt.Errorf("%s failed to marshal policy %v, error: %v", - domain, p, err) - } - - // Change the modification time to the future, when the policy expires. - // load will check for this to detect expired cache entries, see above for - // the details. - expires := time.Now().Add(p.MaxAge) - chTime := func(fname string) error { - return os.Chtimes(fname, expires, expires) - } - - fname := c.domainPath(domain) - err = safeio.WriteFile(fname, data, 0640, chTime) - if err != nil { - cacheIOErrors.Add(1) - } - return err -} - -func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error) { - cacheFetches.Add(1) - tr := trace.New("STSCache.Fetch", domain) - defer tr.Finish() - - p, err := c.load(domain) - if err == nil { - tr.Debugf("cache hit: %v", p) - cacheHits.Add(1) - return p, nil - } - - p, err = Fetch(ctx, domain) - if err != nil { - tr.Debugf("failed to fetch: %v", err) - cacheFailedFetch.Add(1) - return nil, err - } - tr.Debugf("fetched: %v", p) - - // We could do this asynchronously, as we got the policy to give to the - // caller. However, to make troubleshooting easier and the cost of storing - // entries easier to track down, we store synchronously. - // Note that even if the store returns an error, we pass on the policy: at - // this point we rather use the policy even if we couldn't store it in the - // cache. - err = c.store(domain, p) - if err != nil { - tr.Errorf("failed to store: %v", err) - } else { - tr.Debugf("stored") - } - - return p, nil -} - -func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) { - for ctx.Err() == nil { - cacheRefreshCycles.Add(1) - - c.refresh(ctx) - - // Wait 10 minutes between passes; this is a background refresh and - // there's no need to poke the servers very often. - time.Sleep(10 * time.Minute) - } -} - -func (c *PolicyCache) refresh(ctx context.Context) { - tr := trace.New("STSCache.Refresh", c.dir) - defer tr.Finish() - - entries, err := ioutil.ReadDir(c.dir) - if err != nil { - tr.Errorf("failed to list directory %q: %v", c.dir, err) - return - } - tr.Debugf("%d entries", len(entries)) - - for _, e := range entries { - if !strings.HasPrefix(e.Name(), pathPrefix) { - continue - } - domain := e.Name()[len(pathPrefix):] - cacheRefreshes.Add(1) - tr.Debugf("%v: refreshing", domain) - - fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - p, err := Fetch(fetchCtx, domain) - cancel() - if err != nil { - tr.Debugf("%v: failed to fetch: %v", domain, err) - cacheRefreshErrors.Add(1) - continue - } - tr.Debugf("%v: fetched", domain) - - err = c.store(domain, p) - if err != nil { - tr.Errorf("%v: failed to store: %v", domain, err) - } else { - tr.Debugf("%v: stored", domain) - } - } - - tr.Debugf("refresh done") -} diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go deleted file mode 100644 index bc26940..0000000 --- a/internal/sts/sts_test.go +++ /dev/null @@ -1,384 +0,0 @@ -package sts - -import ( - "context" - "expvar" - "fmt" - "io/ioutil" - "net/http" - "net/http/httptest" - "os" - "strconv" - "testing" - "time" -) - -// Test policy for each of the requested domains. Will be served by the test -// HTTP server. -var policyForDomain = map[string]string{ - // domain.com -> valid, with reasonable policy. - "domain.com": ` - { - "version": "STSv1", - "mode": "enforce", - "mx": ["*.mail.domain.com"], - "max_age": 3600 - }`, - - // version99 -> invalid policy (unknown version). - "version99": ` - { - "version": "STSv99", - "mode": "enforce", - "mx": ["*.mail.version99"], - "max_age": 999 - }`, -} - -func testHTTPHandler(w http.ResponseWriter, r *http.Request) { - // For testing, the domain in the path (see urlForDomain). - policy, ok := policyForDomain[r.URL.Path[1:]] - if !ok { - http.Error(w, "not found", 404) - return - } - fmt.Fprintln(w, policy) - return -} - -func TestMain(m *testing.M) { - // Create a test HTTP server, used by the more end-to-end tests. - httpServer := httptest.NewServer(http.HandlerFunc(testHTTPHandler)) - - fakeURLForTesting = httpServer.URL - os.Exit(m.Run()) -} - -func TestParsePolicy(t *testing.T) { - const pol1 = `{ - "version": "STSv1", - "mode": "enforce", - "mx": ["*.mail.example.com"], - "max_age": 123456 -} -` - p, err := parsePolicy([]byte(pol1)) - if err != nil { - t.Errorf("failed to parse policy: %v", err) - } - - t.Logf("pol1: %+v", p) -} - -func TestCheckPolicy(t *testing.T) { - validPs := []Policy{ - {Version: "STSv1", Mode: "enforce", MaxAge: 1 * time.Hour, - MXs: []string{"mx1", "mx2"}}, - {Version: "STSv1", Mode: "report", MaxAge: 1 * time.Hour, - MXs: []string{"mx1"}}, - } - for i, p := range validPs { - if err := p.Check(); err != nil { - t.Errorf("%d policy %v failed check: %v", i, p, err) - } - } - - invalid := []struct { - p Policy - expected error - }{ - {Policy{Version: "STSv2"}, ErrUnknownVersion}, - {Policy{Version: "STSv1"}, ErrInvalidMaxAge}, - {Policy{Version: "STSv1", MaxAge: 1, Mode: "blah"}, ErrInvalidMode}, - {Policy{Version: "STSv1", MaxAge: 1, Mode: "enforce"}, ErrInvalidMX}, - {Policy{Version: "STSv1", MaxAge: 1, Mode: "enforce", MXs: []string{}}, - ErrInvalidMX}, - } - for i, c := range invalid { - if err := c.p.Check(); err != c.expected { - t.Errorf("%d policy %v check: expected %v, got %v", i, c.p, - c.expected, err) - } - } -} - -func TestMatchDomain(t *testing.T) { - cases := []struct { - domain, pattern string - expected bool - }{ - {"lalala", "lalala", true}, - {"a.b.", "a.b", true}, - {"a.b", "a.b.", true}, - {"abc.com", "*.com", true}, - - {"abc.com", "abc.*.com", false}, - {"abc.com", "x.abc.com", false}, - {"x.abc.com", "*.*.com", false}, - - {"ñaca.com", "ñaca.com", true}, - {"Ñaca.com", "ñaca.com", true}, - {"ñaca.com", "Ñaca.com", true}, - {"x.ñaca.com", "x.xn--aca-6ma.com", true}, - {"x.naca.com", "x.xn--aca-6ma.com", false}, - } - - for _, c := range cases { - if r := matchDomain(c.domain, c.pattern); r != c.expected { - t.Errorf("matchDomain(%q, %q) = %v, expected %v", - c.domain, c.pattern, r, c.expected) - } - } -} - -func TestFetch(t *testing.T) { - // Note the data "fetched" for each domain comes from policyForDomain, - // defined in TestMain above. See httpGet for more details. - - // Normal fetch, all valid. - p, err := Fetch(context.Background(), "domain.com") - if err != nil { - t.Errorf("failed to fetch policy: %v", err) - } - t.Logf("domain.com: %+v", p) - - // Domain without a policy (HTTP get fails). - p, err = Fetch(context.Background(), "unknown") - if err == nil { - t.Errorf("fetched unknown policy: %v", p) - } - t.Logf("unknown: got error as expected: %v", err) - - // Domain with an invalid policy (unknown version). - p, err = Fetch(context.Background(), "version99") - if err != ErrUnknownVersion { - t.Errorf("expected error %v, got %v (and policy: %v)", - ErrUnknownVersion, err, p) - } - t.Logf("version99: got expected error: %v", err) -} - -func TestPolicyTooBig(t *testing.T) { - // Construct a valid but very large JSON as a policy. - raw := `{"version": "STSv1", "mode": "enforce", "mx": [` - for i := 0; i < 2000; i++ { - raw += fmt.Sprintf("\"mx%d\", ", i) - } - raw += `"mxlast"], "max_age": 100}` - policyForDomain["toobig"] = raw - - _, err := Fetch(context.Background(), "toobig") - if err == nil { - t.Errorf("fetch worked, but should have failed") - } - t.Logf("got error as expected: %v", err) -} - -// Tests for the policy cache. - -func mustTempDir(t *testing.T) string { - dir, err := ioutil.TempDir("", "sts_test") - if err != nil { - t.Fatal(err) - } - - err = os.Chdir(dir) - if err != nil { - t.Fatal(err) - } - - t.Logf("test directory: %q", dir) - - return dir -} - -func expvarMustEq(t *testing.T, name string, v *expvar.Int, expected int) { - // TODO: Use v.Value once we drop support of Go 1.7. - value, _ := strconv.Atoi(v.String()) - if value != expected { - t.Errorf("%s is %d, expected %d", name, value, expected) - } -} - -func TestCacheBasics(t *testing.T) { - dir := mustTempDir(t) - c, err := NewCache(dir) - if err != nil { - t.Fatal(err) - } - - // Note the data "fetched" for each domain comes from policyForDomain, - // defined in TestMain above. See httpGet for more details. - - // Reset the expvar counters that we use to validate hits, misses, etc. - cacheFetches.Set(0) - cacheHits.Set(0) - - ctx := context.Background() - - // Fetch domain.com, check we get a reasonable policy, and that it's a - // cache miss. - p, err := c.Fetch(ctx, "domain.com") - if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" { - t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err) - } - t.Logf("cache fetched domain.com: %v", p) - expvarMustEq(t, "cacheFetches", cacheFetches, 1) - expvarMustEq(t, "cacheHits", cacheHits, 0) - - // Fetch domain.com again, this time we should see a cache hit. - p, err = c.Fetch(ctx, "domain.com") - if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" { - t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err) - } - t.Logf("cache fetched domain.com: %v", p) - expvarMustEq(t, "cacheFetches", cacheFetches, 2) - expvarMustEq(t, "cacheHits", cacheHits, 1) - - // Simulate an expired cache entry by changing the mtime of domain.com's - // entry to the past. - expires := time.Now().Add(-1 * time.Minute) - os.Chtimes(c.domainPath("domain.com"), expires, expires) - - // Do a third fetch, check that we don't get a cache hit. - p, err = c.Fetch(ctx, "domain.com") - if err != nil || p.Check() != nil || p.MXs[0] != "*.mail.domain.com" { - t.Errorf("unexpected fetch result - policy = %v ; error = %v", p, err) - } - t.Logf("cache fetched domain.com: %v", p) - expvarMustEq(t, "cacheFetches", cacheFetches, 3) - expvarMustEq(t, "cacheHits", cacheHits, 1) - - if !t.Failed() { - os.RemoveAll(dir) - } -} - -// Test how the cache behaves when the files are corrupt. -func TestCacheBadData(t *testing.T) { - dir := mustTempDir(t) - c, err := NewCache(dir) - if err != nil { - t.Fatal(err) - } - - ctx := context.Background() - - cases := []string{ - // Case 1: A file with invalid json, which will fail unmarshalling. - "this is not valid json", - - // Case 2: A file with a parseable but invalid policy. - `{"version": "STSv1", "mode": "INVALID", "mx": ["mx"], max_age": 1}`, - } - - for _, badContent := range cases { - // Reset the expvar counters that we use to validate hits, misses, etc. - cacheFetches.Set(0) - cacheHits.Set(0) - - // Fetch domain.com, should result in the file being added to the - // cache. - p, err := c.Fetch(ctx, "domain.com") - if err != nil { - t.Fatalf("Fetch failed: %v", err) - } - t.Logf("cache fetched domain.com: %v", p) - expvarMustEq(t, "cacheFetches", cacheFetches, 1) - expvarMustEq(t, "cacheHits", cacheHits, 0) - - // Edit the file, filling it with the bad content for this case. - fname := c.domainPath("domain.com") - err = ioutil.WriteFile(fname, []byte(badContent), 0644) - if err != nil { - t.Fatalf("error writing file: %v", err) - } - - // We now expect Fetch to fall back to getting the policy from the - // network (in our case, from policyForDomain). - p, err = c.Fetch(ctx, "domain.com") - if err != nil { - t.Fatalf("Fetch failed: %v", err) - } - t.Logf("cache fetched domain.com: %v", p) - expvarMustEq(t, "cacheFetches", cacheFetches, 2) - expvarMustEq(t, "cacheHits", cacheHits, 0) - - // And now the file should be fine, resulting in a cache hit. - p, err = c.Fetch(ctx, "domain.com") - if err != nil { - t.Fatalf("Fetch failed: %v", err) - } - t.Logf("cache fetched domain.com: %v", p) - expvarMustEq(t, "cacheFetches", cacheFetches, 3) - expvarMustEq(t, "cacheHits", cacheHits, 1) - - // Remove the file, to start with a clean slate for the next case. - os.Remove(fname) - } - - if !t.Failed() { - os.RemoveAll(dir) - } -} - -func mustFetch(t *testing.T, c *PolicyCache, ctx context.Context, d string) *Policy { - p, err := c.Fetch(ctx, d) - if err != nil { - t.Fatalf("Fetch %q failed: %v", d, err) - } - t.Logf("Fetch %q: %v", d, p) - return p -} - -func TestCacheRefresh(t *testing.T) { - dir := mustTempDir(t) - c, err := NewCache(dir) - if err != nil { - t.Fatal(err) - } - - ctx := context.Background() - - policyForDomain["refresh-test"] = ` - {"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 100}` - p := mustFetch(t, c, ctx, "refresh-test") - if p.MaxAge != 100*time.Second { - t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge) - } - - // Change the "published" policy, check that we see the old version at - // fetch (should be cached), and a new version after a refresh. - policyForDomain["refresh-test"] = ` - {"version": "STSv1", "mode": "enforce", "mx": ["mx"], "max_age": 200}` - - p = mustFetch(t, c, ctx, "refresh-test") - if p.MaxAge != 100*time.Second { - t.Fatalf("policy.MaxAge is %v, expected 100s", p.MaxAge) - } - - c.refresh(ctx) - - p = mustFetch(t, c, ctx, "refresh-test") - if p.MaxAge != 200*time.Second { - t.Fatalf("policy.MaxAge is %v, expected 200s", p.MaxAge) - } - - if !t.Failed() { - os.RemoveAll(dir) - } -} - -func TestURLForDomain(t *testing.T) { - // This function will behave differently if fakeURLForTesting is set, so - // temporarily unset it. - oldURL := fakeURLForTesting - fakeURLForTesting = "" - defer func() { fakeURLForTesting = oldURL }() - - got := urlForDomain("a-test-domain") - expected := "https://mta-sts.a-test-domain/.well-known/mta-sts.json" - if got != expected { - t.Errorf("got %q, expected %q", got, expected) - } -}