diff --git a/codecov.yml b/codecov.yml index 597f9d9e..8fa7519a 100644 --- a/codecov.yml +++ b/codecov.yml @@ -2,6 +2,7 @@ coverage: round: nearest ignore: - ackhandler/packet_linkedlist.go + - h2quic/gzipreader.go - h2quic/response.go - utils/byteinterval_linkedlist.go - utils/packetinterval_linkedlist.go diff --git a/h2quic/client.go b/h2quic/client.go index 65bca65b..634aed95 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -168,7 +168,12 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { c.Close(err) return nil, err } - err = c.requestWriter.WriteRequest(req, dataStreamID) + + var requestedGzip bool + if req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { + requestedGzip = true + } + err = c.requestWriter.WriteRequest(req, dataStreamID, requestedGzip) if err != nil { c.Close(err) return nil, err @@ -199,6 +204,13 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { res.Body = noBody } else { res.Body = dataStream + if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = &gzipReader{body: res.Body} + setUncompressed(res) + } } res.Request = req diff --git a/h2quic/client_test.go b/h2quic/client_test.go index 59ab85ea..ca812cc6 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -1,9 +1,12 @@ package h2quic import ( + "bytes" + "compress/gzip" "net/http" "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" @@ -88,19 +91,41 @@ var _ = Describe("Client", func() { }) Context("Doing requests", func() { + var request *http.Request + + getRequest := func(data []byte) *http2.MetaHeadersFrame { + r := bytes.NewReader(data) + decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) + h2framer := http2.NewFramer(nil, r) + frame, err := h2framer.ReadFrame() + Expect(err).ToNot(HaveOccurred()) + mhframe := &http2.MetaHeadersFrame{HeadersFrame: frame.(*http2.HeadersFrame)} + mhframe.Fields, err = decoder.DecodeFull(mhframe.HeadersFrame.HeaderBlockFragment()) + Expect(err).ToNot(HaveOccurred()) + return mhframe + } + + getHeaderFields := func(f *http2.MetaHeadersFrame) map[string]string { + fields := make(map[string]string) + for _, hf := range f.Fields { + fields[hf.Name] = hf.Value + } + return fields + } + BeforeEach(func() { + var err error client.encryptionLevel = protocol.EncryptionForwardSecure + request, err = http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil) + Expect(err).ToNot(HaveOccurred()) }) It("does a request", func(done Done) { - req, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil) - Expect(err).ToNot(HaveOccurred()) - var doRsp *http.Response var doErr error var doReturned bool go func() { - doRsp, doErr = client.Do(req) + doRsp, doErr = client.Do(request) doReturned = true }() @@ -118,14 +143,11 @@ var _ = Describe("Client", func() { Expect(doRsp).To(Equal(rsp)) Expect(doRsp.Body).ToNot(BeNil()) Expect(doRsp.ContentLength).To(BeEquivalentTo(-1)) - Expect(doRsp.Request).To(Equal(req)) + Expect(doRsp.Request).To(Equal(request)) close(done) }) It("closes the quic client when encountering an error on the header stream", func() { - req, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil) - Expect(err).ToNot(HaveOccurred()) - headerStream.dataToRead.Write([]byte("invalid response")) go client.handleHeaderStream() @@ -133,7 +155,7 @@ var _ = Describe("Client", func() { var doErr error var doReturned bool go func() { - doRsp, doErr = client.Do(req) + doRsp, doErr = client.Do(request) doReturned = true }() @@ -180,6 +202,81 @@ var _ = Describe("Client", func() { }) }) + Context("gzip compression", func() { + var gzippedData []byte // a gzipped foobar + var response *http.Response + + BeforeEach(func() { + var b bytes.Buffer + w := gzip.NewWriter(&b) + w.Write([]byte("foobar")) + w.Close() + gzippedData = b.Bytes() + response = &http.Response{ + StatusCode: 200, + Header: http.Header{"Content-Length": []string{"1000"}}, + } + }) + + It("adds the gzip header to requests", func() { + var doRsp *http.Response + var doErr error + go func() { doRsp, doErr = client.Do(request) }() + + Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) + qClient.streams[5].dataToRead.Write(gzippedData) + response.Header.Add("Content-Encoding", "gzip") + client.responses[5] <- response + Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil()) + Expect(doErr).ToNot(HaveOccurred()) + headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes())) + Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip")) + Expect(doRsp.ContentLength).To(BeEquivalentTo(-1)) + Expect(doRsp.Header.Get("Content-Encoding")).To(BeEmpty()) + Expect(doRsp.Header.Get("Content-Length")).To(BeEmpty()) + data := make([]byte, 6) + doRsp.Body.Read(data) + Expect(data).To(Equal([]byte("foobar"))) + }) + + It("only decompresses the response if the response contains the right content-encoding header", func() { + var doRsp *http.Response + var doErr error + go func() { doRsp, doErr = client.Do(request) }() + + Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) + qClient.streams[5].dataToRead.Write([]byte("not gzipped")) + client.responses[5] <- response + Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil()) + Expect(doErr).ToNot(HaveOccurred()) + headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes())) + Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip")) + data := make([]byte, 11) + doRsp.Body.Read(data) + Expect(doRsp.ContentLength).ToNot(BeEquivalentTo(-1)) + Expect(data).To(Equal([]byte("not gzipped"))) + }) + + It("doesn't add the gzip header for requests that have the accept-enconding set", func() { + request.Header.Add("accept-encoding", "gzip") + var doRsp *http.Response + var doErr error + go func() { doRsp, doErr = client.Do(request) }() + + Eventually(func() chan *http.Response { return client.responses[5] }).ShouldNot(BeNil()) + qClient.streams[5].dataToRead.Write([]byte("gzipped data")) + client.responses[5] <- response + Eventually(func() *http.Response { return doRsp }).ShouldNot(BeNil()) + Expect(doErr).ToNot(HaveOccurred()) + headers := getHeaderFields(getRequest(headerStream.dataWritten.Bytes())) + Expect(headers).To(HaveKeyWithValue("accept-encoding", "gzip")) + data := make([]byte, 12) + doRsp.Body.Read(data) + Expect(doRsp.ContentLength).ToNot(BeEquivalentTo(-1)) + Expect(data).To(Equal([]byte("gzipped data"))) + }) + }) + Context("handling the header stream", func() { var h2framer *http2.Framer diff --git a/h2quic/gzipreader.go b/h2quic/gzipreader.go new file mode 100644 index 00000000..91c226b1 --- /dev/null +++ b/h2quic/gzipreader.go @@ -0,0 +1,35 @@ +package h2quic + +// copied from net/transport.go + +// gzipReader wraps a response body so it can lazily +// call gzip.NewReader on the first call to Read +import ( + "compress/gzip" + "io" +) + +// call gzip.NewReader on the first call to Read +type gzipReader struct { + body io.ReadCloser // underlying Response.Body + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // sticky error +} + +func (gz *gzipReader) Read(p []byte) (n int, err error) { + if gz.zerr != nil { + return 0, gz.zerr + } + if gz.zr == nil { + gz.zr, err = gzip.NewReader(gz.body) + if err != nil { + gz.zerr = err + return 0, err + } + } + return gz.zr.Read(p) +} + +func (gz *gzipReader) Close() error { + return gz.body.Close() +} diff --git a/h2quic/request_writer.go b/h2quic/request_writer.go index 7f421adc..ec4b8cef 100644 --- a/h2quic/request_writer.go +++ b/h2quic/request_writer.go @@ -34,7 +34,7 @@ func newRequestWriter(headerStream utils.Stream) *requestWriter { return rw } -func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID) error { +func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, requestGzip bool) error { // TODO: add support for trailers // TODO: add support for gzip compression // TODO: write continuation frames, if the header frame is too long @@ -42,7 +42,7 @@ func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.St w.mutex.Lock() defer w.mutex.Unlock() - w.encodeHeaders(req, false, "", actualContentLength(req)) + w.encodeHeaders(req, requestGzip, "", actualContentLength(req)) h2framer := http2.NewFramer(w.headerStream, nil) return h2framer.WriteHeaders(http2.HeadersFrameParam{ StreamID: uint32(dataStreamID), diff --git a/h2quic/request_writer_test.go b/h2quic/request_writer_test.go index 78ad181e..826ad828 100644 --- a/h2quic/request_writer_test.go +++ b/h2quic/request_writer_test.go @@ -44,7 +44,7 @@ var _ = Describe("Request", func() { It("writes a GET request", func() { req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil) Expect(err).ToNot(HaveOccurred()) - rw.WriteRequest(req, 1337) + rw.WriteRequest(req, 1337, false) headerFrame, headerFields := decode(headerStream.dataWritten.Bytes()) Expect(headerFrame.StreamID).To(Equal(uint32(1337))) Expect(headerFrame.HasPriority()).To(BeTrue()) @@ -52,6 +52,15 @@ var _ = Describe("Request", func() { Expect(headerFields).To(HaveKeyWithValue(":method", "GET")) Expect(headerFields).To(HaveKeyWithValue(":path", "/index.html?foo=bar")) Expect(headerFields).To(HaveKeyWithValue(":scheme", "https")) + Expect(headerFields).ToNot(HaveKey("accept-encoding")) + }) + + It("requests gzip compression, if requested", func() { + req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil) + Expect(err).ToNot(HaveOccurred()) + rw.WriteRequest(req, 1337, true) + _, headerFields := decode(headerStream.dataWritten.Bytes()) + Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip")) }) It("writes a POST request", func() { @@ -59,7 +68,7 @@ var _ = Describe("Request", func() { form.Add("foo", "bar") req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", strings.NewReader(form.Encode())) Expect(err).ToNot(HaveOccurred()) - rw.WriteRequest(req, 5) + rw.WriteRequest(req, 5, false) _, headerFields := decode(headerStream.dataWritten.Bytes()) Expect(headerFields).To(HaveKeyWithValue(":method", "POST")) Expect(headerFields).To(HaveKey("content-length")) @@ -81,7 +90,7 @@ var _ = Describe("Request", func() { } req.AddCookie(cookie1) req.AddCookie(cookie2) - rw.WriteRequest(req, 11) + rw.WriteRequest(req, 11, false) _, headerFields := decode(headerStream.dataWritten.Bytes()) Expect(headerFields).To(HaveKeyWithValue("cookie", "Cookie #1=Value #1; Cookie #2=Value #2")) }) diff --git a/h2quic/response_setuncompressed.go b/h2quic/response_setuncompressed.go new file mode 100644 index 00000000..191a2484 --- /dev/null +++ b/h2quic/response_setuncompressed.go @@ -0,0 +1,9 @@ +// +build go1.7 + +package h2quic + +import "net/http" + +func setUncompressed(res *http.Response) { + res.Uncompressed = true +} diff --git a/h2quic/response_setuncompressed_go16.go b/h2quic/response_setuncompressed_go16.go new file mode 100644 index 00000000..7359f04a --- /dev/null +++ b/h2quic/response_setuncompressed_go16.go @@ -0,0 +1,9 @@ +// +build !go1.7 + +package h2quic + +import "net/http" + +func setUncompressed(res *http.Response) { + // http.Response.Uncompressed was introduced in go 1.7 +}