git » go-net » commit e2ba55e

http2: fix Transport.RoundTrip hang on stream error before headers

author Brad Fitzpatrick
2016-08-02 15:44:32 UTC
committer Brad Fitzpatrick
2016-08-03 01:52:23 UTC
parent f6d211983832e1efcd77e80779fd5f1142741356

http2: fix Transport.RoundTrip hang on stream error before headers

If the Transport got a stream error while processing the response
headers, it never unblocked the caller of RoundTrip. Previously,
Response.Body reads would be aborted with the stream error, but
RoundTrip itself would never return.

The Transport now also sends a RST_STREAM to the server when it
encounters a stream error.

Also, add a "Cause" field to StreamError with additional detail. The
old code returned only that detail (the Framer's errDetail), dropping
the StreamError itself along with its stream ID and error code.

Fixes golang/go#16572

Change-Id: Ibecedb5779f17bf98c32787b68eb8a9b850833b3
Reviewed-on: https://go-review.googlesource.com/25402
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Andrew Gerrand <adg@golang.org>

http2/errors.go +8 -0
http2/frame.go +4 -4
http2/frame_test.go +13 -6
http2/server.go +17 -17
http2/server_test.go +16 -22
http2/transport.go +11 -2
http2/transport_test.go +81 -2

diff --git a/http2/errors.go b/http2/errors.go
index 71a4e29..20fd762 100644
--- a/http2/errors.go
+++ b/http2/errors.go
@@ -64,9 +64,17 @@ func (e ConnectionError) Error() string { return fmt.Sprintf("connection error:
 type StreamError struct {
 	StreamID uint32
 	Code     ErrCode
+	Cause    error // optional additional detail
+}
+
+func streamError(id uint32, code ErrCode) StreamError {
+	return StreamError{StreamID: id, Code: code}
 }
 
 func (e StreamError) Error() string {
+	if e.Cause != nil {
+		return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause)
+	}
 	return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code)
 }
 
diff --git a/http2/frame.go b/http2/frame.go
index 6769907..c9b09bb 100644
--- a/http2/frame.go
+++ b/http2/frame.go
@@ -863,7 +863,7 @@ func parseWindowUpdateFrame(fh FrameHeader, p []byte) (Frame, error) {
 		if fh.StreamID == 0 {
 			return nil, ConnectionError(ErrCodeProtocol)
 		}
-		return nil, StreamError{fh.StreamID, ErrCodeProtocol}
+		return nil, streamError(fh.StreamID, ErrCodeProtocol)
 	}
 	return &WindowUpdateFrame{
 		FrameHeader: fh,
@@ -944,7 +944,7 @@ func parseHeadersFrame(fh FrameHeader, p []byte) (_ Frame, err error) {
 		}
 	}
 	if len(p)-int(padLength) <= 0 {
-		return nil, StreamError{fh.StreamID, ErrCodeProtocol}
+		return nil, streamError(fh.StreamID, ErrCodeProtocol)
 	}
 	hf.headerFragBuf = p[:len(p)-int(padLength)]
 	return hf, nil
@@ -1483,14 +1483,14 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
 		if VerboseLogs {
 			log.Printf("http2: invalid header: %v", invalid)
 		}
-		return nil, StreamError{mh.StreamID, ErrCodeProtocol}
+		return nil, StreamError{mh.StreamID, ErrCodeProtocol, invalid}
 	}
 	if err := mh.checkPseudos(); err != nil {
 		fr.errDetail = err
 		if VerboseLogs {
 			log.Printf("http2: invalid pseudo headers: %v", err)
 		}
-		return nil, StreamError{mh.StreamID, ErrCodeProtocol}
+		return nil, StreamError{mh.StreamID, ErrCodeProtocol, err}
 	}
 	return mh, nil
 }
