author | Alberto Bertogli
<albertito@blitiri.com.ar> 2024-09-08 18:15:35 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2024-09-08 18:15:35 UTC |
parent | e7b953e15713ca141059e5d683e4432ea619755c |
kxd/kxd.go | +24 | -16 |
kxd/kxd_test.go | +69 | -0 |
tests/cover.sh | +7 | -3 |
diff --git a/kxd/kxd.go b/kxd/kxd.go index 1e7a6b5..6234538 100644 --- a/kxd/kxd.go +++ b/kxd/kxd.go @@ -11,6 +11,7 @@ package main import ( "crypto/tls" "crypto/x509" + "errors" "flag" "fmt" "io" @@ -61,16 +62,33 @@ func (req *Request) Printf(format string, a ...interface{}) { logging.Output(2, msg) } +var ( + errInvalidVersion = errors.New("invalid version") + errHasDotDot = errors.New("path contains '..'") +) + // KeyPath returns the path to the requested key, extracting it from the URL. func (req *Request) KeyPath() (string, error) { - s := strings.Split(req.URL.Path, "/") + // Clean the path to remove any noise. + kp := path.Clean(req.URL.Path) + + // Must start with "/v1/". Doing this after the Clean also ensures there + // is something else after the version (because Clean removes trailing + // slashes). + kp, hasVersion := strings.CutPrefix(kp, "/v1/") + if !hasVersion { + return "", errInvalidVersion + } - // We expect the path to be "/v1/path/to/key". - if len(s) < 2 || !(s[0] == "" || s[1] == "v1") { - return "", fmt.Errorf("invalid path %q", s) + // Be extra paranoid and reject keys with "..", even if they're valid + // (e.g. "/v1/x..y" is valid, but will get rejected anyway). + // Note requests like this shouldn't reach this stage anyway, due to the + // http library processing and the path.Clean above. + if strings.Contains(kp, "..") { + return "", errHasDotDot } - return strings.Join(s[2:], "/"), nil + return kp, nil } func certToString(cert *x509.Certificate) string { @@ -109,17 +127,7 @@ func HandlerV1(w http.ResponseWriter, httpreq *http.Request) { return } - // Be extra paranoid and reject keys with "..", even if they're valid - // (e.g. "/v1/x..y" is valid, but will get rejected anyway). - if strings.Contains(keyPath, "..") { - req.Printf("Rejecting because requested key %q contained '..'", - keyPath) - req.Printf("Full request: %+v", *req.Request) - http.Error(w, "Invalid key path", http.StatusNotAcceptable) - return - } - - realKeyPath := path.Clean(*dataDir + "/" + keyPath) + realKeyPath := path.Join(*dataDir, keyPath) keyConf := NewKeyConfig(realKeyPath) exists, err := keyConf.Exists() diff --git a/kxd/kxd_test.go b/kxd/kxd_test.go new file mode 100644 index 0000000..c7d8478 --- /dev/null +++ b/kxd/kxd_test.go @@ -0,0 +1,69 @@ +package main + +import ( + "crypto/tls" + "errors" + "log" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func init() { + // Initialize the global logger. The testing framework will capture the + // output and use it as needed. + logging = log.Default() +} + +func TestKeyPath(t *testing.T) { + cases := []struct { + url string + want string + err error + }{ + {"/v1/key", "key", nil}, + {"/v1/path/to/key", "path/to/key", nil}, + {"/v1/path/to/key/", "path/to/key", nil}, + + {"", "", errInvalidVersion}, + {"/", "", errInvalidVersion}, + {"/v1", "", errInvalidVersion}, + {"/v1/", "", errInvalidVersion}, + {"/v1//", "", errInvalidVersion}, + {"v1/path/to/key/", "", errInvalidVersion}, + {"/v2/path/to/key", "", errInvalidVersion}, + + {"/v1/a..b", "", errHasDotDot}, + } + + for _, c := range cases { + u, _ := url.Parse(c.url) + req := Request{&http.Request{ + URL: u, + }} + got, err := req.KeyPath() + if got != c.want { + t.Errorf("%q KeyPath == %q, want %q", c.url, got, c.want) + } + if !errors.Is(err, c.err) { + t.Errorf("%q KeyPath error == %v, want %v", c.url, err, c.err) + } + } +} + +func TestHandlerWithoutCert(t *testing.T) { + // Reject request without a client certificate. + // Usually the http server doesn't let it get this far, so we have a + // custom test for it. + req := &http.Request{ + URL: &url.URL{}, + TLS: &tls.ConnectionState{}, + } + w := httptest.NewRecorder() + HandlerV1(w, req) + if w.Code != http.StatusNotAcceptable { + t.Errorf("HandlerV1(%v) == %d, want %d", + req, w.Code, http.StatusNotAcceptable) + } +} diff --git a/tests/cover.sh b/tests/cover.sh index d9c211c..47927b7 100755 --- a/tests/cover.sh +++ b/tests/cover.sh @@ -7,12 +7,16 @@ cd "$(realpath `dirname ${0}`)/../" make GOFLAGS="-cover -covermode=count" rm -rf .coverage/ -mkdir -p .coverage +mkdir -p .coverage/{go,sh,all} export GOCOVERDIR="${PWD}/.coverage" -tests/run_tests -b +go test -covermode=count -coverpkg=./... ./... \ + -args -test.gocoverdir="${GOCOVERDIR}/go" -go tool covdata textfmt -i "${GOCOVERDIR}" -o .cover-merged.out +GOCOVERDIR="${GOCOVERDIR}/sh" tests/run_tests -b + +go tool covdata merge -i "${GOCOVERDIR}/go,${GOCOVERDIR}/sh" -o "${GOCOVERDIR}/all" +go tool covdata textfmt -i "${GOCOVERDIR}/all" -o .cover-merged.out go tool cover -func=.cover-merged.out | grep -i total go tool cover -html=.cover-merged.out -o .cover-kxd.html