author | Brad Fitzpatrick
<bradfitz@golang.org> 2016-08-02 15:44:32 UTC |
committer | Brad Fitzpatrick
<bradfitz@golang.org> 2016-08-03 01:52:23 UTC |
parent | f6d211983832e1efcd77e80779fd5f1142741356 |
http2/errors.go | +8 | -0 |
http2/frame.go | +4 | -4 |
http2/frame_test.go | +13 | -6 |
http2/server.go | +17 | -17 |
http2/server_test.go | +16 | -22 |
http2/transport.go | +11 | -2 |
http2/transport_test.go | +81 | -2 |
diff --git a/http2/errors.go b/http2/errors.go index 71a4e29..20fd762 100644 --- a/http2/errors.go +++ b/http2/errors.go @@ -64,9 +64,17 @@ func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: type StreamError struct { StreamID uint32 Code ErrCode + Cause error // optional additional detail +} + +func streamError(id uint32, code ErrCode) StreamError { + return StreamError{StreamID: id, Code: code} } func (e StreamError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause) + } return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code) } diff --git a/http2/frame.go b/http2/frame.go index 6769907..c9b09bb 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -863,7 +863,7 @@ func parseWindowUpdateFrame(fh FrameHeader, p []byte) (Frame, error) { if fh.StreamID == 0 { return nil, ConnectionError(ErrCodeProtocol) } - return nil, StreamError{fh.StreamID, ErrCodeProtocol} + return nil, streamError(fh.StreamID, ErrCodeProtocol) } return &WindowUpdateFrame{ FrameHeader: fh, @@ -944,7 +944,7 @@ func parseHeadersFrame(fh FrameHeader, p []byte) (_ Frame, err error) { } } if len(p)-int(padLength) <= 0 { - return nil, StreamError{fh.StreamID, ErrCodeProtocol} + return nil, streamError(fh.StreamID, ErrCodeProtocol) } hf.headerFragBuf = p[:len(p)-int(padLength)] return hf, nil @@ -1483,14 +1483,14 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { if VerboseLogs { log.Printf("http2: invalid header: %v", invalid) } - return nil, StreamError{mh.StreamID, ErrCodeProtocol} + return nil, StreamError{mh.StreamID, ErrCodeProtocol, invalid} } if err := mh.checkPseudos(); err != nil { fr.errDetail = err if VerboseLogs { log.Printf("http2: invalid pseudo headers: %v", err) } - return nil, StreamError{mh.StreamID, ErrCodeProtocol} + return nil, StreamError{mh.StreamID, ErrCodeProtocol, err} } return mh, nil } diff --git a/http2/frame_test.go b/http2/frame_test.go index 689ef57..7b1933d 100644 --- a/http2/frame_test.go +++ b/http2/frame_test.go @@ -992,7 +992,7 @@ func TestMetaFrameHeader(t *testing.T) { ":path", "/", // bogus )) }, - want: StreamError{1, ErrCodeProtocol}, + want: streamError(1, ErrCodeProtocol), wantErrReason: "pseudo header field after regular", }, 7: { @@ -1003,7 +1003,7 @@ func TestMetaFrameHeader(t *testing.T) { "foo", "bar", )) }, - want: StreamError{1, ErrCodeProtocol}, + want: streamError(1, ErrCodeProtocol), wantErrReason: "invalid pseudo-header \":unknown\"", }, 8: { @@ -1014,7 +1014,7 @@ func TestMetaFrameHeader(t *testing.T) { ":status", "100", )) }, - want: StreamError{1, ErrCodeProtocol}, + want: streamError(1, ErrCodeProtocol), wantErrReason: "mix of request and response pseudo headers", }, 9: { @@ -1025,7 +1025,7 @@ func TestMetaFrameHeader(t *testing.T) { ":method", "POST", )) }, - want: StreamError{1, ErrCodeProtocol}, + want: streamError(1, ErrCodeProtocol), wantErrReason: "duplicate pseudo-header \":method\"", }, 10: { @@ -1036,13 +1036,13 @@ func TestMetaFrameHeader(t *testing.T) { 11: { name: "invalid_field_name", w: func(f *Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) }, - want: StreamError{1, ErrCodeProtocol}, + want: streamError(1, ErrCodeProtocol), wantErrReason: "invalid header field name \"CapitalBad\"", }, 12: { name: "invalid_field_value", w: func(f *Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) }, - want: StreamError{1, ErrCodeProtocol}, + want: streamError(1, ErrCodeProtocol), wantErrReason: "invalid header field value \"bad_null\\x00\"", }, } @@ -1063,6 +1063,13 @@ func TestMetaFrameHeader(t *testing.T) { got, err = f.ReadFrame() if err != nil { got = err + + // Ignore the StreamError.Cause field, if it matches the wantErrReason. + // The test table above predates the Cause field. + if se, ok := err.(StreamError); ok && se.Cause != nil && se.Cause.Error() == tt.wantErrReason { + se.Cause = nil + got = se + } } if !reflect.DeepEqual(got, tt.want) { if mhg, ok := got.(*MetaHeadersFrame); ok { diff --git a/http2/server.go b/http2/server.go index 679bda4..8206fa7 100644 --- a/http2/server.go +++ b/http2/server.go @@ -922,7 +922,7 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) { // state here anyway, after telling the peer // we're hanging up on them. st.state = stateHalfClosedLocal // won't last long, but necessary for closeStream via resetStream - errCancel := StreamError{st.id, ErrCodeCancel} + errCancel := streamError(st.id, ErrCodeCancel) sc.resetStream(errCancel) case stateHalfClosedRemote: sc.closeStream(st, errHandlerComplete) @@ -1133,7 +1133,7 @@ func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error { return nil } if !st.flow.add(int32(f.Increment)) { - return StreamError{f.StreamID, ErrCodeFlowControl} + return streamError(f.StreamID, ErrCodeFlowControl) } default: // connection-level flow control if !sc.flow.add(int32(f.Increment)) { @@ -1159,7 +1159,7 @@ func (sc *serverConn) processResetStream(f *RSTStreamFrame) error { if st != nil { st.gotReset = true st.cancelCtx() - sc.closeStream(st, StreamError{f.StreamID, f.ErrCode}) + sc.closeStream(st, streamError(f.StreamID, f.ErrCode)) } return nil } @@ -1299,7 +1299,7 @@ func (sc *serverConn) processData(f *DataFrame) error { // and return any flow control bytes since we're not going // to consume them. if sc.inflow.available() < int32(f.Length) { - return StreamError{id, ErrCodeFlowControl} + return streamError(id, ErrCodeFlowControl) } // Deduct the flow control from inflow, since we're // going to immediately add it back in @@ -1308,7 +1308,7 @@ func (sc *serverConn) processData(f *DataFrame) error { sc.inflow.take(int32(f.Length)) sc.sendWindowUpdate(nil, int(f.Length)) // conn-level - return StreamError{id, ErrCodeStreamClosed} + return streamError(id, ErrCodeStreamClosed) } if st.body == nil { panic("internal error: should have a body in this state") @@ -1317,19 +1317,19 @@ func (sc *serverConn) processData(f *DataFrame) error { // Sender sending more than they'd declared? if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) - return StreamError{id, ErrCodeStreamClosed} + return streamError(id, ErrCodeStreamClosed) } if f.Length > 0 { // Check whether the client has flow control quota. if st.inflow.available() < int32(f.Length) { - return StreamError{id, ErrCodeFlowControl} + return streamError(id, ErrCodeFlowControl) } st.inflow.take(int32(f.Length)) if len(data) > 0 { wrote, err := st.body.Write(data) if err != nil { - return StreamError{id, ErrCodeStreamClosed} + return streamError(id, ErrCodeStreamClosed) } if wrote != len(data) { panic("internal error: bad Writer") @@ -1446,14 +1446,14 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // REFUSED_STREAM." if sc.unackedSettings == 0 { // They should know better. - return StreamError{st.id, ErrCodeProtocol} + return streamError(st.id, ErrCodeProtocol) } // Assume it's a network race, where they just haven't // received our last SETTINGS update. But actually // this can't happen yet, because we don't yet provide // a way for users to adjust server parameters at // runtime. - return StreamError{st.id, ErrCodeRefusedStream} + return streamError(st.id, ErrCodeRefusedStream) } rw, req, err := sc.newWriterAndRequest(st, f) @@ -1487,11 +1487,11 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error { } st.gotTrailerHeader = true if !f.StreamEnded() { - return StreamError{st.id, ErrCodeProtocol} + return streamError(st.id, ErrCodeProtocol) } if len(f.PseudoFields()) > 0 { - return StreamError{st.id, ErrCodeProtocol} + return streamError(st.id, ErrCodeProtocol) } if st.trailer != nil { for _, hf := range f.RegularFields() { @@ -1500,7 +1500,7 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error { // TODO: send more details to the peer somehow. But http2 has // no way to send debug data at a stream level. Discuss with // HTTP folk. - return StreamError{st.id, ErrCodeProtocol} + return streamError(st.id, ErrCodeProtocol) } st.trailer[key] = append(st.trailer[key], hf.Value) } @@ -1561,7 +1561,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res isConnect := method == "CONNECT" if isConnect { if path != "" || scheme != "" || authority == "" { - return nil, nil, StreamError{f.StreamID, ErrCodeProtocol} + return nil, nil, streamError(f.StreamID, ErrCodeProtocol) } } else if method == "" || path == "" || (scheme != "https" && scheme != "http") { @@ -1575,13 +1575,13 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res // "All HTTP/2 requests MUST include exactly one valid // value for the :method, :scheme, and :path // pseudo-header fields" - return nil, nil, StreamError{f.StreamID, ErrCodeProtocol} + return nil, nil, streamError(f.StreamID, ErrCodeProtocol) } bodyOpen := !f.StreamEnded() if method == "HEAD" && bodyOpen { // HEAD requests can't have bodies - return nil, nil, StreamError{f.StreamID, ErrCodeProtocol} + return nil, nil, streamError(f.StreamID, ErrCodeProtocol) } var tlsState *tls.ConnectionState // nil if not scheme https @@ -1639,7 +1639,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res var err error url_, err = url.ParseRequestURI(path) if err != nil { - return nil, nil, StreamError{f.StreamID, ErrCodeProtocol} + return nil, nil, streamError(f.StreamID, ErrCodeProtocol) } requestURI = path } diff --git a/http2/server_test.go b/http2/server_test.go index c1f654d..ecacf84 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -55,11 +55,6 @@ type serverTester struct { // writing headers: headerBuf bytes.Buffer hpackEnc *hpack.Encoder - - // reading frames: - frc chan Frame - frErrc chan error - readTimer *time.Timer } func init() { @@ -117,8 +112,6 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} t: t, ts: ts, logBuf: logBuf, - frc: make(chan Frame, 1), - frErrc: make(chan error, 1), } st.hpackEnc = hpack.NewEncoder(&st.headerBuf) st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField) @@ -365,32 +358,33 @@ func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, p } } -func (st *serverTester) readFrame() (Frame, error) { +func readFrameTimeout(fr *Framer, wait time.Duration) (Frame, error) { + ch := make(chan interface{}, 1) go func() { - fr, err := st.fr.ReadFrame() + fr, err := fr.ReadFrame() if err != nil { - st.frErrc <- err + ch <- err } else { - st.frc <- fr + ch <- fr } }() - t := st.readTimer - if t == nil { - t = time.NewTimer(2 * time.Second) - st.readTimer = t - } - t.Reset(2 * time.Second) - defer t.Stop() + t := time.NewTimer(wait) select { - case f := <-st.frc: - return f, nil - case err := <-st.frErrc: - return nil, err + case v := <-ch: + t.Stop() + if fr, ok := v.(Frame); ok { + return fr, nil + } + return nil, v.(error) case <-t.C: return nil, errors.New("timeout waiting for frame") } } +func (st *serverTester) readFrame() (Frame, error) { + return readFrameTimeout(st.fr, 2*time.Second) +} + func (st *serverTester) wantHeaders() *HeadersFrame { f, err := st.readFrame() if err != nil { diff --git a/http2/transport.go b/http2/transport.go index f6019db..149dcca 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -1229,7 +1229,11 @@ func (rl *clientConnReadLoop) run() error { } if se, ok := err.(StreamError); ok { if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil { - rl.endStreamError(cs, cc.fr.errDetail) + cs.cc.writeStreamReset(cs.ID, se.Code, err) + if se.Cause == nil { + se.Cause = cc.fr.errDetail + } + rl.endStreamError(cs, se) } continue } else if err != nil { @@ -1639,6 +1643,11 @@ func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) { if isConnectionCloseRequest(cs.req) { rl.closeWhenIdle = true } + + select { + case cs.resc <- resAndError{err: err}: + default: + } } func (cs *clientStream) copyTrailers() { @@ -1740,7 +1749,7 @@ func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error { // which closes this, so there // isn't a race. default: - err := StreamError{cs.ID, f.ErrCode} + err := streamError(cs.ID, f.ErrCode) cs.resetErr = err close(cs.peerReset) cs.bufPipe.CloseWithError(err) diff --git a/http2/transport_test.go b/http2/transport_test.go index 614fa44..a09b6c1 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -699,6 +699,28 @@ func (ct *clientTester) start(which string, errc chan<- error, fn func() error) }() } +func (ct *clientTester) readFrame() (Frame, error) { + return readFrameTimeout(ct.fr, 2*time.Second) +} + +func (ct *clientTester) firstHeaders() (*HeadersFrame, error) { + for { + f, err := ct.readFrame() + if err != nil { + return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err) + } + switch f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + continue + } + hf, ok := f.(*HeadersFrame) + if !ok { + return nil, fmt.Errorf("Got %T; want HeadersFrame", f) + } + return hf, nil + } +} + type countingReader struct { n *int64 } @@ -1224,8 +1246,9 @@ func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeT return fmt.Errorf("status code = %v; want 200", res.StatusCode) } slurp, err := ioutil.ReadAll(res.Body) - if err != wantErr { - return fmt.Errorf("res.Body ReadAll error = %q, %#v; want %T of %#v", slurp, err, wantErr, wantErr) + se, ok := err.(StreamError) + if !ok || se.Cause != wantErr { + return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr) } if len(slurp) > 0 { return fmt.Errorf("body = %q; want nothing", slurp) @@ -2278,3 +2301,59 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { } ct.run() } + +// golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a +// StreamError as a result of the response HEADERS +func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) { + ct := newClientTester(t) + + ct.client = func() error { + req, _ := http.NewRequest("GET", "https://dummy.tld/", nil) + res, err := ct.tr.RoundTrip(req) + if err == nil { + res.Body.Close() + return errors.New("unexpected successful GET") + } + want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")} + if !reflect.DeepEqual(want, err) { + t.Errorf("RoundTrip error = %#v; want %#v", err, want) + } + return nil + } + ct.server = func() error { + ct.greet() + + hf, err := ct.firstHeaders() + if err != nil { + return err + } + + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + enc.WriteField(hpack.HeaderField{Name: " content-type", Value: "bogus"}) // bogus spaces + ct.fr.WriteHeaders(HeadersFrameParam{ + StreamID: hf.StreamID, + EndHeaders: true, + EndStream: false, + BlockFragment: buf.Bytes(), + }) + + for { + fr, err := ct.readFrame() + if err != nil { + return fmt.Errorf("error waiting for RST_STREAM from client: %v", err) + } + if _, ok := fr.(*SettingsFrame); ok { + continue + } + if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol { + t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr)) + } + break + } + + return nil + } + ct.run() +}