git » go-net » commit 8aacbec

http2: with Go 1.7 set Request.Context in ServeHTTP handlers

author Brad Fitzpatrick
2016-05-18 19:54:31 UTC
committer Brad Fitzpatrick
2016-05-18 21:19:18 UTC
parent 3b993948b6f0e651ffb58ba135d8538a68b1cddf

http2: with Go 1.7 set Request.Context in ServeHTTP handlers

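This makes HTTP/2 ServeHTTP handlers behave the same as HTTP/1.x
handlers in Go 1.7: each request's Context is canceled when the
client resets the stream, when the handler returns, or when
ServeConn returns.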

Updates golang/go#15134

Change-Id: Ib64dd82cc5f8dd60e1680525f664d5b72be11fc6
Reviewed-on: https://go-review.googlesource.com/23220
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Andrew Gerrand <adg@golang.org>

http2/go17.go +22 -0
http2/not_go17.go +20 -1
http2/server.go +21 -7

diff --git a/http2/go17.go b/http2/go17.go
index 2e2eabd..3d3c71e 100644
--- a/http2/go17.go
+++ b/http2/go17.go
@@ -8,11 +8,33 @@ package http2
 
 import (
 	"context"
+	"net"
 	"net/http"
 	"net/http/httptrace"
 	"time"
 )
 
+type contextContext interface {
+	context.Context
+}
+
+func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx contextContext, cancel func()) {
+	ctx, cancel = context.WithCancel(context.Background())
+	ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr())
+	if hs := opts.baseConfig(); hs != nil {
+		ctx = context.WithValue(ctx, http.ServerContextKey, hs)
+	}
+	return
+}
+
+func contextWithCancel(ctx contextContext) (_ contextContext, cancel func()) {
+	return context.WithCancel(ctx)
+}
+
+func requestWithContext(req *http.Request, ctx contextContext) *http.Request {
+	return req.WithContext(ctx)
+}
+
 type clientTrace httptrace.ClientTrace
 
 func reqContext(r *http.Request) context.Context { return r.Context() }
diff --git a/http2/not_go17.go b/http2/not_go17.go
index deffe68..077db39 100644
--- a/http2/not_go17.go
+++ b/http2/not_go17.go
@@ -6,7 +6,12 @@
 
 package http2
 
-import "net/http"
+import (
+	"net"
+	"net/http"
+)
+
+type contextContext interface{}
 
 type fakeContext struct{}
 
@@ -28,3 +33,17 @@ func traceGotConn(*http.Request, *ClientConn) {}
 func traceFirstResponseByte(*clientTrace)     {}
 func traceWroteHeaders(*clientTrace)          {}
 func traceWroteRequest(*clientTrace, error)   {}
+
+func nop() {}
+
+func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx contextContext, cancel func()) {
+	return nil, nop
+}
+
+func contextWithCancel(ctx contextContext) (_ contextContext, cancel func()) {
+	return ctx, nop
+}
+
+func requestWithContext(req *http.Request, ctx contextContext) *http.Request {
+	return req
+}
diff --git a/http2/server.go b/http2/server.go
index 3a46db6..a2b6c4b 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -250,10 +250,14 @@ func (o *ServeConnOpts) handler() http.Handler {
 //
 // The opts parameter is optional. If nil, default values are used.
 func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
+	baseCtx, cancel := serverConnBaseContext(c, opts)
+	defer cancel()
+
 	sc := &serverConn{
 		srv:              s,
 		hs:               opts.baseConfig(),
 		conn:             c,
+		baseCtx:          baseCtx,
 		remoteAddrStr:    c.RemoteAddr().String(),
 		bw:               newBufferedWriter(c),
 		handler:          opts.handler(),
@@ -272,6 +276,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
 		serveG:            newGoroutineLock(),
 		pushEnabled:       true,
 	}
+
 	sc.flow.add(initialWindowSize)
 	sc.inflow.add(initialWindowSize)
 	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
@@ -373,6 +378,7 @@ type serverConn struct {
 	conn             net.Conn
 	bw               *bufferedWriter // writing to conn
 	handler          http.Handler
+	baseCtx          contextContext
 	framer           *Framer
 	doneServing      chan struct{}         // closed when serverConn.serve ends
 	readFrameCh      chan readFrameResult  // written by serverConn.readFrames
@@ -436,10 +442,12 @@ func (sc *serverConn) maxHeaderListSize() uint32 {
 // responseWriter's state field.
 type stream struct {
 	// immutable:
-	sc   *serverConn
-	id   uint32
-	body *pipe       // non-nil if expecting DATA frames
-	cw   closeWaiter // closed wait stream transitions to closed state
+	sc        *serverConn
+	id        uint32
+	body      *pipe       // non-nil if expecting DATA frames
+	cw        closeWaiter // closed wait stream transitions to closed state
+	ctx       contextContext
+	cancelCtx func()
 
 	// owned by serverConn's serve loop:
 	bodyBytes        int64   // body bytes seen so far
@@ -1157,6 +1165,7 @@ func (sc *serverConn) processResetStream(f *RSTStreamFrame) error {
 	}
 	if st != nil {
 		st.gotReset = true
+		st.cancelCtx()
 		sc.closeStream(st, StreamError{f.StreamID, f.ErrCode})
 	}
 	return nil
@@ -1380,10 +1389,13 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
 	}
 	sc.maxStreamID = id
 
+	ctx, cancelCtx := contextWithCancel(sc.baseCtx)
 	st = &stream{
-		sc:    sc,
-		id:    id,
-		state: stateOpen,
+		sc:        sc,
+		id:        id,
+		state:     stateOpen,
+		ctx:       ctx,
+		cancelCtx: cancelCtx,
 	}
 	if f.StreamEnded() {
 		st.state = stateHalfClosedRemote
@@ -1617,6 +1629,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
 		Body:       body,
 		Trailer:    trailer,
 	}
+	req = requestWithContext(req, st.ctx)
 	if bodyOpen {
 		// Disabled, per golang.org/issue/14960:
 		// st.reqBuf = sc.getRequestBodyBuf()
@@ -1661,6 +1674,7 @@ func (sc *serverConn) getRequestBodyBuf() []byte {
 func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
 	didPanic := true
 	defer func() {
+		rw.rws.stream.cancelCtx()
 		if didPanic {
 			e := recover()
 			// Same as net/http: