git » chasquid » commit d4992ef

auth: Implement an Authenticator type

author Alberto Bertogli
2018-01-29 21:55:34 UTC
committer Alberto Bertogli
2018-02-10 22:24:39 UTC
parent 08e6f57d2edc8d4bf55b135ece0aa9c570d36f12

auth: Implement an Authenticator type

This patch implements an Authenticator type, which connections use to
do authentication and user existence checks.

It simplifies the abstractions (the server doesn't need to know about
userdb, or keep track of domain-userdb maps), and lays the foundation
for other types of authentication backends which will come in later
patches.

internal/auth/auth.go +132 -21
internal/auth/auth_test.go +198 -21
internal/smtpsrv/conn.go +11 -8
internal/smtpsrv/server.go +13 -9
internal/userdb/userdb.go +2 -2
internal/userdb/userdb_test.go +6 -6

diff --git a/internal/auth/auth.go b/internal/auth/auth.go
index 0beaf82..6ab74c2 100644
--- a/internal/auth/auth.go
+++ b/internal/auth/auth.go
@@ -1,17 +1,130 @@
+// Package auth implements authentication services for chasquid.
 package auth
 
 import (
 	"bytes"
 	"encoding/base64"
+	"errors"
 	"fmt"
 	"math/rand"
 	"strings"
 	"time"
 
 	"blitiri.com.ar/go/chasquid/internal/normalize"
-	"blitiri.com.ar/go/chasquid/internal/userdb"
 )
 
+// Backend is the interface that authentication backends must implement.
+type Backend interface {
+	Authenticate(user, password string) (bool, error)
+	Exists(user string) (bool, error)
+	Reload() error
+}
+
+// NoErrorBackend is the interface for authentication backends whose
+// Authenticate and Exists don't need to emit errors. This lets such backends
+// avoid unnecessary complexity, in exchange for a bit more of it in this
+// package. They can be converted to a normal Backend using
+// WrapNoErrorBackend (defined below).
+type NoErrorBackend interface {
+	Authenticate(user, password string) bool
+	Exists(user string) bool
+	Reload() error
+}
+
+type Authenticator struct {
+	// Registered backends, map of domain (string) -> Backend.
+	// Backend operations will _not_ include the domain in the username.
+	backends map[string]Backend
+
+	// Fallback backend, to use when backends[domain] (which may not exist)
+	// did not yield a positive result.
+	// Note that this backend gets the user with the domain included, of the
+	// form "user@domain".
+	Fallback Backend
+
+	// How long Authenticate calls should last, approximately.
+	// This will be applied both for successful and unsuccessful attempts.
+	// We will increase this number by 0-20%.
+	AuthDuration time.Duration
+}
+
+func NewAuthenticator() *Authenticator {
+	return &Authenticator{
+		backends:     map[string]Backend{},
+		AuthDuration: 100 * time.Millisecond,
+	}
+}
+
+func (a *Authenticator) Register(domain string, be Backend) {
+	a.backends[domain] = be
+}
+
+// Authenticate the user@domain with the given password.
+func (a *Authenticator) Authenticate(user, domain, password string) (bool, error) {
+	// Make sure the call takes a.AuthDuration + 0-20% regardless of the
+	// outcome, to prevent basic timing attacks.
+	defer func(start time.Time) {
+		elapsed := time.Since(start)
+		delay := a.AuthDuration - elapsed
+		if delay > 0 {
+			maxDelta := int64(float64(delay) * 0.2)
+			delay += time.Duration(rand.Int63n(maxDelta))
+			time.Sleep(delay)
+		}
+	}(time.Now())
+
+	if be, ok := a.backends[domain]; ok {
+		ok, err := be.Authenticate(user, password)
+		if ok || err != nil {
+			return ok, err
+		}
+	}
+
+	if a.Fallback != nil {
+		return a.Fallback.Authenticate(user+"@"+domain, password)
+	}
+
+	return false, nil
+}
+
+func (a *Authenticator) Exists(user, domain string) (bool, error) {
+	if be, ok := a.backends[domain]; ok {
+		ok, err := be.Exists(user)
+		if ok || err != nil {
+			return ok, err
+		}
+	}
+
+	if a.Fallback != nil {
+		return a.Fallback.Exists(user + "@" + domain)
+	}
+
+	return false, nil
+}
+
+// Reload the registered backends and the fallback backend, if any.
+func (a *Authenticator) Reload() error {
+	msgs := []string{}
+
+	for domain, be := range a.backends {
+		err := be.Reload()
+		if err != nil {
+			msgs = append(msgs, fmt.Sprintf("%q: %v", domain, err))
+		}
+	}
+	if a.Fallback != nil {
+		err := a.Fallback.Reload()
+		if err != nil {
+			msgs = append(msgs, fmt.Sprintf("<fallback>: %v", err))
+		}
+	}
+
+	if len(msgs) > 0 {
+		return errors.New(strings.Join(msgs, " ; "))
+	}
+	return nil
+}
+
 // DecodeResponse decodes a plain auth response.
 //
 // It must be a a base64-encoded string of the form:
