diff --git a/http3/client.go b/http3/client.go index 977355cc..7c59bc74 100644 --- a/http3/client.go +++ b/http3/client.go @@ -323,6 +323,21 @@ func (c *client) doRequest( respBody := newResponseBody(str, reqDone, func() { c.session.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "") }) + + // Rules for when to set Content-Length are defined in https://tools.ietf.org/html/rfc7230#section-3.3.2. + _, hasTransferEncoding := res.Header["Transfer-Encoding"] + isInformational := res.StatusCode >= 100 && res.StatusCode < 200 + isNoContent := res.StatusCode == 204 + isSuccessfulConnect := req.Method == http.MethodConnect && res.StatusCode >= 200 && res.StatusCode < 300 + if !hasTransferEncoding && !isInformational && !isNoContent && !isSuccessfulConnect { + res.ContentLength = -1 + if clens, ok := res.Header["Content-Length"]; ok && len(clens) == 1 { + if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { + res.ContentLength = clen64 + } + } + } + if requestGzip && res.Header.Get("Content-Encoding") == "gzip" { res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") diff --git a/http3/client_test.go b/http3/client_test.go index 52723f5e..8dd79601 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -393,6 +393,19 @@ var _ = Describe("Client", func() { ) testDone := make(chan struct{}) + getHeadersFrame := func(headers map[string]string) []byte { + buf := &bytes.Buffer{} + headerBuf := &bytes.Buffer{} + enc := qpack.NewEncoder(headerBuf) + for name, value := range headers { + Expect(enc.WriteField(qpack.HeaderField{Name: name, Value: value})).To(Succeed()) + } + Expect(enc.Close()).To(Succeed()) + (&headersFrame{Length: uint64(headerBuf.Len())}).Write(buf) + buf.Write(headerBuf.Bytes()) + return buf.Bytes() + } + decodeHeader := func(str io.Reader) map[string]string { fields := make(map[string]string) decoder := qpack.NewDecoder(nil) @@ -548,15 +561,33 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError("test done")) }) + It("sets the Content-Length", func() { + done := make(chan struct{}) + buf := &bytes.Buffer{} + buf.Write(getHeadersFrame(map[string]string{ + ":status": "200", + "Content-Length": "1337", + })) + (&dataFrame{Length: 0x6}).Write(buf) + buf.Write([]byte("foobar")) + str.EXPECT().Close().Do(func() { close(done) }) + sess.EXPECT().ConnectionState().Return(quic.ConnectionState{}) + str.EXPECT().CancelWrite(gomock.Any()).MaxTimes(1) // when reading the response errors + // the response body is sent asynchronously, while already reading the response + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + req, err := client.RoundTrip(request) + Expect(err).ToNot(HaveOccurred()) + Expect(req.ContentLength).To(BeEquivalentTo(1337)) + Eventually(done).Should(BeClosed()) + }) + It("closes the connection when the first frame is not a HEADERS frame", func() { buf := &bytes.Buffer{} (&dataFrame{Length: 0x42}).Write(buf) sess.EXPECT().CloseWithError(quic.ErrorCode(errorFrameUnexpected), gomock.Any()) closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) - str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) { - return buf.Read(b) - }).AnyTimes() + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() _, err := client.RoundTrip(request) Expect(err).To(MatchError("expected first frame to be a HEADERS frame")) Eventually(closed).Should(BeClosed()) @@ -568,9 +599,7 @@ var _ = Describe("Client", func() { str.EXPECT().CancelWrite(quic.ErrorCode(errorFrameError)) closed := make(chan struct{}) str.EXPECT().Close().Do(func() { close(closed) }) - str.EXPECT().Read(gomock.Any()).DoAndReturn(func(b []byte) (int, error) { - return buf.Read(b) - }).AnyTimes() + str.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() _, err := client.RoundTrip(request) Expect(err).To(MatchError("HEADERS frame too large: 1338 bytes (max: 1337)")) Eventually(closed).Should(BeClosed()) @@ -723,7 +752,6 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) data, err := ioutil.ReadAll(rsp.Body) Expect(err).ToNot(HaveOccurred()) - Expect(rsp.ContentLength).ToNot(BeEquivalentTo(-1)) Expect(string(data)).To(Equal("not gzipped")) Expect(rsp.Header.Get("Content-Encoding")).To(BeEmpty()) })