git » kxd » master » tree

[master] / kxd / kxd.go

// kxd is a key exchange daemon.
//
// It serves blobs of data (keys) over https, authenticating and authorizing
// the clients using SSL certificates, and notifying upon key accesses.
//
// It can be used to get keys remotely instead of using local storage.
// The main use case is to get keys to open dm-crypt devices automatically,
// without having to store them on the machine.
package main

import (
	"crypto/tls"
	"crypto/x509"
	"flag"
	"fmt"
	"io"
	"log"
	"log/syslog"
	"net/http"
	"os"
	"path"
	"strings"
)

var port = flag.Int(
	"port", 19840, "Port to listen on")
var ipAddr = flag.String(
	"ip_addr", "", "IP address to listen on")
var dataDir = flag.String(
	"data_dir", "/etc/kxd/data", "Data directory")
var certFile = flag.String(
	"cert", "/etc/kxd/cert.pem", "Certificate")
var keyFile = flag.String(
	"key", "/etc/kxd/key.pem", "Private key")
var smtpAddr = flag.String(
	"smtp_addr", "", "Address of the SMTP server to use to send emails")
var emailFrom = flag.String(
	"email_from", "", "Email address to send email from")
var logFile = flag.String(
	"logfile", "", "File to write logs to, use '-' for stdout")
var hookPath = flag.String(
	"hook", "/etc/kxd/hook",
	"Hook to run before authorizing keys (skipped if it doesn't exist)")

// Logger we will use to log entries.
var logging *log.Logger

// Request is our wrap around http.Request, so we can augment it with custom
// methods.
type Request struct {
	*http.Request
}

// Printf is a wrapper for fmt.Printf+logging.Output, which prefixes a string
// identifying this request.
func (req *Request) Printf(format string, a ...interface{}) {
	msg := fmt.Sprintf(format, a...)
	msg = fmt.Sprintf("%s %s %s", req.RemoteAddr, req.URL.Path, msg)
	logging.Output(2, msg)
}

// KeyPath returns the path to the requested key, extracting it from the URL.
func (req *Request) KeyPath() (string, error) {
	s := strings.Split(req.URL.Path, "/")

	// We expect the path to be "/v1/path/to/key".
	if len(s) < 2 || !(s[0] == "" || s[1] == "v1") {
		return "", fmt.Errorf("invalid path %q", s)
	}

	return strings.Join(s[2:], "/"), nil
}

func certToString(cert *x509.Certificate) string {
	return fmt.Sprintf(
		"(0x%.8s ou:%s)",
		fmt.Sprintf("%x", cert.Signature),
		cert.Subject.OrganizationalUnit)
}

// ChainToString makes a human-readable string out of the given certificate
// chain.
func ChainToString(chain []*x509.Certificate) (s string) {
	for i, cert := range chain {
		s += certToString(cert)
		if i < len(chain)-1 {
			s += " -> "
		}
	}
	return s
}

// HandlerV1 handles /v1/ key requests.
func HandlerV1(w http.ResponseWriter, httpreq *http.Request) {
	req := Request{httpreq}
	if len(req.TLS.PeerCertificates) <= 0 {
		req.Printf("Rejecting request without certificate")
		http.Error(w, "Client certificate not provided",
			http.StatusNotAcceptable)
		return
	}

	keyPath, err := req.KeyPath()
	if err != nil {
		req.Printf("Rejecting request with invalid key path: %s", err)
		http.Error(w, "Invalid key path", http.StatusNotAcceptable)
		return
	}

	// Be extra paranoid and reject keys with "..", even if they're valid
	// (e.g. "/v1/x..y" is valid, but will get rejected anyway).
	if strings.Contains(keyPath, "..") {
		req.Printf("Rejecting because requested key %q contained '..'",
			keyPath)
		req.Printf("Full request: %+v", *req.Request)
		http.Error(w, "Invalid key path", http.StatusNotAcceptable)
		return
	}

	realKeyPath := path.Clean(*dataDir + "/" + keyPath)
	keyConf := NewKeyConfig(realKeyPath)

	exists, err := keyConf.Exists()
	if err != nil {
		req.Printf("Error checking key path %q: %s", keyPath, err)
		http.Error(w, "Error checking key",
			http.StatusInternalServerError)
		return
	}
	if !exists {
		req.Printf("Unknown key path %q", keyPath)
		http.Error(w, "Unknown key", http.StatusNotFound)
		return
	}

	if err = keyConf.LoadClientCerts(); err != nil {
		req.Printf("Error loading certs: %s", err)
		http.Error(w, "Error loading certs",
			http.StatusInternalServerError)
		return
	}

	if err = keyConf.LoadAllowedHosts(); err != nil {
		req.Printf("Error loading allowed hosts: %s", err)
		http.Error(w, "Error loading allowed hosts",
			http.StatusInternalServerError)
		return
	}

	err = keyConf.IsHostAllowed(req.RemoteAddr)
	if err != nil {
		req.Printf("Host not allowed: %s", err)
		http.Error(w, "Host not allowed", http.StatusForbidden)
		return
	}

	validChains := keyConf.IsAnyCertAllowed(req.TLS.PeerCertificates)
	if validChains == nil {
		req.Printf("No allowed certificate found")
		http.Error(w, "No allowed certificate found",
			http.StatusForbidden)
		return
	}

	keyData, err := keyConf.Key()
	if err != nil {
		req.Printf("Error getting key data: %s", err)
		http.Error(w, "Error getting key data",
			http.StatusInternalServerError)
		return
	}

	err = RunHook(keyConf, &req, validChains)
	if err != nil {
		req.Printf("Prevented by hook: %s", err)
		http.Error(w, "Prevented by hook", http.StatusForbidden)
		return
	}

	req.Printf("Allowing request to %s", certToString(validChains[0][0]))

	err = SendMail(keyConf, &req, validChains)
	if err != nil {
		req.Printf("Error sending notification: %s", err)
		http.Error(w, "Error sending notification",
			http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/octet-stream")
	w.Write(keyData)
}

func initLog() {
	var err error
	var logfd io.Writer

	if *logFile == "-" {
		logfd = os.Stdout
	} else if *logFile != "" {
		logfd, err = os.OpenFile(*logFile,
			os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666)
		if err != nil {
			log.Fatalf("Error opening log file %s: %s",
				*logFile, err)
		}
	} else {
		logfd, err = syslog.New(
			syslog.LOG_INFO|syslog.LOG_DAEMON, "kxd")
		if err != nil {
			log.Fatalf("Error opening syslog: %s", err)
		}
	}

	logging = log.New(logfd, "",
		log.Ldate|log.Ltime|log.Lmicroseconds|log.Lshortfile)
}

func main() {
	flag.Parse()

	initLog()

	if *smtpAddr == "" {
		logging.Print(
			"WARNING: No emails will be sent, use --smtp_addr")
	}

	if *emailFrom == "" {
		// Try to get a sane default if not provided, using
		// kxd@<smtp host>.
		*emailFrom = fmt.Sprintf("kxd@%s",
			strings.Split(*smtpAddr, ":")[0])
	}

	listenAddr := fmt.Sprintf("%s:%d", *ipAddr, *port)

	tlsConfig := tls.Config{
		ClientAuth: tls.RequireAnyClientCert,
	}

	server := http.Server{
		Addr:      listenAddr,
		TLSConfig: &tlsConfig,
		ErrorLog:  logging,
	}

	http.HandleFunc("/v1/", HandlerV1)

	logging.Printf("Listening on %s", listenAddr)
	err := server.ListenAndServeTLS(*certFile, *keyFile)
	if err != nil {
		logging.Fatal(err)
	}
}