git » chasquid » smarthost » tree

[smarthost] / internal / courier / smtp_test.go

package courier

import (
	"fmt"
	"net"
	"strings"
	"testing"
	"time"

	"blitiri.com.ar/go/chasquid/internal/domaininfo"
	"blitiri.com.ar/go/chasquid/internal/sts"
	"blitiri.com.ar/go/chasquid/internal/testlib"
	"blitiri.com.ar/go/chasquid/internal/trace"
)

// This domain will cause idna.ToASCII to fail.
var invalidDomain = "test " + strings.Repeat("x", 65536) + "\uff00"

// Override the netLookupMX function, to return controlled results for
// testing.
var testMX = map[string][]*net.MX{}
var testMXErr = map[string]error{}

func init() {
	netLookupMX = func(name string) ([]*net.MX, error) {
		return testMX[name], testMXErr[name]
	}
}

func newSMTP(t *testing.T) (*SMTP, string) {
	dir := testlib.MustTempDir(t)
	dinfo, err := domaininfo.New(dir)
	if err != nil {
		t.Fatal(err)
	}

	return &SMTP{"hello", dinfo, nil}, dir
}

func TestSMTP(t *testing.T) {
	// Shorten the total timeout, so the test fails quickly if the protocol
	// gets stuck.
	smtpTotalTimeout = 5 * time.Second

	responses := map[string]string{
		"_welcome":          "220 welcome\n",
		"EHLO hello":        "250 ehlo ok\n",
		"MAIL FROM:<me@me>": "250 mail ok\n",
		"RCPT TO:<to@to>":   "250 rcpt ok\n",
		"DATA":              "354 send data\n",
		"_DATA":             "250 data ok\n",
		"QUIT":              "250 quit ok\n",
	}
	srv := newFakeServer(t, responses)
	defer srv.Cleanup()
	host, port := srv.HostPort()

	// Put a non-existing host first, so we check that if the first host
	// doesn't work, we try with the rest.
	// The host we use is invalid, to avoid having to do an actual network
	// lookup whick makes the test more hermetic. This is a hack, ideally we
	// would be able to override the default resolver, but Go does not
	// implement that yet.
	testMX["to"] = []*net.MX{
		{Host: ":::", Pref: 10},
		{Host: host, Pref: 20},
	}
	*smtpPort = port

	s, tmpDir := newSMTP(t)
	defer testlib.RemoveIfOk(t, tmpDir)
	err, _ := s.Deliver("me@me", "to@to", []byte("data"))
	if err != nil {
		t.Errorf("deliver failed: %v", err)
	}

	srv.Wait()
}

func TestSMTPErrors(t *testing.T) {
	// Shorten the total timeout, so the test fails quickly if the protocol
	// gets stuck.
	smtpTotalTimeout = 1 * time.Second

	responses := []map[string]string{
		// First test: hang response, should fail due to timeout.
		{
			"_welcome": "220 no newline",
		},

		// MAIL FROM not allowed.
		{
			"_welcome":          "220 mail from not allowed\n",
			"EHLO hello":        "250 ehlo ok\n",
			"MAIL FROM:<me@me>": "501 mail error\n",
		},

		// RCPT TO not allowed.
		{
			"_welcome":          "220 rcpt to not allowed\n",
			"EHLO hello":        "250 ehlo ok\n",
			"MAIL FROM:<me@me>": "250 mail ok\n",
			"RCPT TO:<to@to>":   "501 rcpt error\n",
		},

		// DATA error.
		{
			"_welcome":          "220 data error\n",
			"EHLO hello":        "250 ehlo ok\n",
			"MAIL FROM:<me@me>": "250 mail ok\n",
			"RCPT TO:<to@to>":   "250 rcpt ok\n",
			"DATA":              "554 data error\n",
		},

		// DATA response error.
		{
			"_welcome":          "220 data response error\n",
			"EHLO hello":        "250 ehlo ok\n",
			"MAIL FROM:<me@me>": "250 mail ok\n",
			"RCPT TO:<to@to>":   "250 rcpt ok\n",
			"DATA":              "354 send data\n",
			"_DATA":             "551 data response error\n",
		},
	}

	for _, rs := range responses {
		srv := newFakeServer(t, rs)
		defer srv.Cleanup()
		host, port := srv.HostPort()

		testMX["to"] = []*net.MX{{Host: host, Pref: 10}}
		*smtpPort = port

		s, tmpDir := newSMTP(t)
		defer testlib.RemoveIfOk(t, tmpDir)
		err, _ := s.Deliver("me@me", "to@to", []byte("data"))
		if err == nil {
			t.Errorf("deliver not failed in case %q: %v", rs["_welcome"], err)
		}
		t.Logf("failed as expected: %v", err)

		srv.Wait()
	}
}

