handle request cancelations while waiting for handshake completion

This commit is contained in:
Marten Seemann 2020-01-26 11:08:36 +07:00
parent 639925628d
commit af1e8f4c78
3 changed files with 23 additions and 2 deletions

View file

@ -101,7 +101,6 @@ func (c *client) dial() error {
}
}()
<-c.session.HandshakeComplete().Done()
return nil
}
@ -151,6 +150,13 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, c.handshakeErr
}
// wait for the handshake to complete
select {
case <-c.session.HandshakeComplete().Done():
case <-req.Context().Done():
return nil, req.Context().Err()
}
str, err := c.session.OpenStreamSync(req.Context())
if err != nil {
return nil, err

View file

@ -310,6 +310,21 @@ var _ = Describe("Client", func() {
})
Context("request cancellations", func() {
It("cancels a request while waiting for the handshake to complete", func() {
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)
sess.EXPECT().HandshakeComplete().Return(context.Background())
errChan := make(chan error)
go func() {
_, err := client.RoundTrip(req)
errChan <- err
}()
Consistently(errChan).ShouldNot(Receive())
cancel()
Eventually(errChan).Should(Receive(MatchError("context canceled")))
})
It("cancels a request while the request is still in flight", func() {
ctx, cancel := context.WithCancel(context.Background())
req := request.WithContext(ctx)

View file

@ -135,8 +135,8 @@ var _ = Describe("RoundTripper", func() {
It("reuses existing clients", func() {
closed := make(chan struct{})
testErr := errors.New("test err")
session.EXPECT().HandshakeComplete().Return(handshakeCtx)
session.EXPECT().OpenUniStream().AnyTimes().Return(nil, testErr)
session.EXPECT().HandshakeComplete().Return(handshakeCtx).Times(2)
session.EXPECT().OpenStreamSync(context.Background()).Return(nil, testErr).Times(2)
session.EXPECT().CloseWithError(gomock.Any(), gomock.Any()).Do(func(quic.ErrorCode, string) { close(closed) })
req, err := http.NewRequest("GET", "https://quic.clemente.io/file1.html", nil)