author | Brad Fitzpatrick <bradfitz@golang.org> | 2016-05-19 05:29:09 UTC |
committer | Brad Fitzpatrick <bradfitz@golang.org> | 2016-05-20 02:53:00 UTC |
parent | 5916dcb167ed985a5b9e6871fbfd74848a4c170b |
http2/go16.go | +16 | -0 |
http2/go17.go | +14 | -2 |
http2/not_go16.go | +8 | -1 |
http2/not_go17.go | +2 | -0 |
http2/transport.go | +127 | -32 |
diff --git a/http2/go16.go b/http2/go16.go new file mode 100644 index 0000000..00b2e9e --- /dev/null +++ b/http2/go16.go @@ -0,0 +1,16 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.6 + +package http2 + +import ( + "net/http" + "time" +) + +func transportExpectContinueTimeout(t1 *http.Transport) time.Duration { + return t1.ExpectContinueTimeout +} diff --git a/http2/go17.go b/http2/go17.go index 3d3c71e..730319d 100644 --- a/http2/go17.go +++ b/http2/go17.go @@ -49,8 +49,8 @@ func traceGotConn(req *http.Request, cc *ClientConn) { ci := httptrace.GotConnInfo{Conn: cc.tconn} cc.mu.Lock() ci.Reused = cc.nextStreamID > 1 - ci.WasIdle = len(cc.streams) == 0 - if ci.WasIdle { + ci.WasIdle = len(cc.streams) == 0 && ci.Reused + if ci.WasIdle && !cc.lastActive.IsZero() { ci.IdleTime = time.Now().Sub(cc.lastActive) } cc.mu.Unlock() @@ -64,6 +64,18 @@ func traceWroteHeaders(trace *clientTrace) { } } +func traceGot100Continue(trace *clientTrace) { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() + } +} + +func traceWait100Continue(trace *clientTrace) { + if trace != nil && trace.Wait100Continue != nil { + trace.Wait100Continue() + } +} + func traceWroteRequest(trace *clientTrace, err error) { if trace != nil && trace.WroteRequest != nil { trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) diff --git a/http2/not_go16.go b/http2/not_go16.go index db53c5b..51a7f19 100644 --- a/http2/not_go16.go +++ b/http2/not_go16.go @@ -6,8 +6,15 @@ package http2 -import "net/http" +import ( + "net/http" + "time" +) func configureTransport(t1 *http.Transport) (*Transport, error) { return nil, errTransportVersion } + +func transportExpectContinueTimeout(t1 *http.Transport) time.Duration { + return 0 +} diff --git a/http2/not_go17.go b/http2/not_go17.go index 077db39..28df0c1 100644 --- a/http2/not_go17.go +++ b/http2/not_go17.go 
@@ -33,6 +33,8 @@ func traceGotConn(*http.Request, *ClientConn) {} func traceFirstResponseByte(*clientTrace) {} func traceWroteHeaders(*clientTrace) {} func traceWroteRequest(*clientTrace, error) {} +func traceGot100Continue(trace *clientTrace) {} +func traceWait100Continue(trace *clientTrace) {} func nop() {} diff --git a/http2/transport.go b/http2/transport.go index 5f1564a..03712d5 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -178,6 +178,7 @@ type clientStream struct { resc chan resAndError bufPipe pipe // buffered pipe with the flow-controlled response payload requestedGzip bool + on100 func() // optional code to run if get a 100 continue response flow flow // guarded by cc.mu inflow flow // guarded by cc.mu @@ -387,6 +388,13 @@ func (t *Transport) disableKeepAlives() bool { return t.t1 != nil && t.t1.DisableKeepAlives } +func (t *Transport) expectContinueTimeout() time.Duration { + if t.t1 == nil { + return 0 + } + return transportExpectContinueTimeout(t.t1) +} + func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { if VerboseLogs { t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr()) @@ -593,6 +601,33 @@ func checkConnHeaders(req *http.Request) error { return nil } +func bodyAndLength(req *http.Request) (body io.Reader, contentLen int64) { + body = req.Body + if body == nil { + return nil, 0 + } + if req.ContentLength != 0 { + return req.Body, req.ContentLength + } + + // We have a body but a zero content length. Test to see if + // it's actually zero or just unset. + var buf [1]byte + n, rerr := io.ReadFull(body, buf[:]) + if rerr != nil && rerr != io.EOF { + return errorReader{rerr}, -1 + } + if n == 1 { + // Oh, guess there is data in this Body Reader after all. + // The ContentLength field just wasn't set. + // Stich the Body back together again, re-attaching our + // consumed byte. + return io.MultiReader(bytes.NewReader(buf[:]), body), -1 + } + // Body is actually zero bytes. 
+ return nil, 0 +} + func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { if err := checkConnHeaders(req); err != nil { return nil, err @@ -604,27 +639,8 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { } hasTrailers := trailers != "" - var body io.Reader = req.Body - contentLen := req.ContentLength - if req.Body != nil && contentLen == 0 { - // Test to see if it's actually zero or just unset. - var buf [1]byte - n, rerr := io.ReadFull(body, buf[:]) - if rerr != nil && rerr != io.EOF { - contentLen = -1 - body = errorReader{rerr} - } else if n == 1 { - // Oh, guess there is data in this Body Reader after all. - // The ContentLength field just wasn't set. - // Stich the Body back together again, re-attaching our - // consumed byte. - contentLen = -1 - body = io.MultiReader(bytes.NewReader(buf[:]), body) - } else { - // Body is actually empty. - body = nil - } - } + body, contentLen := bodyAndLength(req) + hasBody := body != nil cc.mu.Lock() cc.lastActive = time.Now() @@ -666,8 +682,9 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { cs := cc.newStream() cs.req = req cs.trace = requestTrace(req) - hasBody := body != nil cs.requestedGzip = requestedGzip + bodyWriter := cc.t.getBodyWriterState(cs, body) + cs.on100 = bodyWriter.on100 cc.wmu.Lock() endStream := !hasBody && !hasTrailers @@ -679,6 +696,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { if werr != nil { if hasBody { req.Body.Close() // per RoundTripper contract + bodyWriter.cancel() } cc.forgetStreamID(cs.ID) // Don't bother sending a RST_STREAM (our write already failed; @@ -688,12 +706,8 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { } var respHeaderTimer <-chan time.Time - var bodyCopyErrc chan error // result of body copy if hasBody { - bodyCopyErrc = make(chan error, 1) - go func() { - bodyCopyErrc <- cs.writeRequestBody(body, req.Body) - }() + 
bodyWriter.scheduleBodyWrite() } else { traceWroteRequest(cs.trace, nil) if d := cc.responseHeaderTimeout(); d != 0 { @@ -721,6 +735,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { // doesn't, they'll RST_STREAM us soon enough. This is a // heuristic to avoid adding knobs to Transport. Hopefully // we can keep it. + bodyWriter.cancel() cs.abortRequestBodyWrite(errStopReqBodyWrite) } if re.err != nil { @@ -735,6 +750,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { if !hasBody || bodyWritten { cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) } else { + bodyWriter.cancel() cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) } return nil, errTimeout @@ -743,6 +759,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { if !hasBody || bodyWritten { cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) } else { + bodyWriter.cancel() cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) } return nil, ctx.Err() @@ -751,6 +768,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { if !hasBody || bodyWritten { cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) } else { + bodyWriter.cancel() cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) } return nil, errRequestCanceled @@ -759,8 +777,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { // stream from the streams map; no need for // forgetStreamID. 
return nil, cs.resetErr - case err := <-bodyCopyErrc: - traceWroteRequest(cs.trace, err) + case err := <-bodyWriter.resc: if err != nil { return nil, err } @@ -821,6 +838,7 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( defer cc.putFrameScratchBuffer(buf) defer func() { + traceWroteRequest(cs.trace, err) // TODO: write h12Compare test showing whether // Request.Body is closed by the Transport, // and in multiple cases: server replies <=299 and >299 @@ -1281,9 +1299,10 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra } if statusCode == 100 { - // Just skip 100-continue response headers for now. - // TODO: golang.org/issue/13851 for doing it properly. - // TODO: also call the httptrace.ClientTrace hooks + traceGot100Continue(cs.trace) + if cs.on100 != nil { + cs.on100() // forces any write delay timer to fire + } cs.pastHeaders = false // do it all again return nil, nil } @@ -1716,3 +1735,79 @@ func (gz *gzipReader) Close() error { type errorReader struct{ err error } func (r errorReader) Read(p []byte) (int, error) { return 0, r.err } + +// bodyWriterState encapsulates various state around the Transport's writing +// of the request body, particularly regarding doing delayed writes of the body +// when the request contains "Expect: 100-continue". 
+type bodyWriterState struct { + cs *clientStream + timer *time.Timer // if non-nil, we're doing a delayed write + fnonce *sync.Once // to call fn with + fn func() // the code to run in the goroutine, writing the body + resc chan error // result of fn's execution + delay time.Duration // how long we should delay a delayed write for +} + +func (t *Transport) getBodyWriterState(cs *clientStream, body io.Reader) (s bodyWriterState) { + s.cs = cs + if body == nil { + return + } + resc := make(chan error, 1) + s.resc = resc + s.fn = func() { + resc <- cs.writeRequestBody(body, cs.req.Body) + } + s.delay = t.expectContinueTimeout() + if s.delay == 0 || + !httplex.HeaderValuesContainsToken( + cs.req.Header["Expect"], + "100-continue") { + return + } + s.fnonce = new(sync.Once) + + // Arm the timer with a very large duration, which we'll + // intentionally lower later. It has to be large now because + // we need a handle to it before writing the headers, but the + // s.delay value is defined to not start until after the + // request headers were written. + const hugeDuration = 365 * 24 * time.Hour + s.timer = time.AfterFunc(hugeDuration, func() { + s.fnonce.Do(s.fn) + }) + return +} + +func (s bodyWriterState) cancel() { + if s.timer != nil { + s.timer.Stop() + } +} + +func (s bodyWriterState) on100() { + if s.timer == nil { + // If we didn't do a delayed write, ignore the server's + // bogus 100 continue response. + return + } + s.timer.Stop() + go func() { s.fnonce.Do(s.fn) }() +} + +// scheduleBodyWrite starts writing the body, either immediately (in +// the common case) or after the delay timeout. It should not be +// called until after the headers have been written. +func (s bodyWriterState) scheduleBodyWrite() { + if s.timer == nil { + // We're not doing a delayed write (see + // getBodyWriterState), so just start the writing + // goroutine immediately. + go s.fn() + return + } + traceWait100Continue(s.cs.trace) + if s.timer.Stop() { + s.timer.Reset(s.delay) + } +}