author | Brad Fitzpatrick <bradfitz@golang.org> | 2016-05-21 00:18:04 UTC |
committer | Brad Fitzpatrick <bradfitz@golang.org> | 2016-05-21 00:25:09 UTC |
parent | 4d07e8a493e586d735e32ed147c5e64a3da7c230 |
http2/server.go | +18 | -1 |
http2/server_test.go | +51 | -0 |
diff --git a/http2/server.go b/http2/server.go
index 4e07a20..1de8146 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -461,6 +461,7 @@ type stream struct {
 	sentReset        bool // only true once detached from streams map
 	gotReset         bool // only true once detacted from streams map
 	gotTrailerHeader bool // HEADER frame for trailers was seen
+	wroteHeaders     bool // whether we wrote headers (not status 100)
 	reqBuf           []byte
 
 	trailer    http.Header // accumulated trailers
@@ -848,7 +849,23 @@ func (sc *serverConn) writeFrameFromHandler(wm frameWriteMsg) error {
 // If you're not on the serve goroutine, use writeFrameFromHandler instead.
 func (sc *serverConn) writeFrame(wm frameWriteMsg) {
 	sc.serveG.check()
-	sc.writeSched.add(wm)
+
+	var ignoreWrite bool
+
+	// Don't send a 100-continue response if we've already sent headers.
+	// See golang.org/issue/14030.
+	switch wm.write.(type) {
+	case *writeResHeaders:
+		wm.stream.wroteHeaders = true
+	case write100ContinueHeadersFrame:
+		if wm.stream.wroteHeaders {
+			ignoreWrite = true
+		}
+	}
+
+	if !ignoreWrite {
+		sc.writeSched.add(wm)
+	}
 	sc.scheduleFrameWrite()
 }
 
diff --git a/http2/server_test.go b/http2/server_test.go
index 012bfd4..540b0d9 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -3244,3 +3244,54 @@ func TestCheckValidHTTP2Request(t *testing.T) {
 		}
 	}
 }
+
+// golang.org/issue/14030
+func TestExpect100ContinueAfterHandlerWrites(t *testing.T) {
+	const msg = "Hello"
+	const msg2 = "World"
+
+	doRead := make(chan bool, 1)
+	defer close(doRead) // fallback cleanup
+
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		io.WriteString(w, msg)
+		w.(http.Flusher).Flush()
+
+		// Do a read, which might force a 100-continue status to be sent.
+		<-doRead
+		r.Body.Read(make([]byte, 10))
+
+		io.WriteString(w, msg2)
+
+	}, optOnlyServer)
+	defer st.Close()
+
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	defer tr.CloseIdleConnections()
+
+	req, _ := http.NewRequest("POST", st.ts.URL, io.LimitReader(neverEnding('A'), 2<<20))
+	req.Header.Set("Expect", "100-continue")
+
+	res, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer res.Body.Close()
+
+	buf := make([]byte, len(msg))
+	if _, err := io.ReadFull(res.Body, buf); err != nil {
+		t.Fatal(err)
+	}
+	if string(buf) != msg {
+		t.Fatalf("msg = %q; want %q", buf, msg)
+	}
+
+	doRead <- true
+
+	if _, err := io.ReadFull(res.Body, buf); err != nil {
+		t.Fatal(err)
+	}
+	if string(buf) != msg2 {
+		t.Fatalf("second msg = %q; want %q", buf, msg2)
+	}
+}