From de6ab8843778fd4f95decbafda98d1ec89c7d590 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 15 Apr 2019 16:13:42 +0900 Subject: [PATCH] add support for gzipped HTTP/3 requests --- codecov.yml | 1 + http3/client.go | 22 ++++++- http3/client_test.go | 94 ++++++++++++++++++++++++++++++ http3/gzip_reader.go | 39 +++++++++++++ http3/request_writer.go | 8 +-- http3/request_writer_test.go | 15 ++++- http3/server_test.go | 2 +- integrationtests/self/http_test.go | 25 ++++++++ 8 files changed, 195 insertions(+), 11 deletions(-) create mode 100644 http3/gzip_reader.go diff --git a/codecov.yml b/codecov.yml index 45a1d83e..cfe901ea 100644 --- a/codecov.yml +++ b/codecov.yml @@ -5,6 +5,7 @@ coverage: - streams_map_incoming_uni.go - streams_map_outgoing_bidi.go - streams_map_outgoing_uni.go + - http3/gzip_reader.go - internal/ackhandler/packet_linkedlist.go - internal/utils/byteinterval_linkedlist.go - internal/utils/packetinterval_linkedlist.go diff --git a/http3/client.go b/http3/client.go index ca45381a..fc100818 100644 --- a/http3/client.go +++ b/http3/client.go @@ -29,6 +29,7 @@ type roundTripperOpts struct { type client struct { tlsConf *tls.Config config *quic.Config + opts *roundTripperOpts dialOnce sync.Once dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error) @@ -47,7 +48,7 @@ type client struct { func newClient( hostname string, tlsConf *tls.Config, - _ *roundTripperOpts, // TODO: implement gzip compression + opts *roundTripperOpts, quicConfig *quic.Config, dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error), ) *client { @@ -67,6 +68,7 @@ func newClient( requestWriter: newRequestWriter(logger), decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), config: quicConfig, + opts: opts, dialer: dialer, logger: logger, } @@ -138,7 +140,11 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { return nil, err } - if err := c.requestWriter.WriteRequest(str, req); err != nil { + var requestGzip bool + if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" { + requestGzip = true + } + if err := c.requestWriter.WriteRequest(str, req, requestGzip); err != nil { return nil, err } @@ -163,7 +169,6 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { Proto: "HTTP/3", ProtoMajor: 3, Header: http.Header{}, - Body: newResponseBody(&responseBody{str}), } for _, hf := range hfs { switch hf.Name { @@ -178,5 +183,16 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { res.Header.Add(hf.Name, hf.Value) } } + respBody := newResponseBody(&responseBody{str}) + if requestGzip && res.Header.Get("Content-Encoding") == "gzip" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = newGzipReader(respBody) + res.Uncompressed = true + } else { + res.Body = respBody + } + return res, nil } diff --git a/http3/client_test.go b/http3/client_test.go index 906c3b5d..62f3c79a 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -2,9 +2,11 @@ package http3 import ( "bytes" + "compress/gzip" "crypto/tls" "errors" "io" + "io/ioutil" "net/http" "time" @@ -268,5 +270,97 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError("test done")) }) }) + + Context("gzip compression", func() { + var gzippedData []byte // a gzipped foobar + var response *http.Response + + BeforeEach(func() { + var b bytes.Buffer + w := gzip.NewWriter(&b) + w.Write([]byte("foobar")) + w.Close() + gzippedData = b.Bytes() + response = &http.Response{ + StatusCode: 200, + Header: http.Header{"Content-Length": []string{"1000"}}, + } + _ = gzippedData + _ = response + }) + + It("adds the gzip header to requests", func() { + sess.EXPECT().OpenStreamSync().Return(str, nil) + buf := &bytes.Buffer{} + str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return buf.Write(p) + }) + str.EXPECT().Close() + str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) + _, err := client.RoundTrip(request) + Expect(err).To(MatchError("test done")) + hfs := decodeHeader(buf) + Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip")) + }) + + It("doesn't add gzip if the header disable it", func() { + client = newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil) + sess.EXPECT().OpenStreamSync().Return(str, nil) + buf := &bytes.Buffer{} + str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return buf.Write(p) + }) + str.EXPECT().Close() + str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done")) + _, err := client.RoundTrip(request) + Expect(err).To(MatchError("test done")) + hfs := decodeHeader(buf) + Expect(hfs).ToNot(HaveKey("accept-encoding")) + }) + + It("decompresses the response", func() { + sess.EXPECT().OpenStreamSync().Return(str, nil) + buf := &bytes.Buffer{} + rw := newResponseWriter(buf, utils.DefaultLogger) + rw.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(rw) + gz.Write([]byte("gzipped response")) + gz.Close() + str.EXPECT().Write(gomock.Any()).AnyTimes() + str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return buf.Read(p) + }).AnyTimes() + str.EXPECT().Close() + + rsp, err := client.RoundTrip(request) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(rsp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.ContentLength).To(BeEquivalentTo(-1)) + Expect(string(data)).To(Equal("gzipped response")) + Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) + Expect(rsp.Uncompressed).To(BeTrue()) + }) + + It("only decompresses the response if the response contains the right content-encoding header", func() { + sess.EXPECT().OpenStreamSync().Return(str, nil) + buf := &bytes.Buffer{} + rw := newResponseWriter(buf, utils.DefaultLogger) + rw.Write([]byte("not gzipped")) + str.EXPECT().Write(gomock.Any()).AnyTimes() + str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) { + return buf.Read(p) + }).AnyTimes() + str.EXPECT().Close() + + rsp, err := client.RoundTrip(request) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(rsp.Body) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1)) + Expect(string(data)).To(Equal("not gzipped")) + Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) + }) + }) }) }) diff --git a/http3/gzip_reader.go b/http3/gzip_reader.go new file mode 100644 index 00000000..01983ac7 --- /dev/null +++ b/http3/gzip_reader.go @@ -0,0 +1,39 @@ +package http3 + +// copied from net/transport.go + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +import ( + "compress/gzip" + "io" +) + +// call gzip.NewReader on the first call to Read +type gzipReader struct { + body io.ReadCloser // underlying Response.Body + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // sticky error +} + +func newGzipReader(body io.ReadCloser) io.ReadCloser { + return &gzipReader{body: body} +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zerr != nil { + return 0, gz.zerr + } + if gz.zr == nil { + gz.zr, err = gzip.NewReader(gz.body) + if err != nil { + gz.zerr = err + return 0, err + } + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} diff --git a/http3/request_writer.go b/http3/request_writer.go index 95026257..dfbe207f 100644 --- a/http3/request_writer.go +++ b/http3/request_writer.go @@ -36,8 +36,8 @@ func newRequestWriter(logger utils.Logger) *requestWriter { } } -func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request) error { - headers, err := w.getHeaders(req) +func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error { + headers, err := w.getHeaders(req, gzip) if err != nil { return err } @@ -62,12 +62,12 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request) error { return nil } -func (w *requestWriter) getHeaders(req *http.Request) ([]byte, error) { +func (w *requestWriter) getHeaders(req *http.Request, gzip bool) ([]byte, error) { w.mutex.Lock() defer w.mutex.Unlock() defer w.encoder.Close() - if err := w.encodeHeaders(req, false, "", actualContentLength(req)); err != nil { + if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil { return nil, err } diff --git a/http3/request_writer_test.go b/http3/request_writer_test.go index a8dc6f5d..ffde6d6f 100644 --- a/http3/request_writer_test.go +++ b/http3/request_writer_test.go @@ -54,7 +54,7 @@ var _ = Describe("Request Writer", func() { str.EXPECT().Close() req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req)).To(Succeed()) + Expect(rw.WriteRequest(str, req, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io")) Expect(headerFields).To(HaveKeyWithValue(":method", "GET")) @@ -69,7 +69,7 @@ var _ = Describe("Request Writer", func() { postData := bytes.NewReader([]byte("foobar")) req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", postData) Expect(err).ToNot(HaveOccurred()) - Expect(rw.WriteRequest(str, req)).To(Succeed()) + Expect(rw.WriteRequest(str, req, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue(":method", "POST")) Expect(headerFields).To(HaveKey("content-length")) @@ -98,8 +98,17 @@ var _ = Describe("Request Writer", func() { } req.AddCookie(cookie1) req.AddCookie(cookie2) - Expect(rw.WriteRequest(str, req)).To(Succeed()) + Expect(rw.WriteRequest(str, req, false)).To(Succeed()) headerFields := decode(strBuf) Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`)) }) + + It("adds the header for gzip support", func() { + str.EXPECT().Close() + req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(rw.WriteRequest(str, req, true)).To(Succeed()) + headerFields := decode(strBuf) + Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip")) + }) }) diff --git a/http3/server_test.go b/http3/server_test.go index 66b44781..012817b6 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -81,7 +81,7 @@ var _ = Describe("Server", func() { closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) rw := newRequestWriter(utils.DefaultLogger) - Expect(rw.WriteRequest(str, req)).To(Succeed()) + Expect(rw.WriteRequest(str, req, false)).To(Succeed()) Eventually(closed).Should(BeClosed()) return buf.Bytes() } diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 0aa4b2d6..9403be94 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -2,6 +2,7 @@ package self_test import ( "bytes" + "compress/gzip" "crypto/tls" "fmt" "io/ioutil" @@ -42,6 +43,7 @@ var _ = Describe("HTTP tests", func() { TLSClientConfig: &tls.Config{ RootCAs: testdata.GetRootCA(), }, + DisableCompression: true, QuicConfig: &quic.Config{ Versions: []protocol.VersionNumber{version}, IdleTimeout: 10 * time.Second, @@ -159,6 +161,29 @@ var _ = Describe("HTTP tests", func() { Expect(err).ToNot(HaveOccurred()) Expect(body).To(Equal(testserver.PRData)) }) + + It("uses gzip compression", func() { + http.HandleFunc("/gzipped/hello", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + Expect(r.Header.Get("Accept-Encoding")).To(Equal("gzip")) + w.Header().Set("Content-Encoding", "gzip") + w.Header().Set("foo", "bar") + + gw := gzip.NewWriter(w) + defer gw.Close() + gw.Write([]byte("Hello, World!\n")) + }) + + client.Transport.(*http3.RoundTripper).DisableCompression = false + resp, err := client.Get("https://localhost:" + testserver.Port() + "/gzipped/hello") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Uncompressed).To(BeTrue()) + + body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World!\n")) + }) }) } })