author | Brad Fitzpatrick
<bradfitz@golang.org> 2016-05-20 18:55:47 UTC |
committer | Brad Fitzpatrick
<bradfitz@golang.org> 2016-05-20 20:56:38 UTC |
parent | 8a52c78636f6b7be1b1e5cb58b01a85f1e082659 |
http2/server.go | +1 | -1 |
http2/transport.go | +8 | -5 |
http2/transport_test.go | +53 | -0 |
diff --git a/http2/server.go b/http2/server.go index 57c8276..4e07a20 100644 --- a/http2/server.go +++ b/http2/server.go @@ -1833,7 +1833,7 @@ type requestBody struct { func (b *requestBody) Close() error { if b.pipe != nil { - b.pipe.CloseWithError(errClosedBody) + b.pipe.BreakWithError(errClosedBody) } b.closed = true return nil diff --git a/http2/transport.go b/http2/transport.go index b666e37..2ae7437 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -220,12 +220,14 @@ func (cs *clientStream) awaitRequestCancel(req *http.Request) { } } -// checkReset reports any error sent in a RST_STREAM frame by the -// server. -func (cs *clientStream) checkReset() error { +// checkResetOrDone reports any error sent in a RST_STREAM frame by the +// server, or errStreamClosed if the stream is complete. +func (cs *clientStream) checkResetOrDone() error { select { case <-cs.peerReset: return cs.resetErr + case <-cs.done: + return errStreamClosed default: return nil } @@ -935,7 +937,7 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) if cs.stopReqBody != nil { return 0, cs.stopReqBody } - if err := cs.checkReset(); err != nil { + if err := cs.checkResetOrDone(); err != nil { return 0, err } if a := cs.flow.available(); a > 0 { @@ -1121,6 +1123,7 @@ func (cc *ClientConn) streamByID(id uint32, andRemove bool) *clientStream { cc.lastActive = time.Now() delete(cc.streams, id) close(cs.done) + cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl } return cs } @@ -1627,7 +1630,7 @@ func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error { cs.resetErr = err close(cs.peerReset) cs.bufPipe.CloseWithError(err) - cs.cc.cond.Broadcast() // wake up checkReset via clientStream.awaitFlowControl + cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl } delete(rl.activeRes, cs.ID) return nil diff --git a/http2/transport_test.go b/http2/transport_test.go index bd07c93..7bba6a7 
100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -20,6 +20,7 @@ import ( "net/url" "os" "reflect" + "runtime" "sort" "strconv" "strings" @@ -101,6 +102,7 @@ func TestTransport(t *testing.T) { t.Errorf("Body = %q; want %q", slurp, body) } } + func onSameConn(t *testing.T, modReq func(*http.Request)) bool { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, r.RemoteAddr) @@ -1862,3 +1864,54 @@ func TestTransportReadHeadResponse(t *testing.T) { } ct.run() } + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (int, error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +// golang.org/issue/15425: test that a handler closing the request +// body doesn't terminate the stream to the peer. (It just stops +// readability from the handler's side, and eventually the client +// runs out of flow control tokens) +func TestTransportHandlerBodyClose(t *testing.T) { + const bodySize = 10 << 20 + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + io.Copy(w, io.LimitReader(neverEnding('A'), bodySize)) + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + g0 := runtime.NumGoroutine() + + const numReq = 10 + for i := 0; i < numReq; i++ { + req, err := http.NewRequest("POST", st.ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)}) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + n, err := io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + if n != bodySize || err != nil { + t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize) + } + } + tr.CloseIdleConnections() + + gd := runtime.NumGoroutine() - g0 + if gd > numReq/2 { + t.Errorf("appeared to leak goroutines") + } + +}