author | Alberto Bertogli
<albertito@blitiri.com.ar> 2024-08-16 00:41:41 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2024-09-08 10:14:19 UTC |
parent | ccc0a2ab4f2520ce56769c550c58cb64738fab29 |
tests/run_tests | +132 | -12 |
diff --git a/tests/run_tests b/tests/run_tests index cec676c..4d2fd34 100755 --- a/tests/run_tests +++ b/tests/run_tests @@ -23,6 +23,7 @@ import ssl import subprocess import tempfile import textwrap +import threading import time import tracemalloc import unittest @@ -179,7 +180,7 @@ class StaticConfig(Config): raise NotImplementedError("StaticConfig does not support gen_cert") -def launch_daemon(cfg): +def launch_daemon(cfg, smtp_addr=None): args = [ BINS + "/kxd", "--data_dir=%s/data" % cfg, @@ -188,6 +189,8 @@ def launch_daemon(cfg): "--logfile=%s/log" % cfg, "--hook=%s/hook" % cfg, ] + if smtp_addr: + args.append("--smtp_addr=%s:%s" % smtp_addr) print("Launching server: ", " ".join(args)) return subprocess.Popen(args) @@ -197,6 +200,61 @@ def read_all(fname): return fd.read() +def receive_emails(): + """Receive emails from the server. + + Return: + - The server address. + - A list where we will put the emails: each email is a tuple of + (destination addresses, body). + """ + server = socket.create_server(("localhost", 0)) + server.listen(1) + addr = server.getsockname() + + emails = [] + + def handler(): + # Complete a simple SMTP transaction. + sock, addr = server.accept() + sock.sendall(b"220 localhost SMTP\n") + body = b"" + rcpt_to = [] + while True: + data = sock.recv(1024) + if not data: + break + if data.startswith(b"HELO"): + sock.sendall(b"250 localhost\n") + elif data.startswith(b"MAIL FROM"): + sock.sendall(b"250 OK\n") + elif data.startswith(b"RCPT TO"): + rcpt_to.append(data.split(b"<")[1].split(b">")[0].decode()) + sock.sendall(b"250 OK\n") + elif data.startswith(b"DATA"): + sock.sendall(b"354 Start data\n") + body = b"" + while True: + body += sock.recv(1024) + if body.endswith(b"\r\n.\r\n"): + sock.sendall(b"250 OK\n") + break + elif data.startswith(b"QUIT"): + sock.sendall(b"221 Chau\n") + break + else: + sock.sendall(b"500 Unknown command\n") + + emails.append((rcpt_to, body)) + sock.close() + server.close() + + server_thread = threading.Thread(target=handler, daemon=True) + server_thread.start() + + return addr, emails + + class TestCase(unittest.TestCase): def setUp(self): self.server = ServerConfig() @@ -210,8 +268,8 @@ class TestCase(unittest.TestCase): self.daemon.terminate() self.daemon.wait() - def launch_server(self, server): - self.daemon = launch_daemon(server.path) + def launch_server(self, server, smtp_addr=None): + self.daemon = launch_daemon(server.path, smtp_addr) # Wait for the server to start accepting connections. deadline = time.time() + 5 @@ -620,6 +678,15 @@ class BrokenServerConfig(TestCase): self.assertClientFails("kxd://localhost/k1", "404 Not Found") +EMAIL_TO_FILE = textwrap.dedent( + """ + # Comment. + me@example.com + you@test.net + """.strip() +) + + class Hook(TestCase): """Test cases for hook support.""" @@ -632,14 +699,6 @@ class Hook(TestCase): """.strip() ) - EMAIL_TO_FILE = textwrap.dedent( - """ - # Comment. - me@example.com - you@test.net - """.strip() - ) - def write_hook(self, exit_code): path = self.server.path + "/hook" script = self.HOOK_SCRIPT_TMPL.format(exit_code=exit_code) @@ -679,7 +738,7 @@ class Hook(TestCase): "k2", allowed_clients=[self.client.cert()], allowed_hosts=["localhost"], - email_to=self.EMAIL_TO_FILE, + email_to=EMAIL_TO_FILE, ) key = self.client.call(self.server.cert_path(), "kxd://localhost/k2") self.assertEqual(key, self.server.keys["k2"]) @@ -689,5 +748,66 @@ class Hook(TestCase): self.assertIn("EMAIL_TO=me@example.com you@test.net", hook_out) +class Emails(TestCase): + """Tests for email notifications.""" + + def setUp(self): + self.smtp_addr, self.emails = receive_emails() + self.server = ServerConfig() + self.client = ClientConfig() + self.daemon = None + self.ca = None # pylint: disable=invalid-name + self.launch_server(self.server, smtp_addr=self.smtp_addr) + + def test_emails(self): + self.server.new_key( + "k1", + allowed_clients=[self.client.cert()], + allowed_hosts=["localhost", self.server.host], + email_to=EMAIL_TO_FILE, + ) + key = self.client.call( + self.server.cert_path(), "kxd://" + self.server.host + "/k1" + ) + self.assertEqual(key, self.server.keys["k1"]) + self.assertEqual(len(self.emails), 1) + self.assertEqual(self.emails[0][0], ["me@example.com", "you@test.net"]) + self.assertRegex( + self.emails[0][1].decode().replace("\r\n", "\n"), + textwrap.dedent( + """\ + Date: .* + From: Key Exchange Daemon <kxd@127.0.0.1> + To: me@example.com, you@test.net + Subject: Access to key k1 + + Key: k1 + Accessed by: .* + On: .* + + Client certificate: + Signature: .* + Subject: O=kxd-tests-client + + Authorizing chains: + .* + """ + ), + ) + + def test_no_emails(self): + self.server.new_key( + "k2", + allowed_clients=[self.client.cert()], + allowed_hosts=["localhost", self.server.host], + ) + key = self.client.call( + self.server.cert_path(), "kxd://" + self.server.host + "/k2" + ) + self.assertEqual(key, self.server.keys["k2"]) + # Note that we did not set up email_to, so we don't expect any emails. + self.assertEqual(self.emails, []) + + if __name__ == "__main__": unittest.main()