func TestNoMXServer(t *testing.T) {
	testMX["to"] = []*net.MX{}

	s, tmpDir := newSMTP(t)
	defer testlib.RemoveIfOk(t, tmpDir)
	err, permanent := s.Deliver("me@me", "to@to", []byte("data"))
	if err == nil {
		t.Errorf("delivery worked, expected failure")
	}
	if !permanent {
		t.Errorf("expected permanent failure, got transient (%v)", err)
	}
	t.Logf("got permanent failure, as expected: %v", err)
}

func TestTooManyMX(t *testing.T) {
	tr := trace.New("test", "test")
	testMX["domain"] = []*net.MX{
		{Host: "h1", Pref: 10}, {Host: "h2", Pref: 20},
		{Host: "h3", Pref: 30}, {Host: "h4", Pref: 40},
		{Host: "h5", Pref: 50}, {Host: "h5", Pref: 60},
	}
	mxs, err, perm := lookupMXs(tr, "domain")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if perm != true {
		t.Fatalf("expected perm == true")
	}
	if len(mxs) != 5 {
		t.Errorf("expected len(mxs) == 5, got: %v", mxs)
	}
}

func TestFallbackToA(t *testing.T) {
	tr := trace.New("test", "test")
	testMX["domain"] = nil
	testMXErr["domain"] = &net.DNSError{
		Err:         "no such host (test)",
		IsTemporary: false,
		IsNotFound:  true,
	}

	mxs, err, perm := lookupMXs(tr, "domain")
	if err != nil {
		t.Errorf("unexpected error: %v", err)
	}
	if perm != true {
		t.Errorf("expected perm == true")
	}
	if !(len(mxs) == 1 && mxs[0] == "domain") {
		t.Errorf("expected mxs == [domain], got: %v", mxs)
	}
}

func TestTemporaryDNSerror(t *testing.T) {
	tr := trace.New("test", "test")
	testMX["domain"] = nil
	testMXErr["domain"] = &net.DNSError{
		Err:         "temp error (test)",
		IsTemporary: true,
	}

	mxs, err, perm := lookupMXs(tr, "domain")
	if !(mxs == nil && err == testMXErr["domain"]) {
		t.Errorf("expected mxs == nil, err == test error, got: %v, %v", mxs, err)
	}
	if perm != false {
		t.Errorf("expected perm == false")
	}
}

func TestMXLookupError(t *testing.T) {
	tr := trace.New("test", "test")
	testMX["domain"] = nil
	testMXErr["domain"] = fmt.Errorf("test error")

	mxs, err, perm := lookupMXs(tr, "domain")
	if !(mxs == nil && err == testMXErr["domain"]) {
		t.Errorf("expected mxs == nil, err == test error, got: %v, %v", mxs, err)
	}
	if perm != false {
		t.Errorf("expected perm == false")
	}
}

func TestLookupInvalidDomain(t *testing.T) {
	tr := trace.New("test", "test")

	mxs, err, perm := lookupMXs(tr, invalidDomain)
	if !(mxs == nil && err != nil) {
		t.Errorf("expected err != nil, got: %v, %v", mxs, err)
	}
	if perm != true {
		t.Fatalf("expected perm == true")
	}
}

// Server fake responses for a complete TLS delivery.
// We use this in a few tests, so make it common.
var tlsResponses = map[string]string{
	"_welcome":          "220 welcome\n",
	"EHLO hello":        "250-ehlo ok\n250 STARTTLS\n",
	"STARTTLS":          "220 starttls go\n",
	"_STARTTLS":         "ok",
	"MAIL FROM:<me@me>": "250 mail ok\n",
	"RCPT TO:<to@to>":   "250 rcpt ok\n",
	"DATA":              "354 send data\n",
	"_DATA":             "250 data ok\n",
	"QUIT":              "250 quit ok\n",
}

