git » go-net » commit ef2e00e

http2: make Transport return server's GOAWAY error back to the user

author Brad Fitzpatrick
2016-06-28 23:00:16 UTC
committer Brad Fitzpatrick
2016-06-28 23:24:14 UTC
parent 04557861f124410b768b1ba5bb3a91b705afbfc6

http2: make Transport return server's GOAWAY error back to the user

Updates golang/go#14627 (fixes once bundled into std)

Change-Id: Iae91d165df749e06549a25f9664ee416f115573f
Reviewed-on: https://go-review.googlesource.com/24560
Reviewed-by: Andrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

http2/transport.go +25 -2
http2/transport_test.go +74 -0

diff --git a/http2/transport.go b/http2/transport.go
index 060471e..52bc9a3 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -153,6 +153,7 @@ type ClientConn struct {
 	inflow       flow       // peer's conn-level flow control
 	closed       bool
 	goAway       *GoAwayFrame             // if non-nil, the GoAwayFrame we received
+	goAwayDebug  string                   // goAway frame's debug data, retained as a string
 	streams      map[uint32]*clientStream // client-initiated
 	nextStreamID uint32
 	bw           *bufio.Writer
@@ -495,6 +496,7 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	cc.goAway = f
+	cc.goAwayDebug = string(f.DebugData())
 }
 
 func (cc *ClientConn) CanTakeNewRequest() bool {
@@ -1160,6 +1162,19 @@ func (cc *ClientConn) readLoop() {
 	}
 }
 
+// GoAwayError is returned by the Transport when the server closes the
+// TCP connection after sending a GOAWAY frame.
+type GoAwayError struct {
+	LastStreamID uint32  // LastStreamID field of the GOAWAY frame the client received
+	ErrCode      ErrCode // ErrCode field of the GOAWAY frame the client received
+	DebugData    string  // the GOAWAY frame's debug data, retained as a string
+}
+
+// Error implements the error interface, reporting every field of the GOAWAY that preceded the close.
+func (e GoAwayError) Error() string {
+	return fmt.Sprintf("http2: server sent GOAWAY and closed the connection; LastStreamID=%v, ErrCode=%v, debug=%q", e.LastStreamID, e.ErrCode, e.DebugData)
+}
+
 func (rl *clientConnReadLoop) cleanup() {
 	cc := rl.cc
 	defer cc.tconn.Close()
@@ -1170,10 +1185,18 @@ func (rl *clientConnReadLoop) cleanup() {
 	// TODO: also do this if we've written the headers but not
 	// gotten a response yet.
 	err := cc.readerErr
+	cc.mu.Lock()
 	if err == io.EOF {
-		err = io.ErrUnexpectedEOF
+		if cc.goAway != nil {
+			err = GoAwayError{
+				LastStreamID: cc.goAway.LastStreamID,
+				ErrCode:      cc.goAway.ErrCode,
+				DebugData:    cc.goAwayDebug,
+			}
+		} else {
+			err = io.ErrUnexpectedEOF
+		}
 	}
-	cc.mu.Lock()
 	for _, cs := range rl.activeRes {
 		cs.bufPipe.CloseWithError(err)
 	}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 631a04b..e1274b0 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -2011,3 +2011,77 @@ func TestTransportFlowControl(t *testing.T) {
 		time.Sleep(1 * time.Millisecond)
 	}
 }
+
+// golang.org/issue/14627 -- if the server sends a GOAWAY frame, make
+// the Transport remember it and return it back to users (via
+// RoundTrip or response body reads) if needed (e.g. if the server
+// proceeds to close the TCP connection before the client gets its
+// response).
+func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) {
+	testTransportUsesGoAwayDebugError(t, false) // failMidBody=false: the GoAwayError must come from RoundTrip itself
+}
+
+func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
+	testTransportUsesGoAwayDebugError(t, true) // failMidBody=true: the GoAwayError must surface while reading the response body
+}
+
+func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) {
+	ct := newClientTester(t)
+	clientDone := make(chan struct{}) // closed when the client func returns; the server func waits on it
+
+	const goAwayErrCode = ErrCodeHTTP11Required // arbitrary
+	const goAwayDebugData = "some debug data"
+
+	ct.client = func() error {
+		defer close(clientDone)
+		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		res, err := ct.tr.RoundTrip(req)
+		if failMidBody { // response headers arrive first, so the failure is expected from the body read instead
+			if err != nil {
+				return fmt.Errorf("unexpected client RoundTrip error: %v", err)
+			}
+			_, err = io.Copy(ioutil.Discard, res.Body)
+			res.Body.Close()
+		}
+		want := GoAwayError{
+			LastStreamID: 0,
+			ErrCode:      goAwayErrCode,
+			DebugData:    goAwayDebugData,
+		}
+		if !reflect.DeepEqual(err, want) {
+			t.Errorf("RoundTrip error = %T: %#v, want %T (%#v)", err, err, want, want)
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				t.Logf("ReadFrame: %v", err)
+				return nil
+			}
+			hf, ok := f.(*HeadersFrame)
+			if !ok { // ignore every frame until the client's request HEADERS
+				continue
+			}
+			if failMidBody { // start a 200 response but never send any of its 123 promised body bytes
+				var buf bytes.Buffer
+				enc := hpack.NewEncoder(&buf)
+				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
+				enc.WriteField(hpack.HeaderField{Name: "content-length", Value: "123"})
+				ct.fr.WriteHeaders(HeadersFrameParam{
+					StreamID:      hf.StreamID,
+					EndHeaders:    true,
+					EndStream:     false,
+					BlockFragment: buf.Bytes(),
+				})
+			}
+			ct.fr.WriteGoAway(0, goAwayErrCode, []byte(goAwayDebugData))
+			ct.sc.Close() // close right after the GOAWAY so the client sees the connection end
+			<-clientDone  // keep the server func alive until the client has checked its error
+			return nil
+		}
+	}
+	ct.run()
+}