author | Brad Fitzpatrick <bradfitz@golang.org> | 2016-04-26 21:38:55 UTC |
committer | Brad Fitzpatrick <bradfitz@golang.org> | 2016-04-26 22:44:14 UTC |
parent | b797637b7aeeed133049c7281bfa31dcc9ca42d6 |
http2/frame.go | +13 | -3 |
http2/frame_test.go | +49 | -0 |
diff --git a/http2/frame.go b/http2/frame.go index 6943f93..4badb9d 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -590,7 +590,14 @@ func parseDataFrame(fh FrameHeader, payload []byte) (Frame, error) { return f, nil } -var errStreamID = errors.New("invalid streamid") +var ( + errStreamID = errors.New("invalid stream ID") + errDepStreamID = errors.New("invalid dependent stream ID") +) + +func validStreamIDOrZero(streamID uint32) bool { + return streamID&(1<<31) == 0 +} func validStreamID(streamID uint32) bool { return streamID != 0 && streamID&(1<<31) == 0 @@ -977,8 +984,8 @@ func (f *Framer) WriteHeaders(p HeadersFrameParam) error { } if !p.Priority.IsZero() { v := p.Priority.StreamDep - if !validStreamID(v) && !f.AllowIllegalWrites { - return errors.New("invalid dependent stream id") + if !validStreamIDOrZero(v) && !f.AllowIllegalWrites { + return errDepStreamID } if p.Priority.Exclusive { v |= 1 << 31 @@ -1046,6 +1053,9 @@ func (f *Framer) WritePriority(streamID uint32, p PriorityParam) error { if !validStreamID(streamID) && !f.AllowIllegalWrites { return errStreamID } + if !validStreamIDOrZero(p.StreamDep) { + return errDepStreamID + } f.startWrite(FramePriority, 0, streamID) v := p.StreamDep if p.Exclusive { diff --git a/http2/frame_test.go b/http2/frame_test.go index bf37a67..9bd24af 100644 --- a/http2/frame_test.go +++ b/http2/frame_test.go @@ -202,6 +202,37 @@ func TestWriteHeaders(t *testing.T) { headerFragBuf: []byte("abc"), }, }, + { + "with priority stream dep zero", // golang.org/issue/15444 + HeadersFrameParam{ + StreamID: 42, + BlockFragment: []byte("abc"), + EndStream: true, + EndHeaders: true, + PadLength: 2, + Priority: PriorityParam{ + StreamDep: 0, + Exclusive: true, + Weight: 127, + }, + }, + "\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x00\u007fabc\x00\x00", + &HeadersFrame{ + FrameHeader: FrameHeader{ + valid: true, + StreamID: 42, + Type: FrameHeaders, + Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | 
FlagHeadersPriority, + Length: uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding + }, + Priority: PriorityParam{ + StreamDep: 0, + Exclusive: true, + Weight: 127, + }, + headerFragBuf: []byte("abc"), + }, + }, } for _, tt := range tests { fr, buf := testFramer() @@ -223,6 +254,24 @@ func TestWriteHeaders(t *testing.T) { } } +func TestWriteInvalidStreamDep(t *testing.T) { + fr, _ := testFramer() + err := fr.WriteHeaders(HeadersFrameParam{ + StreamID: 42, + Priority: PriorityParam{ + StreamDep: 1 << 31, + }, + }) + if err != errDepStreamID { + t.Errorf("header error = %v; want %q", err, errDepStreamID) + } + + err = fr.WritePriority(2, PriorityParam{StreamDep: 1 << 31}) + if err != errDepStreamID { + t.Errorf("priority error = %v; want %q", err, errDepStreamID) + } +} + func TestWriteContinuation(t *testing.T) { const streamID = 42 tests := []struct {