func TestTLS(t *testing.T) {
	smtpTotalTimeout = 5 * time.Second
	srv := newFakeServer(t, tlsResponses)
	defer srv.Cleanup()
	_, *smtpPort = srv.HostPort()

	testMX["to"] = []*net.MX{
		{Host: "localhost", Pref: 20},
	}

	s, tmpDir := newSMTP(t)
	defer testlib.RemoveIfOk(t, tmpDir)
	err, _ := s.Deliver("me@me", "to@to", []byte("data"))
	if err != nil {
		t.Errorf("deliver failed: %v", err)
	}

	srv.Wait()

	// Now do another delivery, but without TLS, to check that the detection
	// of connection downgrade is working.
	responses := map[string]string{
		"_welcome":          "220 welcome\n",
		"EHLO hello":        "250 ehlo ok\n",
		"MAIL FROM:<me@me>": "250 mail ok\n",
		"RCPT TO:<to@to>":   "250 rcpt ok\n",
		"DATA":              "354 send data\n",
		"_DATA":             "250 data ok\n",
		"QUIT":              "250 quit ok\n",
	}
	srv = newFakeServer(t, responses)
	defer srv.Cleanup()
	_, *smtpPort = srv.HostPort()

	err, permanent := s.Deliver("me@me", "to@to", []byte("data"))
	if !strings.Contains(err.Error(),
		"Security level check failed (level:PLAIN)") {
		t.Errorf("expected sec level check failed, got: %v", err)
	}
	if permanent != false {
		t.Errorf("expected transient failure, got permanent")
	}

	srv.Wait()
}

func TestTLSError(t *testing.T) {
	smtpTotalTimeout = 5 * time.Second

	responses := map[string]string{
		"_welcome":   "220 welcome\n",
		"EHLO hello": "250-ehlo ok\n250 STARTTLS\n",
		"STARTTLS":   "500 starttls err\n",
		"_STARTTLS":  "no",
	}
	srv := newFakeServer(t, responses)
	defer srv.Cleanup()
	_, *smtpPort = srv.HostPort()

	testMX["to"] = []*net.MX{
		{Host: "localhost", Pref: 20},
	}

	s, tmpDir := newSMTP(t)
	defer testlib.RemoveIfOk(t, tmpDir)
	err, permanent := s.Deliver("me@me", "to@to", []byte("data"))
	if !strings.Contains(err.Error(), "TLS error:") {
		t.Errorf("expected TLS error, got: %v", err)
	}
	if permanent != false {
		t.Errorf("expected transient failure, got permanent")
	}

	srv.Wait()
}

func TestSTSPolicyEnforcement(t *testing.T) {
	smtpTotalTimeout = 5 * time.Second
	srv := newFakeServer(t, tlsResponses)
	defer srv.Cleanup()
	_, *smtpPort = srv.HostPort()

	s, tmpDir := newSMTP(t)
	defer testlib.RemoveIfOk(t, tmpDir)

	a := &attempt{
		courier:  s,
		from:     "me@me",
		to:       "to@to",
		toDomain: "to",
		data:     []byte("data"),
		tr:       trace.New("test", "test"),
	}

	a.stsPolicy = &sts.Policy{
		Version: "STSv1",
		Mode:    sts.Enforce,
		MXs:     []string{"mx"},
		MaxAge:  1 * time.Minute,
	}

	// At this point the cert is not valid, which is incompatible with STS
	// policy, so we expect it to fail.
	err, permanent := a.deliver("localhost")
	if !strings.Contains(err.Error(),
		"invalid security level (TLS_INSECURE) for STS policy") {
		t.Errorf("expected invalid sec level error, got %v", err)
	}
	if permanent != false {
		t.Errorf("expected transient error, got permanent")
	}

	srv.Wait()

	// Do another delivery attempt, but this time we trust the server cert.
	// This time it should be successful, because the connection level should
	// be TLS_SECURE which is required by the STS policy.
	srv = newFakeServer(t, tlsResponses)
	_, *smtpPort = srv.HostPort()
	defer srv.Cleanup()

	certRoots = srv.rootCA()
	defer func() {
		certRoots = nil
	}()

	err, permanent = a.deliver("localhost")
	if err != nil {
		t.Errorf("expected success, got %v (permanent=%v)", err, permanent)
	}

	srv.Wait()
}