author | Alberto Bertogli
<albertito@blitiri.com.ar> 2015-11-06 02:03:21 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2015-11-13 03:41:06 UTC |
parent | e5c2676f830dbfcfdb0470ac8c5dcbbb7bd0a03a |
chasquid.go | +34 | -6 |
internal/courier/courier.go | +48 | -0 |
internal/courier/courier_test.go | +44 | -0 |
internal/courier/procmail.go | +93 | -0 |
internal/courier/procmail_test.go | +65 | -0 |
internal/courier/smtp.go | +128 | -0 |
internal/courier/smtp_test.go | +144 | -0 |
internal/queue/queue.go | +28 | -10 |
internal/queue/queue_test.go | +38 | -8 |
internal/trace/trace.go | +17 | -1 |
diff --git a/chasquid.go b/chasquid.go index 4cefc73..4a207fc 100644 --- a/chasquid.go +++ b/chasquid.go @@ -17,6 +17,7 @@ import ( "time" "blitiri.com.ar/go/chasquid/internal/config" + "blitiri.com.ar/go/chasquid/internal/courier" "blitiri.com.ar/go/chasquid/internal/queue" "blitiri.com.ar/go/chasquid/internal/systemd" "blitiri.com.ar/go/chasquid/internal/trace" @@ -55,21 +56,28 @@ func main() { s.MaxDataSize = conf.MaxDataSizeMb * 1024 * 1024 // Load domains. - domains, err := filepath.Glob(*configDir + "/domains/*") + domainDirs, err := ioutil.ReadDir(*configDir + "/domains/") if err != nil { glog.Fatalf("Error in glob: %v", err) } - if len(domains) == 0 { + if len(domainDirs) == 0 { glog.Warningf("No domains found in config, using test certs") s.AddCerts(*testCert, *testKey) } else { glog.Infof("Domain config paths:") - for _, d := range domains { - glog.Infof(" %s", d) - s.AddCerts(d+"/cert.pem", d+"/key.pem") + for _, info := range domainDirs { + glog.Infof(" %s", info.Name()) + s.AddDomain(info.Name()) + dir := filepath.Join(*configDir, "domains", info.Name()) + s.AddCerts(dir+"/cert.pem", dir+"/key.pem") } } + // Always include localhost as local domain. + // This can prevent potential trouble if we were to accidentally treat it + // as a remote domain (for loops, alias resolutions, etc.). + s.AddDomain("localhost") + // Load addresses. acount := 0 for _, addr := range conf.Address { @@ -115,6 +123,9 @@ type Server struct { // TLS config. tlsConfig *tls.Config + // Local domains. + localDomains map[string]bool + // Time before we give up on a connection, even if it's sending data. connTimeout time.Duration @@ -129,7 +140,7 @@ func NewServer() *Server { return &Server{ connTimeout: 20 * time.Minute, commandTimeout: 1 * time.Minute, - queue: queue.New(), + localDomains: map[string]bool{}, } } @@ -146,6 +157,10 @@ func (s *Server) AddListeners(ls []net.Listener) { s.listeners = append(s.listeners, ls...) } +func (s *Server) AddDomain(d string) { + s.localDomains[d] = true +} + func (s *Server) getTLSConfig() (*tls.Config, error) { var err error conf := &tls.Config{} @@ -172,6 +187,15 @@ func (s *Server) ListenAndServe() { glog.Fatalf("Error loading TLS config: %v", err) } + // Create the queue, giving it a routing courier for delivery. + // We need to do this early, before accepting connections. + courier := &courier.Router{ + Local: &courier.Procmail{}, + Remote: &courier.SMTP{}, + LocalDomains: s.localDomains, + } + s.queue = queue.New(courier) + for _, addr := range s.addrs { // Listen. l, err := net.Listen("tcp", addr) @@ -420,6 +444,10 @@ func (c *Conn) RCPT(params string) (code int, msg string) { return 500, "unknown command" } + // TODO: Write our own parser (we have different needs, mail.ParseAddress + // is useful for other things). + // Allow utf8, but prevent "control" characters. + e, err := mail.ParseAddress(sp[1]) if err != nil || e.Address == "" { return 501, "malformed address" diff --git a/internal/courier/courier.go b/internal/courier/courier.go new file mode 100644 index 0000000..42635f5 --- /dev/null +++ b/internal/courier/courier.go @@ -0,0 +1,48 @@ +// Package courier implements various couriers for delivering messages. +package courier + +import "strings" + +// Courier delivers mail to a single recipient. +// It is implemented by different couriers, for both local and remote +// recipients. +type Courier interface { + Deliver(from string, to string, data []byte) error +} + +// Router decides if the destination is local or remote, and delivers +// accordingly. +type Router struct { + Local Courier + Remote Courier + LocalDomains map[string]bool +} + +func (r *Router) Deliver(from string, to string, data []byte) error { + d := domainOf(to) + if r.LocalDomains[d] { + return r.Local.Deliver(from, to, data) + } else { + return r.Remote.Deliver(from, to, data) + } +} + +// Split an user@domain address into user and domain. +func split(addr string) (string, string) { + ps := strings.SplitN(addr, "@", 2) + if len(ps) != 2 { + return addr, "" + } + + return ps[0], ps[1] +} + +func userOf(addr string) string { + user, _ := split(addr) + return user +} + +func domainOf(addr string) string { + _, domain := split(addr) + return domain +} diff --git a/internal/courier/courier_test.go b/internal/courier/courier_test.go new file mode 100644 index 0000000..069aa29 --- /dev/null +++ b/internal/courier/courier_test.go @@ -0,0 +1,44 @@ +package courier + +import "testing" + +// Counter courier, for testing purposes. +type counter struct { + c int +} + +func (c *counter) Deliver(from string, to string, data []byte) error { + c.c++ + return nil +} + +func TestRouter(t *testing.T) { + localC := &counter{} + remoteC := &counter{} + r := Router{ + Local: localC, + Remote: remoteC, + LocalDomains: map[string]bool{ + "local1": true, + "local2": true, + }, + } + + for domain, c := range map[string]int{ + "local1": 1, + "local2": 2, + "remote": 9, + } { + for i := 0; i < c; i++ { + r.Deliver("from", "a@"+domain, nil) + } + } + + if localC.c != 3 { + t.Errorf("local mis-count: expected 3, got %d", localC.c) + } + + if remoteC.c != 9 { + t.Errorf("remote mis-count: expected 9, got %d", remoteC.c) + } +} diff --git a/internal/courier/procmail.go b/internal/courier/procmail.go new file mode 100644 index 0000000..bc21cb9 --- /dev/null +++ b/internal/courier/procmail.go @@ -0,0 +1,93 @@ +package courier + +import ( + "bytes" + "fmt" + "os/exec" + "strings" + "time" + + "blitiri.com.ar/go/chasquid/internal/trace" +) + +var ( + // Location of the procmail binary, and arguments to use. + // The string "%user%" will be replaced with the local user. + procmailBin = "procmail" + procmailArgs = []string{"-d", "%user%"} + + // Give procmail 1m to deliver mail. + procmailTimeout = 1 * time.Minute +) + +var ( + timeoutError = fmt.Errorf("Operation timed out") +) + +// Procmail delivers local mail via procmail. +type Procmail struct { +} + +func (p *Procmail) Deliver(from string, to string, data []byte) error { + tr := trace.New("Procmail", "Deliver") + defer tr.Finish() + + // Get the user, and sanitize to be extra paranoid. + user := sanitizeForProcmail(userOf(to)) + tr.LazyPrintf("%s -> %s (%s)", from, user, to) + + // Prepare the command, replacing the necessary arguments. + args := []string{} + for _, a := range procmailArgs { + args = append(args, strings.Replace(a, "%user%", user, -1)) + } + cmd := exec.Command(procmailBin, args...) + + cmdStdin, err := cmd.StdinPipe() + if err != nil { + return tr.Errorf("StdinPipe: %v", err) + } + + output := &bytes.Buffer{} + cmd.Stdout = output + cmd.Stderr = output + + err = cmd.Start() + if err != nil { + return tr.Errorf("Error starting procmail: %v", err) + } + + _, err = bytes.NewBuffer(data).WriteTo(cmdStdin) + if err != nil { + return tr.Errorf("Error sending data to procmail: %v", err) + } + + cmdStdin.Close() + + timer := time.AfterFunc(procmailTimeout, func() { + cmd.Process.Kill() + }) + err = cmd.Wait() + timedOut := !timer.Stop() + + if timedOut { + return tr.Error(timeoutError) + } + if err != nil { + return tr.Errorf("Procmail failed: %v - %q", err, output.String()) + } + return nil +} + +// sanitizeForProcmail cleans the string, leaving only [a-zA-Z-.]. +func sanitizeForProcmail(s string) string { + valid := func(r rune) rune { + switch { + case r >= 'A' && r <= 'Z', r >= 'a' && r <= 'z', r == '-', r == '.': + return r + default: + return rune(-1) + } + } + return strings.Map(valid, s) +} diff --git a/internal/courier/procmail_test.go b/internal/courier/procmail_test.go new file mode 100644 index 0000000..7e2d35c --- /dev/null +++ b/internal/courier/procmail_test.go @@ -0,0 +1,65 @@ +package courier + +import ( + "bytes" + "io/ioutil" + "os" + "testing" + "time" +) + +func TestProcmail(t *testing.T) { + dir, err := ioutil.TempDir("", "test-chasquid-courier") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(dir) + + procmailBin = "tee" + procmailArgs = []string{dir + "/%user%"} + + p := Procmail{} + err = p.Deliver("from@x", "to@y", []byte("data")) + if err != nil { + t.Fatalf("Deliver: %v", err) + } + + data, err := ioutil.ReadFile(dir + "/to") + if err != nil || !bytes.Equal(data, []byte("data")) { + t.Errorf("Invalid data: %q - %v", string(data), err) + } +} + +func TestProcmailTimeout(t *testing.T) { + procmailBin = "/bin/sleep" + procmailArgs = []string{"1"} + procmailTimeout = 100 * time.Millisecond + + p := Procmail{} + err := p.Deliver("from", "to", []byte("data")) + if err != timeoutError { + t.Errorf("Unexpected error: %v", err) + } + + procmailTimeout = 1 * time.Second +} + +func TestProcmailBadCommandLine(t *testing.T) { + p := Procmail{} + + // Non-existent binary. + procmailBin = "thisdoesnotexist" + err := p.Deliver("from", "to", []byte("data")) + if err == nil { + t.Errorf("Unexpected success: %q %v", procmailBin, procmailArgs) + } + + // Incorrect arguments. + procmailBin = "cat" + procmailArgs = []string{"--fail_unknown_option"} + + err = p.Deliver("from", "to", []byte("data")) + if err == nil { + t.Errorf("Unexpected success: %q %v", procmailBin, procmailArgs) + } +} diff --git a/internal/courier/smtp.go b/internal/courier/smtp.go new file mode 100644 index 0000000..4750cd1 --- /dev/null +++ b/internal/courier/smtp.go @@ -0,0 +1,128 @@ +package courier + +import ( + "crypto/tls" + "net" + "net/smtp" + "time" + + "github.com/golang/glog" + + "blitiri.com.ar/go/chasquid/internal/trace" +) + +var ( + // Timeouts for SMTP delivery. + smtpDialTimeout = 1 * time.Minute + smtpTotalTimeout = 10 * time.Minute + + // Port for outgoing SMTP. + // Tests can override this. + smtpPort = "25" + + // Fake MX records, used for testing only. + fakeMX = map[string]string{} +) + +// SMTP delivers remote mail via outgoing SMTP. +type SMTP struct { +} + +func (s *SMTP) Deliver(from string, to string, data []byte) error { + tr := trace.New("goingSMTP", "Deliver") + defer tr.Finish() + tr.LazyPrintf("%s -> %s", from, to) + + mx, err := lookupMX(domainOf(to)) + if err != nil { + return tr.Errorf("Could not find mail server: %v", err) + } + tr.LazyPrintf("MX: %s", mx) + + // Do we use insecure TLS? + // Set as fallback when retrying. + insecure := false + +retry: + conn, err := net.DialTimeout("tcp", mx+":"+smtpPort, smtpDialTimeout) + if err != nil { + return tr.Errorf("Could not dial: %v", err) + } + conn.SetDeadline(time.Now().Add(smtpTotalTimeout)) + + c, err := smtp.NewClient(conn, mx) + if err != nil { + return tr.Errorf("Error creating client: %v", err) + } + + // TODO: Keep track of hosts and MXs that we've successfully done TLS + // against, and enforce it. + if ok, _ := c.Extension("STARTTLS"); ok { + config := &tls.Config{ + ServerName: mx, + InsecureSkipVerify: insecure, + } + err = c.StartTLS(config) + if err != nil { + // Unfortunately, many servers use self-signed certs, so if we + // fail verification we just try again without validating. + if insecure { + return tr.Errorf("TLS error: %v", err) + } + + insecure = true + tr.LazyPrintf("TLS error, retrying insecurely") + goto retry + } + + if config.InsecureSkipVerify { + tr.LazyPrintf("Insecure - self-signed certificate") + } else { + tr.LazyPrintf("Secure - using TLS") + } + } else { + tr.LazyPrintf("Insecure - not using TLS") + } + + if err = c.Mail(from); err != nil { + return tr.Errorf("MAIL %v", err) + } + + if err = c.Rcpt(to); err != nil { + return tr.Errorf("RCPT TO %v", err) + } + + w, err := c.Data() + if err != nil { + return tr.Errorf("DATA %v", err) + } + _, err = w.Write(data) + if err != nil { + return tr.Errorf("DATA writing: %v", err) + } + + err = w.Close() + if err != nil { + return tr.Errorf("DATA closing %v", err) + } + + c.Quit() + + return nil +} + +func lookupMX(domain string) (string, error) { + if v, ok := fakeMX[domain]; ok { + return v, nil + } + + mxs, err := net.LookupMX(domain) + if err != nil { + return "", err + } else if len(mxs) == 0 { + glog.Infof("domain %q has no MX, falling back to A", domain) + return domain, nil + } + + return mxs[0].Host, nil +} diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go new file mode 100644 index 0000000..6a6c3e4 --- /dev/null +++ b/internal/courier/smtp_test.go @@ -0,0 +1,144 @@ +package courier + +import ( + "bufio" + "net" + "net/textproto" + "testing" + "time" +) + +// Fake server, to test SMTP out. +func fakeServer(t *testing.T, responses map[string]string) string { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("fake server listen: %v", err) + } + + go func() { + defer l.Close() + + c, err := l.Accept() + if err != nil { + t.Fatalf("fake server accept: %v", err) + } + defer c.Close() + + t.Logf("fakeServer got connection") + + r := textproto.NewReader(bufio.NewReader(c)) + c.Write([]byte(responses["_welcome"])) + for { + line, err := r.ReadLine() + if err != nil { + t.Logf("fakeServer exiting: %v\n", err) + return + } + + t.Logf("fakeServer read: %q\n", line) + c.Write([]byte(responses[line])) + + if line == "DATA" { + _, err = r.ReadDotBytes() + if err != nil { + t.Logf("fakeServer exiting: %v\n", err) + return + } + c.Write([]byte(responses["_DATA"])) + } + } + }() + + return l.Addr().String() +} + +func TestSMTP(t *testing.T) { + // Shorten the total timeout, so the test fails quickly if the protocol + // gets stuck. + smtpTotalTimeout = 5 * time.Second + + responses := map[string]string{ + "_welcome": "220 welcome\n", + "EHLO localhost": "250 ehlo ok\n", + "MAIL FROM:<me@me>": "250 mail ok\n", + "RCPT TO:<to@to>": "250 rcpt ok\n", + "DATA": "354 send data\n", + "_DATA": "250 data ok\n", + "QUIT": "250 quit ok\n", + } + addr := fakeServer(t, responses) + host, port, _ := net.SplitHostPort(addr) + + fakeMX["to"] = host + smtpPort = port + + s := &SMTP{} + err := s.Deliver("me@me", "to@to", []byte("data")) + if err != nil { + t.Errorf("deliver failed: %v", err) + } +} + +func TestSMTPErrors(t *testing.T) { + // Shorten the total timeout, so the test fails quickly if the protocol + // gets stuck. + smtpTotalTimeout = 1 * time.Second + + responses := []map[string]string{ + // First test: hang response, should fail due to timeout. + map[string]string{ + "_welcome": "220 no newline", + }, + + // MAIL FROM not allowed. + map[string]string{ + "_welcome": "220 mail from not allowed\n", + "EHLO localhost": "250 ehlo ok\n", + "MAIL FROM:<me@me>": "501 mail error\n", + }, + + // RCPT TO not allowed. + map[string]string{ + "_welcome": "220 rcpt to not allowed\n", + "EHLO localhost": "250 ehlo ok\n", + "MAIL FROM:<me@me>": "250 mail ok\n", + "RCPT TO:<to@to>": "501 rcpt error\n", + }, + + // DATA error. + map[string]string{ + "_welcome": "220 data error\n", + "EHLO localhost": "250 ehlo ok\n", + "MAIL FROM:<me@me>": "250 mail ok\n", + "RCPT TO:<to@to>": "250 rcpt ok\n", + "DATA": "554 data error\n", + }, + + // DATA response error. + map[string]string{ + "_welcome": "220 data response error\n", + "EHLO localhost": "250 ehlo ok\n", + "MAIL FROM:<me@me>": "250 mail ok\n", + "RCPT TO:<to@to>": "250 rcpt ok\n", + "DATA": "354 send data\n", + "_DATA": "551 data response error\n", + }, + } + + for _, rs := range responses { + addr := fakeServer(t, rs) + host, port, _ := net.SplitHostPort(addr) + + fakeMX["to"] = host + smtpPort = port + + s := &SMTP{} + err := s.Deliver("me@me", "to@to", []byte("data")) + if err == nil { + t.Errorf("deliver not failed in case %q: %v", rs["_welcome"], err) + } + t.Logf("failed as expected: %v", err) + } +} + +// TODO: Test STARTTLS negotiation. diff --git a/internal/queue/queue.go b/internal/queue/queue.go index e46d251..7bbd877 100644 --- a/internal/queue/queue.go +++ b/internal/queue/queue.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "blitiri.com.ar/go/chasquid/internal/courier" + "github.com/golang/glog" "golang.org/x/net/trace" ) @@ -57,19 +59,29 @@ type Queue struct { // Mutex protecting q. mu sync.RWMutex + + // Courier to use to deliver mail. + courier courier.Courier } // TODO: Store the queue on disk. // Load the queue and launch the sending loops on startup. -func New() *Queue { +func New(c courier.Courier) *Queue { return &Queue{ - q: map[string]*Item{}, + q: map[string]*Item{}, + courier: c, } } +func (q *Queue) Len() int { + q.mu.RLock() + defer q.mu.RUnlock() + return len(q.q) +} + // Put an envelope in the queue. func (q *Queue) Put(from string, to []string, data []byte) (string, error) { - if len(q.q) >= maxQueueSize { + if q.Len() >= maxQueueSize { return "", queueFullError } @@ -85,7 +97,7 @@ func (q *Queue) Put(from string, to []string, data []byte) (string, error) { q.q[item.ID] = item q.mu.Unlock() - glog.Infof("Queue accepted %s from %q", item.ID, from) + glog.Infof("%s accepted from %q", item.ID, from) // Begin to send it right away. go item.SendLoop(q) @@ -131,6 +143,7 @@ func (item *Item) SendLoop(q *Queue) { defer tr.Finish() tr.LazyPrintf("from: %s", item.From) + var err error for time.Since(item.Created) < giveUpAfter { // Send to all recipients that are still pending. successful := 0 @@ -144,11 +157,16 @@ func (item *Item) SendLoop(q *Queue) { tr.LazyPrintf("%s sending", to) // TODO: deliver, serially or in parallel with a waitgroup. - // Fake a successful send for now. - item.Results[to] = nil - successful++ - - tr.LazyPrintf("%s successful", to) + err = q.courier.Deliver(item.From, to, item.Data) + item.Results[to] = err + if err != nil { + tr.LazyPrintf("error: %v", err) + glog.Infof("%s -> %q fail: %v", item.ID, to, err) + } else { + successful++ + tr.LazyPrintf("%s successful", to) + glog.Infof("%s -> %q sent", item.ID, to) + } } if successful == len(item.To) { @@ -165,7 +183,7 @@ func (item *Item) SendLoop(q *Queue) { // Put a table and function below, to change this easily. // We should track the duration of the previous one too? Or computed // based on created? - time.Sleep(3 * time.Minute) + time.Sleep(30 * time.Second) } diff --git a/internal/queue/queue_test.go b/internal/queue/queue_test.go index 06a9f9f..119e3e3 100644 --- a/internal/queue/queue_test.go +++ b/internal/queue/queue_test.go @@ -6,8 +6,34 @@ import ( "time" ) +// Our own courier, for testing purposes. +// Delivery is done by sending on a channel. +type ChanCourier struct { + requests chan deliverRequest + results chan error +} + +type deliverRequest struct { + from string + to string + data []byte +} + +func (cc *ChanCourier) Deliver(from string, to string, data []byte) error { + cc.requests <- deliverRequest{from, to, data} + return <-cc.results +} + +func newCourier() *ChanCourier { + return &ChanCourier{ + requests: make(chan deliverRequest), + results: make(chan error), + } +} + func TestBasic(t *testing.T) { - q := New() + courier := newCourier() + q := New(courier) id, err := q.Put("from", []string{"to"}, []byte("data")) if err != nil { @@ -22,23 +48,27 @@ func TestBasic(t *testing.T) { item := q.q[id] q.mu.RUnlock() - // TODO: There's a race because the item may finish the loop before we - // poll it from the queue, and we would get a nil item in that case. - // We have to live with this for now, and will close it later once we - // implement deliveries. if item == nil { - t.Logf("hit item race, nothing else to do") - return + t.Fatalf("item not in queue, racy test?") } if item.From != "from" || item.To[0] != "to" || !bytes.Equal(item.Data, []byte("data")) { t.Errorf("different item: %#v", item) } + + // Test that we delivered the item. + req := <-courier.requests + courier.results <- nil + + if req.from != "from" || req.to != "to" || + !bytes.Equal(req.data, []byte("data")) { + t.Errorf("different courier request: %#v", req) + } } func TestFullQueue(t *testing.T) { - q := New() + q := New(newCourier()) // Force-insert maxQueueSize items in the queue. oneID := "" diff --git a/internal/trace/trace.go b/internal/trace/trace.go index f2b5952..4314079 100644 --- a/internal/trace/trace.go +++ b/internal/trace/trace.go @@ -35,7 +35,23 @@ func (t *Trace) SetError() { func (t *Trace) Errorf(format string, a ...interface{}) error { err := fmt.Errorf(format, a...) t.t.SetError() - t.LazyPrintf("Error: %v", err) + t.t.LazyPrintf("error: %v", err) + + if glog.V(2) { + msg := fmt.Sprintf("%p %s %s: error: %v", t, t.family, t.title, err) + glog.InfoDepth(1, msg) + } + return err +} + +func (t *Trace) Error(err error) error { + t.t.SetError() + t.t.LazyPrintf("error: %v", err) + + if glog.V(2) { + msg := fmt.Sprintf("%p %s %s: error: %v", t, t.family, t.title, err) + glog.InfoDepth(1, msg) + } return err }