git » chasquid » commit 764c09e

domaininfo: Add a Clear method to clear information for a given domain

author Alberto Bertogli
2023-07-30 10:34:00 UTC
committer Alberto Bertogli
2023-07-30 10:34:27 UTC
parent ac1c849a279073cef9772310d47e04d68749bc38

domaininfo: Add a Clear method to clear information for a given domain

This patch adds a Clear method to the domaininfo database, which resets
the stored security levels for the given domain back to plain.

This can be used to manually make the server forget about a domain, in
case there are operational reasons to do so.

Today, this is done via chasquid-util (which removes the backing file),
but that is hacky, and this is part of replacing it with a cleaner
implementation.

internal/domaininfo/domaininfo.go +25 -1
internal/domaininfo/domaininfo_test.go +81 -9

diff --git a/internal/domaininfo/domaininfo.go b/internal/domaininfo/domaininfo.go
index 0e02ead..ae53d82 100644
--- a/internal/domaininfo/domaininfo.go
+++ b/internal/domaininfo/domaininfo.go
@@ -75,7 +75,7 @@ func (db *DB) Reload() error {
 	return nil
 }
 
-func (db *DB) write(tr *trace.Trace, d *Domain) {
+func (db *DB) write(tr *trace.Trace, d *Domain) error {
 	tr = tr.NewChild("DomainInfo.write", d.Name)
 	defer tr.Finish()
 
@@ -85,6 +85,7 @@ func (db *DB) write(tr *trace.Trace, d *Domain) {
 	} else {
 		tr.Debugf("saved")
 	}
+	return err
 }
 
 // IncomingSecLevel checks an incoming security level for the domain.
@@ -158,3 +159,26 @@ func (db *DB) OutgoingSecLevel(tr *trace.Trace, domain string, level SecLevel) b
 		return true
 	}
 }
+
+// Clear resets the incoming and outgoing security levels of the given domain
+// to plain (e.g. for manual overrides) and returns false if it is unknown.
+// The reset applies in memory even if saving it fails; that is only traced.
+func (db *DB) Clear(tr *trace.Trace, domain string) bool {
+	tr = tr.NewChild("DomainInfo.Clear", domain)
+	defer tr.Finish()
+
+	db.Lock()
+	defer db.Unlock()
+
+	d, exists := db.info[domain]
+	if !exists {
+		tr.Debugf("does not exist")
+		return false
+	}
+
+	d.IncomingSecLevel = SecLevel_PLAIN
+	d.OutgoingSecLevel = SecLevel_PLAIN
+	err := db.write(tr, d)
+	tr.Printf("set to plain (saved: %t)", err == nil)
+	return true
+}
diff --git a/internal/domaininfo/domaininfo_test.go b/internal/domaininfo/domaininfo_test.go
index 2c50d21..77121a6 100644
--- a/internal/domaininfo/domaininfo_test.go
+++ b/internal/domaininfo/domaininfo_test.go
@@ -1,6 +1,8 @@
 package domaininfo
 
 import (
+	"errors"
+	"os"
 	"testing"
 
 	"blitiri.com.ar/go/chasquid/internal/testlib"
@@ -17,14 +19,26 @@ func TestBasic(t *testing.T) {
 	tr := trace.New("test", "basic")
 	defer tr.Finish()
 
+	// IncomingSecLevel checks.
 	if !db.IncomingSecLevel(tr, "d1", SecLevel_PLAIN) {
-		t.Errorf("new domain as plain not allowed")
+		t.Errorf("incoming: new domain as plain not allowed")
 	}
 	if !db.IncomingSecLevel(tr, "d1", SecLevel_TLS_SECURE) {
-		t.Errorf("increment to tls-secure not allowed")
+		t.Errorf("incoming: increment to tls-secure not allowed")
 	}
 	if db.IncomingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) {
-		t.Errorf("decrement to tls-insecure was allowed")
+		t.Errorf("incoming: decrement to tls-insecure was allowed")
+	}
+
+	// OutgoingSecLevel checks.
+	if !db.OutgoingSecLevel(tr, "d1", SecLevel_PLAIN) {
+		t.Errorf("outgoing: new domain as plain not allowed")
+	}
+	if !db.OutgoingSecLevel(tr, "d1", SecLevel_TLS_SECURE) {
+		t.Errorf("outgoing: increment to tls-secure not allowed")
+	}
+	if db.OutgoingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) {
+		t.Errorf("outgoing: decrement to tls-insecure was allowed")
 	}
 
 	// Check that it was added to the store and a new db sees it.
