http3: fix double close of chan when using DontCloseRequestStream (#3561)

This commit is contained in:
Marten Seemann 2022-09-23 13:10:25 +03:00 committed by GitHub
parent 17761bf35f
commit 424a66389c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 56 deletions

View file

@ -301,6 +301,7 @@ func (c *client) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Respon
}
c.conn.CloseWithError(quic.ApplicationErrorCode(rerr.connErr), reason)
}
return nil, rerr.err
}
if opt.DontCloseRequestStream {
close(reqDone)

View file

@ -743,6 +743,22 @@ var _ = Describe("Client", func() {
Expect(rsp.StatusCode).To(Equal(418))
})
It("doesn't close the request stream, with DontCloseRequestStream set", func() {
rspBuf := bytes.NewBuffer(getResponse(418))
gomock.InOrder(
conn.EXPECT().HandshakeComplete().Return(handshakeCtx),
conn.EXPECT().OpenStreamSync(context.Background()).Return(str, nil),
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{}),
)
str.EXPECT().Write(gomock.Any()).AnyTimes().DoAndReturn(func(p []byte) (int, error) { return len(p), nil })
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
Expect(err).ToNot(HaveOccurred())
Expect(rsp.Proto).To(Equal("HTTP/3.0"))
Expect(rsp.ProtoMajor).To(Equal(3))
Expect(rsp.StatusCode).To(Equal(418))
})
Context("requests containing a Body", func() {
var strBuf *bytes.Buffer
@ -846,47 +862,55 @@ var _ = Describe("Client", func() {
})
Context("request cancellations", func() {
It("cancels a request while waiting for the handshake to complete", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(context.Background())
for _, dontClose := range []bool{false, true} {
dontClose := dontClose
errChan := make(chan error)
go func() {
_, err := client.RoundTripOpt(req, RoundTripOpt{})
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
cancel()
Eventually(errChan).Should(Receive(MatchError("context canceled")))
})
Context(fmt.Sprintf("with DontCloseRequestStream: %t", dontClose), func() {
roundTripOpt := RoundTripOpt{DontCloseRequestStream: dontClose}
It("cancels a request while the request is still in flight", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
It("cancels a request while waiting for the handshake to complete", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(context.Background())
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
errChan := make(chan error)
go func() {
_, err := client.RoundTripOpt(req, roundTripOpt)
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
cancel()
Eventually(errChan).Should(Receive(MatchError("context canceled")))
})
done := make(chan struct{})
canceled := make(chan struct{})
gomock.InOrder(
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
)
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
cancel()
<-canceled
return 0, errors.New("test done")
It("cancels a request while the request is still in flight", func() {
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
done := make(chan struct{})
canceled := make(chan struct{})
gomock.InOrder(
str.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(canceled) }),
str.EXPECT().CancelRead(quic.StreamErrorCode(errorRequestCanceled)).Do(func(quic.StreamErrorCode) { close(done) }),
)
str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1)
str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) {
cancel()
<-canceled
return 0, errors.New("test done")
})
_, err := client.RoundTripOpt(req, roundTripOpt)
Expect(err).To(MatchError("test done"))
Eventually(done).Should(BeClosed())
})
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("test done"))
Eventually(done).Should(BeClosed())
})
}
It("cancels a request after the response arrived", func() {
rspBuf := bytes.NewBuffer(getResponse(404))
@ -909,26 +933,6 @@ var _ = Describe("Client", func() {
cancel()
Eventually(done).Should(BeClosed())
})
It("doesn't cancel a request if DontCloseRequestStream is set", func() {
rspBuf := bytes.NewBuffer(getResponse(404))
ctx, cancel := context.WithCancel(context.Background())
req := req.WithContext(ctx)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(ctx).Return(str, nil)
conn.EXPECT().ConnectionState().Return(quic.ConnectionState{})
buf := &bytes.Buffer{}
str.EXPECT().Close().MaxTimes(1)
str.EXPECT().Write(gomock.Any()).DoAndReturn(buf.Write)
str.EXPECT().Read(gomock.Any()).DoAndReturn(rspBuf.Read).AnyTimes()
rsp, err := client.RoundTripOpt(req, RoundTripOpt{DontCloseRequestStream: true})
Expect(err).ToNot(HaveOccurred())
cancel()
_, err = io.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
})
})
Context("gzip compression", func() {