author | Alberto Bertogli
<albertito@blitiri.com.ar> 2022-11-12 12:07:07 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2022-11-12 16:34:35 UTC |
parent | cbb620eec24de2fcd0a2d5ef1e598f0ff0b4d61f |
internal/courier/fakeserver_test.go | +150 | -0 |
internal/courier/smtp_test.go | +1 | -143 |
diff --git a/internal/courier/fakeserver_test.go b/internal/courier/fakeserver_test.go new file mode 100644 index 0000000..8cabe98 --- /dev/null +++ b/internal/courier/fakeserver_test.go @@ -0,0 +1,150 @@ +package courier + +import ( + "bufio" + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net" + "net/textproto" + "os" + "sync" + "testing" + + "blitiri.com.ar/go/chasquid/internal/testlib" +) + +// Fake server, to test SMTP out. +type FakeServer struct { + t *testing.T + tmpDir string + responses map[string]string + wg *sync.WaitGroup + addr string + tlsConfig *tls.Config +} + +func newFakeServer(t *testing.T, responses map[string]string) *FakeServer { + s := &FakeServer{ + t: t, + tmpDir: testlib.MustTempDir(t), + responses: responses, + wg: &sync.WaitGroup{}, + } + s.start() + return s +} + +func (s *FakeServer) Cleanup() { + // Remove our temporary data. Be extra paranoid and make sure the + // directory isn't too shallow. + if len(s.tmpDir) > 8 { + os.RemoveAll(s.tmpDir) + } +} + +func (s *FakeServer) initTLS() { + var err error + s.tlsConfig, err = testlib.GenerateCert(s.tmpDir) + if err != nil { + s.t.Fatalf("error generating cert: %v", err) + } + + cert, err := tls.LoadX509KeyPair(s.tmpDir+"/cert.pem", s.tmpDir+"/key.pem") + if err != nil { + s.t.Fatalf("error loading temp cert: %v", err) + } + + s.tlsConfig.Certificates = []tls.Certificate{cert} +} + +func (s *FakeServer) rootCA() *x509.CertPool { + s.t.Helper() + pool := x509.NewCertPool() + path := s.tmpDir + "/cert.pem" + data, err := ioutil.ReadFile(path) + if err != nil { + s.t.Fatalf("error reading cert %q: %v", path, err) + } + ok := pool.AppendCertsFromPEM(data) + if !ok { + s.t.Fatalf("failed to load cert %q", path) + } + return pool +} + +func (s *FakeServer) start() string { + s.t.Helper() + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + s.t.Fatalf("fake server listen: %v", err) + } + s.addr = l.Addr().String() + + s.initTLS() + + s.wg.Add(1) + + go func() { + defer s.wg.Done() + defer l.Close() + + c, err := l.Accept() + if err != nil { + panic(err) + } + defer c.Close() + + s.t.Logf("fakeServer got connection") + + r := textproto.NewReader(bufio.NewReader(c)) + c.Write([]byte(s.responses["_welcome"])) + for { + line, err := r.ReadLine() + if err != nil { + s.t.Logf("fakeServer exiting: %v\n", err) + return + } + + s.t.Logf("fakeServer read: %q\n", line) + if line == "STARTTLS" && s.responses["_STARTTLS"] == "ok" { + c.Write([]byte(s.responses["STARTTLS"])) + + tlssrv := tls.Server(c, s.tlsConfig) + err = tlssrv.Handshake() + if err != nil { + s.t.Logf("starttls handshake error: %v", err) + return + } + + // Replace the connection with the wrapped one. + // Don't send a reply, as per the protocol. + c = tlssrv + defer c.Close() + r = textproto.NewReader(bufio.NewReader(c)) + continue + } + + c.Write([]byte(s.responses[line])) + if line == "DATA" { + _, err = r.ReadDotBytes() + if err != nil { + s.t.Logf("fakeServer exiting: %v\n", err) + return + } + c.Write([]byte(s.responses["_DATA"])) + } + } + }() + + return s.addr +} + +func (s *FakeServer) HostPort() (string, string) { + host, port, _ := net.SplitHostPort(s.addr) + return host, port +} + +func (s *FakeServer) Wait() { + s.wg.Wait() +} diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go index 5827e3d..2f47627 100644 --- a/internal/courier/smtp_test.go +++ b/internal/courier/smtp_test.go @@ -1,16 +1,9 @@ package courier import ( - "bufio" - "crypto/tls" - "crypto/x509" "fmt" - "io/ioutil" "net" - "net/textproto" - "os" "strings" - "sync" "testing" "time" @@ -44,126 +37,6 @@ func newSMTP(t *testing.T) (*SMTP, string) { return &SMTP{"hello", dinfo, nil}, dir } -// Fake server, to test SMTP out. -type FakeServer struct { - t *testing.T - tmpDir string - responses map[string]string - wg *sync.WaitGroup - addr string - tlsConfig *tls.Config -} - -func newFakeServer(t *testing.T, responses map[string]string) *FakeServer { - s := &FakeServer{ - t: t, - tmpDir: testlib.MustTempDir(t), - responses: responses, - wg: &sync.WaitGroup{}, - } - s.start() - return s -} - -func (s *FakeServer) Cleanup() { - // Remove our temporary data. Be extra paranoid and make sure the - // directory isn't too shallow. - if len(s.tmpDir) > 8 { - os.RemoveAll(s.tmpDir) - } -} - -func (s *FakeServer) initTLS() { - var err error - s.tlsConfig, err = testlib.GenerateCert(s.tmpDir) - if err != nil { - s.t.Fatalf("error generating cert: %v", err) - } - - cert, err := tls.LoadX509KeyPair(s.tmpDir+"/cert.pem", s.tmpDir+"/key.pem") - if err != nil { - s.t.Fatalf("error loading temp cert: %v", err) - } - - s.tlsConfig.Certificates = []tls.Certificate{cert} -} - -func (s *FakeServer) start() string { - s.t.Helper() - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - s.t.Fatalf("fake server listen: %v", err) - } - s.addr = l.Addr().String() - - s.initTLS() - - s.wg.Add(1) - - go func() { - defer s.wg.Done() - defer l.Close() - - c, err := l.Accept() - if err != nil { - panic(err) - } - defer c.Close() - - s.t.Logf("fakeServer got connection") - - r := textproto.NewReader(bufio.NewReader(c)) - c.Write([]byte(s.responses["_welcome"])) - for { - line, err := r.ReadLine() - if err != nil { - s.t.Logf("fakeServer exiting: %v\n", err) - return - } - - s.t.Logf("fakeServer read: %q\n", line) - if line == "STARTTLS" && s.responses["_STARTTLS"] == "ok" { - c.Write([]byte(s.responses["STARTTLS"])) - - tlssrv := tls.Server(c, s.tlsConfig) - err = tlssrv.Handshake() - if err != nil { - s.t.Logf("starttls handshake error: %v", err) - return - } - - // Replace the connection with the wrapped one. - // Don't send a reply, as per the protocol. - c = tlssrv - defer c.Close() - r = textproto.NewReader(bufio.NewReader(c)) - continue - } - - c.Write([]byte(s.responses[line])) - if line == "DATA" { - _, err = r.ReadDotBytes() - if err != nil { - s.t.Logf("fakeServer exiting: %v\n", err) - return - } - c.Write([]byte(s.responses["_DATA"])) - } - } - }() - - return s.addr -} - -func (s *FakeServer) HostPort() (string, string) { - host, port, _ := net.SplitHostPort(s.addr) - return host, port -} - -func (s *FakeServer) Wait() { - s.wg.Wait() -} - func TestSMTP(t *testing.T) { // Shorten the total timeout, so the test fails quickly if the protocol // gets stuck. @@ -503,7 +376,7 @@ func TestSTSPolicyEnforcement(t *testing.T) { _, *smtpPort = srv.HostPort() defer srv.Cleanup() - certRoots = loadCert(t, srv.tmpDir+"/cert.pem") + certRoots = srv.rootCA() defer func() { certRoots = nil }() @@ -515,18 +388,3 @@ func TestSTSPolicyEnforcement(t *testing.T) { srv.Wait() } - -func loadCert(t *testing.T, path string) *x509.CertPool { - t.Helper() - - pool := x509.NewCertPool() - data, err := ioutil.ReadFile(path) - if err != nil { - t.Fatalf("error reading cert %q: %v", path, err) - } - ok := pool.AppendCertsFromPEM(data) - if !ok { - t.Fatalf("failed to load cert %q", path) - } - return pool -}