http3: fix double close of chan when using DontCloseRequestStream (#3561)

This commit is contained in:
Marten Seemann 2022-09-23 13:10:25 +03:00 committed by GitHub
parent 17761bf35f
commit 424a66389c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 56 deletions

View file

@ -301,6 +301,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
} }
c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason) c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
} }
return nil, rerr.err
} }
if opt.DontCloseRequestStream { if opt.DontCloseRequestStream {
close(reqDone) close(reqDone)

View file

@ -743,6 +743,22 @@ var _ = Describe("Client", func() {
Expect(rsp.StatusCode).To(Equal(418)) Expect(rsp.StatusCode).To(Equal(418))
}) })
It("doesn't close the request stream, with DontCloseRequestStream set", func() {
rspBuf := bytes.NewBuffer(getResponse(418))
gomock.InOrder(
conn.EXPECT().HandshakeComplete().Return(handshakeCtx),
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
)
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
Expect(rsp.ProtoMajor).To(Equal(3))
Expect(rsp.StatusCode).To(Equal(418))
})
Context("requests containing a Body", func() { Context("requests containing a Body", func() {
var strBuf *bytes.Buffer var strBuf *bytes.Buffer
@ -846,47 +862,55 @@ var _ = Describe("Client", func() {
}) })
Context("request cancellations", func() { Context("request cancellations", func() {
It("cancels a request while waiting for the handshake to complete", func() { for _, dontClose := range []bool{false, true} {
ctx, cancel := context.WithCancel(context.Background()) dontClose := dontClose
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(context.Background())
errChan := make(chan error) Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() {
go func() { roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose}
_, err := client.RoundTripOpt(req, RoundTripOpt{})
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
cancel()
Eventually(errChan).Should(Receive(MatchError("context canceled")))
})
It("cancels a request while the request is still in flight", func() { It("cancels a request while waiting for the handshake to complete", func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx) req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx) conn.EXPECT().HandshakeComplete().Return(context.Background())
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write) errChan := make(chan error)
go func() {
_, err := client.RoundTripOpt(req, roundTripOpt)
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
cancel()
Eventually(errChan).Should(Receive(MatchError("context canceled")))
})
done := make(chan struct{}) It("cancels a request while the request is still in flight", func() {
canceled := make(chan struct{}) ctx, cancel := context.WithCancel(context.Background())
gomock.InOrder( req := req.WithContext(ctx)
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }), conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }), conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
) buf := &bytes.Buffer{}
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) str.EXPECT().Close().MaxTimes(1)
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
cancel() str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
<-canceled
return 0, errors.New("test done") done := make(chan struct{})
canceled := make(chan struct{})
gomock.InOrder(
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
)
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
cancel()
<-canceled
return 0, errors.New("test done")
})
_, err := client.RoundTripOpt(req, roundTripOpt)
Expect(err).To(MatchError("test done"))
Eventually(done).Should(BeClosed())
})
}) })
_, err := client.RoundTripOpt(req, RoundTripOpt{}) }
Expect(err).To(MatchError("test done"))
Eventually(done).Should(BeClosed())
})
It("cancels a request after the response arrived", func() { It("cancels a request after the response arrived", func() {
rspBuf := bytes.NewBuffer(getResponse(404)) rspBuf := bytes.NewBuffer(getResponse(404))
@ -909,26 +933,6 @@ var _ = Describe("Client", func() {
cancel() cancel()
Eventually(done).Should(BeClosed()) Eventually(done).Should(BeClosed())
}) })
It("doesn't cancel a request if DontCloseRequestStream is set", func() {
rspBuf := bytes.NewBuffer(getResponse(404))
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
Expect(err).ToNot(HaveOccurred())
cancel()
_, err = io.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
})
}) })
Context("gzip compression", func() { Context("gzip compression", func() {