add support for gzipped HTTP/3 requests

This commit is contained in:
Marten Seemann 2019-04-15 16:13:42 +09:00
parent 89ecbdfdc2
commit de6ab88437
8 changed files with 195 additions and 11 deletions

View file

@ -5,6 +5,7 @@ coverage:
- streams_map_incoming_uni.go
- streams_map_outgoing_bidi.go
- streams_map_outgoing_uni.go
- http3/gzip_reader.go
- internal/ackhandler/packet_linkedlist.go
- internal/utils/byteinterval_linkedlist.go
- internal/utils/packetinterval_linkedlist.go

View file

@ -29,6 +29,7 @@ type roundTripperOpts struct {
type client struct {
tlsConf *tls.Config
config *quic.Config
opts *roundTripperOpts
dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
@ -47,7 +48,7 @@ type client struct {
func newClient(
hostname string,
tlsConf *tls.Config,
_ *roundTripperOpts, // TODO: implement gzip compression
opts *roundTripperOpts,
quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
) *client {
@ -67,6 +68,7 @@ func newClient(
requestWriter: newRequestWriter(logger),
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
config: quicConfig,
opts: opts,
dialer: dialer,
logger: logger,
}
@ -138,7 +140,11 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, err
}
if err := c.requestWriter.WriteRequest(str, req); err != nil {
var requestGzip bool
if !c.opts.DisableCompression && req.Method != "HEAD" && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" {
requestGzip = true
}
if err := c.requestWriter.WriteRequest(str, req, requestGzip); err != nil {
return nil, err
}
@ -163,7 +169,6 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
Proto: "HTTP/3",
ProtoMajor: 3,
Header: http.Header{},
Body: newResponseBody(&responseBody{str}),
}
for _, hf := range hfs {
switch hf.Name {
@ -178,5 +183,16 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
res.Header.Add(hf.Name, hf.Value)
}
}
respBody := newResponseBody(&responseBody{str})
if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = newGzipReader(respBody)
res.Uncompressed = true
} else {
res.Body = respBody
}
return res, nil
}

View file

@ -2,9 +2,11 @@ package http3
import (
"bytes"
"compress/gzip"
"crypto/tls"
"errors"
"io"
"io/ioutil"
"net/http"
"time"
@ -268,5 +270,97 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError("test done"))
})
})
Context("gzip compression", func() {
var gzippedData []byte // a gzipped foobar
var response *http.Response
BeforeEach(func() {
var b bytes.Buffer
w := gzip.NewWriter(&b)
w.Write([]byte("foobar"))
w.Close()
gzippedData = b.Bytes()
response = &http.Response{
StatusCode: 200,
Header: http.Header{"Content-Length": []string{"1000"}},
}
_ = gzippedData
_ = response
})
It("adds the gzip header to requests", func() {
sess.EXPECT().OpenStreamSync().Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
})
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(buf)
Expect(hfs).To(HaveKeyWithValue("accept-encoding", "gzip"))
})
It("doesn't add gzip if the header disable it", func() {
client = newClient("quic.clemente.io:1337", nil, &roundTripperOpts{DisableCompression: true}, nil, nil)
sess.EXPECT().OpenStreamSync().Return(str, nil)
buf := &bytes.Buffer{}
str.EXPECT().Write(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Write(p)
})
str.EXPECT().Close()
str.EXPECT().Read(gomock.Any()).Return(0, errors.New("test done"))
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("test done"))
hfs := decodeHeader(buf)
Expect(hfs).ToNot(HaveKey("accept-encoding"))
})
It("decompresses the response", func() {
sess.EXPECT().OpenStreamSync().Return(str, nil)
buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(rw)
gz.Write([]byte("gzipped response"))
gz.Close()
str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Read(p)
}).AnyTimes()
str.EXPECT().Close()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.ContentLength).To(BeEquivalentTo(-1))
Expect(string(data)).To(Equal("gzipped response"))
Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
Expect(rsp.Uncompressed).To(BeTrue())
})
It("only decompresses the response if the response contains the right content-encoding header", func() {
sess.EXPECT().OpenStreamSync().Return(str, nil)
buf := &bytes.Buffer{}
rw := newResponseWriter(buf, utils.DefaultLogger)
rw.Write([]byte("not gzipped"))
str.EXPECT().Write(gomock.Any()).AnyTimes()
str.EXPECT().Read(gomock.Any()).DoAndReturn(func(p []byte) (int, error) {
return buf.Read(p)
}).AnyTimes()
str.EXPECT().Close()
rsp, err := client.RoundTrip(request)
Expect(err).ToNot(HaveOccurred())
data, err := ioutil.ReadAll(rsp.Body)
Expect(err).ToNot(HaveOccurred())
Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1))
Expect(string(data)).To(Equal("not gzipped"))
Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty())
})
})
})
})

39
http3/gzip_reader.go Normal file
View file

