git » chasquid » commit 933ab54

sts: Experimental MTA-STS (Strict Transport Security) implementation

author Alberto Bertogli
2017-01-04 16:26:20 UTC
committer Alberto Bertogli
2017-02-28 22:27:15 UTC
parent b8551729dbea260ee9384fb3a2113c10e60b9421

sts: Experimental MTA-STS (Strict Transport Security) implementation

This EXPERIMENTAL patch has a basic implementation of MTA-STS (Strict
Transport Security), based on the current draft at
https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02.

It integrates the policy fetching and checking into the smtp-check tool
for convenience, but not yet in chasquid itself.

This is a proof of concept. Many features and tests are missing; in
particular, there is no caching at all yet.

cmd/smtp-check/smtp-check.go +25 -0
internal/sts/sts.go +218 -0
internal/sts/sts_test.go +119 -0

diff --git a/cmd/smtp-check/smtp-check.go b/cmd/smtp-check/smtp-check.go
index c7f03a7..2dd2471 100644
--- a/cmd/smtp-check/smtp-check.go
+++ b/cmd/smtp-check/smtp-check.go
@@ -2,13 +2,16 @@
 package main
 
 import (
+	"context"
 	"crypto/tls"
 	"flag"
 	"log"
 	"net"
 	"net/smtp"
+	"time"
 
 	"blitiri.com.ar/go/chasquid/internal/spf"
+	"blitiri.com.ar/go/chasquid/internal/sts"
 	"blitiri.com.ar/go/chasquid/internal/tlsconst"
 
 	"golang.org/x/net/idna"
@@ -34,6 +37,21 @@ func main() {
 		log.Fatalf("IDNA conversion failed: %v", err)
 	}
 
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	log.Printf("=== STS policy")
+	policy, err := sts.UncheckedFetch(ctx, domain)
+	if err != nil {
+		log.Printf("Not available (%s)", err)
+	} else {
+		log.Printf("Parsed contents:  [%+v]\n", *policy)
+		if err := policy.Check(); err != nil {
+			log.Fatalf("Invalid: %v", err)
+		}
+		log.Printf("OK")
+	}
+
 	mxs, err := net.LookupMX(domain)
 	if err != nil {
 		log.Fatalf("MX lookup: %v", err)
@@ -83,6 +101,13 @@ func main() {
 			c.Close()
 		}
 
+		if policy != nil {
+			if !policy.MXIsAllowed(mx.Host) {
+				log.Fatalf("NOT allowed by STS policy")
+			}
+			log.Printf("Allowed by policy")
+		}
+
 		log.Printf("")
 	}
 
diff --git a/internal/sts/sts.go b/internal/sts/sts.go
new file mode 100644
index 0000000..df78758
--- /dev/null
+++ b/internal/sts/sts.go
@@ -0,0 +1,218 @@
+// Package sts implements MTA-STS (Strict Transport Security), based on
+// the current draft, https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02.
+//
+// This is an EXPERIMENTAL implementation for now.
+//
+// It lacks (at least) the following:
+// - Caching.
+// - DNS TXT checking.
+// - Facilities for reporting.
+//
+package sts
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"io/ioutil"
+	"net/http"
+	"strings"
+	"time"
+
+	"golang.org/x/net/context/ctxhttp"
+	"golang.org/x/net/idna"
+)
+
+// Policy represents a parsed policy.
+// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
+type Policy struct {
+	Version string        `json:"version"` // Policy version; Check accepts only "STSv1".
+	Mode    Mode          `json:"mode"`    // Policy mode; Check accepts Enforce or Report.
+	MXs     []string      `json:"mx"`      // MX patterns, matched by MXIsAllowed.
+	MaxAge  time.Duration `json:"max_age"` // Policy lifetime; parsePolicy converts it from seconds.
+}
+
+// Mode is the mode declared by a policy, as found in its "mode" field.
+type Mode string
+
+// Valid modes.
+const (
+	Enforce Mode = "enforce"
+	Report  Mode = "report"
+)
+
+// parsePolicy parses a JSON representation of the policy, and returns the
+// corresponding Policy structure.
+func parsePolicy(raw []byte) (*Policy, error) {
+	p := &Policy{}
+	if err := json.Unmarshal(raw, p); err != nil {
+		return nil, err
+	}
+
+	// MaxAge is in seconds.
+	p.MaxAge = p.MaxAge * time.Second
+
+	return p, nil
+}
+
+var (
+	ErrUnknownVersion = errors.New("unknown policy version") // Returned by Check: Version is not "STSv1".
+	ErrInvalidMaxAge  = errors.New("invalid max_age")        // Returned by Check: MaxAge is zero or negative.
+	ErrInvalidMode    = errors.New("invalid mode")           // Returned by Check: Mode is neither Enforce nor Report.
+)
+
+// Check validates the policy contents.
+//
+// The version, the maximum age and the mode are examined in that order, and
+// the error for the first problem found is returned: ErrUnknownVersion,
+// ErrInvalidMaxAge or ErrInvalidMode. It returns nil if the policy is valid.
+func (p *Policy) Check() error {
+	switch {
+	case p.Version != "STSv1":
+		return ErrUnknownVersion
+	case p.MaxAge <= 0:
+		return ErrInvalidMaxAge
+	case p.Mode != Enforce && p.Mode != Report:
+		return ErrInvalidMode
+	}
+
+	return nil
+}
+
+// MXIsAllowed checks if the given MX is allowed, according to the policy.
+// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-4.1
+func (p *Policy) MXIsAllowed(mx string) bool {
+	// TODO: Clarify how we should treat an empty MX list; for now, it allows no MX.
+	for _, pattern := range p.MXs {
+		if matchDomain(mx, pattern) {
+			return true
+		}
+	}
+
+	return false
+}
+
+// UncheckedFetch fetches and parses the policy, but does NOT check it.
+// This can be useful for debugging and troubleshooting, but you should always
+// call Check on the policy before using it.
+func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) {
+	// Convert the domain to ascii form, as httpGet does not support IDNs in
+	// any other way.
+	domain, err := idna.ToASCII(domain)
+	if err != nil {
+		return nil, err
+	}
+
+	// URL composed from the domain, as explained in:
+	// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3
+	// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.2
+	url := "https://mta-sts." + domain + "/.well-known/mta-sts.json"
+
+	rawPolicy, err := httpGet(ctx, url)
+	if err != nil {
+		return nil, err
+	}
+
+	return parsePolicy(rawPolicy)
+}
+
+// Fetch retrieves the policy for the given domain. This involves network
+// lookups and an HTTPS GET, so it can be slow.
+// The policy is parsed and then validated with Policy.Check; a policy that
+// fails either step is not returned, the corresponding error is.
+func Fetch(ctx context.Context, domain string) (*Policy, error) {
+	policy, err := UncheckedFetch(ctx, domain)
+	if err != nil {
+		return nil, err
+	}
+
+	if err := policy.Check(); err != nil {
+		return nil, err
+	}
+
+	return policy, nil
+}
+
+// Fake HTTP content for testing purposes only: maps URLs to bodies; while it is not empty, httpGet makes no real requests.
+var fakeContent = map[string]string{}
+
+// httpGet performs an HTTP GET of the given URL, using the context and
+// rejecting redirects, as per the standard.
+//
+// It returns the response body if the server replied with a 200 (OK) status,
+// and an error if the request failed or any other status was received.
+// If fakeContent is not empty, no request is made at all: the body (or an
+// error, if the URL is missing) comes from fakeContent instead.
+func httpGet(ctx context.Context, url string) ([]byte, error) {
+	if len(fakeContent) > 0 {
+		// If we have fake content for testing, then return the content for
+		// the URL, or an error if it's missing.
+		// This makes sure we don't make actual requests for testing.
+		if d, ok := fakeContent[url]; ok {
+			return []byte(d), nil
+		}
+		return nil, errors.New("error for testing")
+	}
+
+	client := &http.Client{
+		// We MUST NOT follow redirects, see
+		// https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-3.3
+		CheckRedirect: rejectRedirect,
+	}
+
+	// Note that http does not care for the context deadline, so we need to
+	// construct it here.
+	if deadline, ok := ctx.Deadline(); ok {
+		client.Timeout = deadline.Sub(time.Now())
+	}
+
+	resp, err := ctxhttp.Get(ctx, client, url)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	// The body of any non-200 reply (not found, server error, etc.) is not a
+	// policy, so fail here instead of handing it over to the parser.
+	if resp.StatusCode != http.StatusOK {
+		return nil, errors.New("HTTP response status: " + resp.Status)
+	}
+
+	// NOTE(review): the body is read without a size cap; consider limiting
+	// it, as it comes from a remote server we do not control.
+	return ioutil.ReadAll(resp.Body)
+}
+
+// errRejectRedirect is the error rejectRedirect returns for every redirect.
+var errRejectRedirect = errors.New("redirects not allowed in MTA-STS")
+
+// rejectRedirect is used as the CheckRedirect function of the HTTP client;
+// it ignores its arguments and refuses the redirect unconditionally.
+func rejectRedirect(_ *http.Request, _ []*http.Request) error {
+	return errRejectRedirect
+}
+
+// matchDomain reports whether the domain matches the given pattern, following
+// https://tools.ietf.org/html/rfc6125#section-6.4
+// (from https://tools.ietf.org/html/draft-ietf-uta-mta-sts-02#section-4.1).
+//
+// Both arguments are normalized with domainToASCII before being compared; if
+// either one fails to convert, the result is false.
+func matchDomain(domain, pattern string) bool {
+	// Domains should already have been checked and normalized by the caller,
+	// so a conversion failure is simply reported as "no match": exposing the
+	// error is not worth the API complexity in this case.
+	asciiDomain, err := domainToASCII(domain)
+	if err != nil {
+		return false
+	}
+	asciiPattern, err := domainToASCII(pattern)
+	if err != nil {
+		return false
+	}
+
+	dLabels := strings.Split(asciiDomain, ".")
+	pLabels := strings.Split(asciiPattern, ".")
+
+	// A wildcard stands for exactly one label (see below), so both names
+	// must have the same number of labels to have a chance of matching.
+	if len(dLabels) != len(pLabels) {
+		return false
+	}
+
+	for i := range pLabels {
+		// Wildcards only apply to the first part, see
+		// https://tools.ietf.org/html/rfc6125#section-6.4.3 #1 and #2.
+		// This is also what makes the length comparison above valid.
+		isWildcard := i == 0 && pLabels[i] == "*"
+		if !isWildcard && pLabels[i] != dLabels[i] {
+			return false
+		}
+	}
+
+	return true
+}
+
+// domainToASCII converts the domain to ASCII form, like idna.ToASCII does,
+// but first it strips one trailing "." (if present) and lowercases the
+// result, which is convenient for our use cases.
+func domainToASCII(domain string) (string, error) {
+	normalized := strings.ToLower(strings.TrimSuffix(domain, "."))
+	return idna.ToASCII(normalized)
+}
diff --git a/internal/sts/sts_test.go b/internal/sts/sts_test.go
new file mode 100644
index 0000000..aab8b6b
--- /dev/null
+++ b/internal/sts/sts_test.go
@@ -0,0 +1,119 @@
+package sts
+
+import (
+	"context"
+	"testing"
+	"time"
+)
+
+// TestParsePolicy parses a well-formed policy, and verifies that every field
+// made it into the resulting Policy, including the conversion of max_age
+// from seconds into a time.Duration.
+func TestParsePolicy(t *testing.T) {
+	const pol1 = `{
+  "version": "STSv1",
+  "mode": "enforce",
+  "mx": ["*.mail.example.com"],
+  "max_age": 123456
+}
+`
+	p, err := parsePolicy([]byte(pol1))
+	if err != nil {
+		// Nothing else to verify without a policy.
+		t.Fatalf("failed to parse policy: %v", err)
+	}
+
+	t.Logf("pol1: %+v", p)
+
+	if p.Version != "STSv1" {
+		t.Errorf("version: expected %q, got %q", "STSv1", p.Version)
+	}
+	if p.Mode != Enforce {
+		t.Errorf("mode: expected %q, got %q", Enforce, p.Mode)
+	}
+	if len(p.MXs) != 1 || p.MXs[0] != "*.mail.example.com" {
+		t.Errorf("mx: expected [%q], got %q", "*.mail.example.com", p.MXs)
+	}
+	if expected := 123456 * time.Second; p.MaxAge != expected {
+		t.Errorf("max_age: expected %v, got %v", expected, p.MaxAge)
+	}
+}
+
+// TestCheckPolicy runs Policy.Check over a set of policies that must be
+// accepted, and over a set that must be rejected with a specific error.
+func TestCheckPolicy(t *testing.T) {
+	// Policies that Check must accept.
+	accepted := []Policy{
+		{Version: "STSv1", Mode: "enforce", MaxAge: 1 * time.Hour},
+		{Version: "STSv1", Mode: "report", MaxAge: 1 * time.Hour},
+		{Version: "STSv1", Mode: "report", MaxAge: 1 * time.Hour,
+			MXs: []string{"mx1", "mx2"}},
+	}
+	for i := range accepted {
+		if err := accepted[i].Check(); err != nil {
+			t.Errorf("%d policy %v failed check: %v", i, accepted[i], err)
+		}
+	}
+
+	// Policies that Check must reject, with the error expected for each.
+	rejected := []struct {
+		p        Policy
+		expected error
+	}{
+		{Policy{Version: "STSv2"}, ErrUnknownVersion},
+		{Policy{Version: "STSv1"}, ErrInvalidMaxAge},
+		{Policy{Version: "STSv1", MaxAge: 1, Mode: "blah"}, ErrInvalidMode},
+	}
+	for i := range rejected {
+		tc := &rejected[i]
+		err := tc.p.Check()
+		if err == tc.expected {
+			continue
+		}
+		t.Errorf("%d policy %v check: expected %v, got %v", i, tc.p,
+			tc.expected, err)
+	}
+}
+
+// TestMatchDomain is a table-driven test for matchDomain.
+func TestMatchDomain(t *testing.T) {
+	tests := []struct {
+		domain, pattern string
+		expected        bool
+	}{
+		// Exact matches, trailing dots, and a leading wildcard.
+		{"lalala", "lalala", true},
+		{"a.b.", "a.b", true},
+		{"a.b", "a.b.", true},
+		{"abc.com", "*.com", true},
+
+		// Wildcards outside the first label, and label count mismatches.
+		{"abc.com", "abc.*.com", false},
+		{"abc.com", "x.abc.com", false},
+		{"x.abc.com", "*.*.com", false},
+
+		// Internationalized names: case and Unicode/ASCII forms.
+		{"ñaca.com", "ñaca.com", true},
+		{"Ñaca.com", "ñaca.com", true},
+		{"ñaca.com", "Ñaca.com", true},
+		{"x.ñaca.com", "x.xn--aca-6ma.com", true},
+		{"x.naca.com", "x.xn--aca-6ma.com", false},
+	}
+
+	for i := range tests {
+		tc := tests[i]
+		got := matchDomain(tc.domain, tc.pattern)
+		if got == tc.expected {
+			continue
+		}
+		t.Errorf("matchDomain(%q, %q) = %v, expected %v",
+			tc.domain, tc.pattern, got, tc.expected)
+	}
+}
+
+func TestFetch(t *testing.T) {
+	// Normal fetch, all valid.
+	fakeContent["https://mta-sts.domain.com/.well-known/mta-sts.json"] = `
+		{
+             "version": "STSv1",
+             "mode": "enforce",
+             "mx": ["*.mail.example.com"],
+             "max_age": 123456
+        }`
+	p, err := Fetch(context.Background(), "domain.com")
+	if err != nil {
+		t.Errorf("failed to fetch policy: %v", err)
+	}
+	t.Logf("domain.com: %+v", p)
+
+	// Domain without a policy (HTTP get fails).
+	p, err = Fetch(context.Background(), "unknown")
+	if err == nil {
+		t.Errorf("fetched unknown policy: %v", p)
+	}
+	t.Logf("unknown: got error as expected: %v", err)
+
+	// Domain with an invalid policy (unknown version).
+	fakeContent["https://mta-sts.version99/.well-known/mta-sts.json"] = `
+		{
+             "version": "STSv99",
+             "mode": "enforce",
+             "mx": ["*.mail.example.com"],
+             "max_age": 123456
+        }`
+	p, err = Fetch(context.Background(), "version99")
+	if err != ErrUnknownVersion {
+		t.Errorf("expected error %v, got %v (and policy: %v)",
+			ErrUnknownVersion, err, p)
+	}
+	t.Logf("version99: got expected error: %v", err)
+}