author | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-05-20 11:21:15 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-05-20 12:18:17 UTC |
parent | 2064e9e65dc940f1409acf09f3056b9c1cb63194 |
internal/domaininfo/domaininfo.go | +13 | -2 |
internal/domaininfo/domaininfo_test.go | +27 | -18 |
internal/smtpsrv/server.go | +5 | -5 |
internal/testlib/testlib.go | +14 | -0 |
internal/testlib/testlib_test.go | +24 | -0 |
diff --git a/internal/domaininfo/domaininfo.go b/internal/domaininfo/domaininfo.go
index 2d7ff42..0992e2f 100644
--- a/internal/domaininfo/domaininfo.go
+++ b/internal/domaininfo/domaininfo.go
@@ -38,11 +38,22 @@ func New(dir string) (*DB, error) {
 	}
 	l.ev = trace.NewEventLog("DomainInfo", dir)
 
+	err = l.Reload()
+	if err != nil {
+		return nil, err
+	}
+
 	return l, nil
 }
 
-// Load the database from disk; should be called once at initialization.
-func (db *DB) Load() error {
+// Reload the database from disk.
+func (db *DB) Reload() error {
+	db.Lock()
+	defer db.Unlock()
+
+	// Clear the map, in case it has data.
+	db.info = map[string]*Domain{}
+
 	ids, err := db.store.ListIDs()
 	if err != nil {
 		return err
diff --git a/internal/domaininfo/domaininfo_test.go b/internal/domaininfo/domaininfo_test.go
index 6aa2790..521894c 100644
--- a/internal/domaininfo/domaininfo_test.go
+++ b/internal/domaininfo/domaininfo_test.go
@@ -2,7 +2,6 @@ package domaininfo
 
 import (
 	"testing"
-	"time"
 
 	"blitiri.com.ar/go/chasquid/internal/testlib"
 )
@@ -15,10 +14,6 @@ func TestBasic(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	if err := db.Load(); err != nil {
-		t.Fatal(err)
-	}
-
 	if !db.IncomingSecLevel("d1", SecLevel_PLAIN) {
 		t.Errorf("new domain as plain not allowed")
 	}
@@ -29,24 +24,11 @@ func TestBasic(t *testing.T) {
 		t.Errorf("decrement to tls-insecure was allowed")
 	}
 
-	// Wait until it is written to disk.
-	for dl := time.Now().Add(30 * time.Second); time.Now().Before(dl); {
-		d := &Domain{}
-		ok, _ := db.store.Get("d1", d)
-		if ok {
-			break
-		}
-		time.Sleep(50 * time.Millisecond)
-	}
-
 	// Check that it was added to the store and a new db sees it.
 	db2, err := New(dir)
 	if err != nil {
 		t.Fatal(err)
 	}
-	if err := db2.Load(); err != nil {
-		t.Fatal(err)
-	}
 	if db2.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) {
 		t.Errorf("decrement to tls-insecure was allowed in new DB")
 	}
@@ -113,3 +95,30 @@ func TestProgressions(t *testing.T) {
 		}
 	}
 }
+
+func TestErrors(t *testing.T) {
+	// Non-existent directory.
+	_, err := New("/doesnotexists")
+	if err == nil {
+		t.Error("could create a DB on a non-existent directory")
+	}
+
+	// Corrupt/invalid file.
+	dir := testlib.MustTempDir(t)
+	defer testlib.RemoveIfOk(t, dir)
+	db, err := New(dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !db.IncomingSecLevel("d1", SecLevel_TLS_SECURE) {
+		t.Errorf("increment to tls-secure not allowed")
+	}
+
+	testlib.Rewrite(t, dir+"/s:d1", "invalid-text-protobuf-contents")
+
+	err = db.Reload()
+	if err == nil {
+		t.Errorf("no error when reloading db with invalid file")
+	}
+}
diff --git a/internal/smtpsrv/server.go b/internal/smtpsrv/server.go
index f3cb0e6..dec8e13 100644
--- a/internal/smtpsrv/server.go
+++ b/internal/smtpsrv/server.go
@@ -140,11 +140,6 @@ func (s *Server) InitDomainInfo(dir string) *domaininfo.DB {
 		log.Fatalf("Error opening domain info database: %v", err)
 	}
 
-	err = s.dinfo.Load()
-	if err != nil {
-		log.Fatalf("Error loading domain info database: %v", err)
-	}
-
 	return s.dinfo
 }
 
@@ -176,6 +171,11 @@ func (s *Server) periodicallyReload() {
 		if err != nil {
 			log.Errorf("Error reloading authenticators: %v", err)
 		}
+
+		err = s.dinfo.Reload()
+		if err != nil {
+			log.Errorf("Error reloading domaininfo: %v", err)
+		}
 	}
 }
 
diff --git a/internal/testlib/testlib.go b/internal/testlib/testlib.go
index 7a96cbd..638b47f 100644
--- a/internal/testlib/testlib.go
+++ b/internal/testlib/testlib.go
@@ -37,3 +37,17 @@ func RemoveIfOk(t *testing.T, dir string) {
 		os.RemoveAll(dir)
 	}
 }
+
+func Rewrite(t *testing.T, path, contents string) error {
+	// Safeguard, to make sure we only mess with test files.
+	if !strings.Contains(path, "testlib_") {
+		panic("invalid/dangerous path")
+	}
+
+	err := ioutil.WriteFile(path, []byte(contents), 0600)
+	if err != nil {
+		t.Errorf("failed to rewrite file: %v", err)
+	}
+
+	return err
+}
diff --git a/internal/testlib/testlib_test.go b/internal/testlib/testlib_test.go
index fad4b65..71b84ef 100644
--- a/internal/testlib/testlib_test.go
+++ b/internal/testlib/testlib_test.go
@@ -52,3 +52,27 @@ func TestLeaveDirOnError(t *testing.T) {
 	// Remove the directory for real this time.
 	RemoveIfOk(t, dir)
 }
+
+func TestRewriteSafeguard(t *testing.T) {
+	myt := &testing.T{}
+	defer func() {
+		if r := recover(); r != nil {
+			t.Logf("recovered: %v", r)
+		} else {
+			t.Fatalf("check did not panic as expected")
+		}
+	}()
+
+	Rewrite(myt, "/something", "test")
+}
+
+func TestRewrite(t *testing.T) {
+	dir := MustTempDir(t)
+	defer RemoveIfOk(t, dir)
+
+	myt := &testing.T{}
+	Rewrite(myt, dir+"/file", "hola")
+	if myt.Failed() {
+		t.Errorf("basic rewrite failed")
+	}
+}