@@ -35,6 +49,24 @@ func TestBasic(t *testing.T) {
 	if db2.IncomingSecLevel(tr, "d1", SecLevel_TLS_INSECURE) {
 		t.Errorf("decrement to tls-insecure was allowed in new DB")
 	}
+
+	// Check that Clear resets the entry back to plain.
+	ok := db.Clear(tr, "d1")
+	if !ok {
+		t.Errorf("Clear(d1) did not find the domain")
+	}
+	if !db.IncomingSecLevel(tr, "d1", SecLevel_PLAIN) {
+		t.Errorf("Clear did not reset the domain back to plain (incoming)")
+	}
+	if !db.OutgoingSecLevel(tr, "d1", SecLevel_PLAIN) {
+		t.Errorf("Clear did not reset the domain back to plain (outgoing)")
+	}
+
+	// Check that Clear returns false if the domain does not exist.
+	ok = db.Clear(tr, "notexist")
+	if ok {
+		t.Errorf("Clear(notexist) returned true")
+	}
 }
 
 func TestNewDomain(t *testing.T) {
@@ -44,7 +76,7 @@ func TestNewDomain(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	tr := trace.New("test", "basic")
+	tr := trace.New("test", "newdomain")
 	defer tr.Finish()
 
 	cases := []struct {
@@ -56,12 +88,15 @@ func TestNewDomain(t *testing.T) {
 		{"secure", SecLevel_TLS_SECURE},
 	}
 	for _, c := range cases {
-		if !db.IncomingSecLevel(tr, c.domain, c.level) {
-			t.Errorf("domain %q not allowed (in) at %s", c.domain, c.level)
-		}
+		// The other tests do an incoming check first, so new domains get
+		// created via that path. We switch the order here to check that
+		// OutgoingSecLevel also handles new domains successfully.
 		if !db.OutgoingSecLevel(tr, c.domain, c.level) {
 			t.Errorf("domain %q not allowed (out) at %s", c.domain, c.level)
 		}
+		if !db.IncomingSecLevel(tr, c.domain, c.level) {
+			t.Errorf("domain %q not allowed (in) at %s", c.domain, c.level)
+		}
 	}
 }
 
@@ -72,7 +107,7 @@ func TestProgressions(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	tr := trace.New("test", "basic")
+	tr := trace.New("test", "progressions")
 	defer tr.Finish()
 
 	cases := []struct {
@@ -118,7 +153,7 @@ func TestErrors(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	tr := trace.New("test", "basic")
+	tr := trace.New("test", "errors")
 	defer tr.Finish()
 
 	if !db.IncomingSecLevel(tr, "d1", SecLevel_TLS_SECURE) {
@@ -131,4 +166,41 @@ func TestErrors(t *testing.T) {
 	if err == nil {
 		t.Errorf("no error when reloading db with invalid file")
 	}
+
+	// Creating a db with an invalid file should also result in an error.
+	_, err = New(dir)
+	if err == nil {
+		t.Errorf("no error when creating db with invalid file")
+	}
+}
+
+func TestDirectoryErrors(t *testing.T) {
+	// Errors from the directory backing the store must propagate out of
+	// both Reload and write, rather than being swallowed.
+	tmpDir := testlib.MustTempDir(t)
+	defer testlib.RemoveIfOk(t, tmpDir)
+
+	dbDir := tmpDir + "/db"
+	db, err := New(dbDir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	tr := trace.New("test", "direrrors")
+	defer tr.Finish()
+
+	// Removing the db directory out from under the DB makes the directory
+	// read behind store.ListIDs (Readdir) fail.
+	if err := os.Remove(dbDir); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := db.Reload(); !errors.Is(err, os.ErrNotExist) {
+		t.Errorf("got %v, expected %v", err, os.ErrNotExist)
+	}
+
+	// With nowhere to store data, write must fail the same way.
+	if err := db.write(tr, &Domain{Name: "d1"}); !errors.Is(err, os.ErrNotExist) {
+		t.Errorf("got %v, expected %v", err, os.ErrNotExist)
+	}
 }