git » go-net » commit 6050c11

http2/hpack: forbid excess and invalid padding in hpack decoder

author Carl Mastrangelo
2016-05-12 00:54:17 UTC
committer Brad Fitzpatrick
2016-05-13 23:09:52 UTC
parent 96dbb961a39ddccf16860cdd355bfa639c497f23

http2/hpack: forbid excess and invalid padding in hpack decoder

This change fixes a few bugs in the HPACK decoder:
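 * Excess trailing padding (more than 7 bits) is treated as an error,
     per RFC 7541 (HPACK) section 5.2
 * Padding that is not a prefix of the EOS symbol is treated as an error
 * The maximum length is now enforced for all decoded symbols, including
     symbols decoded from the trailing bits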

The idea here is to track the length of the symbol prefix currently
being decoded, rather than only the number of unconsumed bits in cur.
To this end, nbits has been renamed cbits (cur bits), and sbits (sym
bits) has been introduced.  The main problem with using nbits is that
it can easily be zero even when an incomplete symbol remains, such as
when decoding {0xff, 0xff}.  Using clearer names makes it easier to see
why checking cbits > 0 at the end of the function is not sufficient.

Fixes golang/go#15614

Change-Id: I1ae868caa9c207fcf9c9dec7f10ee9f400211f99
Reviewed-on: https://go-review.googlesource.com/23067
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>

http2/hpack/hpack_test.go +41 -0
http2/hpack/huffman.go +32 -10

diff --git a/http2/hpack/hpack_test.go b/http2/hpack/hpack_test.go
index 6dc69f9..4c7b17b 100644
--- a/http2/hpack/hpack_test.go
+++ b/http2/hpack/hpack_test.go
@@ -524,6 +524,47 @@ func testDecodeSeries(t *testing.T, size uint32, steps []encAndWant) {
 	}
 }
 
+func TestHuffmanDecodeExcessPadding(t *testing.T) {
+	tests := [][]byte{
+		{0xff},                                   // Padding Exceeds 7 bits
+		{0x1f, 0xff},                             // {"a", 1 byte excess padding}
+		{0x1f, 0xff, 0xff},                       // {"a", 2 byte excess padding}
+		{0x1f, 0xff, 0xff, 0xff},                 // {"a", 3 byte excess padding}
+		{0xff, 0x9f, 0xff, 0xff, 0xff},           // {"|", 29 bit excess padding}
+		{'R', 0xbc, '0', 0xff, 0xff, 0xff, 0xff}, // Padding ends on partial symbol.
+	}
+	for i, in := range tests {
+		var buf bytes.Buffer
+		if _, err := HuffmanDecode(&buf, in); err != ErrInvalidHuffman {
+			t.Errorf("test-%d: decode(%q) = %v; want ErrInvalidHuffman", i, in, err)
+		}
+	}
+}
+
+func TestHuffmanDecodeEOS(t *testing.T) {
+	in := []byte{0xff, 0xff, 0xff, 0xff, 0xfc} // {EOS, "?"}
+	var buf bytes.Buffer
+	if _, err := HuffmanDecode(&buf, in); err != ErrInvalidHuffman {
+		t.Errorf("error = %v; want ErrInvalidHuffman", err)
+	}
+}
+
+func TestHuffmanDecodeMaxLengthOnTrailingByte(t *testing.T) {
+	in := []byte{0x00, 0x01} // {"0", "0", "0"}
+	var buf bytes.Buffer
+	if err := huffmanDecode(&buf, 2, in); err != ErrStringLength {
+		t.Errorf("error = %v; want ErrStringLength", err)
+	}
+}
+
+func TestHuffmanDecodeCorruptPadding(t *testing.T) {
+	in := []byte{0x00}
+	var buf bytes.Buffer
+	if _, err := HuffmanDecode(&buf, in); err != ErrInvalidHuffman {
+		t.Errorf("error = %v; want ErrInvalidHuffman", err)
+	}
+}
+
 func TestHuffmanDecode(t *testing.T) {
 	tests := []struct {
 		inHex, want string
diff --git a/http2/hpack/huffman.go b/http2/hpack/huffman.go
index eb4b1f0..8850e39 100644
--- a/http2/hpack/huffman.go
+++ b/http2/hpack/huffman.go
@@ -48,12 +48,16 @@ var ErrInvalidHuffman = errors.New("hpack: invalid Huffman-encoded data")
 // maxLen bytes will return ErrStringLength.
 func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error {
 	n := rootHuffmanNode
-	cur, nbits := uint(0), uint8(0)
+	// cur is the bit buffer that has not been fed into n.
+	// cbits is the number of low order bits in cur that are valid.
+	// sbits is the number of bits of the symbol prefix being decoded.
+	cur, cbits, sbits := uint(0), uint8(0), uint8(0)
 	for _, b := range v {
 		cur = cur<<8 | uint(b)
-		nbits += 8
-		for nbits >= 8 {
-			idx := byte(cur >> (nbits - 8))
+		cbits += 8
+		sbits += 8
+		for cbits >= 8 {
+			idx := byte(cur >> (cbits - 8))
 			n = n.children[idx]
 			if n == nil {
 				return ErrInvalidHuffman
@@ -63,22 +67,40 @@ func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error {
 					return ErrStringLength
 				}
 				buf.WriteByte(n.sym)
-				nbits -= n.codeLen
+				cbits -= n.codeLen
 				n = rootHuffmanNode
+				sbits = cbits
 			} else {
-				nbits -= 8
+				cbits -= 8
 			}
 		}
 	}
-	for nbits > 0 {
-		n = n.children[byte(cur<<(8-nbits))]
-		if n.children != nil || n.codeLen > nbits {
+	for cbits > 0 {
+		n = n.children[byte(cur<<(8-cbits))]
+		if n == nil {
+			return ErrInvalidHuffman
+		}
+		if n.children != nil || n.codeLen > cbits {
 			break
 		}
+		if maxLen != 0 && buf.Len() == maxLen {
+			return ErrStringLength
+		}
 		buf.WriteByte(n.sym)
-		nbits -= n.codeLen
+		cbits -= n.codeLen
 		n = rootHuffmanNode
+		sbits = cbits
+	}
+	if sbits > 7 {
+		// Either there was an incomplete symbol, or overlong padding.
+		// Both are decoding errors per RFC 7541 section 5.2.
+		return ErrInvalidHuffman
 	}
+	if mask := uint(1<<cbits - 1); cur&mask != mask {
+		// Trailing bits must be a prefix of EOS per RFC 7541 section 5.2.
+		return ErrInvalidHuffman
+	}
+
 	return nil
 }