diff --git a/http2/frame_test.go b/http2/frame_test.go
index 689ef57..7b1933d 100644
--- a/http2/frame_test.go
+++ b/http2/frame_test.go
@@ -992,7 +992,7 @@ func TestMetaFrameHeader(t *testing.T) {
 					":path", "/", // bogus
 				))
 			},
-			want:          StreamError{1, ErrCodeProtocol},
+			want:          streamError(1, ErrCodeProtocol),
 			wantErrReason: "pseudo header field after regular",
 		},
 		7: {
@@ -1003,7 +1003,7 @@ func TestMetaFrameHeader(t *testing.T) {
 					"foo", "bar",
 				))
 			},
-			want:          StreamError{1, ErrCodeProtocol},
+			want:          streamError(1, ErrCodeProtocol),
 			wantErrReason: "invalid pseudo-header \":unknown\"",
 		},
 		8: {
@@ -1014,7 +1014,7 @@ func TestMetaFrameHeader(t *testing.T) {
 					":status", "100",
 				))
 			},
-			want:          StreamError{1, ErrCodeProtocol},
+			want:          streamError(1, ErrCodeProtocol),
 			wantErrReason: "mix of request and response pseudo headers",
 		},
 		9: {
@@ -1025,7 +1025,7 @@ func TestMetaFrameHeader(t *testing.T) {
 					":method", "POST",
 				))
 			},
