git » kxd » commit d08f76c

kxd: Simplify key path cleanup and checking

author Alberto Bertogli
2024-09-08 18:15:35 UTC
committer Alberto Bertogli
2024-09-08 18:15:35 UTC
parent e7b953e15713ca141059e5d683e4432ea619755c

kxd: Simplify key path cleanup and checking

The key path cleanup is unnecessarily complex and leaves some weird
gaps (for example, the version check accepted paths that did not start
with "/v1/", such as "/v2/key"). This patch simplifies it, makes it more
streamlined and direct, and also adds some tests.

The tests are in Go because these conditions can't be reproduced end to
end: the checks are only there for defense in depth.

kxd/kxd.go +24 -16
kxd/kxd_test.go +69 -0
tests/cover.sh +7 -3

diff --git a/kxd/kxd.go b/kxd/kxd.go
index 1e7a6b5..6234538 100644
--- a/kxd/kxd.go
+++ b/kxd/kxd.go
@@ -11,6 +11,7 @@ package main
 import (
 	"crypto/tls"
 	"crypto/x509"
+	"errors"
 	"flag"
 	"fmt"
 	"io"
@@ -61,16 +62,33 @@ func (req *Request) Printf(format string, a ...interface{}) {
 	logging.Output(2, msg)
 }
 
+var (
+	errInvalidVersion = errors.New("invalid version")    // Path is not "/v1/<key>".
+	errHasDotDot      = errors.New("path contains '..'") // Key contains "..".
+)
+
 // KeyPath returns the path to the requested key, extracting it from the URL.
 func (req *Request) KeyPath() (string, error) {
-	s := strings.Split(req.URL.Path, "/")
+	// Clean the path (collapse "//", resolve "." and "..", drop trailing "/").
+	kp := path.Clean(req.URL.Path)
+
+	// Must start with "/v1/". Doing this after the Clean also ensures there
+	// is something else after the version (because Clean removes trailing
+	// slashes).
+	kp, hasVersion := strings.CutPrefix(kp, "/v1/")
+	if !hasVersion {
+		return "", errInvalidVersion
+	}

-	// We expect the path to be "/v1/path/to/key".
-	if len(s) < 2 || !(s[0] == "" || s[1] == "v1") {
-		return "", fmt.Errorf("invalid path %q", s)
+	// Be extra paranoid and reject any key containing "..", even when it is
+	// not a path element (e.g. "/v1/x..y" is valid, but is rejected anyway).
+	// Real ".." elements can't reach this point, because path.Clean above
+	// resolves them in any path starting with "/v1/"; this is defense in depth.
+	if strings.Contains(kp, "..") {
+		return "", errHasDotDot
 	}

-	return strings.Join(s[2:], "/"), nil
+	return kp, nil
 }
 
 func certToString(cert *x509.Certificate) string {
@@ -109,17 +127,7 @@ func HandlerV1(w http.ResponseWriter, httpreq *http.Request) {
 		return
 	}
 
-	// Be extra paranoid and reject keys with "..", even if they're valid
-	// (e.g. "/v1/x..y" is valid, but will get rejected anyway).
-	if strings.Contains(keyPath, "..") {
-		req.Printf("Rejecting because requested key %q contained '..'",
-			keyPath)
-		req.Printf("Full request: %+v", *req.Request)
-		http.Error(w, "Invalid key path", http.StatusNotAcceptable)
-		return
-	}
-
-	realKeyPath := path.Clean(*dataDir + "/" + keyPath)
+	realKeyPath := path.Join(*dataDir, keyPath)
 	keyConf := NewKeyConfig(realKeyPath)
 
 	exists, err := keyConf.Exists()
diff --git a/kxd/kxd_test.go b/kxd/kxd_test.go
new file mode 100644
index 0000000..c7d8478
--- /dev/null
+++ b/kxd/kxd_test.go
@@ -0,0 +1,69 @@
+package main
+
+import (
+	"crypto/tls"
+	"errors"
+	"log"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"testing"
+)
+
+func init() {
+	// Point the global logger at the standard logger (stderr); "go test"
+	// captures that output and shows it as needed.
+	logging = log.Default()
+}
+
+func TestKeyPath(t *testing.T) {
+	cases := []struct {
+		url  string
+		want string
+		err  error
+	}{
+		{"/v1/key", "key", nil},
+		{"/v1/path/to/key", "path/to/key", nil},
+		{"/v1/path/to/key/", "path/to/key", nil},
+		// Anything other than "/v1/" followed by a non-empty key is rejected.
+		{"", "", errInvalidVersion},
+		{"/", "", errInvalidVersion},
+		{"/v1", "", errInvalidVersion},
+		{"/v1/", "", errInvalidVersion},
+		{"/v1//", "", errInvalidVersion},
+		{"v1/path/to/key/", "", errInvalidVersion},
+		{"/v2/path/to/key", "", errInvalidVersion},
+		// ".." is rejected even when it is not a whole path element.
+		{"/v1/a..b", "", errHasDotDot},
+	}
+
+	for _, c := range cases {
+		u, _ := url.Parse(c.url) // All the test URLs above parse cleanly.
+		req := Request{&http.Request{
+			URL: u,
+		}}
+		got, err := req.KeyPath()
+		if got != c.want {
+			t.Errorf("%q KeyPath == %q, want %q", c.url, got, c.want)
+		}
+		if !errors.Is(err, c.err) {
+			t.Errorf("%q KeyPath error == %v, want %v", c.url, err, c.err)
+		}
+	}
+}
+
+func TestHandlerWithoutCert(t *testing.T) {
+	// HandlerV1 must reject a request with no client certificate (an empty
+	// tls.ConnectionState has no peer certificates) with 406 Not Acceptable.
+	// The server usually stops these earlier, so call the handler directly.
+	req := &http.Request{
+		URL: &url.URL{},
+		TLS: &tls.ConnectionState{},
+	}
+	w := httptest.NewRecorder()
+	HandlerV1(w, req)
+	if w.Code != http.StatusNotAcceptable {
+		t.Errorf("HandlerV1(%v) == %d, want %d",
+			req, w.Code, http.StatusNotAcceptable)
+	}
+}
diff --git a/tests/cover.sh b/tests/cover.sh
index d9c211c..47927b7 100755
--- a/tests/cover.sh
+++ b/tests/cover.sh
@@ -7,12 +7,16 @@ cd "$(realpath `dirname ${0}`)/../"
 make GOFLAGS="-cover -covermode=count"

 rm -rf .coverage/
-mkdir -p .coverage
+mkdir -p .coverage/{go,sh,all}  # go: go test, sh: tests/run_tests, all: merged.
 export GOCOVERDIR="${PWD}/.coverage"

-tests/run_tests -b
+go test -covermode=count -coverpkg=./... ./... \
+	-args -test.gocoverdir="${GOCOVERDIR}/go"

-go tool covdata textfmt -i "${GOCOVERDIR}" -o .cover-merged.out
+GOCOVERDIR="${GOCOVERDIR}/sh" tests/run_tests -b
+# Merge the go test and tests/run_tests coverage data, then report on it.
+go tool covdata merge -i "${GOCOVERDIR}/go,${GOCOVERDIR}/sh" -o "${GOCOVERDIR}/all"
+go tool covdata textfmt -i "${GOCOVERDIR}/all" -o .cover-merged.out
 go tool cover -func=.cover-merged.out | grep -i total
 go tool cover -html=.cover-merged.out -o .cover-kxd.html