author | Alberto Bertogli
<albertito@blitiri.com.ar> 2023-07-30 10:34:00 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2023-07-30 10:34:27 UTC |
parent | ac1c849a279073cef9772310d47e04d68749bc38 |
internal/domaininfo/domaininfo.go | +25 | -1 |
internal/domaininfo/domaininfo_test.go | +81 | -9 |
diff --git a/internal/domaininfo/domaininfo.go b/internal/domaininfo/domaininfo.go index 0e02ead..ae53d82 100644 --- a/internal/domaininfo/domaininfo.go +++ b/internal/domaininfo/domaininfo.go @@ -75,7 +75,7 @@ func (db *DB) Reload() error { return nil } -func (db *DB) write(tr *trace.Trace, d *Domain) { +func (db *DB) write(tr *trace.Trace, d *Domain) error { tr = tr.NewChild("DomainInfo.write", d.Name) defer tr.Finish() @@ -85,6 +85,7 @@ func (db *DB) write(tr *trace.Trace, d *Domain) { } else { tr.Debugf("saved") } + return err } // IncomingSecLevel checks an incoming security level for the domain. @@ -158,3 +159,26 @@ func (db *DB) OutgoingSecLevel(tr *trace.Trace, domain string, level SecLevel) b return true } } + +// Clear sets the security level for the given domain to plain. +// This can be used for manual overrides in case there's an operational need +// to do so. +func (db *DB) Clear(tr *trace.Trace, domain string) bool { + tr = tr.NewChild("DomainInfo.SetToPlain", domain) + defer tr.Finish() + + db.Lock() + defer db.Unlock() + + d, exists := db.info[domain] + if !exists { + tr.Debugf("does not exist") + return false + } + + d.IncomingSecLevel = SecLevel_PLAIN + d.OutgoingSecLevel = SecLevel_PLAIN + db.write(tr, d) + tr.Printf("set to plain") + return true +} diff --git a/internal/domaininfo/domaininfo_test.go b/internal/domaininfo/domaininfo_test.go index 2c50d21..77121a6 100644 --- a/internal/domaininfo/domaininfo_test.go +++ b/internal/domaininfo/domaininfo_test.go @@ -1,6 +1,8 @@ package domaininfo import ( + "errors" + "os" "testing" "blitiri.com.ar/go/chasquid/internal/testlib" @@ -17,14 +19,26 @@ func TestBasic(t *testing.T) { tr := trace.New("test", "basic") defer tr.Finish() + // IncomingSecLevel checks. 
if !db.IncomingSecLevel(tr, "d1", SecLevel_PLAIN) { - t.Errorf("new domain as plain not allowed") + t.Errorf("incoming: new domain as plain not allowed") } if !db.IncomingSecLevel(tr, "d1", SecLevel_TLS_SECURE) { - t.Errorf("increment to tls-secure not allowed") + t.Errorf("incoming: increment to tls-secure not allowed") } if db.IncomingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) { - t.Errorf("decrement to tls-insecure was allowed") + t.Errorf("incoming: decrement to tls-insecure was allowed") + } + + // OutgoingSecLevel checks. + if !db.OutgoingSecLevel(tr, "d1", SecLevel_PLAIN) { + t.Errorf("outgoing: new domain as plain not allowed") + } + if !db.OutgoingSecLevel(tr, "d1", SecLevel_TLS_SECURE) { + t.Errorf("outgoing: increment to tls-secure not allowed") + } + if db.OutgoingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) { + t.Errorf("outgoing: decrement to tls-insecure was allowed") } // Check that it was added to the store and a new db sees it. @@ -35,6 +49,24 @@ func TestBasic(t *testing.T) { if db2.IncomingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) { t.Errorf("decrement to tls-insecure was allowed in new DB") } + + // Check that Clear resets the entry back to plain. + ok := db.Clear(tr, "d1") + if !ok { + t.Errorf("Clear(d1) did not find the domain") + } + if !db.IncomingSecLevel(tr, "d1", SecLevel_PLAIN) { + t.Errorf("Clear did not reset the domain back to plain (incoming)") + } + if !db.OutgoingSecLevel(tr, "d1", SecLevel_PLAIN) { + t.Errorf("Clear did not reset the domain back to plain (outgoing)") + } + + // Check that Clear returns false if the domain does not exist. 
+ ok = db.Clear(tr, "notexist") + if ok { + t.Errorf("Clear(notexist) returned true") + } } func TestNewDomain(t *testing.T) { @@ -44,7 +76,7 @@ func TestNewDomain(t *testing.T) { if err != nil { t.Fatal(err) } - tr := trace.New("test", "basic") + tr := trace.New("test", "newdomain") defer tr.Finish() cases := []struct { @@ -56,12 +88,15 @@ func TestNewDomain(t *testing.T) { {"secure", SecLevel_TLS_SECURE}, } for _, c := range cases { - if !db.IncomingSecLevel(tr, c.domain, c.level) { - t.Errorf("domain %q not allowed (in) at %s", c.domain, c.level) - } + // The other tests do an incoming check first, so new domains would get + // created via that path. We switch the order here to exercise that + // OutgoingSecLevel also handles new domains successfully. if !db.OutgoingSecLevel(tr, c.domain, c.level) { t.Errorf("domain %q not allowed (out) at %s", c.domain, c.level) } + if !db.IncomingSecLevel(tr, c.domain, c.level) { + t.Errorf("domain %q not allowed (in) at %s", c.domain, c.level) + } } } @@ -72,7 +107,7 @@ func TestProgressions(t *testing.T) { if err != nil { t.Fatal(err) } - tr := trace.New("test", "basic") + tr := trace.New("test", "progressions") defer tr.Finish() cases := []struct { @@ -118,7 +153,7 @@ func TestErrors(t *testing.T) { t.Fatal(err) } - tr := trace.New("test", "basic") + tr := trace.New("test", "errors") defer tr.Finish() if !db.IncomingSecLevel(tr, "d1", SecLevel_TLS_SECURE) { @@ -131,4 +166,41 @@ func TestErrors(t *testing.T) { if err == nil { t.Errorf("no error when reloading db with invalid file") } + + // Creating a db with an invalid file should also result in an error.
+ _, err = New(dir) + if err == nil { + t.Errorf("no error when creating db with invalid file") + } +} + +func TestDirectoryErrors(t *testing.T) { + dir := testlib.MustTempDir(t) + defer testlib.RemoveIfOk(t, dir) + db, err := New(dir + "/db") + if err != nil { + t.Fatal(err) + } + + tr := trace.New("test", "direrrors") + defer tr.Finish() + + // We want to cause store.ListIDs to return an error. To do so, we will + // cause Readdir to fail by removing the underlying db directory. + err = os.Remove(dir + "/db") + if err != nil { + t.Fatal(err) + } + + err = db.Reload() + if !errors.Is(err, os.ErrNotExist) { + t.Errorf("got %v, expected %v", err, os.ErrNotExist) + } + + // We expect write() to also fail to store data in this scenario. + d := Domain{Name: "d1"} + err = db.write(tr, &d) + if !errors.Is(err, os.ErrNotExist) { + t.Errorf("got %v, expected %v", err, os.ErrNotExist) + } }