author | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-01-29 21:55:34 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-02-10 22:24:39 UTC |
parent | 08e6f57d2edc8d4bf55b135ece0aa9c570d36f12 |
internal/auth/auth.go | +132 | -21 |
internal/auth/auth_test.go | +198 | -21 |
internal/smtpsrv/conn.go | +11 | -8 |
internal/smtpsrv/server.go | +13 | -9 |
internal/userdb/userdb.go | +2 | -2 |
internal/userdb/userdb_test.go | +6 | -6 |
diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 0beaf82..6ab74c2 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,17 +1,130 @@ +// Package auth implements authentication services for chasquid. package auth import ( "bytes" "encoding/base64" + "errors" "fmt" "math/rand" "strings" "time" "blitiri.com.ar/go/chasquid/internal/normalize" - "blitiri.com.ar/go/chasquid/internal/userdb" ) +// Interface for authentication backends. +type Backend interface { + Authenticate(user, password string) (bool, error) + Exists(user string) (bool, error) + Reload() error +} + +// Interface for authentication backends that don't need to emit errors. +// This allows backends to avoid unnecessary complexity, in exchange for a bit +// more here. +// They can be converted to normal Backend using WrapNoErrorBackend (defined +// below). +type NoErrorBackend interface { + Authenticate(user, password string) bool + Exists(user string) bool + Reload() error +} + +type Authenticator struct { + // Registered backends, map of domain (string) -> Backend. + // Backend operations will _not_ include the domain in the username. + backends map[string]Backend + + // Fallback backend, to use when backends[domain] (which may not exist) + // did not yield a positive result. + // Note that this backend gets the user with the domain included, of the + // form "user@domain". + Fallback Backend + + // How long Authenticate calls should last, approximately. + // This will be applied both for successful and unsuccessful attempts. + // We will increase this number by 0-20%. + AuthDuration time.Duration +} + +func NewAuthenticator() *Authenticator { + return &Authenticator{ + backends: map[string]Backend{}, + AuthDuration: 100 * time.Millisecond, + } +} + +func (a *Authenticator) Register(domain string, be Backend) { + a.backends[domain] = be +} + +// Authenticate the user@domain with the given password. +func (a *Authenticator) Authenticate(user, domain, password string) (bool, error) { + // Make sure the call takes a.AuthDuration + 0-20% regardless of the + // outcome, to prevent basic timing attacks. + defer func(start time.Time) { + elapsed := time.Since(start) + delay := a.AuthDuration - elapsed + if delay > 0 { + maxDelta := int64(float64(delay) * 0.2) + delay += time.Duration(rand.Int63n(maxDelta)) + time.Sleep(delay) + } + }(time.Now()) + + if be, ok := a.backends[domain]; ok { + ok, err := be.Authenticate(user, password) + if ok || err != nil { + return ok, err + } + } + + if a.Fallback != nil { + return a.Fallback.Authenticate(user+"@"+domain, password) + } + + return false, nil +} + +func (a *Authenticator) Exists(user, domain string) (bool, error) { + if be, ok := a.backends[domain]; ok { + ok, err := be.Exists(user) + if ok || err != nil { + return ok, err + } + } + + if a.Fallback != nil { + return a.Fallback.Exists(user + "@" + domain) + } + + return false, nil +} + +// Reload the registered backends. +func (a *Authenticator) Reload() error { + msgs := []string{} + + for domain, be := range a.backends { + err := be.Reload() + if err != nil { + msgs = append(msgs, fmt.Sprintf("%q: %v", domain, err)) + } + } + if a.Fallback != nil { + err := a.Fallback.Reload() + if err != nil { + msgs = append(msgs, fmt.Sprintf("<fallback>: %v", err)) + } + } + + if len(msgs) > 0 { + return errors.New(strings.Join(msgs, " ; ")) + } + return nil +} + // DecodeResponse decodes a plain auth response. // // It must be a a base64-encoded string of the form: @@ -89,27 +202,25 @@ func DecodeResponse(response string) (user, domain, passwd string, err error) { return } -// How long Authenticate calls should last, approximately. -// This will be applied both for successful and unsuccessful attempts. -// We will increase this number by 0-20%. -var AuthenticateTime = 100 * time.Millisecond +// WrapNoErrorBackend wraps a NoErrorBackend, converting it into a valid +// Backend. This is normally used in Auth.Register calls, to register no-error +// backends. +func WrapNoErrorBackend(be NoErrorBackend) Backend { + return &wrapNoErrorBackend{be} +} -// Authenticate user/password on the given database. -func Authenticate(udb *userdb.DB, user, passwd string) bool { - defer func(start time.Time) { - elapsed := time.Since(start) - delay := AuthenticateTime - elapsed - if delay > 0 { - maxDelta := int64(float64(delay) * 0.2) - delay += time.Duration(rand.Int63n(maxDelta)) - time.Sleep(delay) - } - }(time.Now()) +type wrapNoErrorBackend struct { + be NoErrorBackend +} - // Note that the database CAN be nil, to simplify callers. - if udb == nil { - return false - } +func (w *wrapNoErrorBackend) Authenticate(user, password string) (bool, error) { + return w.be.Authenticate(user, password), nil +} + +func (w *wrapNoErrorBackend) Exists(user string) (bool, error) { + return w.be.Exists(user), nil +} - return udb.Authenticate(user, passwd) +func (w *wrapNoErrorBackend) Reload() error { + return w.be.Reload() } diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 0d5fba5..95c3eb5 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -2,6 +2,7 @@ package auth import ( "encoding/base64" + "fmt" "testing" "time" @@ -58,36 +59,212 @@ func TestAuthenticate(t *testing.T) { db := userdb.New("/dev/null") db.AddUser("user", "password") + a := NewAuthenticator() + a.Register("domain", WrapNoErrorBackend(db)) + + // Shorten the duration to speed up the test. This should still be long + // enough for it to fail if we don't sleep intentionally. + a.AuthDuration = 20 * time.Millisecond + // Test the correct case first + check(t, a, "user", "domain", "password", true) + + // Wrong password, but valid user@domain. ts := time.Now() - if !Authenticate(db, "user", "password") { - t.Errorf("failed valid authentication for user/password") + if ok, _ := a.Authenticate("user", "domain", "invalid"); ok { + t.Errorf("invalid password, but authentication succeeded") } - if time.Since(ts) < AuthenticateTime { - t.Errorf("authentication was too fast") + if time.Since(ts) < a.AuthDuration { + t.Errorf("authentication was too fast (invalid case)") } - // Incorrect cases. - cases := []struct{ user, password string }{ - {"user", "incorrect"}, - {"invalid", "p"}, + // Incorrect cases, where the user@domain do not exist. + cases := []struct{ user, domain, password string }{ + {"user", "unknown", "password"}, + {"invalid", "domain", "p"}, + {"invalid", "unknown", "p"}, + {"user", "", "password"}, + {"invalid", "", "p"}, + {"", "domain", "password"}, + {"", "", ""}, } for _, c := range cases { - ts = time.Now() - if Authenticate(db, c.user, c.password) { - t.Errorf("successful auth on %v", c) - } - if time.Since(ts) < AuthenticateTime { - t.Errorf("authentication was too fast") - } + check(t, a, c.user, c.domain, c.password, false) + } +} + +func check(t *testing.T, a *Authenticator, user, domain, passwd string, expect bool) { + c := fmt.Sprintf("{%s@%s %s}", user, domain, passwd) + ts := time.Now() + + ok, err := a.Authenticate(user, domain, passwd) + if time.Since(ts) < a.AuthDuration { + t.Errorf("auth on %v was too fast", c) + } + if ok != expect { + t.Errorf("auth on %v: got %v, expected %v", c, ok, expect) + } + if err != nil { + t.Errorf("auth on %v: got error %v", c, err) + } + + ok, err = a.Exists(user, domain) + if ok != expect { + t.Errorf("exists on %v: got %v, expected %v", c, ok, expect) + } + if err != nil { + t.Errorf("exists on %v: error %v", c, err) + } +} + +func TestInterfaces(t *testing.T) { + var _ NoErrorBackend = userdb.New("/dev/null") +} + +// Backend implementation for testing. +type TestBE struct { + users map[string]string + reloadCount int + nextError error +} + +func NewTestBE() *TestBE { + return &TestBE{ + users: map[string]string{}, + } +} +func (d *TestBE) add(user, password string) { + d.users[user] = password +} + +func (d *TestBE) Authenticate(user, password string) (bool, error) { + if d.nextError != nil { + return false, d.nextError + } + + if validP, ok := d.users[user]; ok { + return validP == password, nil + } + return false, nil +} + +func (d *TestBE) Exists(user string) (bool, error) { + if d.nextError != nil { + return false, d.nextError + } + _, ok := d.users[user] + return ok, nil +} + +func (d *TestBE) Reload() error { + d.reloadCount++ + if d.nextError != nil { + return d.nextError } + return nil +} + +func TestMultipleBackends(t *testing.T) { + domain1 := NewTestBE() + domain2 := NewTestBE() + fallback := NewTestBE() + + a := NewAuthenticator() + a.Register("domain1", domain1) + a.Register("domain2", domain2) + a.Fallback = fallback + + // Shorten the duration to speed up the test. This should still be long + // enough for it to fail if we don't sleep intentionally. + a.AuthDuration = 20 * time.Millisecond + + domain1.add("user1", "passwd1") + domain2.add("user2", "passwd2") + fallback.add("user3@fallback", "passwd3") + fallback.add("user4@domain1", "passwd4") + + // Successful tests. + cases := []struct{ user, domain, password string }{ + {"user1", "domain1", "passwd1"}, + {"user2", "domain2", "passwd2"}, + {"user3", "fallback", "passwd3"}, + {"user4", "domain1", "passwd4"}, + } + for _, c := range cases { + check(t, a, c.user, c.domain, c.password, true) + } + + // Unsuccessful tests (users don't exist). + cases = []struct{ user, domain, password string }{ + {"nobody", "domain1", "p"}, + {"nobody", "domain2", "p"}, + {"nobody", "fallback", "p"}, + {"user3", "", "p"}, + } + for _, c := range cases { + check(t, a, c.user, c.domain, c.password, false) + } +} + +func TestErrors(t *testing.T) { + be := NewTestBE() + be.add("user", "passwd") + + a := NewAuthenticator() + a.Register("domain", be) + a.AuthDuration = 0 - // And the special case of a nil userdb. - ts = time.Now() - if Authenticate(nil, "user", "password") { - t.Errorf("successful auth on a nil userdb") + ok, err := a.Authenticate("user", "domain", "passwd") + if err != nil || !ok { + t.Fatalf("failed auth") + } + + expectedErr := fmt.Errorf("test error") + be.nextError = expectedErr + + ok, err = a.Authenticate("user", "domain", "passwd") + if ok { + t.Errorf("authentication succeeded, expected error") + } + if err != expectedErr { + t.Errorf("expected error, got %v", err) + } + + ok, err = a.Exists("user", "domain") + if ok { + t.Errorf("exists succeeded, expected error") + } + if err != expectedErr { + t.Errorf("expected error, got %v", err) + } +} + +func TestReload(t *testing.T) { + be1 := NewTestBE() + be2 := NewTestBE() + fallback := NewTestBE() + + a := NewAuthenticator() + a.Register("domain1", be1) + a.Register("domain2", be2) + a.Fallback = fallback + + err := a.Reload() + if err != nil { + t.Errorf("unexpected error reloading: %v", err) + } + if be1.reloadCount != 1 || be2.reloadCount != 1 || fallback.reloadCount != 1 { + t.Errorf("unexpected reload counts: %d %d %d != 1 1 1", + be1.reloadCount, be2.reloadCount, fallback.reloadCount) + } + + be2.nextError = fmt.Errorf("test error") + err = a.Reload() + if err == nil { + t.Errorf("expected error reloading, got nil") } - if time.Since(ts) < AuthenticateTime { - t.Errorf("authentication was too fast") + if be1.reloadCount != 2 || be2.reloadCount != 2 || fallback.reloadCount != 2 { + t.Errorf("unexpected reload counts: %d %d %d != 2 2 2", + be1.reloadCount, be2.reloadCount, fallback.reloadCount) } } diff --git a/internal/smtpsrv/conn.go b/internal/smtpsrv/conn.go index 9ab3981..42a8cd7 100644 --- a/internal/smtpsrv/conn.go +++ b/internal/smtpsrv/conn.go @@ -30,7 +30,6 @@ import ( "blitiri.com.ar/go/chasquid/internal/set" "blitiri.com.ar/go/chasquid/internal/tlsconst" "blitiri.com.ar/go/chasquid/internal/trace" - "blitiri.com.ar/go/chasquid/internal/userdb" "blitiri.com.ar/go/spf" ) @@ -120,9 +119,9 @@ type Conn struct { // Are we using TLS? onTLS bool - // User databases, aliases and local domains, taken from the server at + // Authenticator, aliases and local domains, taken from the server at // creation time. - userDBs map[string]*userdb.DB + authr *auth.Authenticator localDomains *set.String aliasesR *aliases.Resolver dinfo *domaininfo.DB @@ -897,7 +896,11 @@ func (c *Conn) AUTH(params string) (code int, msg string) { return 535, fmt.Sprintf("error decoding AUTH response: %v", err) } - if auth.Authenticate(c.userDBs[domain], user, passwd) { + authOk, err := c.authr.Authenticate(user, domain, passwd) + if err != nil { + c.tr.Errorf("error authenticating %q@%q: %v", user, domain, err) + } + if authOk { c.authUser = user c.authDomain = domain c.completedAuth = true @@ -929,11 +932,11 @@ func (c *Conn) userExists(addr string) bool { // look up "user" in our databases if the domain is local, which is what // we want. user, domain := envelope.Split(addr) - udb := c.userDBs[domain] - if udb == nil { - return false + ok, err := c.authr.Exists(user, domain) + if err != nil { + c.tr.Errorf("error checking if user %q exists: %v", addr, err) } - return udb.HasUser(user) + return ok } func (c *Conn) readCommand() (cmd, params string, err error) { diff --git a/internal/smtpsrv/server.go b/internal/smtpsrv/server.go index f0d41e6..777159c 100644 --- a/internal/smtpsrv/server.go +++ b/internal/smtpsrv/server.go @@ -9,6 +9,7 @@ import ( "time" "blitiri.com.ar/go/chasquid/internal/aliases" + "blitiri.com.ar/go/chasquid/internal/auth" "blitiri.com.ar/go/chasquid/internal/courier" "blitiri.com.ar/go/chasquid/internal/domaininfo" "blitiri.com.ar/go/chasquid/internal/maillog" @@ -38,7 +39,8 @@ type Server struct { localDomains *set.String // User databases (per domain). - userDBs map[string]*userdb.DB + // Authenticator. + authr *auth.Authenticator // Aliases resolver. aliasesR *aliases.Resolver @@ -67,7 +69,7 @@ func NewServer() *Server { connTimeout: 20 * time.Minute, commandTimeout: 1 * time.Minute, localDomains: &set.String{}, - userDBs: map[string]*userdb.DB{}, + authr: auth.NewAuthenticator(), aliasesR: aliases.NewResolver(), } } @@ -95,13 +97,17 @@ func (s *Server) AddDomain(d string) { } func (s *Server) AddUserDB(domain string, db *userdb.DB) { - s.userDBs[domain] = db + s.authr.Register(domain, auth.WrapNoErrorBackend(db)) } func (s *Server) AddAliasesFile(domain, f string) error { return s.aliasesR.AddAliasesFile(domain, f) } +func (s *Server) SetAuthFallback(be auth.Backend) { + s.authr.Fallback = be +} + func (s *Server) SetAliasesConfig(suffixSep, dropChars string) { s.aliasesR.SuffixSep = suffixSep s.aliasesR.DropChars = dropChars @@ -145,11 +151,9 @@ func (s *Server) periodicallyReload() { log.Errorf("Error reloading aliases: %v", err) } - for domain, udb := range s.userDBs { - err = udb.Reload() - if err != nil { - log.Errorf("Error reloading %q user db: %v", domain, err) - } + err = s.authr.Reload() + if err != nil { + log.Errorf("Error reloading authenticators: %v", err) } } } @@ -219,7 +223,7 @@ func (s *Server) serve(l net.Listener, mode SocketMode) { mode: mode, tlsConfig: s.tlsConfig, onTLS: mode.TLS, - userDBs: s.userDBs, + authr: s.authr, aliasesR: s.aliasesR, localDomains: s.localDomains, dinfo: s.dinfo, diff --git a/internal/userdb/userdb.go b/internal/userdb/userdb.go index 83f9942..0359fbb 100644 --- a/internal/userdb/userdb.go +++ b/internal/userdb/userdb.go @@ -178,8 +178,8 @@ func (db *DB) RemoveUser(name string) bool { return present } -// HasUser returns true if the user is present, False otherwise. -func (db *DB) HasUser(name string) bool { +// Exists returns true if the user is present, False otherwise. +func (db *DB) Exists(name string) bool { db.mu.Lock() _, present := db.db.Users[name] db.mu.Unlock() diff --git a/internal/userdb/userdb_test.go b/internal/userdb/userdb_test.go index f0a2953..37f2c64 100644 --- a/internal/userdb/userdb_test.go +++ b/internal/userdb/userdb_test.go @@ -129,7 +129,7 @@ func TestWrite(t *testing.T) { db = mustLoad(t, fname) for _, name := range []string{"user1", "ñoño"} { - if !db.HasUser(name) { + if !db.Exists(name) { t.Errorf("user %q not in database", name) } if db.db.Users[name].GetScheme() == nil { @@ -294,12 +294,12 @@ func TestRemoveUser(t *testing.T) { } } -func TestHasUser(t *testing.T) { +func TestExists(t *testing.T) { fname := mustCreateDB(t, "") defer removeIfSuccessful(t, fname) db := mustLoad(t, fname) - if db.HasUser("unknown") { + if db.Exists("unknown") { t.Errorf("unknown user exists") } @@ -307,15 +307,15 @@ func TestHasUser(t *testing.T) { t.Fatalf("error adding user: %v", err) } - if db.HasUser("unknown") { + if db.Exists("unknown") { t.Errorf("unknown user exists") } - if !db.HasUser("user") { + if !db.Exists("user") { t.Errorf("known user does not exist") } - if !db.HasUser("user") { + if !db.Exists("user") { t.Errorf("known user does not exist") } }