author | Brad Fitzpatrick
<bradfitz@golang.org> 2016-05-18 19:54:31 UTC |
committer | Brad Fitzpatrick
<bradfitz@golang.org> 2016-05-18 21:19:18 UTC |
parent | 3b993948b6f0e651ffb58ba135d8538a68b1cddf |
http2/go17.go | +22 | -0 |
http2/not_go17.go | +20 | -1 |
http2/server.go | +21 | -7 |
diff --git a/http2/go17.go b/http2/go17.go index 2e2eabd..3d3c71e 100644 --- a/http2/go17.go +++ b/http2/go17.go @@ -8,11 +8,33 @@ package http2 import ( "context" + "net" "net/http" "net/http/httptrace" "time" ) +type contextContext interface { + context.Context +} + +func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx contextContext, cancel func()) { + ctx, cancel = context.WithCancel(context.Background()) + ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr()) + if hs := opts.baseConfig(); hs != nil { + ctx = context.WithValue(ctx, http.ServerContextKey, hs) + } + return +} + +func contextWithCancel(ctx contextContext) (_ contextContext, cancel func()) { + return context.WithCancel(ctx) +} + +func requestWithContext(req *http.Request, ctx contextContext) *http.Request { + return req.WithContext(ctx) +} + type clientTrace httptrace.ClientTrace func reqContext(r *http.Request) context.Context { return r.Context() } diff --git a/http2/not_go17.go b/http2/not_go17.go index deffe68..077db39 100644 --- a/http2/not_go17.go +++ b/http2/not_go17.go @@ -6,7 +6,12 @@ package http2 -import "net/http" +import ( + "net" + "net/http" +) + +type contextContext interface{} type fakeContext struct{} @@ -28,3 +33,17 @@ func traceGotConn(*http.Request, *ClientConn) {} func traceFirstResponseByte(*clientTrace) {} func traceWroteHeaders(*clientTrace) {} func traceWroteRequest(*clientTrace, error) {} + +func nop() {} + +func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx contextContext, cancel func()) { + return nil, nop +} + +func contextWithCancel(ctx contextContext) (_ contextContext, cancel func()) { + return ctx, nop +} + +func requestWithContext(req *http.Request, ctx contextContext) *http.Request { + return req +} diff --git a/http2/server.go b/http2/server.go index 3a46db6..a2b6c4b 100644 --- a/http2/server.go +++ b/http2/server.go @@ -250,10 +250,14 @@ func (o *ServeConnOpts) handler() http.Handler { // // The opts parameter is optional. If nil, default values are used. func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { + baseCtx, cancel := serverConnBaseContext(c, opts) + defer cancel() + sc := &serverConn{ srv: s, hs: opts.baseConfig(), conn: c, + baseCtx: baseCtx, remoteAddrStr: c.RemoteAddr().String(), bw: newBufferedWriter(c), handler: opts.handler(), @@ -272,6 +276,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { serveG: newGoroutineLock(), pushEnabled: true, } + sc.flow.add(initialWindowSize) sc.inflow.add(initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) @@ -373,6 +378,7 @@ type serverConn struct { conn net.Conn bw *bufferedWriter // writing to conn handler http.Handler + baseCtx contextContext framer *Framer doneServing chan struct{} // closed when serverConn.serve ends readFrameCh chan readFrameResult // written by serverConn.readFrames @@ -436,10 +442,12 @@ func (sc *serverConn) maxHeaderListSize() uint32 { // responseWriter's state field. type stream struct { // immutable: - sc *serverConn - id uint32 - body *pipe // non-nil if expecting DATA frames - cw closeWaiter // closed wait stream transitions to closed state + sc *serverConn + id uint32 + body *pipe // non-nil if expecting DATA frames + cw closeWaiter // closed wait stream transitions to closed state + ctx contextContext + cancelCtx func() // owned by serverConn's serve loop: bodyBytes int64 // body bytes seen so far @@ -1157,6 +1165,7 @@ func (sc *serverConn) processResetStream(f *RSTStreamFrame) error { } if st != nil { st.gotReset = true + st.cancelCtx() sc.closeStream(st, StreamError{f.StreamID, f.ErrCode}) } return nil @@ -1380,10 +1389,13 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { } sc.maxStreamID = id + ctx, cancelCtx := contextWithCancel(sc.baseCtx) st = &stream{ - sc: sc, - id: id, - state: stateOpen, + sc: sc, + id: id, + state: stateOpen, + ctx: ctx, + cancelCtx: cancelCtx, } if f.StreamEnded() { st.state = stateHalfClosedRemote @@ -1617,6 +1629,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res Body: body, Trailer: trailer, } + req = requestWithContext(req, st.ctx) if bodyOpen { // Disabled, per golang.org/issue/14960: // st.reqBuf = sc.getRequestBodyBuf() @@ -1661,6 +1674,7 @@ func (sc *serverConn) getRequestBodyBuf() []byte { func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) { didPanic := true defer func() { + rw.rws.stream.cancelCtx() if didPanic { e := recover() // Same as net/http: