package userdb
import (
"fmt"
"os"
"reflect"
"strings"
"testing"
)
// removeIfSuccessful deletes fname unless the test has already failed, so the
// files of failing tests are left around for inspection. It is meant to be
// used in defer statements.
//
// It panics if fname does not contain "userdb_test"; that safeguard makes
// sure only test files are ever removed.
func removeIfSuccessful(t *testing.T, fname string) {
	// Safeguard, to make sure we only remove test files.
	// This should help prevent accidental deletions.
	if !strings.Contains(fname, "userdb_test") {
		panic("invalid/dangerous directory")
	}

	if t.Failed() {
		return
	}

	// A file that is already gone is not a problem. Any other failure is
	// reported in the test log, but cleanup stays best-effort and never
	// fails the test.
	if err := os.Remove(fname); err != nil && !os.IsNotExist(err) {
		t.Logf("could not remove %q: %v", fname, err)
	}
}
// mustCreateDB creates a temporary file (its name contains "userdb_test")
// holding the given content, and returns the file name. Any error while
// creating, writing or closing the file aborts the calling test via t.Fatal.
func mustCreateDB(t *testing.T, content string) string {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	f, err := os.CreateTemp("", "userdb_test")
	if err != nil {
		t.Fatal(err)
	}
	fname := f.Name()

	if _, err := f.WriteString(content); err != nil {
		// Release the descriptor before aborting; the write error is the
		// one worth reporting, so the close error is deliberately ignored.
		_ = f.Close()
		t.Fatal(err)
	}

	// Close the handle so the descriptor is not leaked, and so that a
	// failure to flush the content is noticed here instead of later.
	if err := f.Close(); err != nil {
		t.Fatal(err)
	}

	t.Logf("file: %q", fname)
	return fname
}
// dbEquals reports whether a and b hold the same users. When either one has
// no underlying proto, they are equal only if both have none. Otherwise the
// user maps must have the same size and every entry of a must have a deeply
// equal counterpart in b.
func dbEquals(a, b *DB) bool {
	switch {
	case a.db == nil && b.db == nil:
		return true
	case a.db == nil, b.db == nil:
		return false
	}

	left, right := a.db.Users, b.db.Users
	if len(left) != len(right) {
		return false
	}

	for name, want := range left {
		got, found := right[name]
		if !found {
			return false
		}
		if !reflect.DeepEqual(want, got) {
			return false
		}
	}
	return true
}
// emptyDB is the reference value for a database without users; the tests
// compare loaded databases against it using dbEquals.
var emptyDB = &DB{db: &ProtoDB{Users: make(map[string]*Password)}}
// TestEmptyLoad checks loading of databases whose file is either empty or
// malformed, delegating each case to testOneLoad.
func TestEmptyLoad(t *testing.T) {
	type loadCase struct {
		desc     string
		content  string
		fatal    bool
		fatalErr error
	}

	table := []loadCase{
		{desc: "empty file", content: "", fatal: false, fatalErr: nil},
		{desc: "invalid ", content: "users: < invalid >", fatal: true, fatalErr: nil},
	}

	for i := range table {
		tc := table[i]
		testOneLoad(t, tc.desc, tc.content, tc.fatal, tc.fatalErr)
	}
}
// testOneLoad writes content to a temporary database file, loads it, and
// checks the outcome. desc identifies the case in failure messages.
//
// If fatal is true, Load is expected to return an error; when fatalErr is
// also non-nil, the returned error must be equal to it. If fatal is false,
// any error from Load aborts the test. Whenever Load returns a database, it
// must be equal to emptyDB.
func testOneLoad(t *testing.T, desc, content string, fatal bool, fatalErr error) {
	path := mustCreateDB(t, content)
	defer removeIfSuccessful(t, path)

	db, err := Load(path)

	if !fatal {
		if err != nil {
			t.Fatalf("case %q: error loading database: %v", desc, err)
		}
	} else {
		if err == nil {
			t.Errorf("case %q: expected error loading, got nil", desc)
		}
		if fatalErr != nil && err != fatalErr {
			t.Errorf("case %q: expected error %v, got %v", desc, fatalErr, err)
		}
	}

	// Nothing to compare when no database was returned.
	if db == nil {
		return
	}
	if !dbEquals(db, emptyDB) {
		t.Errorf("case %q: DB not empty: %#v", desc, db.db.Users)
	}
}
// mustLoad loads the database stored in fname and returns it. Any error from
// Load aborts the calling test via t.Fatalf.
func mustLoad(t *testing.T, fname string) *DB {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	db, err := Load(fname)
	if err != nil {
		t.Fatalf("error loading database: %v", err)
	}
	return db
}
// TestWrite exercises a write/load cycle: an empty database must still be
// empty after being written and loaded back, and users added before a write
// must exist, carry a scheme, and authenticate as expected after the file is
// loaded again.
func TestWrite(t *testing.T) {
	path := mustCreateDB(t, "")
	defer removeIfSuccessful(t, path)

	db := mustLoad(t, path)
	err := db.Write()
	if err != nil {
		t.Fatalf("error writing database: %v", err)
	}

	// Load again, check it works and it's still empty.
	db = mustLoad(t, path)
	if stillEmpty := dbEquals(emptyDB, db); !stillEmpty {
		t.Fatalf("expected %v, got %v", emptyDB, db)
	}

	// Add users, write, and load again.
	if err = db.AddUser("user1", "passwd1"); err != nil {
		t.Fatalf("failed to add user1: %v", err)
	}
	if err = db.AddUser("ñoño", "añicos"); err != nil {
		t.Fatalf("failed to add ñoño: %v", err)
	}
	if err = db.AddDeniedUser("ñaca"); err != nil {
		t.Fatalf("failed to add ñaca: %v", err)
	}
	if err = db.Write(); err != nil {
		t.Fatalf("error writing database: %v", err)
	}

	db = mustLoad(t, path)

	// Every user added above must have survived the round trip.
	wantUsers := []string{"user1", "ñoño", "ñaca"}
	for i := range wantUsers {
		name := wantUsers[i]
		if !db.Exists(name) {
			t.Errorf("user %q not in database", name)
		}
		entry := db.db.Users[name]
		if entry.GetScheme() == nil {
			t.Errorf("user %q missing scheme: %#v", name, entry)
		}
	}

	// Check various user and password combinations, not all valid.
	attempts := []struct {
		user, passwd string
		expected     bool
	}{
		{"user1", "passwd1", true},
		{"user1", "passwd", false},
		{"user1", "passwd12", false},
		{"ñoño", "añicos", true},
		{"ñoño", "anicos", false},
		{"ñaca", "", false},
		{"ñaca", "lalala", false},
		{"notindb", "something", false},
		{"", "", false},
		{" ", " ", false},
	}
	for _, tc := range attempts {
		got := db.Authenticate(tc.user, tc.passwd)
		if got != tc.expected {
			t.Errorf("auth(%q, %q) != %v", tc.user, tc.passwd, tc.expected)
		}
	}
}
// TestNew checks that a database created with New, populated and written to
// disk, loads back equal to the in-memory original.
func TestNew(t *testing.T) {
	fname := fmt.Sprintf("%s/userdb_test-%d", os.TempDir(), os.Getpid())
	// Best-effort cleanup; the file may not exist if the test fails early.
	defer os.Remove(fname)

	db1 := New(fname)

	// Fail right away if the setup does not work, instead of letting it
	// surface later as a confusing load or comparison failure.
	if err := db1.AddUser("user", "passwd"); err != nil {
		t.Fatalf("error adding user: %v", err)
	}
	if err := db1.Write(); err != nil {
		t.Fatalf("error writing database: %v", err)
	}

	db2, err := Load(fname)
	if err != nil {
		t.Fatalf("error loading: %v", err)
	}
	if !dbEquals(db1, db2) {
		t.Errorf("databases differ. db1:%v != db2:%v", db1, db2)
	}
}
// TestInvalidUsername checks that both AddUser and AddDeniedUser reject a
// list of user names that are expected to be invalid.
func TestInvalidUsername(t *testing.T) {
	path := mustCreateDB(t, "")
	defer removeIfSuccessful(t, path)
	db := mustLoad(t, path)

	// Names that are invalid.
	invalid := []string{
		// Contain various types of spaces.
		" ", " ", "a b", "ñ ñ", "a\xa0b", "a\x85b", "a\nb", "a\tb", "a\xffb",
		// Contain characters not allowed by PRECIS.
		"\u00b9", "\u2163",
		// Names that are not normalized, but would otherwise be valid.
		"A", "Ñ",
	}

	for _, bad := range invalid {
		if err := db.AddUser(bad, "passwd"); err == nil {
			t.Errorf("AddUser(%q) worked, expected it to fail", bad)
		}
		if err := db.AddDeniedUser(bad); err == nil {
			t.Errorf("AddDeniedUser(%q) worked, expected it to fail", bad)
		}
	}
}
// plainPassword builds a Password entry that uses the plain scheme and
// stores p as its password bytes.
func plainPassword(p string) *Password {
	plain := &Plain{Password: []byte(p)}
	scheme := &Password_Plain{Plain: plain}
	return &Password{Scheme: scheme}
}
// TestPlainScheme checks that a user stored with the plain scheme survives a
// write/load cycle and authenticates only with the right password. The plain
// scheme is not expected to be used outside of debugging, but it should be
// functional for that purpose.
func TestPlainScheme(t *testing.T) {
	path := mustCreateDB(t, "")
	defer removeIfSuccessful(t, path)
	db := mustLoad(t, path)

	// Insert the entry directly into the user map.
	db.db.Users["user"] = plainPassword("pass word")
	if err := db.Write(); err != nil {
		t.Errorf("Write failed: %v", err)
	}

	db = mustLoad(t, path)
	if ok := db.Authenticate("user", "pass word"); !ok {
		t.Errorf("failed plain authentication")
	}
	if ok := db.Authenticate("user", "wrong"); ok {
		t.Errorf("plain authentication worked but it shouldn't")
	}
}
// TestDeniedScheme checks that a user stored with the denied scheme cannot
// authenticate after a write/load cycle.
func TestDeniedScheme(t *testing.T) {
	path := mustCreateDB(t, "")
	defer removeIfSuccessful(t, path)
	db := mustLoad(t, path)

	// Insert the entry directly into the user map.
	denied := &Password{Scheme: &Password_Denied{}}
	db.db.Users["user"] = denied
	if err := db.Write(); err != nil {
		t.Errorf("Write failed: %v", err)
	}

	db = mustLoad(t, path)
	if ok := db.Authenticate("user", "anything"); ok {
		t.Errorf("denied authentication worked but it shouldn't")
	}
}
// TestReload checks how Reload reacts to changes in the underlying file:
// valid new content is picked up, broken content yields an error and leaves
// the previously loaded users in place, and a deleted file is not an error
// but results in an empty database.
func TestReload(t *testing.T) {
	content := "users:< key: 'u1' value:< plain:< password: 'pass' >>>"
	fname := mustCreateDB(t, content)
	defer removeIfSuccessful(t, fname)
	db := mustLoad(t, fname)

	// Add a valid line to the file.
	content += "users:< key: 'u2' value:< plain:< password: 'pass' >>>"
	// The checks below are meaningless if the file was not updated, so a
	// failure to write it aborts the test.
	if err := os.WriteFile(fname, []byte(content), 0660); err != nil {
		t.Fatalf("error rewriting database file: %v", err)
	}
	err := db.Reload()
	if err != nil {
		t.Errorf("Reload failed: %v", err)
	}
	if len(db.db.Users) != 2 {
		t.Errorf("expected 2 users, got %d", len(db.db.Users))
	}

	// And now a broken one.
	content += "users:< invalid >"
	if err := os.WriteFile(fname, []byte(content), 0660); err != nil {
		t.Fatalf("error rewriting database file: %v", err)
	}
	err = db.Reload()
	if err == nil {
		t.Errorf("expected error, got nil")
	}
	// The users loaded before the broken reload must still be there.
	if len(db.db.Users) != 2 {
		t.Errorf("expected 2 users, got %d", len(db.db.Users))
	}

	// Delete the file (which is not considered an error).
	if err := os.Remove(fname); err != nil {
		t.Fatalf("error removing database file: %v", err)
	}
	err = db.Reload()
	if err != nil {
		t.Errorf("unexpected error: %v", err)
	}
	if len(db.db.Users) != 0 {
		t.Errorf("expected 0 users, got %d", len(db.db.Users))
	}
}
// TestRemoveUser checks the result of RemoveUser for unknown users, for an
// existing user, and for a user that has already been removed.
func TestRemoveUser(t *testing.T) {
	path := mustCreateDB(t, "")
	defer removeIfSuccessful(t, path)
	db := mustLoad(t, path)

	// Empty database: there is nothing to remove.
	if db.RemoveUser("unknown") {
		t.Errorf("removal of unknown user succeeded")
	}

	err := db.AddUser("user", "passwd")
	if err != nil {
		t.Fatalf("error adding user: %v", err)
	}

	// Having another user in the database must not change the answer.
	if db.RemoveUser("unknown") {
		t.Errorf("removal of unknown user succeeded")
	}

	if removed := db.RemoveUser("user"); !removed {
		t.Errorf("removal of existing user failed")
	}

	// The user is gone now, so a second removal must report failure.
	if db.RemoveUser("user") {
		t.Errorf("removal of unknown user succeeded")
	}
}
// TestExists checks the result of Exists for unknown users, for a regular
// user, and for a user added with AddDeniedUser.
func TestExists(t *testing.T) {
	path := mustCreateDB(t, "")
	defer removeIfSuccessful(t, path)
	db := mustLoad(t, path)

	// Empty database: nobody exists.
	if found := db.Exists("unknown"); found {
		t.Errorf("unknown user exists")
	}

	err := db.AddUser("user", "passwd")
	if err != nil {
		t.Fatalf("error adding user: %v", err)
	}

	if found := db.Exists("unknown"); found {
		t.Errorf("unknown user exists")
	}

	if found := db.Exists("user"); !found {
		t.Errorf("known user does not exist")
	}
	// Same lookup a second time.
	// NOTE(review): presumably verifies that Exists has no side effects; confirm.
	if found := db.Exists("user"); !found {
		t.Errorf("known user does not exist")
	}

	err = db.AddDeniedUser("denieduser")
	if err != nil {
		t.Fatalf("error adding user: %v", err)
	}
	if found := db.Exists("denieduser"); !found {
		t.Errorf("known (denied) user does not exist")
	}
}