git » chasquid » commit c738c7e

courier: Move the test fake server into a separate file

author Alberto Bertogli
2022-11-12 12:07:07 UTC
committer Alberto Bertogli
2022-11-12 16:34:35 UTC
parent cbb620eec24de2fcd0a2d5ef1e598f0ff0b4d61f

courier: Move the test fake server into a separate file

In future patches we will use the test fake server in other tests, so
move it to a separate file for clarity.

internal/courier/fakeserver_test.go +150 -0
internal/courier/smtp_test.go +1 -143

diff --git a/internal/courier/fakeserver_test.go b/internal/courier/fakeserver_test.go
new file mode 100644
index 0000000..8cabe98
--- /dev/null
+++ b/internal/courier/fakeserver_test.go
@@ -0,0 +1,150 @@
+package courier
+
+import (
+	"bufio"
+	"crypto/tls"
+	"crypto/x509"
+	"io/ioutil"
+	"net"
+	"net/textproto"
+	"os"
+	"sync"
+	"testing"
+
+	"blitiri.com.ar/go/chasquid/internal/testlib"
+)
+
+// FakeServer is a fake server used to test outgoing SMTP; it serves a single connection, replying with canned responses.
+type FakeServer struct {
+	t         *testing.T        // test that owns the server; used for logging and for failing on setup errors
+	tmpDir    string            // temporary directory; cert.pem and key.pem are loaded from here
+	responses map[string]string // replies keyed by the received line, plus the special "_welcome", "_STARTTLS" and "_DATA" keys
+	wg        *sync.WaitGroup   // tracks the serving goroutine, so Wait can block until it exits
+	addr      string            // listener address, as returned by l.Addr().String() in start
+	tlsConfig *tls.Config       // server-side TLS configuration used to wrap the connection on STARTTLS
+}
+
+func newFakeServer(t *testing.T, responses map[string]string) *FakeServer {
+	s := &FakeServer{
+		t:         t,
+		tmpDir:    testlib.MustTempDir(t), // presumably a fresh temporary directory — confirm in testlib; removed by Cleanup
+		responses: responses,
+		wg:        &sync.WaitGroup{},
+	}
+	s.start() // listens on a local port and launches the serving goroutine; the address is kept in s.addr
+	return s
+}
+
+func (s *FakeServer) Cleanup() {
+	// Remove our temporary data. Be extra paranoid: only remove the directory
+	// when its path is longer than 8 bytes, as a guard against an empty or very short path.
+	if len(s.tmpDir) > 8 {
+		os.RemoveAll(s.tmpDir) // best effort: the returned error is ignored
+	}
+}
+
+func (s *FakeServer) initTLS() {
+	var err error
+	s.tlsConfig, err = testlib.GenerateCert(s.tmpDir) // presumably writes cert.pem and key.pem into tmpDir — confirm in testlib
+	if err != nil {
+		s.t.Fatalf("error generating cert: %v", err)
+	}
+
+	cert, err := tls.LoadX509KeyPair(s.tmpDir+"/cert.pem", s.tmpDir+"/key.pem") // load the key pair back from tmpDir
+	if err != nil {
+		s.t.Fatalf("error loading temp cert: %v", err)
+	}
+
+	s.tlsConfig.Certificates = []tls.Certificate{cert} // certificate presented by the server in the STARTTLS handshake
+}
+
+func (s *FakeServer) rootCA() *x509.CertPool {
+	s.t.Helper() // report failures at the caller's line rather than in this helper
+	pool := x509.NewCertPool()
+	path := s.tmpDir + "/cert.pem" // same certificate file that initTLS loads for the server
+	data, err := ioutil.ReadFile(path)
+	if err != nil {
+		s.t.Fatalf("error reading cert %q: %v", path, err)
+	}
+	ok := pool.AppendCertsFromPEM(data) // false when no certificate could be parsed from data
+	if !ok {
+		s.t.Fatalf("failed to load cert %q", path)
+	}
+	return pool // contains only what was parsed from cert.pem
+}
+
+func (s *FakeServer) start() string {
+	s.t.Helper()
+	l, err := net.Listen("tcp", "localhost:0") // port 0: let the system pick a free port
+	if err != nil {
+		s.t.Fatalf("fake server listen: %v", err)
+	}
+	s.addr = l.Addr().String()
+
+	s.initTLS() // NOTE(review): if this fails the test via Fatalf, the listener l is never closed
+
+	s.wg.Add(1)
+
+	go func() {
+		defer s.wg.Done() // lets Wait return once this goroutine exits
+		defer l.Close()
+
+		c, err := l.Accept() // only a single connection is ever accepted
+		if err != nil {
+			panic(err) // NOTE(review): panics from the goroutine instead of reporting the failure through s.t
+		}
+		defer c.Close()
+
+		s.t.Logf("fakeServer got connection")
+
+		r := textproto.NewReader(bufio.NewReader(c))
+		c.Write([]byte(s.responses["_welcome"])) // greeting, sent before reading anything; write errors are ignored throughout
+		for {
+			line, err := r.ReadLine()
+			if err != nil {
+				s.t.Logf("fakeServer exiting: %v\n", err)
+				return
+			}
+
+			s.t.Logf("fakeServer read: %q\n", line)
+			if line == "STARTTLS" && s.responses["_STARTTLS"] == "ok" { // the TLS upgrade only happens when "_STARTTLS" is "ok"
+				c.Write([]byte(s.responses["STARTTLS"]))
+
+				tlssrv := tls.Server(c, s.tlsConfig)
+				err = tlssrv.Handshake()
+				if err != nil {
+					s.t.Logf("starttls handshake error: %v", err)
+					return
+				}
+
+				// Replace the connection with the wrapped one.
+				// Don't send a reply, as per the protocol.
+				c = tlssrv
+				defer c.Close() // closes the TLS wrapper; the earlier defer still closes the original connection afterwards
+				r = textproto.NewReader(bufio.NewReader(c))
+				continue
+			}
+
+			c.Write([]byte(s.responses[line])) // a line with no entry in responses maps to "", i.e. a zero-length write
+			if line == "DATA" {
+				_, err = r.ReadDotBytes() // consume the dot-terminated message body; its contents are discarded
+				if err != nil {
+					s.t.Logf("fakeServer exiting: %v\n", err)
+					return
+				}
+				c.Write([]byte(s.responses["_DATA"])) // reply sent once the whole body has been read
+			}
+		}
+	}()
+
+	return s.addr
+}
+
+func (s *FakeServer) HostPort() (string, string) {
+	host, port, _ := net.SplitHostPort(s.addr) // error ignored: s.addr is set from l.Addr().String() in start
+	return host, port
+}
+
+func (s *FakeServer) Wait() {
+	s.wg.Wait() // blocks until the serving goroutine launched by start has finished
+}
diff --git a/internal/courier/smtp_test.go b/internal/courier/smtp_test.go
index 5827e3d..2f47627 100644
--- a/internal/courier/smtp_test.go
+++ b/internal/courier/smtp_test.go
@@ -1,16 +1,9 @@
 package courier
 
 import (
-	"bufio"
-	"crypto/tls"
-	"crypto/x509"
 	"fmt"
-	"io/ioutil"
 	"net"
-	"net/textproto"
-	"os"
 	"strings"
-	"sync"
 	"testing"
 	"time"
 
@@ -44,126 +37,6 @@ func newSMTP(t *testing.T) (*SMTP, string) {
 	return &SMTP{"hello", dinfo, nil}, dir
 }
 
-// Fake server, to test SMTP out.
-type FakeServer struct {
-	t         *testing.T
-	tmpDir    string
-	responses map[string]string
-	wg        *sync.WaitGroup
-	addr      string
-	tlsConfig *tls.Config
-}
-
-func newFakeServer(t *testing.T, responses map[string]string) *FakeServer {
-	s := &FakeServer{
-		t:         t,
-		tmpDir:    testlib.MustTempDir(t),
-		responses: responses,
-		wg:        &sync.WaitGroup{},
-	}
-	s.start()
-	return s
-}
-
-func (s *FakeServer) Cleanup() {
-	// Remove our temporary data. Be extra paranoid and make sure the
-	// directory isn't too shallow.
-	if len(s.tmpDir) > 8 {
-		os.RemoveAll(s.tmpDir)
-	}
-}
-
-func (s *FakeServer) initTLS() {
-	var err error
-	s.tlsConfig, err = testlib.GenerateCert(s.tmpDir)
-	if err != nil {
-		s.t.Fatalf("error generating cert: %v", err)
-	}
-
-	cert, err := tls.LoadX509KeyPair(s.tmpDir+"/cert.pem", s.tmpDir+"/key.pem")
-	if err != nil {
-		s.t.Fatalf("error loading temp cert: %v", err)
-	}
-
-	s.tlsConfig.Certificates = []tls.Certificate{cert}
-}
-
-func (s *FakeServer) start() string {
-	s.t.Helper()
-	l, err := net.Listen("tcp", "localhost:0")
-	if err != nil {
-		s.t.Fatalf("fake server listen: %v", err)
-	}
-	s.addr = l.Addr().String()
-
-	s.initTLS()
-
-	s.wg.Add(1)
-
-	go func() {
-		defer s.wg.Done()
-		defer l.Close()
-
-		c, err := l.Accept()
-		if err != nil {
-			panic(err)
-		}
-		defer c.Close()
-
-		s.t.Logf("fakeServer got connection")
-
-		r := textproto.NewReader(bufio.NewReader(c))
-		c.Write([]byte(s.responses["_welcome"]))
-		for {
-			line, err := r.ReadLine()
-			if err != nil {
-				s.t.Logf("fakeServer exiting: %v\n", err)
-				return
-			}
-
-			s.t.Logf("fakeServer read: %q\n", line)
-			if line == "STARTTLS" && s.responses["_STARTTLS"] == "ok" {
-				c.Write([]byte(s.responses["STARTTLS"]))
-
-				tlssrv := tls.Server(c, s.tlsConfig)
-				err = tlssrv.Handshake()
-				if err != nil {
-					s.t.Logf("starttls handshake error: %v", err)
-					return
-				}
-
-				// Replace the connection with the wrapped one.
-				// Don't send a reply, as per the protocol.
-				c = tlssrv
-				defer c.Close()
-				r = textproto.NewReader(bufio.NewReader(c))
-				continue
-			}
-
-			c.Write([]byte(s.responses[line]))
-			if line == "DATA" {
-				_, err = r.ReadDotBytes()
-				if err != nil {
-					s.t.Logf("fakeServer exiting: %v\n", err)
-					return
-				}
-				c.Write([]byte(s.responses["_DATA"]))
-			}
-		}
-	}()
-
-	return s.addr
-}
-
-func (s *FakeServer) HostPort() (string, string) {
-	host, port, _ := net.SplitHostPort(s.addr)
-	return host, port
-}
-
-func (s *FakeServer) Wait() {
-	s.wg.Wait()
-}
-
 func TestSMTP(t *testing.T) {
 	// Shorten the total timeout, so the test fails quickly if the protocol
 	// gets stuck.
@@ -503,7 +376,7 @@ func TestSTSPolicyEnforcement(t *testing.T) {
 	_, *smtpPort = srv.HostPort()
 	defer srv.Cleanup()
 
-	certRoots = loadCert(t, srv.tmpDir+"/cert.pem")
+	certRoots = srv.rootCA()
 	defer func() {
 		certRoots = nil
 	}()
@@ -515,18 +388,3 @@ func TestSTSPolicyEnforcement(t *testing.T) {
 
 	srv.Wait()
 }
-
-func loadCert(t *testing.T, path string) *x509.CertPool {
-	t.Helper()
-
-	pool := x509.NewCertPool()
-	data, err := ioutil.ReadFile(path)
-	if err != nil {
-		t.Fatalf("error reading cert %q: %v", path, err)
-	}
-	ok := pool.AppendCertsFromPEM(data)
-	if !ok {
-		t.Fatalf("failed to load cert %q", path)
-	}
-	return pool
-}