git » chasquid » commit 99df5e7

smtpsrv: Limit incoming line length and improve large message handling

author Alberto Bertogli
2019-12-01 01:30:36 UTC
committer Alberto Bertogli
2019-12-01 19:07:58 UTC
parent d7006d0e1628385f01ebebaafeb36c688cc055e7

smtpsrv: Limit incoming line length and improve large message handling

Currently, there is no limit to incoming line length, so an evil client
could cause a memory exhaustion DoS by issuing very long lines.

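This patch fixes the bug by limiting lines to 1000 octets, following the text line limit in RFC 5321 section 4.5.3.1.6 (the check does not count the line terminator). Longer lines are consumed and discarded so the protocol stays in sync, and an error is returned instead; for commands, the client gets a 554 reply.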

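To do that, we replace the textproto.Conn with a buffered reader and
writer pair, which simplifies the code and gives us finer, more explicit
control over how much we read.

This patch also improves how DATA handles messages that exceed the
maximum data size. Previously, the size limit was applied on top of the
dot reader, so oversized messages were silently truncated. Now the limit
is applied to the underlying reader, and an oversized message is
rejected with a "552 5.3.4 Message too big" error. New tests cover both
the line length limit and the message size limit.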

Thanks to Max Mazurov (fox.cpp@disroot.org) for finding and reporting
this issue.

internal/smtpsrv/conn.go +47 -12
internal/smtpsrv/server.go +0 -2
internal/smtpsrv/server_test.go +82 -3
test/t-12-minor_dialogs/line_too_long.cmy +5 -0

diff --git a/internal/smtpsrv/conn.go b/internal/smtpsrv/conn.go
index 4731bae..16385c1 100644
--- a/internal/smtpsrv/conn.go
+++ b/internal/smtpsrv/conn.go
@@ -1,6 +1,7 @@
 package smtpsrv
 
 import (
+	"bufio"
 	"bytes"
 	"context"
 	"crypto/tls"
@@ -94,10 +95,13 @@ type Conn struct {
 
 	// Connection information.
 	conn         net.Conn
-	tc           *textproto.Conn
 	mode         SocketMode
 	tlsConnState *tls.ConnectionState
 
+	// Reader and text writer, so we can control limits.
+	reader *bufio.Reader
+	writer *bufio.Writer
+
 	// Tracer to use.
 	tr *trace.Trace
 
@@ -178,7 +182,12 @@ func (c *Conn) Handle() {
 		}
 	}
 
-	c.tc.PrintfLine("220 %s ESMTP chasquid", c.hostname)
+	// Set up a buffered reader and writer from the conn.
+	// They will be used to do line-oriented, limited I/O.
+	c.reader = bufio.NewReader(c.conn)
+	c.writer = bufio.NewWriter(c.conn)
+
+	c.printfLine("220 %s ESMTP chasquid", c.hostname)
 
 	var cmd, params string
 	var err error
@@ -196,7 +205,7 @@ loop:
 
 		cmd, params, err = c.readCommand()
 		if err != nil {
-			c.tc.PrintfLine("554 error reading command: %v", err)
+			c.printfLine("554 error reading command: %v", err)
 			break
 		}
 
@@ -577,9 +586,14 @@ func (c *Conn) DATA(params string) (code int, msg string) {
 	// one, we don't want the command timeout to interfere.
 	c.conn.SetDeadline(c.deadline)
 
-	dotr := io.LimitReader(c.tc.DotReader(), c.maxDataSize)
+	// Create a dot reader, limited to the maximum size.
+	dotr := textproto.NewReader(bufio.NewReader(
+		io.LimitReader(c.reader, c.maxDataSize))).DotReader()
 	c.data, err = ioutil.ReadAll(dotr)
 	if err != nil {
+		if err == io.ErrUnexpectedEOF {
+			return 552, fmt.Sprintf("5.3.4 Message too big")
+		}
 		return 554, fmt.Sprintf("5.4.0 Error reading DATA: %v", err)
 	}
 
@@ -875,9 +889,10 @@ func (c *Conn) STARTTLS(params string) (code int, msg string) {
 
 	c.tr.Debugf("<> ...  jump to TLS was successful")
 
-	// Override the connections. We don't need the older ones anymore.
+	// Override the connection. We don't need the older one anymore.
 	c.conn = server
-	c.tc = textproto.NewConn(server)
+	c.reader = bufio.NewReader(c.conn)
+	c.writer = bufio.NewWriter(c.conn)
 
 	// Take the connection state, so we can use it later for logging and
 	// tracing purposes.
@@ -1001,9 +1016,7 @@ func (c *Conn) userExists(addr string) bool {
 }
 
 func (c *Conn) readCommand() (cmd, params string, err error) {
-	var msg string
-
-	msg, err = c.tc.ReadLine()
+	msg, err := c.readLine()
 	if err != nil {
 		return "", "", err
 	}
@@ -1018,14 +1031,36 @@ func (c *Conn) readCommand() (cmd, params string, err error) {
 }
 
 func (c *Conn) readLine() (line string, err error) {
-	return c.tc.ReadLine()
+	// The bufio reader's ReadLine will only read up to the buffer size, which
+	// prevents DoS due to memory exhaustion on extremely long lines.
+	l, more, err := c.reader.ReadLine()
+	if err != nil {
+		return "", err
+	}
+
+	// As per RFC, the maximum length of a text line is 1000 octets.
+	// https://tools.ietf.org/html/rfc5321#section-4.5.3.1.6
+	if len(l) > 1000 || more {
+		// Keep reading to maintain the protocol status, but discard the data.
+		for more && err == nil {
+			_, more, err = c.reader.ReadLine()
+		}
+		return "", fmt.Errorf("line too long")
+	}
+
+	return string(l), nil
 }
 
 func (c *Conn) writeResponse(code int, msg string) error {
-	defer c.tc.W.Flush()
+	defer c.writer.Flush()
 
 	responseCodeCount.Add(strconv.Itoa(code), 1)
-	return writeResponse(c.tc.W, code, msg)
+	return writeResponse(c.writer, code, msg)
+}
+
+func (c *Conn) printfLine(format string, args ...interface{}) error {
+	fmt.Fprintf(c.writer, format+"\r\n", args...)
+	return c.writer.Flush()
 }
 
 // writeResponse writes a multi-line response to the given writer.
diff --git a/internal/smtpsrv/server.go b/internal/smtpsrv/server.go
index 882f9e3..4e3105a 100644
--- a/internal/smtpsrv/server.go
+++ b/internal/smtpsrv/server.go
@@ -6,7 +6,6 @@ import (
 	"flag"
 	"net"
 	"net/http"
-	"net/textproto"
 	"path"
 	"time"
 
@@ -247,7 +246,6 @@ func (s *Server) serve(l net.Listener, mode SocketMode) {
 			maxDataSize:    s.MaxDataSize,
 			postDataHook:   pdhook,
 			conn:           conn,
-			tc:             textproto.NewConn(conn),
 			mode:           mode,
 			tlsConfig:      s.tlsConfig,
 			onTLS:          mode.TLS,
diff --git a/internal/smtpsrv/server_test.go b/internal/smtpsrv/server_test.go
index 23a5e90..4fe95ed 100644
--- a/internal/smtpsrv/server_test.go
+++ b/internal/smtpsrv/server_test.go
@@ -43,6 +43,9 @@ var (
 	// TLS configuration to use in the clients.
 	// Will contain the generated server certificate as root CA.
 	tlsConfig *tls.Config
+
+	// Max data size, in MiB.
+	maxDataSizeMiB = 5
 )
 
 //
@@ -259,14 +262,69 @@ func TestRelayForbidden(t *testing.T) {
 	}
 }
 
-func simpleCmd(t *testing.T, c *smtp.Client, cmd string, expected int) {
+var str1MiB string
+
+func sendLargeEmail(tb testing.TB, c *smtp.Client, sizeMiB int) error {
+	tb.Helper()
+	if err := c.Mail("from@from"); err != nil {
+		tb.Fatalf("Mail: %v", err)
+	}
+	if err := c.Rcpt("to@localhost"); err != nil {
+		tb.Fatalf("Rcpt: %v", err)
+	}
+
+	w, err := c.Data()
+	if err != nil {
+		tb.Fatalf("Data: %v", err)
+	}
+
+	if _, err := w.Write([]byte("Subject: I ate too much\n\n")); err != nil {
+		tb.Fatalf("Data write: %v", err)
+	}
+
+	// Write the 1 MiB string sizeMiB times.
+	for i := 0; i < sizeMiB; i++ {
+		if _, err := w.Write([]byte(str1MiB)); err != nil {
+			tb.Fatalf("Data write: %v", err)
+		}
+	}
+
+	return w.Close()
+}
+
+func TestTooMuchData(t *testing.T) {
+	c := mustDial(t, ModeSMTP, true)
+	defer c.Close()
+
+	err := sendLargeEmail(t, c, maxDataSizeMiB-1)
+	if err != nil {
+		t.Errorf("Error sending large but ok email: %v", err)
+	}
+
+	// Repeat the test - we want to check that the limit applies to each
+	// message, not the entire connection.
+	err = sendLargeEmail(t, c, maxDataSizeMiB-1)
+	if err != nil {
+		t.Errorf("Error sending large but ok email: %v", err)
+	}
+
+	err = sendLargeEmail(t, c, maxDataSizeMiB+1)
+	if err == nil || err.Error() != "552 5.3.4 Message too big" {
+		t.Fatalf("Expected message too big, got: %v", err)
+	}
+}
+
+func simpleCmd(t *testing.T, c *smtp.Client, cmd string, expected int) string {
+	t.Helper()
 	if err := c.Text.PrintfLine(cmd); err != nil {
 		t.Fatalf("Failed to write %s: %v", cmd, err)
 	}
 
-	if _, _, err := c.Text.ReadResponse(expected); err != nil {
+	_, msg, err := c.Text.ReadResponse(expected)
+	if err != nil {
 		t.Errorf("Incorrect %s response: %v", cmd, err)
 	}
+	return msg
 }
 
 func TestSimpleCommands(t *testing.T) {
@@ -278,6 +336,20 @@ func TestSimpleCommands(t *testing.T) {
 	simpleCmd(t, c, "EXPN", 502)
 }
 
+func TestLongLines(t *testing.T) {
+	c := mustDial(t, ModeSMTP, false)
+	defer c.Close()
+
+	// Send a not-too-long line.
+	simpleCmd(t, c, fmt.Sprintf("%1000s", "x"), 500)
+
+	// Send a very long line, expect an error.
+	msg := simpleCmd(t, c, fmt.Sprintf("%1001s", "x"), 554)
+	if msg != "error reading command: line too long" {
+		t.Errorf("Expected 'line too long', got %v", msg)
+	}
+}
+
 func TestReset(t *testing.T) {
 	c := mustDial(t, ModeSMTP, false)
 	defer c.Close()
@@ -448,6 +520,13 @@ func waitForServer(addr string) error {
 func realMain(m *testing.M) int {
 	flag.Parse()
 
+	// Create a 1MiB string, which the large message tests use.
+	buf := make([]byte, 1024*1024)
+	for i := 0; i < len(buf); i++ {
+		buf[i] = 'a'
+	}
+	str1MiB = string(buf)
+
 	if *externalSMTPAddr != "" {
 		smtpAddr = *externalSMTPAddr
 		submissionAddr = *externalSubmissionAddr
@@ -476,7 +555,7 @@ func realMain(m *testing.M) int {
 
 		s := NewServer()
 		s.Hostname = "localhost"
-		s.MaxDataSize = 50 * 1024 * 1025
+		s.MaxDataSize = int64(maxDataSizeMiB) * 1024 * 1024
 		s.AddCerts(tmpDir+"/cert.pem", tmpDir+"/key.pem")
 		s.AddAddr(smtpAddr, ModeSMTP)
 		s.AddAddr(submissionAddr, ModeSubmission)
diff --git a/test/t-12-minor_dialogs/line_too_long.cmy b/test/t-12-minor_dialogs/line_too_long.cmy
new file mode 100644
index 0000000..20f6219
--- /dev/null
+++ b/test/t-12-minor_dialogs/line_too_long.cmy
@@ -0,0 +1,5 @@
+c tcp_connect localhost:1025
+c <~ 220
+
+c -> HELO aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1aaaaaaaaa1
+c <~ 554