diff --git a/http3/client.go b/http3/client.go index 8b05a6b7..1d80fa70 100644 --- a/http3/client.go +++ b/http3/client.go @@ -160,6 +160,7 @@ func (c *client) handleUnidirectionalStreams() { c.session.CloseWithError(quic.ErrorCode(errorIDError), "") return default: + str.CancelRead(quic.ErrorCode(errorStreamCreationError)) return } f, err := parseNextFrame(str) diff --git a/http3/client_test.go b/http3/client_test.go index bda93e3e..a03d13eb 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -234,22 +234,17 @@ var _ = Describe("Client", func() { }) It("ignores streams other than the control stream", func() { - controlBuf := &bytes.Buffer{} - utils.WriteVarInt(controlBuf, streamTypeControlStream) - (&settingsFrame{}).Write(controlBuf) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(controlBuf.Read).AnyTimes() - - otherBuf := &bytes.Buffer{} - utils.WriteVarInt(otherBuf, 1337) - otherStr := mockquic.NewMockStream(mockCtrl) - otherStr.EXPECT().Read(gomock.Any()).DoAndReturn(otherBuf.Read).AnyTimes() - - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return otherStr, nil + buf := &bytes.Buffer{} + utils.WriteVarInt(buf, 1337) + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + done := make(chan struct{}) + str.EXPECT().CancelRead(quic.ErrorCode(errorStreamCreationError)).Do(func(code quic.ErrorCode) { + close(done) }) + sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil + return str, nil }) sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone @@ -257,7 +252,7 @@ var _ = Describe("Client", func() { }) _, err := client.RoundTrip(request) Expect(err).To(MatchError("done")) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to sess.CloseWithError + Eventually(done).Should(BeClosed()) }) It("errors when the first frame on the control stream is not a SETTINGS frame", func() { diff --git a/http3/server.go b/http3/server.go index dd4dcf29..bca01fe1 100644 --- a/http3/server.go +++ b/http3/server.go @@ -294,6 +294,7 @@ func (s *Server) handleUnidirectionalStreams(sess quic.EarlySession) { sess.CloseWithError(quic.ErrorCode(errorStreamCreationError), "") return default: + str.CancelRead(quic.ErrorCode(errorStreamCreationError)) return } f, err := parseNextFrame(str) diff --git a/http3/server_test.go b/http3/server_test.go index 98596776..8f081838 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -209,29 +209,24 @@ var _ = Describe("Server", func() { }) It("ignores streams other than the control stream", func() { - controlBuf := &bytes.Buffer{} - utils.WriteVarInt(controlBuf, streamTypeControlStream) - (&settingsFrame{}).Write(controlBuf) - controlStr := mockquic.NewMockStream(mockCtrl) - controlStr.EXPECT().Read(gomock.Any()).DoAndReturn(controlBuf.Read).AnyTimes() - - otherBuf := &bytes.Buffer{} - utils.WriteVarInt(otherBuf, 1337) - otherStr := mockquic.NewMockStream(mockCtrl) - otherStr.EXPECT().Read(gomock.Any()).DoAndReturn(otherBuf.Read).AnyTimes() - - sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return otherStr, nil + buf := &bytes.Buffer{} + utils.WriteVarInt(buf, 1337) + str := mockquic.NewMockStream(mockCtrl) + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + done := make(chan struct{}) + str.EXPECT().CancelRead(quic.ErrorCode(errorStreamCreationError)).Do(func(code quic.ErrorCode) { + close(done) }) + sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { - return controlStr, nil + return str, nil }) sess.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { <-testDone return nil, errors.New("test done") }) s.handleConn(sess) - time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to sess.CloseWithError + Eventually(done).Should(BeClosed()) }) It("errors when the first frame on the control stream is not a SETTINGS frame", func() {