author | Alberto Bertogli
<albertito@blitiri.com.ar> 2016-10-13 01:28:30 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2016-10-21 21:15:09 UTC |
parent | 1d7a207e00f2e2419737b974b3a80f3ef45f6d85 |
chasquid.go | +58 | -2 |
chasquid_test.go | +58 | -5 |
internal/courier/smtp.go | +17 | -2 |
internal/courier/smtp_test.go | +22 | -2 |
internal/domaininfo/domaininfo.go | +133 | -0 |
internal/domaininfo/domaininfo.pb.go | +96 | -0 |
internal/domaininfo/domaininfo.proto | +28 | -0 |
internal/domaininfo/domaininfo_test.go | +133 | -0 |
diff --git a/chasquid.go b/chasquid.go index 866a74b..b689449 100644 --- a/chasquid.go +++ b/chasquid.go @@ -24,6 +24,7 @@ import ( "blitiri.com.ar/go/chasquid/internal/auth" "blitiri.com.ar/go/chasquid/internal/config" "blitiri.com.ar/go/chasquid/internal/courier" + "blitiri.com.ar/go/chasquid/internal/domaininfo" "blitiri.com.ar/go/chasquid/internal/envelope" "blitiri.com.ar/go/chasquid/internal/normalize" "blitiri.com.ar/go/chasquid/internal/queue" @@ -53,6 +54,7 @@ var ( spfResultCount = expvar.NewMap("chasquid/smtpIn/spfResultCount") loopsDetected = expvar.NewInt("chasquid/smtpIn/loopsDetected") tlsCount = expvar.NewMap("chasquid/smtpIn/tlsCount") + slcResults = expvar.NewMap("chasquid/smtpIn/securityLevelChecks") ) // Global event logs. @@ -133,12 +135,14 @@ func main() { // as a remote domain (for loops, alias resolutions, etc.). s.AddDomain("localhost") + s.InitDomainInfo(conf.DataDir + "/domaininfo") + localC := &courier.Procmail{ Binary: conf.MailDeliveryAgentBin, Args: conf.MailDeliveryAgentArgs, Timeout: 30 * time.Second, } - remoteC := &courier.SMTP{} + remoteC := &courier.SMTP{Dinfo: s.dinfo} s.InitQueue(conf.DataDir+"/queue", localC, remoteC) go s.periodicallyReload() @@ -266,6 +270,9 @@ type Server struct { // Aliases resolver. aliasesR *aliases.Resolver + // Domain info database. + dinfo *domaininfo.DB + // Time before we give up on a connection, even if it's sending data. connTimeout time.Duration @@ -314,6 +321,19 @@ func (s *Server) AddUserDB(domain string, db *userdb.DB) { s.userDBs[domain] = db } +func (s *Server) InitDomainInfo(dir string) { + var err error + s.dinfo, err = domaininfo.New(dir) + if err != nil { + glog.Fatalf("Error opening domain info database: %v", err) + } + + err = s.dinfo.Load() + if err != nil { + glog.Fatalf("Error loading domain info database: %v", err) + } +} + func (s *Server) InitQueue(path string, localC, remoteC courier.Courier) { q := queue.New(path, s.localDomains, s.aliasesR, localC, remoteC) err := q.Load() @@ -399,6 +419,7 @@ func (s *Server) serve(l net.Listener, mode SocketMode) { userDBs: s.userDBs, aliasesR: s.aliasesR, localDomains: s.localDomains, + dinfo: s.dinfo, deadline: time.Now().Add(s.connTimeout), commandTimeout: s.commandTimeout, queue: s.queue, @@ -449,6 +470,7 @@ type Conn struct { userDBs map[string]*userdb.DB localDomains *set.String aliasesR *aliases.Resolver + dinfo *domaininfo.DB // Have we successfully completed AUTH? completedAuth bool @@ -697,6 +719,10 @@ func (c *Conn) MAIL(params string) (code int, msg string) { "SPF check failed: %v", c.spfError) } + if !c.secLevelCheck(addr) { + return 550, "security level check failed" + } + addr, err = normalize.DomainToUnicode(addr) if err != nil { return 501, "malformed address (IDNA conversion failed)" @@ -727,6 +753,37 @@ func (c *Conn) checkSPF(addr string) (spf.Result, error) { return "", nil } +// secLevelCheck checks if the security level is acceptable for the given +// address. +func (c *Conn) secLevelCheck(addr string) bool { + // Only check if SPF passes. This serves two purposes: + // - Skip for authenticated connections (we trust them implicitly). + // - Don't apply this if we can't be sure the sender is authorized. + // Otherwise anyone could raise the level of any domain. + if c.spfResult != spf.Pass { + slcResults.Add("skip", 1) + c.tr.Debugf("SPF did not pass, skipping security level check") + return true + } + + domain := envelope.DomainOf(addr) + level := domaininfo.SecLevel_PLAIN + if c.onTLS { + level = domaininfo.SecLevel_TLS_CLIENT + } + + ok := c.dinfo.IncomingSecLevel(domain, level) + if ok { + slcResults.Add("pass", 1) + c.tr.Debugf("security level check for %s passed (%s)", domain, level) + } else { + slcResults.Add("fail", 1) + c.tr.Errorf("security level check for %s failed (%s)", domain, level) + } + + return ok +} + func (c *Conn) RCPT(params string) (code int, msg string) { // params should be: "TO:<name@host>", and possibly followed by options // such as "NOTIFY=SUCCESS,DELAY" (which we ignore). @@ -793,7 +850,6 @@ func (c *Conn) DATA(params string) (code int, msg string) { if c.mailFrom == "" { return 503, "sender not yet given" } - if len(c.rcptTo) == 0 { return 503, "need an address to send to" } diff --git a/chasquid_test.go b/chasquid_test.go index 78b6cb3..04cc469 100644 --- a/chasquid_test.go +++ b/chasquid_test.go @@ -19,6 +19,9 @@ import ( "blitiri.com.ar/go/chasquid/internal/aliases" "blitiri.com.ar/go/chasquid/internal/courier" + "blitiri.com.ar/go/chasquid/internal/domaininfo" + "blitiri.com.ar/go/chasquid/internal/spf" + "blitiri.com.ar/go/chasquid/internal/trace" "blitiri.com.ar/go/chasquid/internal/userdb" "github.com/golang/glog" @@ -169,7 +172,7 @@ func TestWrongMailParsing(t *testing.T) { } } - if err := c.Mail("from@from"); err != nil { + if err := c.Mail("from@plain"); err != nil { t.Errorf("Mail: %v", err) } @@ -200,11 +203,11 @@ func TestRcptBeforeMail(t *testing.T) { } func TestRcptOption(t *testing.T) { - c := mustDial(t, ModeSMTP, false) + c := mustDial(t, ModeSMTP, true) defer c.Close() if err := c.Mail("from@localhost"); err != nil { - t.Errorf("Mail: %v", err) + t.Fatalf("Mail: %v", err) } params := []string{ @@ -250,7 +253,7 @@ func TestReset(t *testing.T) { c := mustDial(t, ModeSMTP, false) defer c.Close() - if err := c.Mail("from@from"); err != nil { + if err := c.Mail("from@plain"); err != nil { t.Fatalf("MAIL FROM: %v", err) } @@ -258,7 +261,7 @@ func TestReset(t *testing.T) { t.Errorf("RSET: %v", err) } - if err := c.Mail("from@from"); err != nil { + if err := c.Mail("from@plain"); err != nil { t.Errorf("MAIL after RSET: %v", err) } } @@ -278,6 +281,55 @@ func TestRepeatedStartTLS(t *testing.T) { } } +func TestSecLevel(t *testing.T) { + // We can't simulate this externally because of the SPF record + // requirement, so do a narrow test on Conn.secLevelCheck. + tmpDir, err := ioutil.TempDir("", "chasquid_test:") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + dinfo, err := domaininfo.New(tmpDir) + if err != nil { + t.Fatalf("Failed to create domain info: %v", err) + } + + c := &Conn{ + tr: trace.New("testconn", "testconn"), + dinfo: dinfo, + } + + // No SPF, skip security checks. + c.spfResult = spf.None + c.onTLS = true + if !c.secLevelCheck("from@slc") { + t.Fatalf("TLS seclevel failed") + } + + c.onTLS = false + if !c.secLevelCheck("from@slc") { + t.Fatalf("plain seclevel failed, even though SPF does not exist") + } + + // Now the real checks, once SPF passes. + c.spfResult = spf.Pass + + if !c.secLevelCheck("from@slc") { + t.Fatalf("plain seclevel failed") + } + + c.onTLS = true + if !c.secLevelCheck("from@slc") { + t.Fatalf("TLS seclevel failed") + } + + c.onTLS = false + if c.secLevelCheck("from@slc") { + t.Fatalf("plain seclevel worked, downgrade was allowed") + } +} + // // === Benchmarks === // @@ -438,6 +490,7 @@ func realMain(m *testing.M) int { localC := &courier.Procmail{} remoteC := &courier.SMTP{} s.InitQueue(tmpDir+"/queue", localC, remoteC) + s.InitDomainInfo(tmpDir + "/domaininfo") udb := userdb.New("/dev/null") udb.AddUser("testuser", "testpasswd") diff --git a/internal/courier/smtp.go b/internal/courier/smtp.go index 26dc03c..91f8bf5 100644 --- a/internal/courier/smtp.go +++ b/internal/courier/smtp.go @@ -11,6 +11,7 @@ import ( "github.com/golang/glog" "golang.org/x/net/idna" + "blitiri.com.ar/go/chasquid/internal/domaininfo" "blitiri.com.ar/go/chasquid/internal/envelope" "blitiri.com.ar/go/chasquid/internal/smtp" "blitiri.com.ar/go/chasquid/internal/trace" @@ -32,11 +33,13 @@ var ( // Exported variables. var ( - tlsCount = expvar.NewMap("chasquid/smtpOut/tlsCount") + tlsCount = expvar.NewMap("chasquid/smtpOut/tlsCount") + slcResults = expvar.NewMap("chasquid/smtpOut/securityLevelChecks") ) // SMTP delivers remote mail via outgoing SMTP. type SMTP struct { + Dinfo *domaininfo.DB } func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) { @@ -44,7 +47,8 @@ func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) { defer tr.Finish() tr.Debugf("%s -> %s", from, to) - mx, err := lookupMX(envelope.DomainOf(to)) + toDomain := envelope.DomainOf(to) + mx, err := lookupMX(toDomain) if err != nil { // Note this is considered a permanent error. // This is in line with what other servers (Exim) do. However, the @@ -68,6 +72,7 @@ func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) { // Do we use insecure TLS? // Set as fallback when retrying. insecure := false + secLevel := domaininfo.SecLevel_PLAIN retry: conn, err := net.DialTimeout("tcp", mx+":"+*smtpPort, smtpDialTimeout) @@ -110,15 +115,25 @@ retry: if config.InsecureSkipVerify { tr.Debugf("Insecure - using TLS, but cert does not match %s", mx) tlsCount.Add("tls:insecure", 1) + secLevel = domaininfo.SecLevel_TLS_INSECURE } else { tlsCount.Add("tls:secure", 1) tr.Debugf("Secure - using TLS") + secLevel = domaininfo.SecLevel_TLS_SECURE } } else { tlsCount.Add("plain", 1) tr.Debugf("Insecure - NOT using TLS") } + if toDomain != "" && !s.Dinfo.OutgoingSecLevel(toDomain, secLevel) { + // We consider the failure transient, so transient misconfigurations + // do not affect deliveries. + slcResults.Add("fail", 1) + return tr.Errorf("Security level check failed (level:%s)", secLevel), false + } + slcResults.Add("pass", 1) + // c.Mail will add the <> for us when the address is empty. if from == "<>" { from = "" diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go index 915b09c..1ba7ffd 100644 --- a/internal/courier/smtp_test.go +++ b/internal/courier/smtp_test.go @@ -2,12 +2,30 @@ package courier import ( "bufio" + "io/ioutil" "net" "net/textproto" + "os" "testing" "time" + + "blitiri.com.ar/go/chasquid/internal/domaininfo" ) +func newSMTP(t *testing.T) (*SMTP, string) { + dir, err := ioutil.TempDir("", "smtp_test") + if err != nil { + t.Fatal(err) + } + + dinfo, err := domaininfo.New(dir) + if err != nil { + t.Fatal(err) + } + + return &SMTP{dinfo}, dir +} + // Fake server, to test SMTP out. func fakeServer(t *testing.T, responses map[string]string) string { l, err := net.Listen("tcp", "localhost:0") @@ -72,7 +90,8 @@ func TestSMTP(t *testing.T) { fakeMX["to"] = host *smtpPort = port - s := &SMTP{} + s, tmpDir := newSMTP(t) + defer os.Remove(tmpDir) err, _ := s.Deliver("me@me", "to@to", []byte("data")) if err != nil { t.Errorf("deliver failed: %v", err) @@ -132,7 +151,8 @@ func TestSMTPErrors(t *testing.T) { fakeMX["to"] = host *smtpPort = port - s := &SMTP{} + s, tmpDir := newSMTP(t) + defer os.Remove(tmpDir) err, _ := s.Deliver("me@me", "to@to", []byte("data")) if err == nil { t.Errorf("deliver not failed in case %q: %v", rs["_welcome"], err) diff --git a/internal/domaininfo/domaininfo.go b/internal/domaininfo/domaininfo.go new file mode 100644 index 0000000..0f0ca77 --- /dev/null +++ b/internal/domaininfo/domaininfo.go @@ -0,0 +1,133 @@ +// Package domaininfo implements a domain information database, to keep track +// of things we know about a particular domain. +package domaininfo + +import ( + "fmt" + "sync" + + "blitiri.com.ar/go/chasquid/internal/protoio" + "blitiri.com.ar/go/chasquid/internal/trace" +) + +// Command to generate domaininfo.pb.go. +//go:generate protoc --go_out=. domaininfo.proto + +type DB struct { + // Persistent store with the list of domains we know. + store *protoio.Store + + info map[string]*Domain + sync.Mutex + + ev *trace.EventLog +} + +func New(dir string) (*DB, error) { + st, err := protoio.NewStore(dir) + if err != nil { + return nil, err + } + + l := &DB{ + store: st, + info: map[string]*Domain{}, + } + l.ev = trace.NewEventLog("DomainInfo", fmt.Sprintf("%p", l)) + + return l, nil +} + +// Load the database from disk; should be called once at initialization. +func (db *DB) Load() error { + ids, err := db.store.ListIDs() + if err != nil { + return err + } + + for _, id := range ids { + d := &Domain{} + _, err := db.store.Get(id, d) + if err != nil { + return fmt.Errorf("error loading %q: %v", id, err) + } + + db.info[d.Name] = d + } + + db.ev.Debugf("loaded: %s", ids) + return nil +} + +func (db *DB) write(d *Domain) { + err := db.store.Put(d.Name, d) + if err != nil { + db.ev.Errorf("%s error saving: %v", d.Name, err) + } else { + db.ev.Debugf("%s saved", d.Name) + } +} + +// IncomingSecLevel checks an incoming security level for the domain. +// Returns true if allowed, false otherwise. +func (db *DB) IncomingSecLevel(domain string, level SecLevel) bool { + db.Lock() + defer db.Unlock() + + d, exists := db.info[domain] + if !exists { + d = &Domain{Name: domain} + db.info[domain] = d + defer db.write(d) + } + + if level < d.IncomingSecLevel { + db.ev.Errorf("%s incoming denied: %s < %s", + d.Name, level, d.IncomingSecLevel) + return false + } else if level == d.IncomingSecLevel { + db.ev.Debugf("%s incoming allowed: %s == %s", + d.Name, level, d.IncomingSecLevel) + return true + } else { + db.ev.Printf("%s incoming level raised: %s > %s", + d.Name, level, d.IncomingSecLevel) + d.IncomingSecLevel = level + if exists { + defer db.write(d) + } + return true + } +} + +// OutgoingSecLevel checks an incoming security level for the domain. +// Returns true if allowed, false otherwise. +func (db *DB) OutgoingSecLevel(domain string, level SecLevel) bool { + db.Lock() + defer db.Unlock() + + d, exists := db.info[domain] + if !exists { + d = &Domain{Name: domain} + db.info[domain] = d + defer db.write(d) + } + + if level < d.OutgoingSecLevel { + db.ev.Errorf("%s outgoing denied: %s < %s", + d.Name, level, d.OutgoingSecLevel) + return false + } else if level == d.OutgoingSecLevel { + db.ev.Debugf("%s outgoing allowed: %s == %s", + d.Name, level, d.OutgoingSecLevel) + return true + } else { + db.ev.Printf("%s outgoing level raised: %s > %s", + d.Name, level, d.OutgoingSecLevel) + d.OutgoingSecLevel = level + if exists { + defer db.write(d) + } + return true + } +} diff --git a/internal/domaininfo/domaininfo.pb.go b/internal/domaininfo/domaininfo.pb.go new file mode 100644 index 0000000..6c5116a --- /dev/null +++ b/internal/domaininfo/domaininfo.pb.go @@ -0,0 +1,96 @@ +// Code generated by protoc-gen-go. +// source: domaininfo.proto +// DO NOT EDIT! + +/* +Package domaininfo is a generated protocol buffer package. + +It is generated from these files: + domaininfo.proto + +It has these top-level messages: + Domain +*/ +package domaininfo + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type SecLevel int32 + +const ( + // Does not do TLS. + SecLevel_PLAIN SecLevel = 0 + // TLS client connection (no certificate validation). + SecLevel_TLS_CLIENT SecLevel = 1 + // TLS, but with invalid certificates. + SecLevel_TLS_INSECURE SecLevel = 2 + // TLS, with valid certificates. + SecLevel_TLS_SECURE SecLevel = 3 +) + +var SecLevel_name = map[int32]string{ + 0: "PLAIN", + 1: "TLS_CLIENT", + 2: "TLS_INSECURE", + 3: "TLS_SECURE", +} +var SecLevel_value = map[string]int32{ + "PLAIN": 0, + "TLS_CLIENT": 1, + "TLS_INSECURE": 2, + "TLS_SECURE": 3, +} + +func (x SecLevel) String() string { + return proto.EnumName(SecLevel_name, int32(x)) +} +func (SecLevel) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +type Domain struct { + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + // Security level for mail coming from this domain (they send to us). + IncomingSecLevel SecLevel `protobuf:"varint,2,opt,name=incoming_sec_level,json=incomingSecLevel,enum=domaininfo.SecLevel" json:"incoming_sec_level,omitempty"` + // Security level for mail going to this domain (we send to them). + OutgoingSecLevel SecLevel `protobuf:"varint,3,opt,name=outgoing_sec_level,json=outgoingSecLevel,enum=domaininfo.SecLevel" json:"outgoing_sec_level,omitempty"` +} + +func (m *Domain) Reset() { *m = Domain{} } +func (m *Domain) String() string { return proto.CompactTextString(m) } +func (*Domain) ProtoMessage() {} +func (*Domain) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func init() { + proto.RegisterType((*Domain)(nil), "domaininfo.Domain") + proto.RegisterEnum("domaininfo.SecLevel", SecLevel_name, SecLevel_value) +} + +func init() { proto.RegisterFile("domaininfo.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 189 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0x12, 0x48, 0xc9, 0xcf, 0x4d, + 0xcc, 0xcc, 0xcb, 0xcc, 0x4b, 0xcb, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x42, 0x88, + 0x28, 0x2d, 0x61, 0xe4, 0x62, 0x73, 0x01, 0x73, 0x85, 0x84, 0xb8, 0x58, 0xf2, 0x12, 0x73, 0x53, + 0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0xc0, 0x6c, 0x21, 0x27, 0x2e, 0xa1, 0xcc, 0xbc, 0xe4, + 0xfc, 0xdc, 0xcc, 0xbc, 0xf4, 0xf8, 0xe2, 0xd4, 0xe4, 0xf8, 0x9c, 0xd4, 0xb2, 0xd4, 0x1c, 0x09, + 0x26, 0xa0, 0x0a, 0x3e, 0x23, 0x11, 0x3d, 0x24, 0x93, 0x83, 0x53, 0x93, 0x7d, 0x40, 0x72, 0x41, + 0x02, 0x30, 0xf5, 0x30, 0x11, 0x90, 0x19, 0xf9, 0xa5, 0x25, 0xe9, 0xf9, 0xa8, 0x66, 0x30, 0xe3, + 0x33, 0x03, 0xa6, 0x1e, 0x26, 0xa2, 0xe5, 0xce, 0xc5, 0x01, 0x37, 0x8f, 0x93, 0x8b, 0x35, 0xc0, + 0xc7, 0xd1, 0xd3, 0x4f, 0x80, 0x41, 0x88, 0x8f, 0x8b, 0x2b, 0xc4, 0x27, 0x38, 0xde, 0xd9, 0xc7, + 0xd3, 0xd5, 0x2f, 0x44, 0x80, 0x51, 0x48, 0x80, 0x8b, 0x07, 0xc4, 0xf7, 0xf4, 0x0b, 0x76, 0x75, + 0x0e, 0x0d, 0x72, 0x15, 0x60, 0x82, 0xa9, 0x80, 0xf2, 0x99, 0x93, 0xd8, 0xc0, 0x41, 0x60, 0x0c, + 0x08, 0x00, 0x00, 0xff, 0xff, 0x2c, 0x78, 0x65, 0x5b, 0x16, 0x01, 0x00, 0x00, +} diff --git a/internal/domaininfo/domaininfo.proto b/internal/domaininfo/domaininfo.proto new file mode 100644 index 0000000..a2df39f --- /dev/null +++ b/internal/domaininfo/domaininfo.proto @@ -0,0 +1,28 @@ + +syntax = "proto3"; + +package domaininfo; + +enum SecLevel { + // Does not do TLS. + PLAIN = 0; + + // TLS client connection (no certificate validation). + TLS_CLIENT = 1; + + // TLS, but with invalid certificates. + TLS_INSECURE = 2; + + // TLS, with valid certificates. + TLS_SECURE = 3; +} + +message Domain { + string name = 1; + + // Security level for mail coming from this domain (they send to us). + SecLevel incoming_sec_level = 2; + + // Security level for mail going to this domain (we send to them). + SecLevel outgoing_sec_level = 3; +} diff --git a/internal/domaininfo/domaininfo_test.go b/internal/domaininfo/domaininfo_test.go new file mode 100644 index 0000000..266dd19 --- /dev/null +++ b/internal/domaininfo/domaininfo_test.go @@ -0,0 +1,133 @@ +package domaininfo + +import ( + "io/ioutil" + "os" + "testing" + "time" +) + +func mustTempDir(t *testing.T) string { + dir, err := ioutil.TempDir("", "greylisting_test") + if err != nil { + t.Fatal(err) + } + + t.Logf("test directory: %q", dir) + return dir +} + +func TestBasic(t *testing.T) { + dir := mustTempDir(t) + db, err := New(dir) + if err != nil { + t.Fatal(err) + } + + if err := db.Load(); err != nil { + t.Fatal(err) + } + + if !db.IncomingSecLevel("d1", SecLevel_PLAIN) { + t.Errorf("new domain as plain not allowed") + } + if !db.IncomingSecLevel("d1", SecLevel_TLS_SECURE) { + t.Errorf("increment to tls-secure not allowed") + } + if db.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) { + t.Errorf("decrement to tls-insecure was allowed") + } + + // Wait until it is written to disk. + for dl := time.Now().Add(30 * time.Second); time.Now().Before(dl); { + d := &Domain{} + ok, _ := db.store.Get("d1", d) + if ok { + break + } + time.Sleep(50 * time.Millisecond) + } + + // Check that it was added to the store and a new db sees it. + db2, err := New(dir) + if err != nil { + t.Fatal(err) + } + if err := db2.Load(); err != nil { + t.Fatal(err) + } + if db2.IncomingSecLevel("d1", SecLevel_TLS_INSECURE) { + t.Errorf("decrement to tls-insecure was allowed in new DB") + } + + if !t.Failed() { + os.RemoveAll(dir) + } +} + +func TestNewDomain(t *testing.T) { + dir := mustTempDir(t) + db, err := New(dir) + if err != nil { + t.Fatal(err) + } + + cases := []struct { + domain string + level SecLevel + }{ + {"plain", SecLevel_PLAIN}, + {"insecure", SecLevel_TLS_INSECURE}, + {"secure", SecLevel_TLS_SECURE}, + } + for _, c := range cases { + if !db.IncomingSecLevel(c.domain, c.level) { + t.Errorf("domain %q not allowed (in) at %s", c.domain, c.level) + } + if !db.OutgoingSecLevel(c.domain, c.level) { + t.Errorf("domain %q not allowed (out) at %s", c.domain, c.level) + } + } + if !t.Failed() { + os.RemoveAll(dir) + } +} + +func TestProgressions(t *testing.T) { + dir := mustTempDir(t) + db, err := New(dir) + if err != nil { + t.Fatal(err) + } + + cases := []struct { + domain string + lvl SecLevel + ok bool + }{ + {"pisis", SecLevel_PLAIN, true}, + {"pisis", SecLevel_TLS_INSECURE, true}, + {"pisis", SecLevel_TLS_SECURE, true}, + {"pisis", SecLevel_TLS_INSECURE, false}, + {"pisis", SecLevel_TLS_SECURE, true}, + + {"ssip", SecLevel_TLS_SECURE, true}, + {"ssip", SecLevel_TLS_SECURE, true}, + {"ssip", SecLevel_TLS_INSECURE, false}, + {"ssip", SecLevel_PLAIN, false}, + } + for i, c := range cases { + if ok := db.IncomingSecLevel(c.domain, c.lvl); ok != c.ok { + t.Errorf("%2d %q in attempt for %s failed: got %v, expected %v", + i, c.domain, c.lvl, ok, c.ok) + } + if ok := db.OutgoingSecLevel(c.domain, c.lvl); ok != c.ok { + t.Errorf("%2d %q out attempt for %s failed: got %v, expected %v", + i, c.domain, c.lvl, ok, c.ok) + } + } + + if !t.Failed() { + os.RemoveAll(dir) + } +}