-			want:          StreamError{1, ErrCodeProtocol},
+			want:          streamError(1, ErrCodeProtocol),
 			wantErrReason: "duplicate pseudo-header \":method\"",
 		},
 		10: {
@@ -1036,13 +1036,13 @@ func TestMetaFrameHeader(t *testing.T) {
 		11: {
 			name:          "invalid_field_name",
 			w:             func(f *Framer) { write(f, encodeHeaderRaw(t, "CapitalBad", "x")) },
-			want:          StreamError{1, ErrCodeProtocol},
+			want:          streamError(1, ErrCodeProtocol),
 			wantErrReason: "invalid header field name \"CapitalBad\"",
 		},
 		12: {
 			name:          "invalid_field_value",
 			w:             func(f *Framer) { write(f, encodeHeaderRaw(t, "key", "bad_null\x00")) },
-			want:          StreamError{1, ErrCodeProtocol},
+			want:          streamError(1, ErrCodeProtocol),
 			wantErrReason: "invalid header field value \"bad_null\\x00\"",
 		},
 	}
@@ -1063,6 +1063,13 @@ func TestMetaFrameHeader(t *testing.T) {
 		got, err = f.ReadFrame()
 		if err != nil {
 			got = err
+
+			// Ignore the StreamError.Cause field, if it matches the wantErrReason.
+			// The test table above predates the Cause field.
+			if se, ok := err.(StreamError); ok && se.Cause != nil && se.Cause.Error() == tt.wantErrReason {
+				se.Cause = nil
+				got = se
+			}
 		}
 		if !reflect.DeepEqual(got, tt.want) {
 			if mhg, ok := got.(*MetaHeadersFrame); ok {
diff --git a/http2/server.go b/http2/server.go
index 679bda4..8206fa7 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -922,7 +922,7 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
 			// state here anyway, after telling the peer
 			// we're hanging up on them.
 			st.state = stateHalfClosedLocal // won't last long, but necessary for closeStream via resetStream
-			errCancel := StreamError{st.id, ErrCodeCancel}
+			errCancel := streamError(st.id, ErrCodeCancel)
 			sc.resetStream(errCancel)
 		case stateHalfClosedRemote:
 			sc.closeStream(st, errHandlerComplete)
@@ -1133,7 +1133,7 @@ func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
 			return nil
 		}
 		if !st.flow.add(int32(f.Increment)) {
-			return StreamError{f.StreamID, ErrCodeFlowControl}
+			return streamError(f.StreamID, ErrCodeFlowControl)
 		}
 	default: // connection-level flow control
 		if !sc.flow.add(int32(f.Increment)) {
@@ -1159,7 +1159,7 @@ func (sc *serverConn) processResetStream(f *RSTStreamFrame) error {
 	if st != nil {
 		st.gotReset = true
 		st.cancelCtx()
-		sc.closeStream(st, StreamError{f.StreamID, f.ErrCode})
+		sc.closeStream(st, streamError(f.StreamID, f.ErrCode))
 	}
 	return nil
 }
@@ -1299,7 +1299,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
 		// and return any flow control bytes since we're not going
 		// to consume them.
 		if sc.inflow.available() < int32(f.Length) {
-			return StreamError{id, ErrCodeFlowControl}
+			return streamError(id, ErrCodeFlowControl)
 		}
 		// Deduct the flow control from inflow, since we're
 		// going to immediately add it back in
@@ -1308,7 +1308,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
 		sc.inflow.take(int32(f.Length))
 		sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
 
-		return StreamError{id, ErrCodeStreamClosed}
+		return streamError(id, ErrCodeStreamClosed)
 	}
 	if st.body == nil {
 		panic("internal error: should have a body in this state")
@@ -1317,19 +1317,19 @@ func (sc *serverConn) processData(f *DataFrame) error {
 	// Sender sending more than they'd declared?
 	if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
 		st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
-		return StreamError{id, ErrCodeStreamClosed}
+		return streamError(id, ErrCodeStreamClosed)
 	}
 	if f.Length > 0 {
 		// Check whether the client has flow control quota.
 		if st.inflow.available() < int32(f.Length) {
-			return StreamError{id, ErrCodeFlowControl}
+			return streamError(id, ErrCodeFlowControl)
 		}
 		st.inflow.take(int32(f.Length))
 
 		if len(data) > 0 {
 			wrote, err := st.body.Write(data)
 			if err != nil {
-				return StreamError{id, ErrCodeStreamClosed}
+				return streamError(id, ErrCodeStreamClosed)
 			}
 			if wrote != len(data) {
 				panic("internal error: bad Writer")
@@ -1446,14 +1446,14 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
 		// REFUSED_STREAM."
 		if sc.unackedSettings == 0 {
 			// They should know better.
-			return StreamError{st.id, ErrCodeProtocol}
+			return streamError(st.id, ErrCodeProtocol)
 		}
 		// Assume it's a network race, where they just haven't
 		// received our last SETTINGS update. But actually
 		// this can't happen yet, because we don't yet provide
 		// a way for users to adjust server parameters at
 		// runtime.
-		return StreamError{st.id, ErrCodeRefusedStream}
+		return streamError(st.id, ErrCodeRefusedStream)
 	}
 
 	rw, req, err := sc.newWriterAndRequest(st, f)
@@ -1487,11 +1487,11 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
 	}
 	st.gotTrailerHeader = true
 	if !f.StreamEnded() {
-		return StreamError{st.id, ErrCodeProtocol}
+		return streamError(st.id, ErrCodeProtocol)
 	}
 
 	if len(f.PseudoFields()) > 0 {
-		return StreamError{st.id, ErrCodeProtocol}
+		return streamError(st.id, ErrCodeProtocol)
 	}
 	if st.trailer != nil {
 		for _, hf := range f.RegularFields() {
@@ -1500,7 +1500,7 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
 				// TODO: send more details to the peer somehow. But http2 has
 				// no way to send debug data at a stream level. Discuss with
 				// HTTP folk.
-				return StreamError{st.id, ErrCodeProtocol}
+				return streamError(st.id, ErrCodeProtocol)
 			}
 			st.trailer[key] = append(st.trailer[key], hf.Value)
 		}
@@ -1561,7 +1561,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
 	isConnect := method == "CONNECT"
 	if isConnect {
 		if path != "" || scheme != "" || authority == "" {
-			return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
+			return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
 		}
 	} else if method == "" || path == "" ||
 		(scheme != "https" && scheme != "http") {
@@ -1575,13 +1575,13 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
 		// "All HTTP/2 requests MUST include exactly one valid
 		// value for the :method, :scheme, and :path
 		// pseudo-header fields"
-		return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
+		return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
 	}
 
 	bodyOpen := !f.StreamEnded()
 	if method == "HEAD" && bodyOpen {
 		// HEAD requests can't have bodies
-		return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
+		return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
 	}
 	var tlsState *tls.ConnectionState // nil if not scheme https
 
@@ -1639,7 +1639,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
 		var err error
 		url_, err = url.ParseRequestURI(path)
 		if err != nil {
-			return nil, nil, StreamError{f.StreamID, ErrCodeProtocol}
+			return nil, nil, streamError(f.StreamID, ErrCodeProtocol)
 		}
 		requestURI = path
 	}
diff --git a/http2/server_test.go b/http2/server_test.go
index c1f654d..ecacf84 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -55,11 +55,6 @@ type serverTester struct {
 	// writing headers:
 	headerBuf bytes.Buffer
 	hpackEnc  *hpack.Encoder
-
-	// reading frames:
-	frc       chan Frame
-	frErrc    chan error
-	readTimer *time.Timer
 }
 
 func init() {
@@ -117,8 +112,6 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}
 		t:      t,
 		ts:     ts,
 		logBuf: logBuf,
-		frc:    make(chan Frame, 1),
-		frErrc: make(chan error, 1),
 	}
 	st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
 	st.hpackDec = hpack.NewDecoder(initialHeaderTableSize, st.onHeaderField)
@@ -365,32 +358,33 @@ func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, p
 	}
 }
 
-func (st *serverTester) readFrame() (Frame, error) {
+func readFrameTimeout(fr *Framer, wait time.Duration) (Frame, error) {
+	ch := make(chan interface{}, 1)
 	go func() {
-		fr, err := st.fr.ReadFrame()
+		fr, err := fr.ReadFrame()
 		if err != nil {
-			st.frErrc <- err
+			ch <- err
 		} else {
-			st.frc <- fr
+			ch <- fr
 		}
 	}()
-	t := st.readTimer
-	if t == nil {
-		t = time.NewTimer(2 * time.Second)
-		st.readTimer = t
-	}
-	t.Reset(2 * time.Second)
-	defer t.Stop()
+	t := time.NewTimer(wait)
 	select {
-	case f := <-st.frc:
-		return f, nil
-	case err := <-st.frErrc:
-		return nil, err
+	case v := <-ch:
+		t.Stop()
+		if fr, ok := v.(Frame); ok {
+			return fr, nil
+		}
+		return nil, v.(error)
 	case <-t.C:
 		return nil, errors.New("timeout waiting for frame")
 	}
 }
 
+func (st *serverTester) readFrame() (Frame, error) {
+	return readFrameTimeout(st.fr, 2*time.Second)
+}
+
 func (st *serverTester) wantHeaders() *HeadersFrame {
 	f, err := st.readFrame()
 	if err != nil {
diff --git a/http2/transport.go b/http2/transport.go
index f6019db..149dcca 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -1229,7 +1229,11 @@ func (rl *clientConnReadLoop) run() error {
 		}
 		if se, ok := err.(StreamError); ok {
 			if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil {
-				rl.endStreamError(cs, cc.fr.errDetail)
+				cs.cc.writeStreamReset(cs.ID, se.Code, err)
+				if se.Cause == nil {
+					se.Cause = cc.fr.errDetail
+				}
+				rl.endStreamError(cs, se)
 			}
 			continue
 		} else if err != nil {
@@ -1639,6 +1643,11 @@ func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) {
 	if isConnectionCloseRequest(cs.req) {
 		rl.closeWhenIdle = true
 	}
+
+	select {
+	case cs.resc <- resAndError{err: err}:
+	default:
+	}
 }
 
 func (cs *clientStream) copyTrailers() {
@@ -1740,7 +1749,7 @@ func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error {
 		// which closes this, so there
 		// isn't a race.
 	default:
-		err := StreamError{cs.ID, f.ErrCode}
+		err := streamError(cs.ID, f.ErrCode)
 		cs.resetErr = err
 		close(cs.peerReset)
 		cs.bufPipe.CloseWithError(err)
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 614fa44..a09b6c1 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -699,6 +699,28 @@ func (ct *clientTester) start(which string, errc chan<- error, fn func() error)
 	}()
 }
 
+func (ct *clientTester) readFrame() (Frame, error) {
+	return readFrameTimeout(ct.fr, 2*time.Second)
+}
+
+func (ct *clientTester) firstHeaders() (*HeadersFrame, error) {
+	for {
+		f, err := ct.readFrame()
+		if err != nil {
+			return nil, fmt.Errorf("ReadFrame while waiting for Headers: %v", err)
+		}
+		switch f.(type) {
+		case *WindowUpdateFrame, *SettingsFrame:
+			continue
+		}
+		hf, ok := f.(*HeadersFrame)
+		if !ok {
+			return nil, fmt.Errorf("Got %T; want HeadersFrame", f)
+		}
+		return hf, nil
+	}
+}
+
 type countingReader struct {
 	n *int64
 }
@@ -1224,8 +1246,9 @@ func testInvalidTrailer(t *testing.T, trailers headerType, wantErr error, writeT
 			return fmt.Errorf("status code = %v; want 200", res.StatusCode)
 		}
 		slurp, err := ioutil.ReadAll(res.Body)
-		if err != wantErr {
-			return fmt.Errorf("res.Body ReadAll error = %q, %#v; want %T of %#v", slurp, err, wantErr, wantErr)
+		se, ok := err.(StreamError)
+		if !ok || se.Cause != wantErr {
+			return fmt.Errorf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", slurp, err, wantErr, wantErr)
 		}
 		if len(slurp) > 0 {
 			return fmt.Errorf("body = %q; want nothing", slurp)
@@ -2278,3 +2301,59 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
 	}
 	ct.run()
 }
+
+// golang.org/issue/16572 -- RoundTrip shouldn't hang when it gets a
+// StreamError as a result of the response HEADERS
+func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
+	ct := newClientTester(t)
+
+	ct.client = func() error {
+		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		res, err := ct.tr.RoundTrip(req)
+		if err == nil {
+			res.Body.Close()
+			return errors.New("unexpected successful GET")
+		}
+		want := StreamError{1, ErrCodeProtocol, headerFieldNameError("  content-type")}
+		if !reflect.DeepEqual(want, err) {
+			t.Errorf("RoundTrip error = %#v; want %#v", err, want)
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+
+		hf, err := ct.firstHeaders()
+		if err != nil {
+			return err
+		}
+
+		var buf bytes.Buffer
+		enc := hpack.NewEncoder(&buf)
+		enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+		enc.WriteField(hpack.HeaderField{Name: "  content-type", Value: "bogus"}) // bogus spaces
+		ct.fr.WriteHeaders(HeadersFrameParam{
+			StreamID:      hf.StreamID,
+			EndHeaders:    true,
+			EndStream:     false,
+			BlockFragment: buf.Bytes(),
+		})
+
+		for {
+			fr, err := ct.readFrame()
+			if err != nil {
+				return fmt.Errorf("error waiting for RST_STREAM from client: %v", err)
+			}
+			if _, ok := fr.(*SettingsFrame); ok {
+				continue
+			}
+			if rst, ok := fr.(*RSTStreamFrame); !ok || rst.StreamID != 1 || rst.ErrCode != ErrCodeProtocol {
+				t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
+			}
+			break
+		}
+
+		return nil
+	}
+	ct.run()
+}