@ -0,0 +1,39 @@
package http3
// copied from net/transport.go
// gzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
import (
"compress/gzip"
"io"
)
// call gzip.NewReader on the first call to Read
type gzipReader struct {
body io.ReadCloser // underlying Response.Body
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
}
func newGzipReader(body io.ReadCloser) io.ReadCloser {
return &gzipReader{body: body}
}
func (gz *gzipReader) Read(p []byte) (n int, err error) {
if gz.zerr != nil {
return 0, gz.zerr
}
if gz.zr == nil {
gz.zr, err = gzip.NewReader(gz.body)
if err != nil {
gz.zerr = err
return 0, err
}
}
return gz.zr.Read(p)
}
func (gz *gzipReader) Close() error {
return gz.body.Close()
}

View file

@ -36,8 +36,8 @@ func newRequestWriter(logger utils.Logger) *requestWriter {
}
}
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request) error {
headers, err := w.getHeaders(req)
func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request, gzip bool) error {
headers, err := w.getHeaders(req, gzip)
if err != nil {
return err
}
@ -62,12 +62,12 @@ func (w *requestWriter) WriteRequest(str quic.Stream, req *http.Request) error {
return nil
}
func (w *requestWriter) getHeaders(req *http.Request) ([]byte, error) {
func (w *requestWriter) getHeaders(req *http.Request, gzip bool) ([]byte, error) {
w.mutex.Lock()
defer w.mutex.Unlock()
defer w.encoder.Close()
if err := w.encodeHeaders(req, false, "", actualContentLength(req)); err != nil {
if err := w.encodeHeaders(req, gzip, "", actualContentLength(req)); err != nil {
return nil, err
}

View file

@ -54,7 +54,7 @@ var _ = Describe("Request Writer", func() {
str.EXPECT().Close()
req, err := http.NewRequest("GET", "https://quic.clemente.io/index.html?foo=bar", nil)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req)).To(Succeed())
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":authority", "quic.clemente.io"))
Expect(headerFields).To(HaveKeyWithValue(":method", "GET"))
@ -69,7 +69,7 @@ var _ = Describe("Request Writer", func() {
postData := bytes.NewReader([]byte("foobar"))
req, err := http.NewRequest("POST", "https://quic.clemente.io/upload.html", postData)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req)).To(Succeed())
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue(":method", "POST"))
Expect(headerFields).To(HaveKey("content-length"))
@ -98,8 +98,17 @@ var _ = Describe("Request Writer", func() {
}
req.AddCookie(cookie1)
req.AddCookie(cookie2)
Expect(rw.WriteRequest(str, req)).To(Succeed())
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue("cookie", `Cookie #1="Value #1"; Cookie #2="Value #2"`))
})
It("adds the header for gzip support", func() {
str.EXPECT().Close()
req, err := http.NewRequest("GET", "https://quic.clemente.io/", nil)
Expect(err).ToNot(HaveOccurred())
Expect(rw.WriteRequest(str, req, true)).To(Succeed())
headerFields := decode(strBuf)
Expect(headerFields).To(HaveKeyWithValue("accept-encoding", "gzip"))
})
})

View file

@ -81,7 +81,7 @@ var _ = Describe("Server", func() {
closed := make(chan struct{})
str.EXPECT().Close().Do(func() { close(closed) })
rw := newRequestWriter(utils.DefaultLogger)
Expect(rw.WriteRequest(str, req)).To(Succeed())
Expect(rw.WriteRequest(str, req, false)).To(Succeed())
Eventually(closed).Should(BeClosed())
return buf.Bytes()
}

View file

@ -2,6 +2,7 @@ package self_test
import (
"bytes"
"compress/gzip"
"crypto/tls"
"fmt"
"io/ioutil"
@ -42,6 +43,7 @@ var _ = Describe("HTTP tests", func() {
TLSClientConfig: &tls.Config{
RootCAs: testdata.GetRootCA(),
},
DisableCompression: true,
QuicConfig: &quic.Config{
Versions: []protocol.VersionNumber{version},
IdleTimeout: 10 * time.Second,
@ -159,6 +161,29 @@ var _ = Describe("HTTP tests", func() {
Expect(err).ToNot(HaveOccurred())
Expect(body).To(Equal(testserver.PRData))
})
It("uses gzip compression", func() {
http.HandleFunc("/gzipped/hello", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
Expect(r.Header.Get("Accept-Encoding")).To(Equal("gzip"))
w.Header().Set("Content-Encoding", "gzip")
w.Header().Set("foo", "bar")
gw := gzip.NewWriter(w)
defer gw.Close()
gw.Write([]byte("Hello, World!\n"))
})
client.Transport.(*http3.RoundTripper).DisableCompression = false
resp, err := client.Get("https://localhost:" + testserver.Port() + "/gzipped/hello")
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
Expect(resp.Uncompressed).To(BeTrue())
body, err := ioutil.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(Equal("Hello, World!\n"))
})
})
}
})