author | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-01-04 16:26:20 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-02-28 22:27:15 UTC |
parent | b8551729dbea260ee9384fb3a2113c10e60b9421 |
cmd/smtp-check/smtp-check.go | +25 | -0 |
internal/sts/sts.go | +218 | -0 |
internal/sts/sts_test.go | +119 | -0 |
diff --git a/cmd/smtp-check/smtp-check.go b/cmd/smtp-check/smtp-check.go index c7f03a7..2dd2471 100644 --- a/cmd/smtp-check/smtp-check.go +++ b/cmd/smtp-check/smtp-check.go @@ -2,13 +2,16 @@ package main import ( + "context" "crypto/tls" "flag" "log" "net" "net/smtp" + "time" "blitiri.com.ar/go/chasquid/internal/spf" + "blitiri.com.ar/go/chasquid/internal/sts" "blitiri.com.ar/go/chasquid/internal/tlsconst" "golang.org/x/net/idna" @@ -34,6 +37,21 @@ func main() { log.Fatalf("IDNA conversion failed: %v", err) } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + log.Printf("=== STS policy") + policy, err := sts.UncheckedFetch(ctx, domain) + if err != nil { + log.Printf("Not available (%s)", err) + } else { + log.Printf("Parsed contents: [%+v]\n", *policy) + if err := policy.Check(); err != nil { + log.Fatalf("Invalid: %v", err) + } + log.Printf("OK") + } + mxs, err := net.LookupMX(domain) if err != nil { log.Fatalf("MX lookup: %v", err) @@ -83,6 +101,13 @@ func main() { c.Close() } + if policy != nil { + if !policy.MXIsAllowed(mx.Host) { + log.Fatalf("NOT allowed by STS policy") + } + log.Printf("Allowed by policy") + } + log.Printf("") } diff --git a/internal/sts/sts.go b/internal/sts/sts.go new file mode 100644 index 0000000..df78758 --- /dev/null +++ b/internal/sts/sts.go @@ -0,0 +1,218 @@ +// Package sts implements MTA-STS (SMTP MTA Strict Transport Security), based on +// the current draft, https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02. +// +// This is an EXPERIMENTAL implementation for now. +// +// It lacks (at least) the following: +// - Caching. +// - DNS TXT checking. +// - Facilities for reporting. +// +package sts + +import ( + "context" + "encoding/json" + "errors" + "io/ioutil" + "net/http" + "strings" + "time" + + "golang.org/x/net/context/ctxhttp" + "golang.org/x/net/idna" +) + +// Policy represents a parsed MTA-STS policy.
+// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
+type Policy struct {
+	// Version of the policy; Check only accepts "STSv1".
+	Version string `json:"version"`
+
+	// Mode of the policy; Check only accepts Enforce or Report.
+	Mode Mode `json:"mode"`
+
+	// MXs are the MX host patterns allowed by the policy (see MXIsAllowed).
+	MXs []string `json:"mx"`
+
+	// MaxAge is the policy's max_age. The JSON value is in seconds, and is
+	// converted to a time.Duration by parsePolicy.
+	MaxAge time.Duration `json:"max_age"`
+}
+
+// Mode is the mode of a policy; see the valid values below.
+type Mode string
+
+// Valid modes.
+const (
+	Enforce = Mode("enforce")
+	Report  = Mode("report")
+)
+
+// parsePolicy parses a JSON representation of the policy, and returns the
+// corresponding Policy structure. The result is NOT validated: callers must
+// use Check for that. It fails if the JSON is malformed, or if max_age is too
+// large (in either direction) to be represented as a time.Duration.
+func parsePolicy(raw []byte) (*Policy, error) {
+	p := &Policy{}
+	if err := json.Unmarshal(raw, p); err != nil {
+		return nil, err
+	}
+
+	// max_age is in seconds, but json.Unmarshal stored the raw number in the
+	// Duration (which counts nanoseconds), so it needs scaling. Values this
+	// large would silently overflow (possibly into a positive value that
+	// Check accepts), so reject them explicitly.
+	const maxSeconds = (1<<63 - 1) / time.Second
+	if p.MaxAge > maxSeconds || p.MaxAge < -maxSeconds {
+		return nil, ErrInvalidMaxAge
+	}
+	p.MaxAge = p.MaxAge * time.Second
+
+	return p, nil
+}
+
+// Errors describing an invalid policy. Check can return any of them;
+// parsePolicy (and therefore UncheckedFetch and Fetch) can also return
+// ErrInvalidMaxAge when max_age cannot be represented.
+var (
+	ErrUnknownVersion = errors.New("unknown policy version")
+	ErrInvalidMaxAge  = errors.New("invalid max_age")
+	ErrInvalidMode    = errors.New("invalid mode")
+)
+
+// Check checks that the policy contents are valid: the version must be
+// "STSv1", MaxAge must be positive, and the mode must be Enforce or Report.
+// It returns nil if so, or ErrUnknownVersion, ErrInvalidMaxAge or
+// ErrInvalidMode (checked in that order) otherwise.
+func (p *Policy) Check() error {
+	if p.Version != "STSv1" {
+		return ErrUnknownVersion
+	}
+	if p.MaxAge <= 0 {
+		return ErrInvalidMaxAge
+	}
+
+	if p.Mode != Enforce && p.Mode != Report {
+		return ErrInvalidMode
+	}
+
+	return nil
+}
+
+// MXIsAllowed checks if the given MX is allowed, according to the policy:
+// it must match at least one of the policy's MX patterns (see matchDomain).
+// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-4.1
+func (p *Policy) MXIsAllowed(mx string) bool {
+	// TODO: Clarify how we should treat an empty MX list. For now, an empty
+	// list allows no MX at all.
+	for _, pattern := range p.MXs {
+		if matchDomain(mx, pattern) {
+			return true
+		}
+	}
+
+	return false
+}
+
+// UncheckedFetch fetches and parses the policy, but does NOT check it.
+// This can be useful for debugging and troubleshooting, but you should always
+// call Check on the policy before using it.
+func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) {
+	// Convert the domain to ascii form, as httpGet does not support IDNs in
+	// any other way.
+	domain, err := idna.ToASCII(domain)
+	if err != nil {
+		return nil, err
+	}
+
+	// URL composed from the domain, as explained in:
+	// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3
+	// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
+	url := "https://mta-sts." + domain + "/.well-known/mta-sts.json"
+
+	rawPolicy, err := httpGet(ctx, url)
+	if err != nil {
+		return nil, err
+	}
+
+	return parsePolicy(rawPolicy)
+}
+
+// Fetch fetches a policy for the given domain. Note this results in various
+// network lookups and HTTPS GETs, so it can be slow.
+// The returned policy is parsed and sanity-checked (using Policy.Check), so
+// it should be safe to use.
+func Fetch(ctx context.Context, domain string) (*Policy, error) {
+	p, err := UncheckedFetch(ctx, domain)
+	if err != nil {
+		return nil, err
+	}
+
+	err = p.Check()
+	if err != nil {
+		return nil, err
+	}
+
+	return p, nil
+}
+
+// fakeContent maps URLs to fake HTTP response bodies, for testing purposes
+// only. When it is not empty, httpGet serves requests from it exclusively.
+var fakeContent = map[string]string{}
+
+// httpGet performs an HTTP GET of the given URL, using the context and
+// rejecting redirects, as per the standard. Only a 200 (OK) response is
+// accepted; its body is returned as-is.
+func httpGet(ctx context.Context, url string) ([]byte, error) {
+	client := &http.Client{
+		// We MUST NOT follow redirects, see
+		// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3
+		CheckRedirect: rejectRedirect,
+	}
+
+	// Mirror the context deadline into the client timeout too, as
+	// http.Client does not derive its Timeout from the context.
+	if deadline, ok := ctx.Deadline(); ok {
+		client.Timeout = deadline.Sub(time.Now())
+	}
+
+	if len(fakeContent) > 0 {
+		// If we have fake content for testing, then return the content for
+		// the URL, or an error if it's missing.
+		// This makes sure we don't make actual requests for testing.
+		if d, ok := fakeContent[url]; ok {
+			return []byte(d), nil
+		}
+		return nil, errors.New("error for testing")
+	}
+
+	resp, err := ctxhttp.Get(ctx, client, url)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	// Anything but a 200 means we did not get a policy (e.g. a 404 error
+	// page), so don't pass its body on to the JSON parser.
+	if resp.StatusCode != http.StatusOK {
+		return nil, errors.New("unexpected HTTP status: " + resp.Status)
+	}
+
+	// TODO(review): consider bounding how much of the body we read.
+	return ioutil.ReadAll(resp.Body)
+}
+
+// errRejectRedirect is returned by rejectRedirect; net/http hands it back
+// (wrapped in a *url.Error) when the server attempts a redirect.
+var errRejectRedirect = errors.New("redirects not allowed in MTA-STS")
+
+// rejectRedirect is the http.Client CheckRedirect hook used by httpGet, so
+// that no redirect is ever followed.
+func rejectRedirect(req *http.Request, via []*http.Request) error {
+	return errRejectRedirect
+}
+
+// matchDomain checks if the domain matches the given pattern, according to
+// https://tools.ietf.org/html/rfc6125#section-6.4
+// (from https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-4.1).
+// Both are normalized with domainToASCII first, and a "*" is only honored
+// when it is the whole left-most label of the pattern.
+func matchDomain(domain, pattern string) bool {
+	domain, dErr := domainToASCII(domain)
+	pattern, pErr := domainToASCII(pattern)
+	if dErr != nil || pErr != nil {
+		// Domains should already have been checked and normalized by the
+		// caller, so exposing this is not worth the API complexity here.
+		return false
+	}
+
+	domainLabels := strings.Split(domain, ".")
+	patternLabels := strings.Split(pattern, ".")
+
+	if len(domainLabels) != len(patternLabels) {
+		return false
+	}
+
+	for i, p := range patternLabels {
+		// Wildcards only apply to the first part, see
+		// https://tools.ietf.org/html/rfc6125#section-6.4.3 #1 and #2.
+		// This also allows us to do the length comparison above.
+		if p == "*" && i == 0 {
+			continue
+		}
+
+		if p != domainLabels[i] {
+			return false
+		}
+	}
+
+	return true
+}
+
+// domainToASCII converts the domain to ASCII form, similar to idna.ToASCII
+// but first removing any trailing "." and lowercasing the domain.
+func domainToASCII(domain string) (string, error) { + domain = strings.TrimSuffix(domain, ".") + domain = strings.ToLower(domain) + return idna.ToASCII(domain) +} diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go new file mode 100644 index 0000000..aab8b6b --- /dev/null +++ b/internal/sts/sts_test.go @@ -0,0 +1,119 @@ +package sts + +import ( + "context" + "testing" + "time" +) + +func TestParsePolicy(t *testing.T) { + const pol1 = `{ + "version": "STSv1", + "mode": "enforce", + "mx": ["*.mail.example.com"], + "max_age": 123456 +} +` + p, err := parsePolicy([]byte(pol1)) + if err != nil { + t.Errorf("failed to parse policy: %v", err) + } + + t.Logf("pol1: %+v", p) +} + +func TestCheckPolicy(t *testing.T) { + validPs := []Policy{ + {Version: "STSv1", Mode: "enforce", MaxAge: 1 * time.Hour}, + {Version: "STSv1", Mode: "report", MaxAge: 1 * time.Hour}, + {Version: "STSv1", Mode: "report", MaxAge: 1 * time.Hour, + MXs: []string{"mx1", "mx2"}}, + } + for i, p := range validPs { + if err := p.Check(); err != nil { + t.Errorf("%d policy %v failed check: %v", i, p, err) + } + } + + invalid := []struct { + p Policy + expected error + }{ + {Policy{Version: "STSv2"}, ErrUnknownVersion}, + {Policy{Version: "STSv1"}, ErrInvalidMaxAge}, + {Policy{Version: "STSv1", MaxAge: 1, Mode: "blah"}, ErrInvalidMode}, + } + for i, c := range invalid { + if err := c.p.Check(); err != c.expected { + t.Errorf("%d policy %v check: expected %v, got %v", i, c.p, + c.expected, err) + } + } +} + +func TestMatchDomain(t *testing.T) { + cases := []struct { + domain, pattern string + expected bool + }{ + {"lalala", "lalala", true}, + {"a.b.", "a.b", true}, + {"a.b", "a.b.", true}, + {"abc.com", "*.com", true}, + + {"abc.com", "abc.*.com", false}, + {"abc.com", "x.abc.com", false}, + {"x.abc.com", "*.*.com", false}, + + {"ñaca.com", "ñaca.com", true}, + {"Ñaca.com", "ñaca.com", true}, + {"ñaca.com", "Ñaca.com", true}, + {"x.ñaca.com", "x.xn--aca-6ma.com", true}, + {"x.naca.com", 
"x.xn--aca-6ma.com", false}, + } + + for _, c := range cases { + if r := matchDomain(c.domain, c.pattern); r != c.expected { + t.Errorf("matchDomain(%q, %q) = %v, expected %v", + c.domain, c.pattern, r, c.expected) + } + } +} + +func TestFetch(t *testing.T) { + // Normal fetch: the fake content is a valid policy, so Fetch must succeed. + fakeContent["https://mta-sts.domain.com/.well-known/mta-sts.json"] = ` + { + "version": "STSv1", + "mode": "enforce", + "mx": ["*.mail.example.com"], + "max_age": 123456 + }` + p, err := Fetch(context.Background(), "domain.com") + if err != nil { + t.Errorf("failed to fetch policy: %v", err) + } + t.Logf("domain.com: %+v", p) + + // Domain without a policy: there is no fake content for its URL, so the HTTP GET fails. + p, err = Fetch(context.Background(), "unknown") + if err == nil { + t.Errorf("fetched unknown policy: %v", p) + } + t.Logf("unknown: got error as expected: %v", err) + + // Domain with an invalid policy (unknown version): Check must reject it with ErrUnknownVersion. + fakeContent["https://mta-sts.version99/.well-known/mta-sts.json"] = ` + { + "version": "STSv99", + "mode": "enforce", + "mx": ["*.mail.example.com"], + "max_age": 123456 + }` + p, err = Fetch(context.Background(), "version99") + if err != ErrUnknownVersion { + t.Errorf("expected error %v, got %v (and policy: %v)", + ErrUnknownVersion, err, p) + } + t.Logf("version99: got expected error: %v", err) +}