@@ -89,27 +202,25 @@ func DecodeResponse(response string) (user, domain, passwd string, err error) {
 	return
 }
 
-// How long Authenticate calls should last, approximately.
-// This will be applied both for successful and unsuccessful attempts.
-// We will increase this number by 0-20%.
-var AuthenticateTime = 100 * time.Millisecond
+// WrapNoErrorBackend wraps a NoErrorBackend, converting it into a valid
+// Backend. This is normally used in Authenticator.Register calls, to
+// register no-error backends.
+func WrapNoErrorBackend(be NoErrorBackend) Backend {
+	return &wrapNoErrorBackend{be}
+}
 
-// Authenticate user/password on the given database.
-func Authenticate(udb *userdb.DB, user, passwd string) bool {
-	defer func(start time.Time) {
-		elapsed := time.Since(start)
-		delay := AuthenticateTime - elapsed
-		if delay > 0 {
-			maxDelta := int64(float64(delay) * 0.2)
-			delay += time.Duration(rand.Int63n(maxDelta))
-			time.Sleep(delay)
-		}
-	}(time.Now())
+type wrapNoErrorBackend struct {
+	be NoErrorBackend
+}
 
-	// Note that the database CAN be nil, to simplify callers.
-	if udb == nil {
-		return false
-	}
+func (w *wrapNoErrorBackend) Authenticate(user, password string) (bool, error) {
+	return w.be.Authenticate(user, password), nil
+}
+
+func (w *wrapNoErrorBackend) Exists(user string) (bool, error) {
+	return w.be.Exists(user), nil
+}
 
-	return udb.Authenticate(user, passwd)
+func (w *wrapNoErrorBackend) Reload() error {
+	return w.be.Reload()
 }
diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go
index 0d5fba5..95c3eb5 100644
--- a/internal/auth/auth_test.go
+++ b/internal/auth/auth_test.go
@@ -2,6 +2,7 @@ package auth
 
 import (
 	"encoding/base64"
+	"fmt"
 	"testing"
 	"time"
 
@@ -58,36 +59,212 @@ func TestAuthenticate(t *testing.T) {
 	db := userdb.New("/dev/null")
 	db.AddUser("user", "password")
 
+	a := NewAuthenticator()
+	a.Register("domain", WrapNoErrorBackend(db))
+
+	// Shorten the duration to speed up the test. This should still be long
+	// enough for it to fail if we don't sleep intentionally.
+	a.AuthDuration = 20 * time.Millisecond
+
 	// Test the correct case first
+	check(t, a, "user", "domain", "password", true)
+
+	// Wrong password, but valid user@domain.
 	ts := time.Now()
-	if !Authenticate(db, "user", "password") {
-		t.Errorf("failed valid authentication for user/password")
+	if ok, _ := a.Authenticate("user", "domain", "invalid"); ok {
+		t.Errorf("invalid password, but authentication succeeded")
 	}
-	if time.Since(ts) < AuthenticateTime {
-		t.Errorf("authentication was too fast")
+	if time.Since(ts) < a.AuthDuration {
+		t.Errorf("authentication was too fast (invalid case)")
 	}
 
-	// Incorrect cases.
-	cases := []struct{ user, password string }{
-		{"user", "incorrect"},
-		{"invalid", "p"},
+	// Incorrect cases, where the user@domain does not exist.
+	cases := []struct{ user, domain, password string }{
+		{"user", "unknown", "password"},
+		{"invalid", "domain", "p"},
+		{"invalid", "unknown", "p"},
+		{"user", "", "password"},
+		{"invalid", "", "p"},
+		{"", "domain", "password"},
+		{"", "", ""},
 	}
 	for _, c := range cases {
-		ts = time.Now()
-		if Authenticate(db, c.user, c.password) {
-			t.Errorf("successful auth on %v", c)
-		}
-		if time.Since(ts) < AuthenticateTime {
-			t.Errorf("authentication was too fast")
-		}
+		check(t, a, c.user, c.domain, c.password, false)
+	}
+}
+
+func check(t *testing.T, a *Authenticator, user, domain, passwd string, expect bool) {
+	c := fmt.Sprintf("{%s@%s %s}", user, domain, passwd)
+	ts := time.Now()
+
+	ok, err := a.Authenticate(user, domain, passwd)
+	if time.Since(ts) < a.AuthDuration {
+		t.Errorf("auth on %v was too fast", c)
+	}
+	if ok != expect {
+		t.Errorf("auth on %v: got %v, expected %v", c, ok, expect)
+	}
+	if err != nil {
+		t.Errorf("auth on %v: got error %v", c, err)
+	}
+
+	ok, err = a.Exists(user, domain)
+	if ok != expect {
+		t.Errorf("exists on %v: got %v, expected %v", c, ok, expect)
+	}
+	if err != nil {
+		t.Errorf("exists on %v: error %v", c, err)
+	}
+}
+
+func TestInterfaces(t *testing.T) {
+	var _ NoErrorBackend = userdb.New("/dev/null")
+}
+
+// TestBE is an in-memory Backend implementation for testing.
+type TestBE struct {
+	users       map[string]string
+	reloadCount int
+	nextError   error
+}
+
+func NewTestBE() *TestBE {
+	return &TestBE{
+		users: map[string]string{},
+	}
+}
+func (d *TestBE) add(user, password string) {
+	d.users[user] = password
+}
+
+func (d *TestBE) Authenticate(user, password string) (bool, error) {
+	if d.nextError != nil {
+		return false, d.nextError
+	}
+
+	if validP, ok := d.users[user]; ok {
+		return validP == password, nil
+	}
+	return false, nil
+}
+
+func (d *TestBE) Exists(user string) (bool, error) {
+	if d.nextError != nil {
+		return false, d.nextError
+	}
+	_, ok := d.users[user]
+	return ok, nil
+}
+
+func (d *TestBE) Reload() error {
+	d.reloadCount++
+	if d.nextError != nil {
+		return d.nextError
 	}
+	return nil
+}
+
+func TestMultipleBackends(t *testing.T) {
+	domain1 := NewTestBE()
+	domain2 := NewTestBE()
+	fallback := NewTestBE()
+
+	a := NewAuthenticator()
+	a.Register("domain1", domain1)
+	a.Register("domain2", domain2)
+	a.Fallback = fallback
+
+	// Shorten the duration to speed up the test. This should still be long
+	// enough for it to fail if we don't sleep intentionally.
+	a.AuthDuration = 20 * time.Millisecond
+
+	domain1.add("user1", "passwd1")
+	domain2.add("user2", "passwd2")
+	fallback.add("user3@fallback", "passwd3")
+	fallback.add("user4@domain1", "passwd4")
+
+	// Successful tests.
+	cases := []struct{ user, domain, password string }{
+		{"user1", "domain1", "passwd1"},
+		{"user2", "domain2", "passwd2"},
+		{"user3", "fallback", "passwd3"},
+		{"user4", "domain1", "passwd4"},
+	}
+	for _, c := range cases {
+		check(t, a, c.user, c.domain, c.password, true)
+	}
+
+	// Unsuccessful tests (users don't exist).
+	cases = []struct{ user, domain, password string }{
+		{"nobody", "domain1", "p"},
+		{"nobody", "domain2", "p"},
+		{"nobody", "fallback", "p"},
+		{"user3", "", "p"},
+	}
+	for _, c := range cases {
+		check(t, a, c.user, c.domain, c.password, false)
+	}
+}
+
+func TestErrors(t *testing.T) {
+	be := NewTestBE()
+	be.add("user", "passwd")
+
+	a := NewAuthenticator()
+	a.Register("domain", be)
+	a.AuthDuration = 0
 
-	// And the special case of a nil userdb.
-	ts = time.Now()
-	if Authenticate(nil, "user", "password") {
-		t.Errorf("successful auth on a nil userdb")
+	ok, err := a.Authenticate("user", "domain", "passwd")
+	if err != nil || !ok {
+		t.Fatalf("failed auth")
+	}
+
+	expectedErr := fmt.Errorf("test error")
+	be.nextError = expectedErr
+
+	ok, err = a.Authenticate("user", "domain", "passwd")
+	if ok {
+		t.Errorf("authentication succeeded, expected error")
+	}
+	if err != expectedErr {
+		t.Errorf("expected error, got %v", err)
+	}
+
+	ok, err = a.Exists("user", "domain")
+	if ok {
+		t.Errorf("exists succeeded, expected error")
+	}
+	if err != expectedErr {
+		t.Errorf("expected error, got %v", err)
+	}
+}
+
+func TestReload(t *testing.T) {
+	be1 := NewTestBE()
+	be2 := NewTestBE()
+	fallback := NewTestBE()
+
+	a := NewAuthenticator()
+	a.Register("domain1", be1)
+	a.Register("domain2", be2)
+	a.Fallback = fallback
+
+	err := a.Reload()
+	if err != nil {
+		t.Errorf("unexpected error reloading: %v", err)
+	}
+	if be1.reloadCount != 1 || be2.reloadCount != 1 || fallback.reloadCount != 1 {
+		t.Errorf("unexpected reload counts: %d %d %d != 1 1 1",
+			be1.reloadCount, be2.reloadCount, fallback.reloadCount)
+	}
+
+	be2.nextError = fmt.Errorf("test error")
+	err = a.Reload()
+	if err == nil {
+		t.Errorf("expected error reloading, got nil")
 	}
-	if time.Since(ts) < AuthenticateTime {
-		t.Errorf("authentication was too fast")
+	if be1.reloadCount != 2 || be2.reloadCount != 2 || fallback.reloadCount != 2 {
+		t.Errorf("unexpected reload counts: %d %d %d != 2 2 2",
+			be1.reloadCount, be2.reloadCount, fallback.reloadCount)
 	}
 }
diff --git a/internal/smtpsrv/conn.go b/internal/smtpsrv/conn.go
index 9ab3981..42a8cd7 100644
--- a/internal/smtpsrv/conn.go
+++ b/internal/smtpsrv/conn.go
@@ -30,7 +30,6 @@ import (
 	"blitiri.com.ar/go/chasquid/internal/set"
 	"blitiri.com.ar/go/chasquid/internal/tlsconst"
 	"blitiri.com.ar/go/chasquid/internal/trace"
-	"blitiri.com.ar/go/chasquid/internal/userdb"
 	"blitiri.com.ar/go/spf"
 )
 
@@ -120,9 +119,9 @@ type Conn struct {
 	// Are we using TLS?
 	onTLS bool
 
-	// User databases, aliases and local domains, taken from the server at
+	// Authenticator, aliases and local domains, taken from the server at
 	// creation time.
-	userDBs      map[string]*userdb.DB
+	authr        *auth.Authenticator
 	localDomains *set.String
 	aliasesR     *aliases.Resolver
 	dinfo        *domaininfo.DB
@@ -897,7 +896,11 @@ func (c *Conn) AUTH(params string) (code int, msg string) {
 		return 535, fmt.Sprintf("error decoding AUTH response: %v", err)
 	}
 
-	if auth.Authenticate(c.userDBs[domain], user, passwd) {
+	authOk, err := c.authr.Authenticate(user, domain, passwd)
+	if err != nil {
+		c.tr.Errorf("error authenticating %q@%q: %v", user, domain, err)
+	}
+	if authOk {
 		c.authUser = user
 		c.authDomain = domain
 		c.completedAuth = true
@@ -929,11 +932,11 @@ func (c *Conn) userExists(addr string) bool {
 	// look up "user" in our databases if the domain is local, which is what
 	// we want.
 	user, domain := envelope.Split(addr)
-	udb := c.userDBs[domain]
-	if udb == nil {
-		return false
+	ok, err := c.authr.Exists(user, domain)
+	if err != nil {
+		c.tr.Errorf("error checking if user %q exists: %v", addr, err)
 	}
-	return udb.HasUser(user)
+	return ok
 }
 
 func (c *Conn) readCommand() (cmd, params string, err error) {
diff --git a/internal/smtpsrv/server.go b/internal/smtpsrv/server.go
index f0d41e6..777159c 100644
--- a/internal/smtpsrv/server.go
+++ b/internal/smtpsrv/server.go
@@ -9,6 +9,7 @@ import (
 	"time"
 
 	"blitiri.com.ar/go/chasquid/internal/aliases"
+	"blitiri.com.ar/go/chasquid/internal/auth"
 	"blitiri.com.ar/go/chasquid/internal/courier"
 	"blitiri.com.ar/go/chasquid/internal/domaininfo"
 	"blitiri.com.ar/go/chasquid/internal/maillog"
@@ -38,7 +39,8 @@ type Server struct {
 	localDomains *set.String
 
 	// User databases (per domain).
-	userDBs map[string]*userdb.DB
+	// Authenticator; per-domain user databases are registered as its backends.
+	authr *auth.Authenticator
 
 	// Aliases resolver.
 	aliasesR *aliases.Resolver
@@ -67,7 +69,7 @@ func NewServer() *Server {
 		connTimeout:    20 * time.Minute,
 		commandTimeout: 1 * time.Minute,
 		localDomains:   &set.String{},
-		userDBs:        map[string]*userdb.DB{},
+		authr:          auth.NewAuthenticator(),
 		aliasesR:       aliases.NewResolver(),
 	}
 }
@@ -95,13 +97,17 @@ func (s *Server) AddDomain(d string) {
 }
 
 func (s *Server) AddUserDB(domain string, db *userdb.DB) {
-	s.userDBs[domain] = db
+	s.authr.Register(domain, auth.WrapNoErrorBackend(db))
 }
 
 func (s *Server) AddAliasesFile(domain, f string) error {
 	return s.aliasesR.AddAliasesFile(domain, f)
 }
 
+func (s *Server) SetAuthFallback(be auth.Backend) {
+	s.authr.Fallback = be
+}
+
 func (s *Server) SetAliasesConfig(suffixSep, dropChars string) {
 	s.aliasesR.SuffixSep = suffixSep
 	s.aliasesR.DropChars = dropChars
@@ -145,11 +151,9 @@ func (s *Server) periodicallyReload() {
 			log.Errorf("Error reloading aliases: %v", err)
 		}
 
-		for domain, udb := range s.userDBs {
-			err = udb.Reload()
-			if err != nil {
-				log.Errorf("Error reloading %q user db: %v", domain, err)
-			}
+		err = s.authr.Reload()
+		if err != nil {
+			log.Errorf("Error reloading authenticators: %v", err)
 		}
 	}
 }
@@ -219,7 +223,7 @@ func (s *Server) serve(l net.Listener, mode SocketMode) {
 			mode:           mode,
 			tlsConfig:      s.tlsConfig,
 			onTLS:          mode.TLS,
-			userDBs:        s.userDBs,
+			authr:          s.authr,
 			aliasesR:       s.aliasesR,
 			localDomains:   s.localDomains,
 			dinfo:          s.dinfo,
diff --git a/internal/userdb/userdb.go b/internal/userdb/userdb.go
index 83f9942..0359fbb 100644
--- a/internal/userdb/userdb.go
+++ b/internal/userdb/userdb.go
@@ -178,8 +178,8 @@ func (db *DB) RemoveUser(name string) bool {
 	return present
 }
 
-// HasUser returns true if the user is present, False otherwise.
-func (db *DB) HasUser(name string) bool {
+// Exists returns true if the user is present, false otherwise.
+func (db *DB) Exists(name string) bool {
 	db.mu.Lock()
 	_, present := db.db.Users[name]
 	db.mu.Unlock()
diff --git a/internal/userdb/userdb_test.go b/internal/userdb/userdb_test.go
index f0a2953..37f2c64 100644
--- a/internal/userdb/userdb_test.go
+++ b/internal/userdb/userdb_test.go
@@ -129,7 +129,7 @@ func TestWrite(t *testing.T) {
 
 	db = mustLoad(t, fname)
 	for _, name := range []string{"user1", "ñoño"} {
-		if !db.HasUser(name) {
+		if !db.Exists(name) {
 			t.Errorf("user %q not in database", name)
 		}
 		if db.db.Users[name].GetScheme() == nil {
@@ -294,12 +294,12 @@ func TestRemoveUser(t *testing.T) {
 	}
 }
 
-func TestHasUser(t *testing.T) {
+func TestExists(t *testing.T) {
 	fname := mustCreateDB(t, "")
 	defer removeIfSuccessful(t, fname)
 	db := mustLoad(t, fname)
 
-	if db.HasUser("unknown") {
+	if db.Exists("unknown") {
 		t.Errorf("unknown user exists")
 	}
 
@@ -307,15 +307,15 @@ func TestHasUser(t *testing.T) {
 		t.Fatalf("error adding user: %v", err)
 	}
 
-	if db.HasUser("unknown") {
+	if db.Exists("unknown") {
 		t.Errorf("unknown user exists")
 	}
 
-	if !db.HasUser("user") {
+	if !db.Exists("user") {
 		t.Errorf("known user does not exist")
 	}
 
-	if !db.HasUser("user") {
+	if !db.Exists("user") {
 		t.Errorf("known user does not exist")
 	}
 }