git » kxd » commit 7403299

tests: Use black to auto-format test code

author Alberto Bertogli
2020-08-21 01:57:25 UTC
committer Alberto Bertogli
2020-08-21 15:53:37 UTC
parent 51986345eb9197dd52a479336c8bc44388709cd4

tests: Use black to auto-format test code

This patch uses the `black` utility to auto-format the test code, for
increased consistency, and adds it to the Makefile's `fmt` target.

Makefile +1 -0
tests/run_tests +93 -80

diff --git a/Makefile b/Makefile
index 649b94d..df795fd 100644
--- a/Makefile
+++ b/Makefile
@@ -18,6 +18,7 @@ kxc:
 
 fmt:
 	gofmt -w .
+	black tests/run_tests
 
 vet:
 	$(GO) vet ./...
diff --git a/tests/run_tests b/tests/run_tests
index 62079fb..f0e8634 100755
--- a/tests/run_tests
+++ b/tests/run_tests
@@ -10,10 +10,8 @@ It will create different test configurations and run the compiled server and
 client under various conditions, to make sure they behave as intended.
 """
 
-# NOTE: Please run "pylint3 --rcfile=.pylintrc run_tests" after making changes,
-# to make sure the file has a reasonably uniform coding style.
-# You can also use "autopep8 -d --ignore=E301,E26 run_tests" to help with
-# this, but make sure the output looks sane.
+# NOTE: Please run "black run_tests" after making changes, to make sure the
+# file has a reasonably uniform coding style.
 
 
 import contextlib
@@ -39,17 +37,16 @@ tracemalloc.start()
 
 # Path to our built binaries; used to run the server and client for testing
 # purposes.
-BINS = os.path.abspath(
-    os.path.dirname(os.path.realpath(__file__)) + "/../out")
+BINS = os.path.abspath(os.path.dirname(os.path.realpath(__file__)) + "/../out")
 
 TEMPDIR = "/does/not/exist"
 
 # User the script is running as. Just informational, for troubleshooting
 # purposes, so we don't care if it's missing.
-LOGNAME = os.environ.get('LOGNAME', 'unknown')
+LOGNAME = os.environ.get("LOGNAME", "unknown")
 
 
-def setUpModule():    # pylint: disable=invalid-name
+def setUpModule():  # pylint: disable=invalid-name
     if not os.path.isfile(BINS + "/kxd"):
         raise RuntimeError("kxd not found at " + BINS + "/kxd")
     if not os.path.isfile(BINS + "/kxc"):
@@ -57,15 +54,15 @@ def setUpModule():    # pylint: disable=invalid-name
     if not os.path.isfile(BINS + "/kxgencert"):
         raise RuntimeError("kxgencert not found at " + BINS + "/kxgencert")
 
-    global TEMPDIR    # pylint: disable=global-statement
+    global TEMPDIR  # pylint: disable=global-statement
     TEMPDIR = tempfile.mkdtemp(prefix="kxdtest-")
 
 
-def tearDownModule():   # pylint: disable=invalid-name
+def tearDownModule():  # pylint: disable=invalid-name
     # Remove the temporary directory only on success.
     # Be extra paranoid about removing.
     # TODO: Only remove on success.
-    if os.environ.get('KEEPTMP'):
+    if os.environ.get("KEEPTMP"):
         return
     if len(TEMPDIR) > 10 and not TEMPDIR.startswith("/home"):
         shutil.rmtree(TEMPDIR)
@@ -86,10 +83,12 @@ class Config:
 
     def gen_cert(self):
         try:
-            cmd = [BINS + "/kxgencert",
-                   "-organization=kxd-tests-%s" % self.name,
-                   "-key=" + self.key_path(),
-                   "-cert=" + self.cert_path()]
+            cmd = [
+                BINS + "/kxgencert",
+                "-organization=kxd-tests-%s" % self.name,
+                "-key=" + self.key_path(),
+                "-cert=" + self.cert_path(),
+            ]
             subprocess.check_output(cmd, stderr=subprocess.STDOUT)
         except subprocess.CalledProcessError as err:
             print("kxgencert call failed, output: %r" % err.output)
@@ -136,11 +135,13 @@ class ClientConfig(Config):
         self.gen_cert()
 
     def call(self, server_cert, url):
-        args = [BINS + "/kxc",
-                "--client_cert=%s/cert.pem" % self.path,
-                "--client_key=%s/key.pem" % self.path,
-                "--server_cert=%s" % server_cert,
-                url]
+        args = [
+            BINS + "/kxc",
+            "--client_cert=%s/cert.pem" % self.path,
+            "--client_key=%s/key.pem" % self.path,
+            "--server_cert=%s" % server_cert,
+            url,
+        ]
         try:
             print("Running client:", " ".join(args))
             return subprocess.check_output(args, stderr=subprocess.STDOUT)
@@ -150,18 +151,20 @@ class ClientConfig(Config):
 
 
 def launch_daemon(cfg):
-    args = [BINS + "/kxd",
-            "--data_dir=%s/data" % cfg,
-            "--key=%s/key.pem" % cfg,
-            "--cert=%s/cert.pem" % cfg,
-            "--logfile=%s/log" % cfg,
-            "--hook=%s/hook" % cfg]
+    args = [
+        BINS + "/kxd",
+        "--data_dir=%s/data" % cfg,
+        "--key=%s/key.pem" % cfg,
+        "--cert=%s/cert.pem" % cfg,
+        "--logfile=%s/log" % cfg,
+        "--hook=%s/hook" % cfg,
+    ]
     print("Launching server: ", " ".join(args))
     return subprocess.Popen(args)
 
 
 def read_all(fname):
-    with open(fname) as fd:    # pylint: disable=invalid-name
+    with open(fname) as fd:  # pylint: disable=invalid-name
         return fd.read()
 
 
@@ -170,7 +173,7 @@ class TestCase(unittest.TestCase):
         self.server = ServerConfig()
         self.client = ClientConfig()
         self.daemon = None
-        self.ca = None    # pylint: disable=invalid-name
+        self.ca = None  # pylint: disable=invalid-name
         self.launch_server(self.server)
 
     def tearDown(self):
@@ -185,8 +188,7 @@ class TestCase(unittest.TestCase):
         deadline = time.time() + 5
         while time.time() < deadline:
             try:
-                with socket.create_connection(
-                        ("localhost", 19840), timeout=5):
+                with socket.create_connection(("localhost", 19840), timeout=5):
                     break
             except socket.error:
                 continue
@@ -212,6 +214,7 @@ class TestCase(unittest.TestCase):
 # Test cases.
 #
 
+
 class Simple(TestCase):
     """Simple test cases for common (mis)configurations."""
 
@@ -221,9 +224,9 @@ class Simple(TestCase):
         # overhead of creating the certificates and bringing up the server.
 
         # Normal successful case.
-        self.server.new_key("k1",
-                            allowed_clients=[self.client.cert()],
-                            allowed_hosts=["localhost"])
+        self.server.new_key(
+            "k1", allowed_clients=[self.client.cert()], allowed_hosts=["localhost"]
+        )
         key = self.client.call(self.server.cert_path(), "kxd://localhost/k1")
         self.assertEqual(key, self.server.keys["k1"])
 
@@ -232,15 +235,15 @@ class Simple(TestCase):
 
         # No certificates allowed -> 403.
         self.server.new_key("k3", allowed_hosts=["localhost"])
-        self.assertClientFails("kxd://localhost/k3",
-                               "403 Forbidden.*No allowed certificate found")
+        self.assertClientFails(
+            "kxd://localhost/k3", "403 Forbidden.*No allowed certificate found"
+        )
 
         # Host not allowed -> 403.
-        self.server.new_key("k4",
-                            allowed_clients=[self.client.cert()],
-                            allowed_hosts=[])
-        self.assertClientFails("kxd://localhost/k4",
-                               "403 Forbidden.*Host not allowed")
+        self.server.new_key(
+            "k4", allowed_clients=[self.client.cert()], allowed_hosts=[]
+        )
+        self.assertClientFails("kxd://localhost/k4", "403 Forbidden.*Host not allowed")
 
         # Nothing allowed -> 403.
         # We don't restrict the reason of failure, that's not defined in this
@@ -251,9 +254,11 @@ class Simple(TestCase):
 
         # We tell the client to expect the server certificate to be the client
         # one, which is never going to work.
-        self.assertClientFails("kxd://localhost/k1",
-                               "certificate signed by unknown authority",
-                               cert_path=self.client.cert_path())
+        self.assertClientFails(
+            "kxd://localhost/k1",
+            "certificate signed by unknown authority",
+            cert_path=self.client.cert_path(),
+        )
 
 
 class Multiples(TestCase):
@@ -264,10 +269,11 @@ class Multiples(TestCase):
         self.client2 = ClientConfig(name="client2")
 
     def test_two_clients(self):
-        self.server.new_key("k1",
-                            allowed_clients=[
-                                self.client.cert(), self.client2.cert()],
-                            allowed_hosts=["localhost"])
+        self.server.new_key(
+            "k1",
+            allowed_clients=[self.client.cert(), self.client2.cert()],
+            allowed_hosts=["localhost"],
+        )
         key = self.client.call(self.server.cert_path(), "kxd://localhost/k1")
         self.assertEqual(key, self.server.keys["k1"])
 
@@ -275,15 +281,17 @@ class Multiples(TestCase):
         self.assertEqual(key, self.server.keys["k1"])
 
         # Only one client allowed.
-        self.server.new_key("k2",
-                            allowed_clients=[self.client.cert()],
-                            allowed_hosts=["localhost"])
+        self.server.new_key(
+            "k2", allowed_clients=[self.client.cert()], allowed_hosts=["localhost"]
+        )
         key = self.client.call(self.server.cert_path(), "kxd://localhost/k2")
         self.assertEqual(key, self.server.keys["k2"])
 
-        self.assertClientFails("kxd://localhost/k2",
-                               "403 Forbidden.*No allowed certificate found",
-                               client=self.client2)
+        self.assertClientFails(
+            "kxd://localhost/k2",
+            "403 Forbidden.*No allowed certificate found",
+            client=self.client2,
+        )
 
     def test_many_keys(self):
         keys = ["a", "d/e", "a/b/c", "d/"]
@@ -291,15 +299,16 @@ class Multiples(TestCase):
             self.server.new_key(
                 key,
                 allowed_clients=[self.client.cert(), self.client2.cert()],
-                allowed_hosts=["localhost"])
+                allowed_hosts=["localhost"],
+            )
 
         for key in keys:
-            data = self.client.call(self.server.cert_path(),
-                                    "kxd://localhost/%s" % key)
+            data = self.client.call(self.server.cert_path(), "kxd://localhost/%s" % key)
             self.assertEqual(data, self.server.keys[key])
 
-            data = self.client2.call(self.server.cert_path(),
-                                     "kxd://localhost/%s" % key)
+            data = self.client2.call(
+                self.server.cert_path(), "kxd://localhost/%s" % key
+            )
             self.assertEqual(data, self.server.keys[key])
 
         self.assertClientFails("kxd://localhost/a/b", "404 Not Found")
@@ -354,9 +363,12 @@ class TrickyRequests(TestCase):
 
     def test_path_with_dotdot(self):
         """Requests with '..'."""
-        conn = self.https_connection("localhost", 19840,
-                                     key_file=self.client.key_path(),
-                                     cert_file=self.client.cert_path())
+        conn = self.https_connection(
+            "localhost",
+            19840,
+            key_file=self.client.key_path(),
+            cert_file=self.client.cert_path(),
+        )
         conn.request("GET", "/v1/a/../b")
         response = conn.getresponse()
         conn.close()
@@ -367,16 +379,15 @@ class TrickyRequests(TestCase):
 
     def test_server_cert(self):
         rawsock = socket.create_connection(("localhost", 19840))
-        sock = ssl.wrap_socket(rawsock,
-                               keyfile=self.client.key_path(),
-                               certfile=self.client.cert_path())
+        sock = ssl.wrap_socket(
+            rawsock, keyfile=self.client.key_path(), certfile=self.client.cert_path()
+        )
 
         # We don't check the cipher itself, as it depends on the environment,
         # but we should be using >= 128 bit secrets.
         self.assertTrue(sock.cipher()[2] >= 128)
 
-        server_cert = ssl.DER_cert_to_PEM_cert(
-            sock.getpeercert(binary_form=True))
+        server_cert = ssl.DER_cert_to_PEM_cert(sock.getpeercert(binary_form=True))
         self.assertEqual(server_cert, self.server.cert())
         sock.close()
 
@@ -385,23 +396,23 @@ class BrokenServerConfig(TestCase):
     """Tests for a broken server config."""
 
     def test_broken_client_certs(self):
-        self.server.new_key("k1",
-                            allowed_clients=[self.client.cert()],
-                            allowed_hosts=["localhost"])
+        self.server.new_key(
+            "k1", allowed_clients=[self.client.cert()], allowed_hosts=["localhost"]
+        )
 
         # Corrupt the client certificate.
         with open(self.server.path + "/data/k1/allowed_clients", "tr+") as cfd:
             cfd.seek(30)
-            cfd.write('+/+BROKEN+/+')
+            cfd.write("+/+BROKEN+/+")
 
         self.assertClientFails(
-            "kxd://localhost/k1",
-            "Error loading certs|No allowed certificate found")
+            "kxd://localhost/k1", "Error loading certs|No allowed certificate found"
+        )
 
     def test_missing_key(self):
-        self.server.new_key("k1",
-                            allowed_clients=[self.client.cert()],
-                            allowed_hosts=["localhost"])
+        self.server.new_key(
+            "k1", allowed_clients=[self.client.cert()], allowed_hosts=["localhost"]
+        )
 
         os.unlink(self.server.path + "/data/k1/key")
         self.assertClientFails("kxd://localhost/k1", "404 Not Found")
@@ -410,12 +421,14 @@ class BrokenServerConfig(TestCase):
 class Hook(TestCase):
     """Test cases for hook support."""
 
-    HOOK_SCRIPT_TMPL = textwrap.dedent("""
+    HOOK_SCRIPT_TMPL = textwrap.dedent(
+        """
         #!/bin/sh
         pwd > hook-output
         env >> hook-output
         exit {exit_code}
-        """.strip())
+        """.strip()
+    )
 
     def write_hook(self, exit_code):
         path = self.server.path + "/hook"
@@ -429,9 +442,9 @@ class Hook(TestCase):
         self.write_hook(exit_code=0)
 
         # Normal successful case.
-        self.server.new_key("k1",
-                            allowed_clients=[self.client.cert()],
-                            allowed_hosts=["localhost"])
+        self.server.new_key(
+            "k1", allowed_clients=[self.client.cert()], allowed_hosts=["localhost"]
+        )
         key = self.client.call(self.server.cert_path(), "kxd://localhost/k1")
         self.assertEqual(key, self.server.keys["k1"])