git » chasquid » commit 21e69aa

Implement AUTH

author Alberto Bertogli
2016-07-16 11:43:29 UTC
committer Alberto Bertogli
2016-07-22 00:44:45 UTC
parent ff103c18c3bcffc88ac2ce9829c8ccedb81a86d8

Implement AUTH

This patch implements the AUTH SMTP command, using per-domain user databases.

Note that we don't yet use or check the result of the authentication for
anything; this patch just implements the command itself.

chasquid.go +133 -5
chasquid_test.go +25 -0
internal/auth/auth.go +110 -0
internal/auth/auth_test.go +93 -0

diff --git a/chasquid.go b/chasquid.go
index 5dd5721..0a801e7 100644
--- a/chasquid.go
+++ b/chasquid.go
@@ -12,15 +12,18 @@ import (
 	"net/http"
 	"net/mail"
 	"net/textproto"
+	"os"
 	"path/filepath"
 	"strings"
 	"time"
 
+	"blitiri.com.ar/go/chasquid/internal/auth"
 	"blitiri.com.ar/go/chasquid/internal/config"
 	"blitiri.com.ar/go/chasquid/internal/courier"
 	"blitiri.com.ar/go/chasquid/internal/queue"
 	"blitiri.com.ar/go/chasquid/internal/systemd"
 	"blitiri.com.ar/go/chasquid/internal/trace"
+	"blitiri.com.ar/go/chasquid/internal/userdb"
 
 	_ "net/http/pprof"
 
@@ -40,6 +43,9 @@ var (
 func main() {
 	flag.Parse()
 
+	// Seed the PRNG, just to prevent for it to be totally predictable.
+	rand.Seed(time.Now().UnixNano())
+
 	conf, err := config.Load(*configDir + "/chasquid.conf")
 	if err != nil {
 		glog.Fatalf("Error reading config")
@@ -69,10 +75,9 @@ func main() {
 	} else {
 		glog.Infof("Domain config paths:")
 		for _, info := range domainDirs {
-			glog.Infof("  %s", info.Name())
-			s.AddDomain(info.Name())
-			dir := filepath.Join(*configDir, "domains", info.Name())
-			s.AddCerts(dir+"/cert.pem", dir+"/key.pem")
+			name := info.Name()
+			dir := filepath.Join(*configDir, "domains", name)
+			loadDomain(s, name, dir)
 		}
 	}
 
@@ -107,6 +112,27 @@ func main() {
 	s.ListenAndServe()
 }
 
+// loadDomain loads a single domain's configuration from dir into the server.
+//
+// It registers the domain itself, its certificate and key (dir/cert.pem and
+// dir/key.pem) and, if dir/users exists, its user database. Problems loading
+// the user database are logged but are not fatal.
+func loadDomain(s *Server, name, dir string) {
+	glog.Infof("  %s", name)
+	s.AddDomain(name)
+	s.AddCerts(filepath.Join(dir, "cert.pem"), filepath.Join(dir, "key.pem"))
+
+	usersPath := filepath.Join(dir, "users")
+	if _, err := os.Stat(usersPath); err != nil {
+		// A missing users file is normal (the domain has no users), but
+		// anything else (e.g. permission problems) deserves a log entry
+		// instead of being silently ignored.
+		if !os.IsNotExist(err) {
+			glog.Errorf("    error checking users file: %v", err)
+		}
+		return
+	}
+
+	glog.Infof("    adding users")
+	udb, warnings, err := userdb.Load(usersPath)
+	if err != nil {
+		glog.Errorf("      error: %v", err)
+		return
+	}
+	for _, w := range warnings {
+		glog.Warningf("     %v", w)
+	}
+	s.AddUserDB(name, udb)
+	// TODO: periodically reload the database.
+}
+
 type Server struct {
 	// Main hostname, used for display only.
 	Hostname string
@@ -129,6 +155,12 @@ type Server struct {
 	// Local domains.
 	localDomains map[string]bool
 
+	// User databases (per domain).
+	userDBs map[string]*userdb.DB
+
+	// Local courier.
+	localCourier courier.Courier
+
 	// Time before we give up on a connection, even if it's sending data.
 	connTimeout time.Duration
 
@@ -144,6 +176,7 @@ func NewServer() *Server {
 		connTimeout:    20 * time.Minute,
 		commandTimeout: 1 * time.Minute,
 		localDomains:   map[string]bool{},
+		userDBs:        map[string]*userdb.DB{},
 	}
 }
 
@@ -164,6 +197,10 @@ func (s *Server) AddDomain(d string) {
 	s.localDomains[d] = true
 }
 
+// AddUserDB registers db as the user database for the given domain,
+// replacing any previously registered one.
+//
+// Connections receive a reference to this same map (not a copy) when they
+// are created, so this should be called before the server starts serving.
+func (s *Server) AddUserDB(domain string, db *userdb.DB) {
+	// Writing to a nil map panics; initialise lazily so a Server that was
+	// not built via NewServer does not crash here.
+	if s.userDBs == nil {
+		s.userDBs = map[string]*userdb.DB{}
+	}
+	s.userDBs[domain] = db
+}
+
 func (s *Server) getTLSConfig() (*tls.Config, error) {
 	var err error
 	conf := &tls.Config{}
@@ -241,6 +278,7 @@ func (s *Server) serve(l net.Listener) {
 			netconn:        conn,
 			tc:             textproto.NewConn(conn),
 			tlsConfig:      s.tlsConfig,
+			userDBs:        s.userDBs,
 			deadline:       time.Now().Add(s.connTimeout),
 			commandTimeout: s.commandTimeout,
 			queue:          s.queue,
@@ -274,6 +312,19 @@ type Conn struct {
 	// Are we using TLS?
 	onTLS bool
 
+	// User databases - taken from the server at creation time.
+	userDBs map[string]*userdb.DB
+
+	// Have we successfully completed AUTH?
+	completedAuth bool
+
+	// How many times have we attempted AUTH?
+	authAttempts int
+
+	// Authenticated user and domain, empty if !completedAuth.
+	authUser   string
+	authDomain string
+
 	// When we should close this connection, no matter what.
 	deadline time.Time
 
@@ -341,6 +392,8 @@ loop:
 			code, msg = c.DATA(params, tr)
 		case "STARTTLS":
 			code, msg = c.STARTTLS(params, tr)
+		case "AUTH":
+			code, msg = c.AUTH(params, tr)
 		case "QUIT":
 			c.writeResponse(221, "Be seeing you...")
 			break loop
@@ -383,7 +436,11 @@ func (c *Conn) EHLO(params string) (code int, msg string) {
 	fmt.Fprintf(buf, "8BITMIME\n")
 	fmt.Fprintf(buf, "PIPELINING\n")
 	fmt.Fprintf(buf, "SIZE %d\n", c.maxDataSize)
-	fmt.Fprintf(buf, "STARTTLS\n")
+	if c.onTLS {
+		fmt.Fprintf(buf, "AUTH PLAIN\n")
+	} else {
+		fmt.Fprintf(buf, "STARTTLS\n")
+	}
 	fmt.Fprintf(buf, "HELP\n")
 	return 250, buf.String()
 }
@@ -582,6 +639,73 @@ func (c *Conn) STARTTLS(params string, tr *trace.Trace) (code int, msg string) {
 	return 0, ""
 }
 
+// AUTH implements the SMTP AUTH command (RFC 4954), supporting only the
+// PLAIN mechanism.
+//
+// Authentication is only allowed over TLS. On success it records the
+// authenticated user and domain in c.authUser / c.authDomain and sets
+// c.completedAuth; it returns the SMTP reply code and message to send.
+func (c *Conn) AUTH(params string, tr *trace.Trace) (code int, msg string) {
+	if !c.onTLS {
+		return 503, "You feel vulnerable"
+	}
+
+	if c.completedAuth {
+		// After a successful AUTH command completes, a server MUST reject
+		// any further AUTH commands with a 503 reply.
+		// https://tools.ietf.org/html/rfc4954#section-4
+		return 503, "You are already wearing that!"
+	}
+
+	// The counter is checked before being incremented, so up to 4 attempts
+	// are allowed per connection.
+	if c.authAttempts > 3 {
+		// TODO: close the connection?
+		return 503, "Too many attempts - go away"
+	}
+	c.authAttempts++
+
+	// We only support PLAIN for now, so no need to make this too complicated.
+	// Params should be either "PLAIN" or "PLAIN <response>".
+	// If the response is not there, we reply with 334, and expect the
+	// response back from the client in the next message.
+	// strings.SplitN always returns at least one element, so sp[0] is safe.
+	sp := strings.SplitN(params, " ", 2)
+	if !strings.EqualFold(sp[0], "PLAIN") {
+		// As we only offer plain, this should not really happen.
+		// An unsupported mechanism is rejected with 504.
+		// https://tools.ietf.org/html/rfc4954#section-4
+		return 504, "Asmodeus demands 504 zorkmids for safe passage"
+	}
+
+	// Note we use more "serious" error messages from now on, as these may
+	// find their way to the users in some circumstances.
+
+	// Get the response, either from the message or interactively.
+	response := ""
+	if len(sp) == 2 {
+		response = sp[1]
+	} else {
+		// Reply 334 and expect the user to provide it.
+		// In this case, the text IS relevant, as it is taken as the
+		// server-side SASL challenge (empty for PLAIN).
+		// https://tools.ietf.org/html/rfc4954#section-4
+		err := c.writeResponse(334, "")
+		if err != nil {
+			return 554, fmt.Sprintf("error writing AUTH 334: %v", err)
+		}
+
+		response, err = c.readLine()
+		if err != nil {
+			return 554, fmt.Sprintf("error reading AUTH response: %v", err)
+		}
+
+		// The client may cancel the exchange by sending a single "*", and
+		// the server MUST then reject the AUTH command with a 501 reply.
+		// https://tools.ietf.org/html/rfc4954#section-4
+		if response == "*" {
+			return 501, "Authentication cancelled"
+		}
+	}
+
+	user, domain, passwd, err := auth.DecodeResponse(response)
+	if err != nil {
+		return 535, fmt.Sprintf("error decoding AUTH response: %v", err)
+	}
+
+	// Unknown domains map to a nil database, which Authenticate rejects
+	// (still taking the usual amount of time).
+	if !auth.Authenticate(c.userDBs[domain], user, passwd) {
+		return 535, "Incorrect user or password"
+	}
+
+	c.authUser = user
+	c.authDomain = domain
+	c.completedAuth = true
+	return 235, ""
+}
+
 func (c *Conn) resetEnvelope() {
 	c.mail_from = ""
 	c.rcpt_to = nil
@@ -605,6 +729,10 @@ func (c *Conn) readCommand() (cmd, params string, err error) {
 	return cmd, params, err
 }
 
+// readLine reads one line from the client connection via textproto, which
+// strips the trailing "\r\n" (or "\n").
+func (c *Conn) readLine() (line string, err error) {
+	line, err = c.tc.ReadLine()
+	return line, err
+}
+
 func (c *Conn) writeResponse(code int, msg string) error {
 	defer c.tc.W.Flush()
 
diff --git a/chasquid_test.go b/chasquid_test.go
index 4ed0d4b..da54706 100644
--- a/chasquid_test.go
+++ b/chasquid_test.go
@@ -17,6 +17,8 @@ import (
 	"testing"
 	"time"
 
+	"blitiri.com.ar/go/chasquid/internal/userdb"
+
 	"github.com/golang/glog"
 )
 
@@ -66,8 +68,18 @@ func mustDial(tb testing.TB, useTLS bool) *smtp.Client {
 }
 
 func sendEmail(tb testing.TB, c *smtp.Client) {
+	sendEmailWithAuth(tb, c, nil)
+}
+
+func sendEmailWithAuth(tb testing.TB, c *smtp.Client, auth smtp.Auth) {
 	var err error
 
+	if auth != nil {
+		if err = c.Auth(auth); err != nil {
+			tb.Errorf("Auth: %v", err)
+		}
+	}
+
 	if err = c.Mail("from@from"); err != nil {
 		tb.Errorf("Mail: %v", err)
 	}
@@ -111,6 +123,14 @@ func TestManyEmails(t *testing.T) {
 	sendEmail(t, c)
 }
 
+// TestAuth authenticates with AUTH PLAIN over TLS as the test user that
+// realMain registers for "localhost", and then sends an email.
+func TestAuth(t *testing.T) {
+	client := mustDial(t, true)
+	defer client.Close()
+
+	plain := smtp.PlainAuth("", "testuser@localhost", "testpasswd", "127.0.0.1")
+	sendEmailWithAuth(t, client, plain)
+}
+
 func TestWrongMailParsing(t *testing.T) {
 	c := mustDial(t, false)
 	defer c.Close()
@@ -360,6 +380,11 @@ func realMain(m *testing.M) int {
 		s.MaxDataSize = 50 * 1024 * 1025
 		s.AddCerts(tmpDir+"/cert.pem", tmpDir+"/key.pem")
 		s.AddAddr(srvAddr)
+
+		udb := userdb.New("/dev/null")
+		udb.AddUser("testuser", "testpasswd")
+		s.AddUserDB("localhost", udb)
+
 		go s.ListenAndServe()
 	}
 
diff --git a/internal/auth/auth.go b/internal/auth/auth.go
new file mode 100644
index 0000000..d9fcb26
--- /dev/null
+++ b/internal/auth/auth.go
@@ -0,0 +1,110 @@
+package auth
+
+import (
+	"bytes"
+	"encoding/base64"
+	"fmt"
+	"math/rand"
+	"strings"
+	"time"
+	"unicode/utf8"
+
+	"blitiri.com.ar/go/chasquid/internal/userdb"
+)
+
+// DecodeResponse decodes a plain auth response.
+//
+// It must be a base64-encoded string of the form:
+//   <authorization id> NUL <authentication id> NUL <password>
+//
+// https://tools.ietf.org/html/rfc4954#section-4
+// https://tools.ietf.org/html/rfc4616#section-2
+//
+// Either both IDs match, or one of them is empty.
+// We expect the ID to be "user@domain" with non-empty user and domain, which
+// is NOT an RFC requirement but our own.
+//
+// On error, user, domain and passwd are all returned empty.
+func DecodeResponse(response string) (user, domain, passwd string, err error) {
+	buf, err := base64.StdEncoding.DecodeString(response)
+	if err != nil {
+		return "", "", "", err
+	}
+
+	bufsp := bytes.SplitN(buf, []byte{0}, 3)
+	if len(bufsp) != 3 {
+		return "", "", "", fmt.Errorf("Response pieces != 3, as per RFC")
+	}
+
+	// We don't make the distinction between the two IDs, as long as one is
+	// empty, or they're the same.
+	z := string(bufsp[0])
+	c := string(bufsp[1])
+
+	// If neither is empty, then they must be the same.
+	if z != "" && c != "" && z != c {
+		return "", "", "", fmt.Errorf("Auth IDs do not match")
+	}
+
+	// Prefer the authentication id; fall back to the authorization id.
+	identity := c
+	if identity == "" {
+		identity = z
+	}
+	if identity == "" {
+		return "", "", "", fmt.Errorf("Empty identity, must be in the form user@domain")
+	}
+
+	// Identity must be in the form "user@domain".
+	// This is NOT an RFC requirement, it's our own.
+	idsp := strings.SplitN(identity, "@", 2)
+	if len(idsp) != 2 {
+		return "", "", "", fmt.Errorf("Identity must be in the form user@domain")
+	}
+	if idsp[0] == "" || idsp[1] == "" {
+		return "", "", "", fmt.Errorf("User and domain must not be empty")
+	}
+
+	// TODO: We stopped here. Validate the domain not (only) as UTF-8, but
+	// also check that it contains neither "/" nor "..". We could use
+	// golang.org/x/net/idna to convert it to Unicode first, or the other way
+	// around; it is not yet decided which one we want.
+	if !utf8.ValidString(idsp[0]) || !utf8.ValidString(idsp[1]) {
+		return "", "", "", fmt.Errorf("User/domain is not valid UTF-8")
+	}
+
+	return idsp[0], idsp[1], string(bufsp[2]), nil
+}
+
+// AuthenticateTime is approximately how long Authenticate calls should last.
+// It is applied to both successful and unsuccessful attempts, so the time
+// taken does not distinguish between them. When the check finishes early,
+// Authenticate sleeps for the remaining time plus a random extra of up to 20%
+// of that remaining time.
+var AuthenticateTime = 100 * time.Millisecond
+
+// Authenticate checks user/password against the given database.
+//
+// udb may be nil (e.g. for an unknown domain), in which case authentication
+// fails. In all cases the call takes at least AuthenticateTime.
+func Authenticate(udb *userdb.DB, user, passwd string) bool {
+	defer func(start time.Time) {
+		delay := AuthenticateTime - time.Since(start)
+		if delay <= 0 {
+			return
+		}
+		// Add up to 20% of random jitter. rand.Int63n panics when its
+		// argument is not positive, which happens when the remaining delay
+		// is tiny (under 5ns), so only add jitter when there is room for it.
+		if maxDelta := int64(float64(delay) * 0.2); maxDelta > 0 {
+			delay += time.Duration(rand.Int63n(maxDelta))
+		}
+		time.Sleep(delay)
+	}(time.Now())
+
+	// Note that the database CAN be nil, to simplify callers.
+	if udb == nil {
+		return false
+	}
+
+	return udb.Authenticate(user, passwd)
+}
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go
new file mode 100644
index 0000000..0d5fba5
--- /dev/null
+++ b/internal/auth/auth_test.go
@@ -0,0 +1,93 @@
+package auth
+
+import (
+	"encoding/base64"
+	"testing"
+	"time"
+
+	"blitiri.com.ar/go/chasquid/internal/userdb"
+)
+
+// TestDecodeResponse checks DecodeResponse on valid PLAIN responses
+// (hard-coded base64) and on a set of malformed ones.
+func TestDecodeResponse(t *testing.T) {
+	// Successful cases. The responses are hard-coded, rather than encoded
+	// here, for extra assurance.
+	valid := []struct {
+		response, user, domain, passwd string
+	}{
+		{"dUBkAHVAZABwYXNz", "u", "d", "pass"},     // u@d\0u@d\0pass
+		{"dUBkAABwYXNz", "u", "d", "pass"},         // u@d\0\0pass
+		{"AHVAZABwYXNz", "u", "d", "pass"},         // \0u@d\0pass
+		{"dUBkAABwYXNz/w==", "u", "d", "pass\xff"}, // u@d\0\0pass\xff
+
+		// "ñaca@ñeque\0\0clavaré"
+		{"w7FhY2FAw7FlcXVlAABjbGF2YXLDqQ==", "ñaca", "ñeque", "clavaré"},
+	}
+	for _, tc := range valid {
+		user, domain, passwd, err := DecodeResponse(tc.response)
+		if err != nil {
+			t.Errorf("Error in case %v: %v", tc, err)
+		}
+
+		if user != tc.user || domain != tc.domain || passwd != tc.passwd {
+			t.Errorf("Expected %q %q %q ; got %q %q %q",
+				tc.user, tc.domain, tc.passwd, user, domain, passwd)
+		}
+	}
+
+	// Input that is not base64 at all must be rejected.
+	if _, _, _, err := DecodeResponse("this is not base64 encoded"); err == nil {
+		t.Errorf("invalid base64 did not fail as expected")
+	}
+
+	// Valid base64 whose decoded content is invalid: wrong number of
+	// NUL-separated pieces, empty identity, mismatched identities, identity
+	// without "@", or invalid UTF-8.
+	invalid := []string{
+		"", "\x00", "\x00\x00", "\x00\x00\x00", "\x00\x00\x00\x00",
+		"a\x00b", "a\x00b\x00c", "a@a\x00b@b\x00pass", "a\x00a\x00pass",
+		"\xffa@b\x00\xffa@b\x00pass",
+	}
+	for _, raw := range invalid {
+		encoded := base64.StdEncoding.EncodeToString([]byte(raw))
+		if _, _, _, err := DecodeResponse(encoded); err != nil {
+			t.Logf("OK: %q failed with %v", raw, err)
+			continue
+		}
+		t.Errorf("Expected case %q to fail, but succeeded", raw)
+	}
+}
+
+// TestAuthenticate checks accept/reject decisions of Authenticate, and that
+// every call (successful or not) lasts at least AuthenticateTime.
+func TestAuthenticate(t *testing.T) {
+	db := userdb.New("/dev/null")
+	db.AddUser("user", "password")
+
+	// checkDuration reports an error if less than AuthenticateTime has
+	// passed since start.
+	checkDuration := func(start time.Time) {
+		if time.Since(start) < AuthenticateTime {
+			t.Errorf("authentication was too fast")
+		}
+	}
+
+	// The only valid combination goes first.
+	start := time.Now()
+	if !Authenticate(db, "user", "password") {
+		t.Errorf("failed valid authentication for user/password")
+	}
+	checkDuration(start)
+
+	// Wrong password, and unknown user.
+	for _, c := range []struct{ user, password string }{
+		{"user", "incorrect"},
+		{"invalid", "p"},
+	} {
+		start = time.Now()
+		if Authenticate(db, c.user, c.password) {
+			t.Errorf("successful auth on %v", c)
+		}
+		checkDuration(start)
+	}
+
+	// A nil userdb must be rejected too.
+	start = time.Now()
+	if Authenticate(nil, "user", "password") {
+		t.Errorf("successful auth on a nil userdb")
+	}
+	checkDuration(start)
+}