author | Brad Fitzpatrick <bradfitz@golang.org> 2016-07-31 22:09:09 UTC |
committer | Brad Fitzpatrick <bradfitz@golang.org> 2016-08-01 02:38:47 UTC |
parent | 6a513affb38dc9788b449d59ffed099b8de18fa0 |
http2/frame.go | +25 | -2 |
http2/frame_test.go | +71 | -0 |
http2/server.go | +22 | -12 |
http2/server_test.go | +40 | -0 |
http2/transport.go | +26 | -9 |
http2/transport_test.go | +72 | -0 |
diff --git a/http2/frame.go b/http2/frame.go index 981d407..bd50d09 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -594,6 +594,7 @@ func parseDataFrame(fh FrameHeader, payload []byte) (Frame, error) { var ( errStreamID = errors.New("invalid stream ID") errDepStreamID = errors.New("invalid dependent stream ID") + errPadLength = errors.New("pad length too large") ) func validStreamIDOrZero(streamID uint32) bool { @@ -607,18 +608,40 @@ func validStreamID(streamID uint32) bool { // WriteData writes a DATA frame. // // It will perform exactly one Write to the underlying Writer. -// It is the caller's responsibility to not call other Write methods concurrently. +// It is the caller's responsibility not to violate the maximum frame size +// and to not call other Write methods concurrently. func (f *Framer) WriteData(streamID uint32, endStream bool, data []byte) error { - // TODO: ignoring padding for now. will add when somebody cares. + return f.WriteDataPadded(streamID, endStream, data, nil) +} + +// WriteDataPadded writes a DATA frame with optional padding. +// +// If pad is nil, the PADDED flag is not set and no pad length byte is written. +// The length of pad must not exceed 255 bytes. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility not to violate the maximum frame size +// and to not call other Write methods concurrently. +func (f *Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { if !validStreamID(streamID) && !f.AllowIllegalWrites { return errStreamID } + if len(pad) > 255 { + return errPadLength + } var flags Flags if endStream { flags |= FlagDataEndStream } + if pad != nil { + flags |= FlagDataPadded + } f.startWrite(FrameData, flags, streamID) + if pad != nil { + f.wbuf = append(f.wbuf, byte(len(pad))) + } f.wbuf = append(f.wbuf, data...) + f.wbuf = append(f.wbuf, pad...)
return f.endWrite() } diff --git a/http2/frame_test.go b/http2/frame_test.go index 9bd24af..689ef57 100644 --- a/http2/frame_test.go +++ b/http2/frame_test.go @@ -100,6 +100,77 @@ func TestWriteData(t *testing.T) { } } +func TestWriteDataPadded(t *testing.T) { + tests := [...]struct { + streamID uint32 + endStream bool + data []byte + pad []byte + wantHeader FrameHeader + }{ + // Unpadded: + 0: { + streamID: 1, + endStream: true, + data: []byte("foo"), + pad: nil, + wantHeader: FrameHeader{ + Type: FrameData, + Flags: FlagDataEndStream, + Length: 3, + StreamID: 1, + }, + }, + + // Padded bit set, but no padding: + 1: { + streamID: 1, + endStream: true, + data: []byte("foo"), + pad: []byte{}, + wantHeader: FrameHeader{ + Type: FrameData, + Flags: FlagDataEndStream | FlagDataPadded, + Length: 4, + StreamID: 1, + }, + }, + + // Padded bit set, with padding: + 2: { + streamID: 1, + endStream: false, + data: []byte("foo"), + pad: []byte("bar"), + wantHeader: FrameHeader{ + Type: FrameData, + Flags: FlagDataPadded, + Length: 7, + StreamID: 1, + }, + }, + } + for i, tt := range tests { + fr, _ := testFramer() + fr.WriteDataPadded(tt.streamID, tt.endStream, tt.data, tt.pad) + f, err := fr.ReadFrame() + if err != nil { + t.Errorf("%d. ReadFrame: %v", i, err) + continue + } + got := f.Header() + tt.wantHeader.valid = true + if got != tt.wantHeader { + t.Errorf("%d. read %+v; want %+v", i, got, tt.wantHeader) + continue + } + df := f.(*DataFrame) + if !bytes.Equal(df.Data(), tt.data) { + t.Errorf("%d. got %q; want %q", i, df.Data(), tt.data) + } + } +} + func TestWriteHeaders(t *testing.T) { tests := []struct { name string diff --git a/http2/server.go b/http2/server.go index dbe6c87..679bda4 100644 --- a/http2/server.go +++ b/http2/server.go @@ -1298,15 +1298,15 @@ func (sc *serverConn) processData(f *DataFrame) error { // But still enforce their connection-level flow control, // and return any flow control bytes since we're not going // to consume them. 
- if int(sc.inflow.available()) < len(data) { + if sc.inflow.available() < int32(f.Length) { return StreamError{id, ErrCodeFlowControl} } // Deduct the flow control from inflow, since we're // going to immediately add it back in // sendWindowUpdate, which also schedules sending the // frames. - sc.inflow.take(int32(len(data))) - sc.sendWindowUpdate(nil, len(data)) // conn-level + sc.inflow.take(int32(f.Length)) + sc.sendWindowUpdate(nil, int(f.Length)) // conn-level return StreamError{id, ErrCodeStreamClosed} } @@ -1319,20 +1319,30 @@ func (sc *serverConn) processData(f *DataFrame) error { st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) return StreamError{id, ErrCodeStreamClosed} } - if len(data) > 0 { + if f.Length > 0 { // Check whether the client has flow control quota. - if int(st.inflow.available()) < len(data) { + if st.inflow.available() < int32(f.Length) { return StreamError{id, ErrCodeFlowControl} } - st.inflow.take(int32(len(data))) - wrote, err := st.body.Write(data) - if err != nil { - return StreamError{id, ErrCodeStreamClosed} + st.inflow.take(int32(f.Length)) + + if len(data) > 0 { + wrote, err := st.body.Write(data) + if err != nil { + return StreamError{id, ErrCodeStreamClosed} + } + if wrote != len(data) { + panic("internal error: bad Writer") + } + st.bodyBytes += int64(len(data)) } - if wrote != len(data) { - panic("internal error: bad Writer") + + // Return any padded flow control now, since we won't + // refund it later on body reads. 
+ if pad := int32(f.Length) - int32(len(data)); pad > 0 { + sc.sendWindowUpdate32(nil, pad) + sc.sendWindowUpdate32(st, pad) } - st.bodyBytes += int64(len(data)) } if f.StreamEnded() { st.endStream() diff --git a/http2/server_test.go b/http2/server_test.go index ac4d351..c1f654d 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -359,6 +359,12 @@ func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) } } +func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) { + if err := st.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil { + st.t.Fatalf("Error writing DATA: %v", err) + } +} + func (st *serverTester) readFrame() (Frame, error) { go func() { fr, err := st.fr.ReadFrame() @@ -1083,6 +1089,40 @@ func TestServer_Handler_Sends_WindowUpdate(t *testing.T) { st.wantWindowUpdate(0, 3) // no more stream-level, since END_STREAM } +// the version of the TestServer_Handler_Sends_WindowUpdate with padding. +// See golang.org/issue/16556 +func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) { + puppet := newHandlerPuppet() + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + puppet.act(w, r) + }) + defer st.Close() + defer puppet.done() + + st.greet() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(":method", "POST"), + EndStream: false, + EndHeaders: true, + }) + st.writeDataPadded(1, false, []byte("abcdef"), []byte("1234")) + + // Expect to immediately get our 5 bytes of padding back for + // both the connection and stream (4 bytes of padding + 1 byte of length) + st.wantWindowUpdate(0, 5) + st.wantWindowUpdate(1, 5) + + puppet.do(readBodyHandler(t, "abc")) + st.wantWindowUpdate(0, 3) + st.wantWindowUpdate(1, 3) + + puppet.do(readBodyHandler(t, "def")) + st.wantWindowUpdate(0, 3) + st.wantWindowUpdate(1, 3) +} + func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) { st := newServerTester(t, nil) defer 
st.Close() diff --git a/http2/transport.go b/http2/transport.go index b6f6f95..a81445d 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -1581,16 +1581,20 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { // by the peer? Tough without accumulating too much state. // But at least return their flow control: - if len(data) > 0 { + if f.Length > 0 { + cc.mu.Lock() + cc.inflow.add(int32(f.Length)) + cc.mu.Unlock() + cc.wmu.Lock() - cc.fr.WriteWindowUpdate(0, uint32(len(data))) + cc.fr.WriteWindowUpdate(0, uint32(f.Length)) cc.bw.Flush() cc.wmu.Unlock() } return nil } - if len(data) > 0 { - if cs.bufPipe.b == nil { + if f.Length > 0 { + if len(data) > 0 && cs.bufPipe.b == nil { // Data frame after it's already closed? cc.logf("http2: Transport received DATA frame for closed stream; closing connection") return ConnectionError(ErrCodeProtocol) @@ -1598,17 +1602,30 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { // Check connection-level flow control. cc.mu.Lock() - if cs.inflow.available() >= int32(len(data)) { - cs.inflow.take(int32(len(data))) + if cs.inflow.available() >= int32(f.Length) { + cs.inflow.take(int32(f.Length)) } else { cc.mu.Unlock() return ConnectionError(ErrCodeFlowControl) } + // Return any padded flow control now, since we won't + // refund it later on body reads. 
+ if pad := int32(f.Length) - int32(len(data)); pad > 0 { + cs.inflow.add(pad) + cc.inflow.add(pad) + cc.wmu.Lock() + cc.fr.WriteWindowUpdate(0, uint32(pad)) + cc.fr.WriteWindowUpdate(cs.ID, uint32(pad)) + cc.bw.Flush() + cc.wmu.Unlock() + } cc.mu.Unlock() - if _, err := cs.bufPipe.Write(data); err != nil { - rl.endStreamError(cs, err) - return err + if len(data) > 0 { + if _, err := cs.bufPipe.Write(data); err != nil { + rl.endStreamError(cs, err) + return err + } } } diff --git a/http2/transport_test.go b/http2/transport_test.go index f22eeca..4f3b8a1 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -2177,3 +2177,75 @@ func TestTransportReturnsUnusedFlowControl(t *testing.T) { } ct.run() } + +// See golang.org/issue/16556 +func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { + ct := newClientTester(t) + + unblockClient := make(chan bool, 1) + + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err != nil { + return err + } + defer res.Body.Close() + <-unblockClient + return nil + } + ct.server = func() error { + ct.greet() + + var hf *HeadersFrame + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for Headers: %v", err) + } + switch f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + continue + } + var ok bool + hf, ok = f.(*HeadersFrame) + if !ok { + return fmt.Errorf("Got %T; want HeadersFrame", f) + } + break + } + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "5000"}) + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + pad := []byte("12345") + ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream + + f, err := ct.fr.ReadFrame() + if 
err != nil { + return fmt.Errorf("ReadFrame while waiting for first WindowUpdateFrame: %v", err) + } + wantBack := uint32(len(pad)) + 1 // one byte for the length of the padding + if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID != 0 { + return fmt.Errorf("Expected conn WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f)) + } + + f, err = ct.fr.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for second WindowUpdateFrame: %v", err) + } + if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != wantBack || wuf.StreamID == 0 { + return fmt.Errorf("Expected stream WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f)) + } + unblockClient <- true + return nil + } + ct.run() +}