git » chasquid » commit c013c98

domaininfo: New package to track domain (security) information

author Alberto Bertogli
2016-10-13 01:28:30 UTC
committer Alberto Bertogli
2016-10-21 21:15:09 UTC
parent 1d7a207e00f2e2419737b974b3a80f3ef45f6d85

domaininfo: New package to track domain (security) information

This patch introduces a new "domaininfo" package, which implements a
database with information about domains.  In particular, it tracks
incoming and outgoing security levels.

That information is used by both incoming and outgoing SMTP to prevent
security level downgrades.

chasquid.go +58 -2
chasquid_test.go +58 -5
internal/courier/smtp.go +17 -2
internal/courier/smtp_test.go +22 -2
internal/domaininfo/domaininfo.go +133 -0
internal/domaininfo/domaininfo.pb.go +96 -0
internal/domaininfo/domaininfo.proto +28 -0
internal/domaininfo/domaininfo_test.go +133 -0

diff --git a/chasquid.go b/chasquid.go
index 866a74b..b689449 100644
--- a/chasquid.go
+++ b/chasquid.go
@@ -24,6 +24,7 @@ import (
 	"blitiri.com.ar/go/chasquid/internal/auth"
 	"blitiri.com.ar/go/chasquid/internal/config"
 	"blitiri.com.ar/go/chasquid/internal/courier"
+	"blitiri.com.ar/go/chasquid/internal/domaininfo"
 	"blitiri.com.ar/go/chasquid/internal/envelope"
 	"blitiri.com.ar/go/chasquid/internal/normalize"
 	"blitiri.com.ar/go/chasquid/internal/queue"
@@ -53,6 +54,7 @@ var (
 	spfResultCount    = expvar.NewMap("chasquid/smtpIn/spfResultCount")
 	loopsDetected     = expvar.NewInt("chasquid/smtpIn/loopsDetected")
 	tlsCount          = expvar.NewMap("chasquid/smtpIn/tlsCount")
+	slcResults        = expvar.NewMap("chasquid/smtpIn/securityLevelChecks")
 )
 
 // Global event logs.
@@ -133,12 +135,14 @@ func main() {
 	// as a remote domain (for loops, alias resolutions, etc.).
 	s.AddDomain("localhost")
 
+	s.InitDomainInfo(conf.DataDir + "/domaininfo")
+
 	localC := &courier.Procmail{
 		Binary:  conf.MailDeliveryAgentBin,
 		Args:    conf.MailDeliveryAgentArgs,
 		Timeout: 30 * time.Second,
 	}
-	remoteC := &courier.SMTP{}
+	remoteC := &courier.SMTP{Dinfo: s.dinfo}
 	s.InitQueue(conf.DataDir+"/queue", localC, remoteC)
 
 	go s.periodicallyReload()
@@ -266,6 +270,9 @@ type Server struct {
 	// Aliases resolver.
 	aliasesR *aliases.Resolver
 
+	// Domain info database.
+	dinfo *domaininfo.DB
+
 	// Time before we give up on a connection, even if it's sending data.
 	connTimeout time.Duration
 
@@ -314,6 +321,19 @@ func (s *Server) AddUserDB(domain string, db *userdb.DB) {
 	s.userDBs[domain] = db
 }
 
+// InitDomainInfo opens the domain info database stored in dir and loads its
+// contents into memory, leaving it in s.dinfo. Any error is fatal (it is
+// reported via glog.Fatalf).
+func (s *Server) InitDomainInfo(dir string) {
+	dinfo, err := domaininfo.New(dir)
+	if err != nil {
+		glog.Fatalf("Error opening domain info database: %v", err)
+	}
+	s.dinfo = dinfo
+
+	if err = s.dinfo.Load(); err != nil {
+		glog.Fatalf("Error loading domain info database: %v", err)
+	}
+}
+
 func (s *Server) InitQueue(path string, localC, remoteC courier.Courier) {
 	q := queue.New(path, s.localDomains, s.aliasesR, localC, remoteC)
 	err := q.Load()
@@ -399,6 +419,7 @@ func (s *Server) serve(l net.Listener, mode SocketMode) {
 			userDBs:        s.userDBs,
 			aliasesR:       s.aliasesR,
 			localDomains:   s.localDomains,
+			dinfo:          s.dinfo,
 			deadline:       time.Now().Add(s.connTimeout),
 			commandTimeout: s.commandTimeout,
 			queue:          s.queue,
@@ -449,6 +470,7 @@ type Conn struct {
 	userDBs      map[string]*userdb.DB
 	localDomains *set.String
 	aliasesR     *aliases.Resolver
+	dinfo        *domaininfo.DB
 
 	// Have we successfully completed AUTH?
 	completedAuth bool
@@ -697,6 +719,10 @@ func (c *Conn) MAIL(params string) (code int, msg string) {
 				"SPF check failed: %v", c.spfError)
 		}
 
+		if !c.secLevelCheck(addr) {
+			return 550, "security level check failed"
+		}
+
 		addr, err = normalize.DomainToUnicode(addr)
 		if err != nil {
 			return 501, "malformed address (IDNA conversion failed)"
@@ -727,6 +753,37 @@ func (c *Conn) checkSPF(addr string) (spf.Result, error) {
 	return "", nil
 }
 
+// secLevelCheck checks if the security level is acceptable for the given
+// address, according to the domain info database.
+//
+// Connections over TLS are checked as SecLevel_TLS_CLIENT, all others as
+// SecLevel_PLAIN. It returns true when the check passes or is skipped, and
+// false when the level is below the one recorded for the sender's domain.
+//
+// NOTE(review): MAIL calls this before normalize.DomainToUnicode, so the
+// domain is looked up exactly as the client wrote it; confirm that different
+// spellings of the same domain cannot be used to sidestep the check.
+func (c *Conn) secLevelCheck(addr string) bool {
+	// Only check if SPF passes. This serves two purposes:
+	//  - Skip for authenticated connections (we trust them implicitly).
+	//  - Don't apply this if we can't be sure the sender is authorized.
+	//    Otherwise anyone could raise the level of any domain.
+	if c.spfResult != spf.Pass {
+		slcResults.Add("skip", 1)
+		c.tr.Debugf("SPF did not pass, skipping security level check")
+		return true
+	}
+
+	var level domaininfo.SecLevel
+	if c.onTLS {
+		level = domaininfo.SecLevel_TLS_CLIENT
+	} else {
+		level = domaininfo.SecLevel_PLAIN
+	}
+	domain := envelope.DomainOf(addr)
+
+	if !c.dinfo.IncomingSecLevel(domain, level) {
+		slcResults.Add("fail", 1)
+		c.tr.Errorf("security level check for %s failed (%s)", domain, level)
+		return false
+	}
+
+	slcResults.Add("pass", 1)
+	c.tr.Debugf("security level check for %s passed (%s)", domain, level)
+	return true
+}
+
 func (c *Conn) RCPT(params string) (code int, msg string) {
 	// params should be: "TO:<name@host>", and possibly followed by options
 	// such as "NOTIFY=SUCCESS,DELAY" (which we ignore).
@@ -793,7 +850,6 @@ func (c *Conn) DATA(params string) (code int, msg string) {
 	if c.mailFrom == "" {
 		return 503, "sender not yet given"
 	}
-
 	if len(c.rcptTo) == 0 {
 		return 503, "need an address to send to"
 	}
diff --git a/chasquid_test.go b/chasquid_test.go
index 78b6cb3..04cc469 100644
--- a/chasquid_test.go
+++ b/chasquid_test.go
@@ -19,6 +19,9 @@ import (
 
 	"blitiri.com.ar/go/chasquid/internal/aliases"
 	"blitiri.com.ar/go/chasquid/internal/courier"
+	"blitiri.com.ar/go/chasquid/internal/domaininfo"
+	"blitiri.com.ar/go/chasquid/internal/spf"
+	"blitiri.com.ar/go/chasquid/internal/trace"
 	"blitiri.com.ar/go/chasquid/internal/userdb"
 
 	"github.com/golang/glog"
@@ -169,7 +172,7 @@ func TestWrongMailParsing(t *testing.T) {
 		}
 	}
 
-	if err := c.Mail("from@from"); err != nil {
+	if err := c.Mail("from@plain"); err != nil {
 		t.Errorf("Mail: %v", err)
 	}
 
@@ -200,11 +203,11 @@ func TestRcptBeforeMail(t *testing.T) {
 }
 
 func TestRcptOption(t *testing.T) {
-	c := mustDial(t, ModeSMTP, false)
+	c := mustDial(t, ModeSMTP, true)
 	defer c.Close()
 
 	if err := c.Mail("from@localhost"); err != nil {
-		t.Errorf("Mail: %v", err)
+		t.Fatalf("Mail: %v", err)
 	}
 
 	params := []string{
@@ -250,7 +253,7 @@ func TestReset(t *testing.T) {
 	c := mustDial(t, ModeSMTP, false)
 	defer c.Close()
 
-	if err := c.Mail("from@from"); err != nil {
+	if err := c.Mail("from@plain"); err != nil {
 		t.Fatalf("MAIL FROM: %v", err)
 	}
 
@@ -258,7 +261,7 @@ func TestReset(t *testing.T) {
 		t.Errorf("RSET: %v", err)
 	}
 
-	if err := c.Mail("from@from"); err != nil {
+	if err := c.Mail("from@plain"); err != nil {
 		t.Errorf("MAIL after RSET: %v", err)
 	}
 }
@@ -278,6 +281,55 @@ func TestRepeatedStartTLS(t *testing.T) {
 	}
 }
 
+func TestSecLevel(t *testing.T) {
+	// We can't simulate this externally because of the SPF record
+	// requirement, so do a narrow test on Conn.secLevelCheck.
+	tmpDir, err := ioutil.TempDir("", "chasquid_test:")
+	if err != nil {
+		t.Fatalf("Failed to create temp dir: %v", err)
+	}
+	defer os.RemoveAll(tmpDir)
+
+	dinfo, err := domaininfo.New(tmpDir)
+	if err != nil {
+		t.Fatalf("Failed to create domain info: %v", err)
+	}
+
+	c := &Conn{
+		tr:    trace.New("testconn", "testconn"),
+		dinfo: dinfo,
+	}
+
+	// Steps run in order against the same connection and database, so
+	// later steps see the level recorded by earlier ones.
+	steps := []struct {
+		spfRes spf.Result
+		onTLS  bool
+		want   bool
+		msg    string
+	}{
+		// No SPF, skip security checks.
+		{spf.None, true, true, "TLS seclevel failed"},
+		{spf.None, false, true, "plain seclevel failed, even though SPF does not exist"},
+
+		// Now the real checks, once SPF passes.
+		{spf.Pass, false, true, "plain seclevel failed"},
+		{spf.Pass, true, true, "TLS seclevel failed"},
+		{spf.Pass, false, false, "plain seclevel worked, downgrade was allowed"},
+	}
+	for _, st := range steps {
+		c.spfResult = st.spfRes
+		c.onTLS = st.onTLS
+		if got := c.secLevelCheck("from@slc"); got != st.want {
+			t.Fatal(st.msg)
+		}
+	}
+}
+
 //
 // === Benchmarks ===
 //
@@ -438,6 +490,7 @@ func realMain(m *testing.M) int {
 		localC := &courier.Procmail{}
 		remoteC := &courier.SMTP{}
 		s.InitQueue(tmpDir+"/queue", localC, remoteC)
+		s.InitDomainInfo(tmpDir + "/domaininfo")
 
 		udb := userdb.New("/dev/null")
 		udb.AddUser("testuser", "testpasswd")
diff --git a/internal/courier/smtp.go b/internal/courier/smtp.go
index 26dc03c..91f8bf5 100644
--- a/internal/courier/smtp.go
+++ b/internal/courier/smtp.go
@@ -11,6 +11,7 @@ import (
 	"github.com/golang/glog"
 	"golang.org/x/net/idna"
 
+	"blitiri.com.ar/go/chasquid/internal/domaininfo"
 	"blitiri.com.ar/go/chasquid/internal/envelope"
 	"blitiri.com.ar/go/chasquid/internal/smtp"
 	"blitiri.com.ar/go/chasquid/internal/trace"
@@ -32,11 +33,13 @@ var (
 
 // Exported variables.
 var (
-	tlsCount = expvar.NewMap("chasquid/smtpOut/tlsCount")
+	tlsCount   = expvar.NewMap("chasquid/smtpOut/tlsCount")
+	slcResults = expvar.NewMap("chasquid/smtpOut/securityLevelChecks") // Outgoing security level check results, by outcome.
 )
 
 // SMTP delivers remote mail via outgoing SMTP.
 type SMTP struct {
+	// Dinfo is the domain info database used to check (and record) the
+	// security level of outgoing connections. It must be set: Deliver calls
+	// Dinfo.OutgoingSecLevel without checking it for nil.
+	Dinfo *domaininfo.DB
 }
 
 func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) {
@@ -44,7 +47,8 @@ func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) {
 	defer tr.Finish()
 	tr.Debugf("%s  ->  %s", from, to)
 
-	mx, err := lookupMX(envelope.DomainOf(to))
+	toDomain := envelope.DomainOf(to)
+	mx, err := lookupMX(toDomain)
 	if err != nil {
 		// Note this is considered a permanent error.
 		// This is in line with what other servers (Exim) do. However, the
@@ -68,6 +72,7 @@ func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) {
 	// Do we use insecure TLS?
 	// Set as fallback when retrying.
 	insecure := false
+	secLevel := domaininfo.SecLevel_PLAIN
 
 retry:
 	conn, err := net.DialTimeout("tcp", mx+":"+*smtpPort, smtpDialTimeout)
@@ -110,15 +115,25 @@ retry:
 		if config.InsecureSkipVerify {
 			tr.Debugf("Insecure - using TLS, but cert does not match %s", mx)
 			tlsCount.Add("tls:insecure", 1)
+			secLevel = domaininfo.SecLevel_TLS_INSECURE
 		} else {
 			tlsCount.Add("tls:secure", 1)
 			tr.Debugf("Secure - using TLS")
+			secLevel = domaininfo.SecLevel_TLS_SECURE
 		}
 	} else {
 		tlsCount.Add("plain", 1)
 		tr.Debugf("Insecure - NOT using TLS")
 	}
 
+	if toDomain != "" && !s.Dinfo.OutgoingSecLevel(toDomain, secLevel) {
+		// We consider the failure transient, so transient misconfigurations
+		// do not affect deliveries.
+		slcResults.Add("fail", 1)
+		return tr.Errorf("Security level check failed (level:%s)", secLevel), false
+	}
+	slcResults.Add("pass", 1)
+
 	// c.Mail will add the <> for us when the address is empty.
 	if from == "<>" {
 		from = ""
diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go
index 915b09c..1ba7ffd 100644
--- a/internal/courier/smtp_test.go
+++ b/internal/courier/smtp_test.go
@@ -2,12 +2,30 @@ package courier
 
 import (
 	"bufio"
+	"io/ioutil"
 	"net"
 	"net/textproto"
+	"os"
 	"testing"
 	"time"
+
+	"blitiri.com.ar/go/chasquid/internal/domaininfo"
 )
 
+// newSMTP returns an SMTP courier backed by a fresh domain info database in
+// a new temporary directory, together with that directory. The database may
+// write files into it, so callers should clean it up with os.RemoveAll
+// (os.Remove fails on a non-empty directory).
+func newSMTP(t *testing.T) (*SMTP, string) {
+	dir, err := ioutil.TempDir("", "smtp_test")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	dinfo, err := domaininfo.New(dir)
+	if err != nil {
+		// The caller never gets dir back on failure, so clean it up here.
+		os.RemoveAll(dir)
+		t.Fatal(err)
+	}
+
+	return &SMTP{Dinfo: dinfo}, dir
+}
+
 // Fake server, to test SMTP out.
 func fakeServer(t *testing.T, responses map[string]string) string {
 	l, err := net.Listen("tcp", "localhost:0")
@@ -72,7 +90,8 @@ func TestSMTP(t *testing.T) {
 	fakeMX["to"] = host
 	*smtpPort = port
 
-	s := &SMTP{}
+	s, tmpDir := newSMTP(t)
+	defer os.Remove(tmpDir)
 	err, _ := s.Deliver("me@me", "to@to", []byte("data"))
 	if err != nil {
 		t.Errorf("deliver failed: %v", err)
@@ -132,7 +151,8 @@ func TestSMTPErrors(t *testing.T) {
 		fakeMX["to"] = host
 		*smtpPort = port
 
-		s := &SMTP{}
+		s, tmpDir := newSMTP(t)
+		defer os.Remove(tmpDir)
 		err, _ := s.Deliver("me@me", "to@to", []byte("data"))
 		if err == nil {
 			t.Errorf("deliver not failed in case %q: %v", rs["_welcome"], err)
diff --git a/internal/domaininfo/domaininfo.go b/internal/domaininfo/domaininfo.go
new file mode 100644
index 0000000..0f0ca77
--- /dev/null
+++ b/internal/domaininfo/domaininfo.go
@@ -0,0 +1,133 @@
+// Package domaininfo implements a domain information database, to keep track
+// of things we know about a particular domain.
+package domaininfo
+
+import (
+	"fmt"
+	"sync"
+
+	"blitiri.com.ar/go/chasquid/internal/protoio"
+	"blitiri.com.ar/go/chasquid/internal/trace"
+)
+
+// Command to generate domaininfo.pb.go.
+//go:generate protoc --go_out=. domaininfo.proto
+
+// DB is an in-memory table of per-domain information, persisted to disk
+// through a protoio.Store.
+type DB struct {
+	// The embedded mutex guards info; the security level checks hold it
+	// while reading and updating entries.
+	sync.Mutex
+
+	// Known domains, indexed by domain name.
+	info map[string]*Domain
+
+	// Persistent store with the list of domains we know.
+	store *protoio.Store
+
+	// Event log where database activity is traced.
+	ev *trace.EventLog
+}
+
+// New returns a DB whose persistent store lives in dir. The returned
+// database starts out empty in memory; call Load to read what is on disk.
+func New(dir string) (*DB, error) {
+	store, err := protoio.NewStore(dir)
+	if err != nil {
+		return nil, err
+	}
+
+	db := &DB{store: store, info: map[string]*Domain{}}
+	db.ev = trace.NewEventLog("DomainInfo", fmt.Sprintf("%p", db))
+	return db, nil
+}
+
+// Load the database from disk; should be called once at initialization.
+//
+// Every entry in the store is indexed in memory by its Name field. An error
+// reading any entry aborts the load and is returned.
+func (db *DB) Load() error {
+	ids, err := db.store.ListIDs()
+	if err != nil {
+		return err
+	}
+
+	for _, id := range ids {
+		d := &Domain{}
+		ok, err := db.store.Get(id, d)
+		if err != nil {
+			return fmt.Errorf("error loading %q: %v", id, err)
+		}
+		if !ok {
+			// NOTE(review): assumes ok == false means the entry was not
+			// found (e.g. removed after ListIDs). Skip it instead of
+			// indexing an empty Domain under the "" key.
+			db.ev.Errorf("%q listed but not found, skipping", id)
+			continue
+		}
+
+		db.info[d.Name] = d
+	}
+
+	db.ev.Debugf("loaded: %s", ids)
+	return nil
+}
+
+// write saves d to the persistent store under its name. Failures are only
+// recorded in the event log; the caller is not told about them.
+func (db *DB) write(d *Domain) {
+	if err := db.store.Put(d.Name, d); err != nil {
+		db.ev.Errorf("%s error saving: %v", d.Name, err)
+		return
+	}
+	db.ev.Debugf("%s saved", d.Name)
+}
+
+// IncomingSecLevel checks an incoming security level for the domain.
+// Returns true if allowed, false otherwise.
+//
+// A level is allowed if it is at least the one recorded for mail coming from
+// the domain; a higher level replaces the recorded one and is saved. Domains
+// seen for the first time are recorded (and saved) at the given level.
+func (db *DB) IncomingSecLevel(domain string, level SecLevel) bool {
+	return db.checkSecLevel(domain, level, "incoming",
+		func(d *Domain) *SecLevel { return &d.IncomingSecLevel })
+}
+
+// OutgoingSecLevel checks an outgoing security level for the domain.
+// Returns true if allowed, false otherwise.
+//
+// The rules are the same as for IncomingSecLevel, applied to the level
+// recorded for mail we send to the domain.
+func (db *DB) OutgoingSecLevel(domain string, level SecLevel) bool {
+	return db.checkSecLevel(domain, level, "outgoing",
+		func(d *Domain) *SecLevel { return &d.OutgoingSecLevel })
+}
+
+// checkSecLevel implements IncomingSecLevel and OutgoingSecLevel.
+// dir names the direction in log messages, and field selects which of the
+// domain's recorded levels is compared against (and raised to) level.
+func (db *DB) checkSecLevel(domain string, level SecLevel, dir string,
+	field func(*Domain) *SecLevel) bool {
+	db.Lock()
+	defer db.Unlock()
+
+	// Deferred calls run in LIFO order, so any db.write deferred below runs
+	// before db.Unlock, i.e. with the lock still held.
+	d, exists := db.info[domain]
+	if !exists {
+		// Register the new domain; it is saved on return, after the level
+		// has possibly been raised below.
+		d = &Domain{Name: domain}
+		db.info[domain] = d
+		defer db.write(d)
+	}
+
+	cur := field(d)
+	switch {
+	case level < *cur:
+		db.ev.Errorf("%s %s denied: %s < %s", d.Name, dir, level, *cur)
+		return false
+	case level == *cur:
+		db.ev.Debugf("%s %s allowed: %s == %s", d.Name, dir, level, *cur)
+		return true
+	default:
+		db.ev.Printf("%s %s level raised: %s > %s", d.Name, dir, level, *cur)
+		*cur = level
+		if exists {
+			// New domains already have a write deferred above.
+			defer db.write(d)
+		}
+		return true
+	}
+}
diff --git a/internal/domaininfo/domaininfo.pb.go b/internal/domaininfo/domaininfo.pb.go
new file mode 100644
index 0000000..6c5116a
--- /dev/null
+++ b/internal/domaininfo/domaininfo.pb.go
@@ -0,0 +1,96 @@
+// Code generated by protoc-gen-go.
+// source: domaininfo.proto
+// DO NOT EDIT!
+
+/*
+Package domaininfo is a generated protocol buffer package.
+
+It is generated from these files:
+	domaininfo.proto
+
+It has these top-level messages:
+	Domain
+*/
+package domaininfo
+
+import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
+import math "math"
+
+// Reference imports to suppress errors if they are not otherwise used.
+var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the proto package it is being compiled against.
+// A compilation error at this line likely means your copy of the
+// proto package needs to be updated.
+const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
+
+type SecLevel int32
+
+const (
+	// Does not do TLS.
+	SecLevel_PLAIN SecLevel = 0
+	// TLS client connection (no certificate validation).
+	SecLevel_TLS_CLIENT SecLevel = 1
+	// TLS, but with invalid certificates.
+	SecLevel_TLS_INSECURE SecLevel = 2
+	// TLS, with valid certificates.
+	SecLevel_TLS_SECURE SecLevel = 3
+)
+
+var SecLevel_name = map[int32]string{
+	0: "PLAIN",
+	1: "TLS_CLIENT",
+	2: "TLS_INSECURE",
+	3: "TLS_SECURE",
+}
+var SecLevel_value = map[string]int32{
+	"PLAIN":        0,
+	"TLS_CLIENT":   1,
+	"TLS_INSECURE": 2,
+	"TLS_SECURE":   3,
+}
+
+func (x SecLevel) String() string {
+	return proto.EnumName(SecLevel_name, int32(x))
+}
+func (SecLevel) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
+
+type Domain struct {
+	Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"`
+	// Security level for mail coming from this domain (they send to us).
+	IncomingSecLevel SecLevel `protobuf:"varint,2,opt,name=incoming_sec_level,json=incomingSecLevel,enum=domaininfo.SecLevel" json:"incoming_sec_level,omitempty"`
+	// Security level for mail going to this domain (we send to them).
+	OutgoingSecLevel SecLevel `protobuf:"varint,3,opt,name=outgoing_sec_level,json=outgoingSecLevel,enum=domaininfo.SecLevel" json:"outgoing_sec_level,omitempty"`
+}
+
+func (m *Domain) Reset()                    { *m = Domain{} }
+func (m *Domain) String() string            { return proto.CompactTextString(m) }
+func (*Domain) ProtoMessage()               {}
+func (*Domain) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
+
+func init() {
+	proto.RegisterType((*Domain)(nil), "domaininfo.Domain")
+	proto.RegisterEnum("domaininfo.SecLevel", SecLevel_name, SecLevel_value)
+}
+
+func init() { proto.RegisterFile("domaininfo.proto", fileDescriptor0) }
+
+var fileDescriptor0 = []byte{
+	// 189 bytes of a gzipped FileDescriptorProto
+	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0x12, 0x48, 0xc9, 0xcf, 0x4d,
+	0xcc, 0xcc, 0xcb, 0xcc, 0x4b, 0xcb, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x42, 0x88,
+	0x28, 0x2d, 0x61, 0xe4, 0x62, 0x73, 0x01, 0x73, 0x85, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, 0x53,
+	0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0xc0, 0x6c, 0x21, 0x27, 0x2e, 0xa1, 0xcc, 0xbc, 0xe4,
+	0xfc, 0xdc, 0xcc, 0xbc, 0xf4, 0xf8, 0xe2, 0xd4, 0xe4, 0xf8, 0x9c, 0xd4, 0xb2, 0xd4, 0x1c, 0x09,
+	0x26, 0xa0, 0x0a, 0x3e, 0x23, 0x11, 0x3d, 0x24, 0x93, 0x83, 0x53, 0x93, 0x7d, 0x40, 0x72, 0x41,
+	0x02, 0x30, 0xf5, 0x30, 0x11, 0x90, 0x19, 0xf9, 0xa5, 0x25, 0xe9, 0xf9, 0xa8, 0x66, 0x30, 0xe3,
+	0x33, 0x03, 0xa6, 0x1e, 0x26, 0xa2, 0xe5, 0xce, 0xc5, 0x01, 0x37, 0x8f, 0x93, 0x8b, 0x35, 0xc0,
+	0xc7, 0xd1, 0xd3, 0x4f, 0x80, 0x41, 0x88, 0x8f, 0x8b, 0x2b, 0xc4, 0x27, 0x38, 0xde, 0xd9, 0xc7,
+	0xd3, 0xd5, 0x2f, 0x44, 0x80, 0x51, 0x48, 0x80, 0x8b, 0x07, 0xc4, 0xf7, 0xf4, 0x0b, 0x76, 0x75,
+	0x0e, 0x0d, 0x72, 0x15, 0x60, 0x82, 0xa9, 0x80, 0xf2, 0x99, 0x93, 0xd8, 0xc0, 0x41, 0x60, 0x0c,
+	0x08, 0x00, 0x00, 0xff, 0xff, 0x2c, 0x78, 0x65, 0x5b, 0x16, 0x01, 0x00, 0x00,
+}
diff --git a/internal/domaininfo/domaininfo.proto b/internal/domaininfo/domaininfo.proto
new file mode 100644
index 0000000..a2df39f
--- /dev/null
+++ b/internal/domaininfo/domaininfo.proto
@@ -0,0 +1,28 @@
+
+syntax = "proto3";
+
+package domaininfo;
+
+enum SecLevel {
+	// Does not do TLS.
+	PLAIN = 0;
+
+	// TLS client connection (no certificate validation).
+	TLS_CLIENT = 1;
+
+	// TLS, but with invalid certificates.
+	TLS_INSECURE = 2;
+
+	// TLS, with valid certificates.
+	TLS_SECURE = 3;
+}
+
+message Domain {
+	string name = 1;
+
+	// Security level for mail coming from this domain (they send to us).
+	SecLevel incoming_sec_level = 2;
+
+	// Security level for mail going to this domain (we send to them).
+	SecLevel outgoing_sec_level = 3;
+}
diff --git a/internal/domaininfo/domaininfo_test.go b/internal/domaininfo/domaininfo_test.go
new file mode 100644
index 0000000..266dd19
--- /dev/null
+++ b/internal/domaininfo/domaininfo_test.go
@@ -0,0 +1,133 @@
+package domaininfo
+
+import (
+	"io/ioutil"
+	"os"
+	"testing"
+	"time"
+)
+
+// mustTempDir creates a fresh temporary directory for a test database and
+// logs its path, failing the test if it cannot be created.
+func mustTempDir(t *testing.T) string {
+	dir, err := ioutil.TempDir("", "domaininfo_test")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Logf("test directory: %q", dir)
+	return dir
+}
+
+func TestBasic(t *testing.T) {
+	dir := mustTempDir(t)
+
+	// open returns a new DB over dir, loaded with whatever is on disk.
+	open := func() *DB {
+		db, err := New(dir)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if err := db.Load(); err != nil {
+			t.Fatal(err)
+		}
+		return db
+	}
+
+	db := open()
+	if !db.IncomingSecLevel("d1", SecLevel_PLAIN) {
+		t.Errorf("new domain as plain not allowed")
+	}
+	if !db.IncomingSecLevel("d1", SecLevel_TLS_SECURE) {
+		t.Errorf("increment to tls-secure not allowed")
+	}
+	if db.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) {
+		t.Errorf("decrement to tls-insecure was allowed")
+	}
+
+	// Wait (for up to 30 seconds) until d1 can be read back from the store.
+	deadline := time.Now().Add(30 * time.Second)
+	for time.Now().Before(deadline) {
+		if found, _ := db.store.Get("d1", &Domain{}); found {
+			break
+		}
+		time.Sleep(50 * time.Millisecond)
+	}
+
+	// A second DB loaded from the same directory must remember the raised
+	// level, and so reject the downgrade as well.
+	db2 := open()
+	if db2.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) {
+		t.Errorf("decrement to tls-insecure was allowed in new DB")
+	}
+
+	if !t.Failed() {
+		os.RemoveAll(dir)
+	}
+}
+
+func TestNewDomain(t *testing.T) {
+	dir := mustTempDir(t)
+	db, err := New(dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// A domain we have never seen must be accepted at any level, in both
+	// directions.
+	fresh := []struct {
+		domain string
+		level  SecLevel
+	}{
+		{"plain", SecLevel_PLAIN},
+		{"insecure", SecLevel_TLS_INSECURE},
+		{"secure", SecLevel_TLS_SECURE},
+	}
+	for _, f := range fresh {
+		if ok := db.IncomingSecLevel(f.domain, f.level); !ok {
+			t.Errorf("domain %q not allowed (in) at %s", f.domain, f.level)
+		}
+		if ok := db.OutgoingSecLevel(f.domain, f.level); !ok {
+			t.Errorf("domain %q not allowed (out) at %s", f.domain, f.level)
+		}
+	}
+
+	// Keep the directory around for inspection if anything failed.
+	if t.Failed() {
+		return
+	}
+	os.RemoveAll(dir)
+}
+
+func TestProgressions(t *testing.T) {
+	dir := mustTempDir(t)
+	db, err := New(dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// Incoming and outgoing levels are stored separately but follow the
+	// same rules, so every step is applied to both, incoming first.
+	directions := []struct {
+		label string
+		check func(string, SecLevel) bool
+	}{
+		{"in ", db.IncomingSecLevel},
+		{"out", db.OutgoingSecLevel},
+	}
+
+	steps := []struct {
+		domain string
+		lvl    SecLevel
+		ok     bool
+	}{
+		{"pisis", SecLevel_PLAIN, true},
+		{"pisis", SecLevel_TLS_INSECURE, true},
+		{"pisis", SecLevel_TLS_SECURE, true},
+		{"pisis", SecLevel_TLS_INSECURE, false},
+		{"pisis", SecLevel_TLS_SECURE, true},
+
+		{"ssip", SecLevel_TLS_SECURE, true},
+		{"ssip", SecLevel_TLS_SECURE, true},
+		{"ssip", SecLevel_TLS_INSECURE, false},
+		{"ssip", SecLevel_PLAIN, false},
+	}
+	for i, s := range steps {
+		for _, d := range directions {
+			if got := d.check(s.domain, s.lvl); got != s.ok {
+				t.Errorf("%2d %q %s attempt for %s failed: got %v, expected %v",
+					i, s.domain, d.label, s.lvl, got, s.ok)
+			}
+		}
+	}
+
+	if !t.Failed() {
+		os.RemoveAll(dir)
+	}
+}