author | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-03-22 14:27:20 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-03-22 20:19:40 UTC |
parent | 76b18c675c3d6d47a94900cb25e104a39ef0b18c |
kxd/hook.go | +75 | -0 |
kxd/kxd.go | +10 | -0 |
tests/run_tests | +45 | -1 |
diff --git a/kxd/hook.go b/kxd/hook.go new file mode 100644 index 0000000..9d3405d --- /dev/null +++ b/kxd/hook.go @@ -0,0 +1,75 @@ +package main + +import ( + "context" + "crypto/x509" + "fmt" + "os" + "os/exec" + "strings" + "time" +) + +// RunHook runs the hook, returns an error if the request is not allowed (or +// there were problems with the hook; we don't make the distinction for now). +// +// Note that if the hook flag is not set, or points to a non-existing path, +// then we allow the request. +func RunHook(kc *KeyConfig, req *Request, chains [][]*x509.Certificate) error { + if *hookPath == "" { + return nil + } + + if _, err := os.Stat(*hookPath); os.IsNotExist(err) { + req.Printf("Hook not present, skipping") + return nil + } + + ctx, cancel := context.WithDeadline(context.Background(), + time.Now().Add(1*time.Minute)) + defer cancel() + cmd := exec.CommandContext(ctx, *hookPath) + + // Run the hook from the data directory. + cmd.Dir = *dataDir + + // Prepare the environment, copying some common variables so the hook has + // something reasonable, and then setting the specific ones for this case. 
+ for _, v := range strings.Fields("USER PWD SHELL PATH") { + cmd.Env = append(cmd.Env, v+"="+os.Getenv(v)) + } + + keyPath, err := req.KeyPath() + if err != nil { + return err + } + cmd.Env = append(cmd.Env, "KEY_PATH="+keyPath) + + cmd.Env = append(cmd.Env, "REMOTE_ADDR="+req.RemoteAddr) + cmd.Env = append(cmd.Env, "MAIL_FROM="+*emailFrom) + if emailTo, _ := kc.EmailTo(); emailTo != nil { + cmd.Env = append(cmd.Env, "EMAIL_TO="+strings.Join(emailTo, " ")) + } + + clientCert := chains[0][0] + cmd.Env = append(cmd.Env, + fmt.Sprintf("CLIENT_CERT_SIGNATURE=%x", clientCert.Signature)) + cmd.Env = append(cmd.Env, + "CLIENT_CERT_SUBJECT="+NameToString(clientCert.Subject)) + + for i, chain := range chains { + cmd.Env = append(cmd.Env, + fmt.Sprintf("CHAIN_%d=%s", i, ChainToString(chain))) + } + + _, err = cmd.Output() + if err != nil { + if ee, ok := err.(*exec.ExitError); ok { + err = fmt.Errorf("exited with error: %v -- stderr: %q", + ee.String(), ee.Stderr) + } + return err + } + + return nil +} diff --git a/kxd/kxd.go b/kxd/kxd.go index e8e24bc..128c65c 100644 --- a/kxd/kxd.go +++ b/kxd/kxd.go @@ -38,6 +38,9 @@ var emailFrom = flag.String( "email_from", "", "Email address to send email from") var logFile = flag.String( "logfile", "", "File to write logs to, use '-' for stdout") +var hookPath = flag.String( + "hook", "/etc/kxd/hook", + "Hook to run before authorizing keys (skipped if it doesn't exist)") // Logger we will use to log entries. 
var logging *log.Logger @@ -167,6 +170,13 @@ func HandlerV1(w http.ResponseWriter, httpreq *http.Request) { return } + err = RunHook(keyConf, &req, validChains) + if err != nil { + req.Printf("Prevented by hook: %s", err) + http.Error(w, "Prevented by hook", http.StatusForbidden) + return + } + req.Printf("Allowing request to %s", certToString(validChains[0][0])) err = SendMail(keyConf, &req, validChains) diff --git a/tests/run_tests b/tests/run_tests index 2973c93..a1c5691 100755 --- a/tests/run_tests +++ b/tests/run_tests @@ -26,6 +26,7 @@ import ssl import subprocess import sys import tempfile +import textwrap import time import unittest @@ -213,7 +214,9 @@ def launch_daemon(cfg): "--data_dir=%s/data" % cfg, "--key=%s/key.pem" % cfg, "--cert=%s/cert.pem" % cfg, - "--logfile=%s/log" % cfg] + "--logfile=%s/log" % cfg, + "--hook=%s/hook" % cfg + ] print "Launching server: ", " ".join(args) return subprocess.Popen(args) @@ -225,6 +228,7 @@ class TestCase(unittest.TestCase): self.daemon = None self.ca = None # pylint: disable=invalid-name self.launch_server(self.server) + self.longMessage = True def tearDown(self): if self.daemon: @@ -552,5 +556,45 @@ class Delegation(TestCase): self.assertEquals(key, self.server.keys["k1"]) +class Hook(TestCase): + """Test cases for hook support.""" + + HOOK_SCRIPT_TMPL = textwrap.dedent(""" + #!/bin/sh + pwd > hook-output + env >> hook-output + exit {exit_code} + """.strip()) + + def write_hook(self, exit_code): + path = self.server.path + "/hook" + script = self.HOOK_SCRIPT_TMPL.format(exit_code=exit_code) + + open(path, "w").write(script) + os.chmod(path, 0770) + + def test_simple(self): + self.write_hook(exit_code=0) + + # Normal successful case. 
+ self.server.new_key("k1", + allowed_clients=[self.client.cert()], + allowed_hosts=["localhost"]) + key = self.client.call(self.server.cert_path(), "kxd://localhost/k1") + self.assertEquals(key, self.server.keys["k1"]) + + hook_out = open(self.server.path + "/data/hook-output").read() + self.assertIn("CLIENT_CERT_SUBJECT=OU=kxd-tests-client", hook_out) + + # Failure caused by the hook exiting with error. + self.write_hook(exit_code=1) + self.assertClientFails("kxd://localhost/k1", "Prevented by hook") + + # Failure caused by the hook not being executable. + self.write_hook(exit_code=0) + os.chmod(self.server.path + "/hook", 0660) + self.assertClientFails("kxd://localhost/k1", "Prevented by hook") + + if __name__ == "__main__": unittest.main()