diff --git a/http3/body.go b/http3/body.go index cbc8e9e9..e450a56a 100644 --- a/http3/body.go +++ b/http3/body.go @@ -13,6 +13,12 @@ type body struct { isRequest bool + // only set for the http.Response + // The channel is closed when the user is done with this response: + // either when Read() errors, or when Close() is called. + reqDone chan<- struct{} + reqDoneClosed bool + bytesRemainingInFrame uint64 } @@ -25,11 +31,22 @@ func newRequestBody(str quic.Stream) *body { } } -func newResponseBody(str quic.Stream) *body { - return &body{str: str} +func newResponseBody(str quic.Stream, done chan<- struct{}) *body { + return &body{ + str: str, + reqDone: done, + } } func (r *body) Read(b []byte) (int, error) { + n, err := r.readImpl(b) + if err != nil && !r.isRequest { + r.requestDone() + } + return n, err +} + +func (r *body) readImpl(b []byte) (int, error) { if r.bytesRemainingInFrame == 0 { parseLoop: for { @@ -61,11 +78,20 @@ func (r *body) Read(b []byte) (int, error) { return n, err } +func (r *body) requestDone() { + if r.reqDoneClosed { + return + } + close(r.reqDone) + r.reqDoneClosed = true +} + func (r *body) Close() error { // quic.Stream.Close() closes the write side, not the read side if r.isRequest { return r.str.Close() } + r.requestDone() r.str.CancelRead(quic.ErrorCode(errorRequestCanceled)) return nil } diff --git a/http3/body_test.go b/http3/body_test.go index dac6e2d5..9843237e 100644 --- a/http3/body_test.go +++ b/http3/body_test.go @@ -29,9 +29,10 @@ func (t bodyType) String() string { var _ = Describe("Body", func() { var ( - rb *body - str *mockquic.MockStream - buf *bytes.Buffer + rb *body + str *mockquic.MockStream + buf *bytes.Buffer + reqDone chan struct{} ) getDataFrame := func(data []byte) []byte { @@ -62,7 +63,8 @@ var _ = Describe("Body", func() { case bodyTypeRequest: rb = newRequestBody(str) case bodyTypeResponse: - rb = newResponseBody(str) + reqDone = make(chan struct{}) + rb = newResponseBody(str, reqDone) } }) @@ -156,10 +158,33 @@ var _ = Describe("Body", func() { } if bodyType == bodyTypeResponse { + It("closes the reqDone channel when Read errors", func() { + buf.Write([]byte("invalid")) + _, err := rb.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + Expect(reqDone).To(BeClosed()) + }) + + It("allows multiple calls to Read, when Read errors", func() { + buf.Write([]byte("invalid")) + _, err := rb.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + Expect(reqDone).To(BeClosed()) + _, err = rb.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + }) + It("closes responses", func() { str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)) Expect(rb.Close()).To(Succeed()) }) + + It("allows multiple calls to Close", func() { + str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)).MaxTimes(2) + Expect(rb.Close()).To(Succeed()) + Expect(reqDone).To(BeClosed()) + Expect(rb.Close()).To(Succeed()) + }) } }) } diff --git a/http3/client.go b/http3/client.go index b4ae76e9..adce6d09 100644 --- a/http3/client.go +++ b/http3/client.go @@ -153,6 +153,19 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { return nil, err } + // Request Cancelation: + // This go routine keeps running even after RoundTrip() returns. + // It is shut down when the application is done processing the body. + reqDone := make(chan struct{}) + go func() { + select { + case <-req.Context().Done(): + str.CancelWrite(quic.ErrorCode(errorRequestCanceled)) + str.CancelRead(quic.ErrorCode(errorRequestCanceled)) + case <-reqDone: + } + }() + var requestGzip bool if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { requestGzip = true @@ -198,7 +211,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { res.Header.Add(hf.Name, hf.Value) } } - respBody := newResponseBody(str) + respBody := newResponseBody(str, reqDone) if requestGzip && res.Header.Get("Content-Encoding") == "gzip" { res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") diff --git a/http3/client_test.go b/http3/client_test.go index e6709c10..6afb7a4a 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -301,6 +301,56 @@ var _ = Describe("Client", func() { }) }) + Context("request cancelations", func() { + It("cancels a request while the request is still in flight", func() { + ctx, cancel := context.WithCancel(context.Background()) + req := request.WithContext(ctx) + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + buf := &bytes.Buffer{} + str.EXPECT().Close().MaxTimes(1) + + done := make(chan struct{}) + str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return buf.Write(p) + }) + str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled)) + str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) { close(done) }) + str.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { + cancel() + return 0, errors.New("test done") + }) + _, err := client.RoundTrip(req) + Expect(err).To(MatchError("test done")) + Eventually(done).Should(BeClosed()) + }) + + It("cancels a request after the response arrived", func() { + rspBuf := &bytes.Buffer{} + rw := newResponseWriter(rspBuf, utils.DefaultLogger) + rw.WriteHeader(418) + + ctx, cancel := context.WithCancel(context.Background()) + req := request.WithContext(ctx) + sess.EXPECT().OpenStreamSync(context.Background()).Return(str, nil) + buf := &bytes.Buffer{} + str.EXPECT().Close().MaxTimes(1) + + done := make(chan struct{}) + str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return buf.Write(p) + }) + str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) { + return rspBuf.Read(b) + }).AnyTimes() + str.EXPECT().CancelWrite(quic.ErrorCode(errorRequestCanceled)) + str.EXPECT().CancelRead(quic.ErrorCode(errorRequestCanceled)).Do(func(quic.ErrorCode) { close(done) }) + _, err := client.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + cancel() + Eventually(done).Should(BeClosed()) + }) + }) + Context("gzip compression", func() { var gzippedData []byte // a gzipped foobar var response *http.Response diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 9403be94..ff44c343 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -3,6 +3,7 @@ package self_test import ( "bytes" "compress/gzip" + "context" "crypto/tls" "fmt" "io/ioutil" @@ -20,6 +21,11 @@ import ( "github.com/onsi/gomega/gbytes" ) +type streamCancelError interface { + Canceled() bool + ErrorCode() protocol.ApplicationErrorCode +} + var _ = Describe("HTTP tests", func() { var client *http.Client @@ -184,6 +190,36 @@ var _ = Describe("HTTP tests", func() { Expect(err).ToNot(HaveOccurred()) Expect(string(body)).To(Equal("Hello, World!\n")) }) + + It("cancels requests", func() { + handlerCalled := make(chan struct{}) + http.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + defer close(handlerCalled) + for { + if _, err := w.Write([]byte("foobar")); err != nil { + Expect(r.Context().Done()).To(BeClosed()) + serr, ok := err.(streamCancelError) + Expect(ok).To(BeTrue()) + Expect(serr.Canceled()).To(BeTrue()) + Expect(serr.ErrorCode()).To(BeEquivalentTo(5)) + return + } + } + }) + + req, err := http.NewRequest(http.MethodGet, "https://localhost:"+testserver.Port()+"/cancel", nil) + Expect(err).ToNot(HaveOccurred()) + ctx, cancel := context.WithCancel(context.Background()) + req = req.WithContext(ctx) + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + cancel() + Eventually(handlerCalled).Should(BeClosed()) + _, err = resp.Body.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + }) }) } })