author | Brad Fitzpatrick <bradfitz@golang.org> | 2016-07-30 20:25:50 UTC |
committer | Brad Fitzpatrick <bradfitz@golang.org> | 2016-08-01 23:26:40 UTC |
parent | 35028a49ca5a73b486af60cd20ac21cd6b67bfdb |
http2/transport.go | +75 | -71 |
http2/transport_test.go | +50 | -21 |
diff --git a/http2/transport.go b/http2/transport.go index a81445d..f6019db 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -148,19 +148,20 @@ type ClientConn struct { readerDone chan struct{} // closed on error readerErr error // set before readerDone is closed - mu sync.Mutex // guards following - cond *sync.Cond // hold mu; broadcast on flow/closed changes - flow flow // our conn-level flow control quota (cs.flow is per stream) - inflow flow // peer's conn-level flow control - closed bool - goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received - goAwayDebug string // goAway frame's debug data, retained as a string - streams map[uint32]*clientStream // client-initiated - nextStreamID uint32 - bw *bufio.Writer - br *bufio.Reader - fr *Framer - lastActive time.Time + mu sync.Mutex // guards following + cond *sync.Cond // hold mu; broadcast on flow/closed changes + flow flow // our conn-level flow control quota (cs.flow is per stream) + inflow flow // peer's conn-level flow control + closed bool + wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back + goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received + goAwayDebug string // goAway frame's debug data, retained as a string + streams map[uint32]*clientStream // client-initiated + nextStreamID uint32 + bw *bufio.Writer + br *bufio.Reader + fr *Framer + lastActive time.Time // Settings from peer: maxFrameSize uint32 @@ -416,10 +417,6 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro if VerboseLogs { t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr()) } - if _, err := c.Write(clientPreface); err != nil { - t.vlogf("client preface write error: %v", err) - return nil, err - } cc := &ClientConn{ t: t, @@ -431,6 +428,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough. 
streams: make(map[uint32]*clientStream), singleUse: singleUse, + wantSettingsAck: true, } cc.cond = sync.NewCond(&cc.mu) cc.flow.add(int32(initialWindowSize)) @@ -459,6 +457,8 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro if max := t.maxHeaderListSize(); max != 0 { initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max}) } + + cc.bw.Write(clientPreface) cc.fr.WriteSettings(initialSettings...) cc.fr.WriteWindowUpdate(0, transportDefaultConnFlow) cc.inflow.add(transportDefaultConnFlow + initialWindowSize) @@ -467,33 +467,6 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro return nil, cc.werr } - // Read the obligatory SETTINGS frame - f, err := cc.fr.ReadFrame() - if err != nil { - return nil, err - } - sf, ok := f.(*SettingsFrame) - if !ok { - return nil, fmt.Errorf("expected settings frame, got: %T", f) - } - cc.fr.WriteSettingsAck() - cc.bw.Flush() - - sf.ForeachSetting(func(s Setting) error { - switch s.ID { - case SettingMaxFrameSize: - cc.maxFrameSize = s.Val - case SettingMaxConcurrentStreams: - cc.maxConcurrentStreams = s.Val - case SettingInitialWindowSize: - cc.initialWindowSize = s.Val - default: - // TODO(bradfitz): handle more; at least SETTINGS_HEADER_TABLE_SIZE? - t.vlogf("Unhandled Setting: %v", s) - } - return nil - }) - go cc.readLoop() return cc, nil } @@ -936,28 +909,26 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( } } + var trls []byte + if !sentEnd && hasTrailers { + cc.mu.Lock() + defer cc.mu.Unlock() + trls = cc.encodeTrailers(req) + } + cc.wmu.Lock() - if !sentEnd { - var trls []byte - if hasTrailers { - cc.mu.Lock() - trls = cc.encodeTrailers(req) - cc.mu.Unlock() - } + defer cc.wmu.Unlock() - // Avoid forgetting to send an END_STREAM if the encoded - // trailers are 0 bytes. Both results produce and END_STREAM. 
- if len(trls) > 0 { - err = cc.writeHeaders(cs.ID, true, trls) - } else { - err = cc.fr.WriteData(cs.ID, true, nil) - } + // Avoid forgetting to send an END_STREAM if the encoded + // trailers are 0 bytes. Both results produce and END_STREAM. + if len(trls) > 0 { + err = cc.writeHeaders(cs.ID, true, trls) + } else { + err = cc.fr.WriteData(cs.ID, true, nil) } if ferr := cc.bw.Flush(); ferr != nil && err == nil { err = ferr } - cc.wmu.Unlock() - return err } @@ -1203,6 +1174,14 @@ func (e GoAwayError) Error() string { e.LastStreamID, e.ErrCode, e.DebugData) } +func isEOFOrNetReadError(err error) bool { + if err == io.EOF { + return true + } + ne, ok := err.(*net.OpError) + return ok && ne.Op == "read" +} + func (rl *clientConnReadLoop) cleanup() { cc := rl.cc defer cc.tconn.Close() @@ -1214,16 +1193,14 @@ func (rl *clientConnReadLoop) cleanup() { // gotten a response yet. err := cc.readerErr cc.mu.Lock() - if err == io.EOF { - if cc.goAway != nil { - err = GoAwayError{ - LastStreamID: cc.goAway.LastStreamID, - ErrCode: cc.goAway.ErrCode, - DebugData: cc.goAwayDebug, - } - } else { - err = io.ErrUnexpectedEOF + if cc.goAway != nil && isEOFOrNetReadError(err) { + err = GoAwayError{ + LastStreamID: cc.goAway.LastStreamID, + ErrCode: cc.goAway.ErrCode, + DebugData: cc.goAwayDebug, } + } else if err == io.EOF { + err = io.ErrUnexpectedEOF } for _, cs := range rl.activeRes { cs.bufPipe.CloseWithError(err) @@ -1243,7 +1220,8 @@ func (rl *clientConnReadLoop) cleanup() { func (rl *clientConnReadLoop) run() error { cc := rl.cc rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse - gotReply := false // ever saw a reply + gotReply := false // ever saw a HEADERS reply + gotSettings := false for { f, err := cc.fr.ReadFrame() if err != nil { @@ -1260,6 +1238,13 @@ func (rl *clientConnReadLoop) run() error { if VerboseLogs { cc.vlogf("http2: Transport received %s", summarizeFrame(f)) } + if !gotSettings { + if _, ok := f.(*SettingsFrame); !ok { + cc.logf("protocol error: 
received %T before a SETTINGS frame", f) + return ConnectionError(ErrCodeProtocol) + } + gotSettings = true + } maybeIdle := false // whether frame might transition us to idle switch f := f.(type) { @@ -1681,7 +1666,16 @@ func (rl *clientConnReadLoop) processSettings(f *SettingsFrame) error { cc := rl.cc cc.mu.Lock() defer cc.mu.Unlock() - return f.ForeachSetting(func(s Setting) error { + + if f.IsAck() { + if cc.wantSettingsAck { + cc.wantSettingsAck = false + return nil + } + return ConnectionError(ErrCodeProtocol) + } + + err := f.ForeachSetting(func(s Setting) error { switch s.ID { case SettingMaxFrameSize: cc.maxFrameSize = s.Val @@ -1700,6 +1694,16 @@ func (rl *clientConnReadLoop) processSettings(f *SettingsFrame) error { } return nil }) + if err != nil { + return err + } + + cc.wmu.Lock() + defer cc.wmu.Unlock() + + cc.fr.WriteSettingsAck() + cc.bw.Flush() + return cc.werr } func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error { diff --git a/http2/transport_test.go b/http2/transport_test.go index 4f3b8a1..614fa44 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -652,6 +652,19 @@ func (ct *clientTester) greet() { } } +func (ct *clientTester) readNonSettingsFrame() (Frame, error) { + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return nil, err + } + if _, ok := f.(*SettingsFrame); ok { + continue + } + return f, nil + } +} + func (ct *clientTester) cleanup() { ct.tr.CloseIdleConnections() } @@ -703,8 +716,12 @@ func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyA func testTransportReqBodyAfterResponse(t *testing.T, status int) { const bodySize = 10 << 20 + clientDone := make(chan struct{}) ct := newClientTester(t) ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + var n int64 // atomic req, err := http.NewRequest("PUT", "https://dummy.tld/", io.LimitReader(countingReader{&n}, bodySize)) if err != nil { @@ -745,7 +762,15 @@ func 
testTransportReqBodyAfterResponse(t *testing.T, status int) { for { f, err := ct.fr.ReadFrame() if err != nil { - return err + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. + return nil + default: + return err + } } //println(fmt.Sprintf("server got frame: %v", f)) switch f := f.(type) { @@ -784,7 +809,6 @@ func testTransportReqBodyAfterResponse(t *testing.T, status int) { if err := ct.fr.WriteData(f.StreamID, true, nil); err != nil { return err } - return nil } default: return fmt.Errorf("Unexpected client frame %v", f) @@ -2090,7 +2114,7 @@ func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) { // the interesting parts of both. ct.fr.WriteGoAway(5, ErrCodeNo, []byte(goAwayDebugData)) ct.fr.WriteGoAway(5, goAwayErrCode, nil) - ct.sc.Close() + ct.sc.(*net.TCPConn).CloseWrite() <-clientDone return nil } @@ -2157,23 +2181,28 @@ func TestTransportReturnsUnusedFlowControl(t *testing.T) { <-clientClosed - f, err := ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for RSTStreamFrame: %v", err) - } - if rf, ok := f.(*RSTStreamFrame); !ok || rf.ErrCode != ErrCodeCancel { - return fmt.Errorf("Expected a WindowUpdateFrame with code cancel; got %v", summarizeFrame(f)) - } - - // And wait for our flow control tokens back: - f, err = ct.fr.ReadFrame() - if err != nil { - return fmt.Errorf("ReadFrame while waiting for WindowUpdateFrame: %v", err) - } - if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != 4999 { - return fmt.Errorf("Expected WindowUpdateFrame for 4999 bytes; got %v", summarizeFrame(f)) + waitingFor := "RSTStreamFrame" + for { + f, err := ct.fr.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame while waiting for %s: %v", waitingFor, err) + } + if _, ok := f.(*SettingsFrame); ok { + continue + } + switch waitingFor { + case "RSTStreamFrame": + if rf, ok := f.(*RSTStreamFrame); !ok || rf.ErrCode != ErrCodeCancel { + return 
fmt.Errorf("Expected a WindowUpdateFrame with code cancel; got %v", summarizeFrame(f)) + } + waitingFor = "WindowUpdateFrame" + case "WindowUpdateFrame": + if wuf, ok := f.(*WindowUpdateFrame); !ok || wuf.Increment != 4999 { + return fmt.Errorf("Expected WindowUpdateFrame for 4999 bytes; got %v", summarizeFrame(f)) + } + return nil + } } - return nil } ct.run() } @@ -2228,7 +2257,7 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { pad := []byte("12345") ct.fr.WriteDataPadded(hf.StreamID, false, make([]byte, 5000), pad) // without ending stream - f, err := ct.fr.ReadFrame() + f, err := ct.readNonSettingsFrame() if err != nil { return fmt.Errorf("ReadFrame while waiting for first WindowUpdateFrame: %v", err) } @@ -2237,7 +2266,7 @@ func TestTransportReturnsDataPaddingFlowControl(t *testing.T) { return fmt.Errorf("Expected conn WindowUpdateFrame for %d bytes; got %v", wantBack, summarizeFrame(f)) } - f, err = ct.fr.ReadFrame() + f, err = ct.readNonSettingsFrame() if err != nil { return fmt.Errorf("ReadFrame while waiting for second WindowUpdateFrame: %v", err) }