diff --git a/http3/client.go b/http3/client.go index 4f7e24ce..6d22c818 100644 --- a/http3/client.go +++ b/http3/client.go @@ -301,6 +301,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon } c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) } + return nil, rerr.err } if opt.DontCloseRequestStream { close(reqDone) diff --git a/http3/client_test.go b/http3/client_test.go index f8988e7e..72eea15e 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -743,6 +743,22 @@ var _ = Describe("Client", func() { Expect(rsp.StatusCode).To(Equal(418)) }) + It("doesn't close the request stream, with DontCloseRequestStream set", func() { + rspBuf := bytes.NewBuffer(getResponse(418)) + gomock.InOrder( + conn.EXPECT().HandshakeComplete().Return(handshakeCtx), + conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil), + conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}), + ) + str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil }) + str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() + rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.Proto).To(Equal("HTTP/3.0")) + Expect(rsp.ProtoMajor).To(Equal(3)) + Expect(rsp.StatusCode).To(Equal(418)) + }) + Context("requests containing a Body", func() { var strBuf *bytes.Buffer @@ -846,47 +862,55 @@ var _ = Describe("Client", func() { }) Context("request cancellations", func() { - It("cancels a request while waiting for the handshake to complete", func() { - ctx, cancel := context.WithCancel(context.Background()) - req := req.WithContext(ctx) - conn.EXPECT().HandshakeComplete().Return(context.Background()) + for _, dontClose := range []bool{false, true} { + dontClose := dontClose - errChan := make(chan error) - go func() { - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - errChan <- err - }() - Consistently(errChan).ShouldNot(Receive()) - cancel() - Eventually(errChan).Should(Receive(MatchError("context canceled"))) - }) + Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() { + roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose} - It("cancels a request while the request is still in flight", func() { - ctx, cancel := context.WithCancel(context.Background()) - req := req.WithContext(ctx) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) - buf := &bytes.Buffer{} - str.EXPECT().Close().MaxTimes(1) + It("cancels a request while waiting for the handshake to complete", func() { + ctx, cancel := context.WithCancel(context.Background()) + req := req.WithContext(ctx) + conn.EXPECT().HandshakeComplete().Return(context.Background()) - str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) + errChan := make(chan error) + go func() { + _, err := client.RoundTripOpt(req, roundTripOpt) + errChan <- err + }() + Consistently(errChan).ShouldNot(Receive()) + cancel() + Eventually(errChan).Should(Receive(MatchError("context canceled"))) + }) - done := make(chan struct{}) - canceled := make(chan struct{}) - gomock.InOrder( - str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), - str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), - ) - str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) - str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { - cancel() - <-canceled - return 0, errors.New("test done") + It("cancels a request while the request is still in flight", func() { + ctx, cancel := context.WithCancel(context.Background()) + req := req.WithContext(ctx) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) + buf := &bytes.Buffer{} + str.EXPECT().Close().MaxTimes(1) + + str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) + + done := make(chan struct{}) + canceled := make(chan struct{}) + gomock.InOrder( + str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), + str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), + ) + str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) + str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { + cancel() + <-canceled + return 0, errors.New("test done") + }) + _, err := client.RoundTripOpt(req, roundTripOpt) + Expect(err).To(MatchError("test done")) + Eventually(done).Should(BeClosed()) + }) }) - _, err := client.RoundTripOpt(req, RoundTripOpt{}) - Expect(err).To(MatchError("test done")) - Eventually(done).Should(BeClosed()) - }) + } It("cancels a request after the response arrived", func() { rspBuf := bytes.NewBuffer(getResponse(404)) @@ -909,26 +933,6 @@ var _ = Describe("Client", func() { cancel() Eventually(done).Should(BeClosed()) }) - - It("doesn't cancel a request if DontCloseRequestStream is set", func() { - rspBuf := bytes.NewBuffer(getResponse(404)) - - ctx, cancel := context.WithCancel(context.Background()) - req := req.WithContext(ctx) - conn.EXPECT().HandshakeComplete().Return(handshakeCtx) - conn.EXPECT().OpenStreamSync(ctx).Return(str, nil) - conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}) - buf := &bytes.Buffer{} - str.EXPECT().Close().MaxTimes(1) - - str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) - str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes() - rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true}) - Expect(err).ToNot(HaveOccurred()) - cancel() - _, err = io.ReadAll(rsp.Body) - Expect(err).ToNot(HaveOccurred()) - }) }) Context("gzip compression", func() {