git » go-net » commit 202ff48

http2: delay sending request body in Transport if 100-continue is set

author Brad Fitzpatrick
2016-05-19 05:29:09 UTC
committer Brad Fitzpatrick
2016-05-20 02:53:00 UTC
parent 5916dcb167ed985a5b9e6871fbfd74848a4c170b

http2: delay sending request body in Transport if 100-continue is set

In Go 1.6, the HTTP/1 client got Transport.ExpectContinueTimeout.

This makes the HTTP/2 client respect a Request's "Expect:
100-continue" field and the Transport.ExpectContinueTimeout
configuration.

This also makes sure to call the traceWroteRequest hook if the server
replied while we're still writing the request, since that code was
in the same spot and it couldn't be trivially separated.

Updates golang/go#13851 (fixed after integrating it into std)
Updates golang/go#15744

Change-Id: I67dfd68532daa6c4a0c026549c6e5cbfce50e1ea
Reviewed-on: https://go-review.googlesource.com/23235
Reviewed-by: Andrew Gerrand <adg@golang.org>

http2/go16.go +16 -0
http2/go17.go +14 -2
http2/not_go16.go +8 -1
http2/not_go17.go +2 -0
http2/transport.go +127 -32

diff --git a/http2/go16.go b/http2/go16.go
new file mode 100644
index 0000000..00b2e9e
--- /dev/null
+++ b/http2/go16.go
@@ -0,0 +1,18 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.6
+
+package http2
+
+import (
+	"net/http"
+	"time"
+)
+
+// transportExpectContinueTimeout returns t1.ExpectContinueTimeout, a field
+// that http.Transport gained in Go 1.6 (hence this file's go1.6 build tag).
+func transportExpectContinueTimeout(t1 *http.Transport) time.Duration {
+	return t1.ExpectContinueTimeout
+}
diff --git a/http2/go17.go b/http2/go17.go
index 3d3c71e..730319d 100644
--- a/http2/go17.go
+++ b/http2/go17.go
@@ -49,8 +49,8 @@ func traceGotConn(req *http.Request, cc *ClientConn) {
 	ci := httptrace.GotConnInfo{Conn: cc.tconn}
 	cc.mu.Lock()
 	ci.Reused = cc.nextStreamID > 1
-	ci.WasIdle = len(cc.streams) == 0
-	if ci.WasIdle {
+	ci.WasIdle = len(cc.streams) == 0 && ci.Reused
+	if ci.WasIdle && !cc.lastActive.IsZero() {
 		ci.IdleTime = time.Now().Sub(cc.lastActive)
 	}
 	cc.mu.Unlock()
@@ -64,6 +64,22 @@ func traceWroteHeaders(trace *clientTrace) {
 	}
 }
 
+// traceGot100Continue calls trace's Got100Continue hook, if trace is
+// non-nil and has that hook set.
+func traceGot100Continue(trace *clientTrace) {
+	if trace == nil || trace.Got100Continue == nil {
+		return
+	}
+	trace.Got100Continue()
+}
+
+// traceWait100Continue calls trace's Wait100Continue hook, if trace is
+// non-nil and has that hook set.
+func traceWait100Continue(trace *clientTrace) {
+	if trace == nil || trace.Wait100Continue == nil {
+		return
+	}
+	trace.Wait100Continue()
+}
+
 func traceWroteRequest(trace *clientTrace, err error) {
 	if trace != nil && trace.WroteRequest != nil {
 		trace.WroteRequest(httptrace.WroteRequestInfo{Err: err})
diff --git a/http2/not_go16.go b/http2/not_go16.go
index db53c5b..51a7f19 100644
--- a/http2/not_go16.go
+++ b/http2/not_go16.go
@@ -6,8 +6,17 @@
 
 package http2
 
-import "net/http"
+import (
+	"net/http"
+	"time"
+)
 
 func configureTransport(t1 *http.Transport) (*Transport, error) {
 	return nil, errTransportVersion
 }
+
+// transportExpectContinueTimeout always reports 0: in builds without go1.6
+// (the counterpart of go16.go) http.Transport has no ExpectContinueTimeout.
+func transportExpectContinueTimeout(*http.Transport) time.Duration {
+	return 0
+}
diff --git a/http2/not_go17.go b/http2/not_go17.go
index 077db39..28df0c1 100644
--- a/http2/not_go17.go
+++ b/http2/not_go17.go
@@ -33,6 +33,8 @@ func traceGotConn(*http.Request, *ClientConn) {}
 func traceFirstResponseByte(*clientTrace)     {}
 func traceWroteHeaders(*clientTrace)          {}
 func traceWroteRequest(*clientTrace, error)   {}
+func traceGot100Continue(*clientTrace)        {}
+func traceWait100Continue(*clientTrace)       {}
 
 func nop() {}
 
diff --git a/http2/transport.go b/http2/transport.go
index 5f1564a..03712d5 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -178,6 +178,7 @@ type clientStream struct {
 	resc          chan resAndError
 	bufPipe       pipe // buffered pipe with the flow-controlled response payload
 	requestedGzip bool
+	on100         func() // optional code to run if get a 100 continue response
 
 	flow        flow  // guarded by cc.mu
 	inflow      flow  // guarded by cc.mu
@@ -387,6 +388,17 @@ func (t *Transport) disableKeepAlives() bool {
 	return t.t1 != nil && t.t1.DisableKeepAlives
 }
 
+// expectContinueTimeout returns how long a request carrying
+// "Expect: 100-continue" waits for the server's "100 Continue" before its
+// body is sent anyway. Zero means the body is sent without waiting, which
+// is always the case when there is no underlying HTTP/1 Transport (t1).
+func (t *Transport) expectContinueTimeout() time.Duration {
+	if t1 := t.t1; t1 != nil {
+		return transportExpectContinueTimeout(t1)
+	}
+	return 0
+}
+
 func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
 	if VerboseLogs {
 		t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr())
@@ -593,6 +601,36 @@ func checkConnHeaders(req *http.Request) error {
 	return nil
 }
 
+// bodyAndLength returns the reader to send as req's body, and its length.
+// A nil reader means there is no body to send; a length of -1 means the
+// length is unknown. A non-zero req.ContentLength is trusted as-is, but a
+// zero ContentLength with a non-nil Body may only mean "unset", so one byte
+// is probed to tell an empty body from one of unknown length. If that probe
+// fails, the returned reader yields the probe's error.
+func bodyAndLength(req *http.Request) (body io.Reader, contentLen int64) {
+	if req.Body == nil {
+		return nil, 0
+	}
+	if n := req.ContentLength; n != 0 {
+		return req.Body, n
+	}
+
+	// Zero ContentLength with a non-nil Body: the body is either truly
+	// empty or its length was simply never set. Read one byte to find out.
+	var first [1]byte
+	switch n, err := io.ReadFull(req.Body, first[:]); {
+	case err != nil && err != io.EOF:
+		return errorReader{err}, -1
+	case n == 1:
+		// There was data after all; ContentLength just wasn't set.
+		// Stitch the consumed byte back onto the front of the Body.
+		return io.MultiReader(bytes.NewReader(first[:]), req.Body), -1
+	default:
+		// Body is actually zero bytes.
+		return nil, 0
+	}
+}
+
 func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	if err := checkConnHeaders(req); err != nil {
 		return nil, err
@@ -604,27 +639,8 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	}
 	hasTrailers := trailers != ""
 
-	var body io.Reader = req.Body
-	contentLen := req.ContentLength
-	if req.Body != nil && contentLen == 0 {
-		// Test to see if it's actually zero or just unset.
-		var buf [1]byte
-		n, rerr := io.ReadFull(body, buf[:])
-		if rerr != nil && rerr != io.EOF {
-			contentLen = -1
-			body = errorReader{rerr}
-		} else if n == 1 {
-			// Oh, guess there is data in this Body Reader after all.
-			// The ContentLength field just wasn't set.
-			// Stich the Body back together again, re-attaching our
-			// consumed byte.
-			contentLen = -1
-			body = io.MultiReader(bytes.NewReader(buf[:]), body)
-		} else {
-			// Body is actually empty.
-			body = nil
-		}
-	}
+	body, contentLen := bodyAndLength(req)
+	hasBody := body != nil
 
 	cc.mu.Lock()
 	cc.lastActive = time.Now()
@@ -666,8 +682,9 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	cs := cc.newStream()
 	cs.req = req
 	cs.trace = requestTrace(req)
-	hasBody := body != nil
 	cs.requestedGzip = requestedGzip
+	bodyWriter := cc.t.getBodyWriterState(cs, body)
+	cs.on100 = bodyWriter.on100
 
 	cc.wmu.Lock()
 	endStream := !hasBody && !hasTrailers
@@ -679,6 +696,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	if werr != nil {
 		if hasBody {
 			req.Body.Close() // per RoundTripper contract
+			bodyWriter.cancel()
 		}
 		cc.forgetStreamID(cs.ID)
 		// Don't bother sending a RST_STREAM (our write already failed;
@@ -688,12 +706,8 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	}
 
 	var respHeaderTimer <-chan time.Time
-	var bodyCopyErrc chan error // result of body copy
 	if hasBody {
-		bodyCopyErrc = make(chan error, 1)
-		go func() {
-			bodyCopyErrc <- cs.writeRequestBody(body, req.Body)
-		}()
+		bodyWriter.scheduleBodyWrite()
 	} else {
 		traceWroteRequest(cs.trace, nil)
 		if d := cc.responseHeaderTimeout(); d != 0 {
@@ -721,6 +735,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 				// doesn't, they'll RST_STREAM us soon enough.  This is a
 				// heuristic to avoid adding knobs to Transport.  Hopefully
 				// we can keep it.
+				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWrite)
 			}
 			if re.err != nil {
@@ -735,6 +750,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 			} else {
+				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
 			return nil, errTimeout
@@ -743,6 +759,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 			} else {
+				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
 			return nil, ctx.Err()
@@ -751,6 +768,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 			} else {
+				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
 			return nil, errRequestCanceled
@@ -759,8 +777,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 			// stream from the streams map; no need for
 			// forgetStreamID.
 			return nil, cs.resetErr
-		case err := <-bodyCopyErrc:
-			traceWroteRequest(cs.trace, err)
+		case err := <-bodyWriter.resc:
 			if err != nil {
 				return nil, err
 			}
@@ -821,6 +838,7 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (
 	defer cc.putFrameScratchBuffer(buf)
 
 	defer func() {
+		traceWroteRequest(cs.trace, err)
 		// TODO: write h12Compare test showing whether
 		// Request.Body is closed by the Transport,
 		// and in multiple cases: server replies <=299 and >299
@@ -1281,9 +1299,10 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
 	}
 
 	if statusCode == 100 {
-		// Just skip 100-continue response headers for now.
-		// TODO: golang.org/issue/13851 for doing it properly.
-		// TODO: also call the httptrace.ClientTrace hooks
+		traceGot100Continue(cs.trace)
+		if cs.on100 != nil {
+			cs.on100() // forces any write delay timer to fire
+		}
 		cs.pastHeaders = false // do it all again
 		return nil, nil
 	}
@@ -1716,3 +1735,94 @@ func (gz *gzipReader) Close() error {
 type errorReader struct{ err error }
 
 func (r errorReader) Read(p []byte) (int, error) { return 0, r.err }
+
+// bodyWriterState encapsulates various state around the Transport's writing
+// of the request body, particularly regarding doing delayed writes of the body
+// when the request contains "Expect: 100-continue".
+type bodyWriterState struct {
+	cs     *clientStream // the stream whose request body is written
+	timer  *time.Timer   // if non-nil, we're doing a delayed write
+	fnonce *sync.Once    // ensures fn runs at most once (timer vs. on100)
+	fn     func()        // the code to run in the goroutine, writing the body
+	resc   chan error    // result of fn's execution; nil if there is no body
+	delay  time.Duration // how long we should delay a delayed write for
+}
+
+// getBodyWriterState prepares the write of body on cs. For a nil body only
+// s.cs is set. Otherwise s.fn writes the body and sends the result on
+// s.resc. If the request also has "Expect: 100-continue" and
+// t.expectContinueTimeout() is non-zero, the write is delayed: s.timer is
+// armed (with a placeholder huge duration) to run s.fn at most once via
+// s.fnonce, and scheduleBodyWrite or on100 later decide when it runs.
+func (t *Transport) getBodyWriterState(cs *clientStream, body io.Reader) (s bodyWriterState) {
+	s.cs = cs
+	if body == nil {
+		return
+	}
+	// Buffered so fn can always deliver its result, even after RoundTrip returned.
+	resc := make(chan error, 1)
+	s.resc = resc
+	s.fn = func() {
+		resc <- cs.writeRequestBody(body, cs.req.Body)
+	}
+	s.delay = t.expectContinueTimeout()
+	if s.delay == 0 ||
+		!httplex.HeaderValuesContainsToken(
+			cs.req.Header["Expect"],
+			"100-continue") {
+		return
+	}
+	s.fnonce = new(sync.Once)
+
+	// Arm the timer with a very large duration, which we'll
+	// intentionally lower later. It has to be large now because
+	// we need a handle to it before writing the headers, but the
+	// s.delay value is defined to not start until after the
+	// request headers were written.
+	const hugeDuration = 365 * 24 * time.Hour
+	s.timer = time.AfterFunc(hugeDuration, func() {
+		s.fnonce.Do(s.fn)
+	})
+	return
+}
+
+// cancel stops the timer of a delayed write, if any, so the timer itself
+// will not start fn. It does not interrupt a write that is already running.
+func (s bodyWriterState) cancel() {
+	if s.timer != nil {
+		s.timer.Stop()
+	}
+}
+
+// on100 is run when the server replies "100 Continue". For a delayed write
+// it stops the timer and starts fn right away in a new goroutine; s.fnonce
+// keeps fn from running twice if the timer has already fired.
+func (s bodyWriterState) on100() {
+	if s.timer == nil {
+		// If we didn't do a delayed write, ignore the server's
+		// bogus 100 continue response.
+		return
+	}
+	s.timer.Stop()
+	go func() { s.fnonce.Do(s.fn) }()
+}
+
+// scheduleBodyWrite starts writing the body, either immediately (in
+// the common case) or after the delay timeout. It should not be
+// called until after the headers have been written.
+func (s bodyWriterState) scheduleBodyWrite() {
+	if s.timer == nil {
+		// We're not doing a delayed write (see
+		// getBodyWriterState), so just start the writing
+		// goroutine immediately.
+		go s.fn()
+		return
+	}
+	traceWait100Continue(s.cs.trace)
+	// Stop reports false if the timer was already stopped (e.g. by on100,
+	// which starts fn itself) or already fired; only a still-pending timer
+	// is re-armed with the real delay.
+	if s.timer.Stop() {
+		s.timer.Reset(s.delay)
+	}
+}