git » kxd » commit 480b41f

tests: Add tests for sending emails via SMTP

author Alberto Bertogli
2024-08-16 00:41:41 UTC
committer Alberto Bertogli
2024-09-08 10:14:19 UTC
parent ccc0a2ab4f2520ce56769c550c58cb64738fab29

tests: Add tests for sending emails via SMTP

This patch adds basic tests for sending emails via SMTP. It is fairly
crude, but works well enough to test basic functionality.

tests/run_tests +132 -12

diff --git a/tests/run_tests b/tests/run_tests
index cec676c..4d2fd34 100755
--- a/tests/run_tests
+++ b/tests/run_tests
@@ -23,6 +23,7 @@ import ssl
 import subprocess
 import tempfile
 import textwrap
+import threading
 import time
 import tracemalloc
 import unittest
@@ -179,7 +180,7 @@ class StaticConfig(Config):
         raise NotImplementedError("StaticConfig does not support gen_cert")
 
 
-def launch_daemon(cfg):
+def launch_daemon(cfg, smtp_addr=None):
     args = [
         BINS + "/kxd",
         "--data_dir=%s/data" % cfg,
@@ -188,6 +189,8 @@ def launch_daemon(cfg):
         "--logfile=%s/log" % cfg,
         "--hook=%s/hook" % cfg,
     ]
+    if smtp_addr:
+        args.append("--smtp_addr=%s:%s" % smtp_addr)
     print("Launching server: ", " ".join(args))
     return subprocess.Popen(args)
 
@@ -197,6 +200,61 @@ def read_all(fname):
         return fd.read()
 
 
+def receive_emails():
+    """Receive emails from the server.
+
+    Return:
+    - The server address.
+    - A list where we will put the emails: each email is a tuple of
+      (destination addresses, body).
+    """
+    server = socket.create_server(("localhost", 0))
+    server.listen(1)
+    addr = server.getsockname()
+
+    emails = []
+
+    def handler():
+        # Complete a simple SMTP transaction.
+        sock, addr = server.accept()
+        sock.sendall(b"220 localhost SMTP\n")
+        body = b""
+        rcpt_to = []
+        while True:
+            data = sock.recv(1024)
+            if not data:
+                break
+            if data.startswith(b"HELO"):
+                sock.sendall(b"250 localhost\n")
+            elif data.startswith(b"MAIL FROM"):
+                sock.sendall(b"250 OK\n")
+            elif data.startswith(b"RCPT TO"):
+                rcpt_to.append(data.split(b"<")[1].split(b">")[0].decode())
+                sock.sendall(b"250 OK\n")
+            elif data.startswith(b"DATA"):
+                sock.sendall(b"354 Start data\n")
+                body = b""
+                while True:
+                    body += sock.recv(1024)
+                    if body.endswith(b"\r\n.\r\n"):
+                        sock.sendall(b"250 OK\n")
+                        break
+            elif data.startswith(b"QUIT"):
+                sock.sendall(b"221 Chau\n")
+                break
+            else:
+                sock.sendall(b"500 Unknown command\n")
+
+        emails.append((rcpt_to, body))
+        sock.close()
+        server.close()
+
+    server_thread = threading.Thread(target=handler, daemon=True)
+    server_thread.start()
+
+    return addr, emails
+
+
 class TestCase(unittest.TestCase):
     def setUp(self):
         self.server = ServerConfig()
@@ -210,8 +268,8 @@ class TestCase(unittest.TestCase):
             self.daemon.terminate()
             self.daemon.wait()
 
-    def launch_server(self, server):
-        self.daemon = launch_daemon(server.path)
+    def launch_server(self, server, smtp_addr=None):
+        self.daemon = launch_daemon(server.path, smtp_addr)
 
         # Wait for the server to start accepting connections.
         deadline = time.time() + 5
@@ -620,6 +678,15 @@ class BrokenServerConfig(TestCase):
         self.assertClientFails("kxd://localhost/k1", "404 Not Found")
 
 
+EMAIL_TO_FILE = textwrap.dedent(
+    """
+    # Comment.
+    me@example.com
+    you@test.net
+    """.strip()
+)
+
+
 class Hook(TestCase):
     """Test cases for hook support."""
 
@@ -632,14 +699,6 @@ class Hook(TestCase):
         """.strip()
     )
 
-    EMAIL_TO_FILE = textwrap.dedent(
-        """
-        # Comment.
-        me@example.com
-        you@test.net
-        """.strip()
-    )
-
     def write_hook(self, exit_code):
         path = self.server.path + "/hook"
         script = self.HOOK_SCRIPT_TMPL.format(exit_code=exit_code)
@@ -679,7 +738,7 @@ class Hook(TestCase):
             "k2",
             allowed_clients=[self.client.cert()],
             allowed_hosts=["localhost"],
-            email_to=self.EMAIL_TO_FILE,
+            email_to=EMAIL_TO_FILE,
         )
         key = self.client.call(self.server.cert_path(), "kxd://localhost/k2")
         self.assertEqual(key, self.server.keys["k2"])
@@ -689,5 +748,66 @@ class Hook(TestCase):
         self.assertIn("EMAIL_TO=me@example.com you@test.net", hook_out)
 
 
+class Emails(TestCase):
+    """Tests for email notifications."""
+
+    def setUp(self):
+        self.smtp_addr, self.emails = receive_emails()
+        self.server = ServerConfig()
+        self.client = ClientConfig()
+        self.daemon = None
+        self.ca = None  # pylint: disable=invalid-name
+        self.launch_server(self.server, smtp_addr=self.smtp_addr)
+
+    def test_emails(self):
+        self.server.new_key(
+            "k1",
+            allowed_clients=[self.client.cert()],
+            allowed_hosts=["localhost", self.server.host],
+            email_to=EMAIL_TO_FILE,
+        )
+        key = self.client.call(
+            self.server.cert_path(), "kxd://" + self.server.host + "/k1"
+        )
+        self.assertEqual(key, self.server.keys["k1"])
+        self.assertEqual(len(self.emails), 1)
+        self.assertEqual(self.emails[0][0], ["me@example.com", "you@test.net"])
+        self.assertRegex(
+            self.emails[0][1].decode().replace("\r\n", "\n"),
+            textwrap.dedent(
+                """\
+                Date: .*
+                From: Key Exchange Daemon <kxd@127.0.0.1>
+                To: me@example.com, you@test.net
+                Subject: Access to key k1
+
+                Key: k1
+                Accessed by: .*
+                On: .*
+
+                Client certificate:
+                  Signature: .*
+                  Subject: O=kxd-tests-client
+
+                Authorizing chains:
+                  .*
+            """
+            ),
+        )
+
+    def test_no_emails(self):
+        self.server.new_key(
+            "k2",
+            allowed_clients=[self.client.cert()],
+            allowed_hosts=["localhost", self.server.host],
+        )
+        key = self.client.call(
+            self.server.cert_path(), "kxd://" + self.server.host + "/k2"
+        )
+        self.assertEqual(key, self.server.keys["k2"])
+        # Note that we did not set up email_to, so we don't expect any emails.
+        self.assertEqual(self.emails, [])
+
+
 if __name__ == "__main__":
     unittest.main()