author | Brad Fitzpatrick <bradfitz@golang.org> | 2016-06-28 23:00:16 UTC |
committer | Brad Fitzpatrick <bradfitz@golang.org> | 2016-06-28 23:24:14 UTC |
parent | 04557861f124410b768b1ba5bb3a91b705afbfc6 |
http2/transport.go | +25 | -2 |
http2/transport_test.go | +74 | -0 |
diff --git a/http2/transport.go b/http2/transport.go
index 060471e..52bc9a3 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -153,6 +153,7 @@ type ClientConn struct {
 	inflow       flow       // peer's conn-level flow control
 	closed       bool
 	goAway       *GoAwayFrame             // if non-nil, the GoAwayFrame we received
+	goAwayDebug  string                   // goAway frame's debug data, retained as a string
 	streams      map[uint32]*clientStream // client-initiated
 	nextStreamID uint32
 	bw           *bufio.Writer
@@ -495,6 +496,7 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	cc.goAway = f
+	cc.goAwayDebug = string(f.DebugData())
 }
 
 func (cc *ClientConn) CanTakeNewRequest() bool {
@@ -1160,6 +1162,19 @@ func (cc *ClientConn) readLoop() {
 	}
 }
 
+// GoAwayError is returned by the Transport when the server closes the
+// TCP connection after sending a GOAWAY frame.
+type GoAwayError struct {
+	LastStreamID uint32
+	ErrCode      ErrCode
+	DebugData    string
+}
+
+func (e GoAwayError) Error() string {
+	return fmt.Sprintf("http2: server sent GOAWAY and closed the connection; LastStreamID=%v, ErrCode=%v, debug=%q",
+		e.LastStreamID, e.ErrCode, e.DebugData)
+}
+
 func (rl *clientConnReadLoop) cleanup() {
 	cc := rl.cc
 	defer cc.tconn.Close()
@@ -1170,10 +1185,18 @@ func (rl *clientConnReadLoop) cleanup() {
 	// TODO: also do this if we've written the headers but not
 	// gotten a response yet.
 	err := cc.readerErr
+	cc.mu.Lock()
 	if err == io.EOF {
-		err = io.ErrUnexpectedEOF
+		if cc.goAway != nil {
+			err = GoAwayError{
+				LastStreamID: cc.goAway.LastStreamID,
+				ErrCode:      cc.goAway.ErrCode,
+				DebugData:    cc.goAwayDebug,
+			}
+		} else {
+			err = io.ErrUnexpectedEOF
+		}
 	}
-	cc.mu.Lock()
 	for _, cs := range rl.activeRes {
 		cs.bufPipe.CloseWithError(err)
 	}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 631a04b..e1274b0 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -2011,3 +2011,77 @@ func TestTransportFlowControl(t *testing.T) {
 		time.Sleep(1 * time.Millisecond)
 	}
 }
+
+// golang.org/issue/14627 -- if the server sends a GOAWAY frame, make
+// the Transport remember it and return it back to users (via
+// RoundTrip or request body reads) if needed (e.g. if the server
+// proceeds to close the TCP connection before the client gets its
+// response)
+func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) {
+	testTransportUsesGoAwayDebugError(t, false)
+}
+
+func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
+	testTransportUsesGoAwayDebugError(t, true)
+}
+
+func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) {
+	ct := newClientTester(t)
+	clientDone := make(chan struct{})
+
+	const goAwayErrCode = ErrCodeHTTP11Required // arbitrary
+	const goAwayDebugData = "some debug data"
+
+	ct.client = func() error {
+		defer close(clientDone)
+		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		res, err := ct.tr.RoundTrip(req)
+		if failMidBody {
+			if err != nil {
+				return fmt.Errorf("unexpected client RoundTrip error: %v", err)
+			}
+			_, err = io.Copy(ioutil.Discard, res.Body)
+			res.Body.Close()
+		}
+		want := GoAwayError{
+			LastStreamID: 0,
+			ErrCode:      goAwayErrCode,
+			DebugData:    goAwayDebugData,
+		}
+		if !reflect.DeepEqual(err, want) {
+			t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want)
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				t.Logf("ReadFrame: %v", err)
+				return nil
+			}
+			hf, ok := f.(*HeadersFrame)
+			if !ok {
+				continue
+			}
+			if failMidBody {
+				var buf bytes.Buffer
+				enc := hpack.NewEncoder(&buf)
+				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+				enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
+				ct.fr.WriteHeaders(HeadersFrameParam{
+					StreamID:      hf.StreamID,
+					EndHeaders:    true,
+					EndStream:     false,
+					BlockFragment: buf.Bytes(),
+				})
+			}
+			ct.fr.WriteGoAway(0, goAwayErrCode, []byte(goAwayDebugData))
+			ct.sc.Close()
+			<-clientDone
+			return nil
+		}
+	}
+	ct.run()
+}