From de8d7a32b88688b9f85ccdab2f6e1a11bbeaa410 Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Thu, 13 Jul 2023 14:20:35 +0800 Subject: [PATCH] http3: return http.ErrContentLength when writing too large response (#3953) --- http3/response_writer.go | 18 ++++++++++++++++++ http3/response_writer_test.go | 11 +++++++++++ 2 files changed, 29 insertions(+) diff --git a/http3/response_writer.go b/http3/response_writer.go index dfbf0279..9fd40f72 100644 --- a/http3/response_writer.go +++ b/http3/response_writer.go @@ -23,6 +23,8 @@ type responseWriter struct { header http.Header status int // status code passed to WriteHeader headerWritten bool + contentLen int64 // if handler set valid Content-Length header + numWritten int64 // bytes written logger utils.Logger } @@ -61,6 +63,16 @@ func (w *responseWriter) WriteHeader(status int) { if _, ok := w.header["Date"]; !ok { w.header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) } + // Content-Length checking + if clen := w.header.Get("Content-Length"); clen != "" { + if cl, err := strconv.ParseInt(clen, 10, 64); err == nil { + w.contentLen = cl + } else { + // emit a warning for malformed Content-Length and remove it + w.logger.Errorf("Malformed Content-Length %s", clen) + w.header.Del("Content-Length") + } + } } w.status = status @@ -111,6 +123,12 @@ func (w *responseWriter) Write(p []byte) (int, error) { if !bodyAllowed { return 0, http.ErrBodyNotAllowed } + + w.numWritten += int64(len(p)) + if w.contentLen != 0 && w.numWritten > w.contentLen { + return 0, http.ErrContentLength + } + df := &dataFrame{Length: uint64(len(p))} w.buf = w.buf[:0] w.buf = df.Append(w.buf) diff --git a/http3/response_writer_test.go b/http3/response_writer_test.go index 637353aa..044ec463 100644 --- a/http3/response_writer_test.go +++ b/http3/response_writer_test.go @@ -167,4 +167,15 @@ var _ = Describe("Response Writer", func() { Expect(rw.SetReadDeadline(time.Now().Add(1 * time.Second))).To(BeNil()) Expect(rw.SetWriteDeadline(time.Now().Add(1 * time.Second))).To(BeNil()) }) + + It(`checks Content-Length header`, func() { + rw.Header().Set("Content-Length", "6") + n, err := rw.Write([]byte("foobar")) + Expect(n).To(Equal(6)) + Expect(err).To(BeNil()) + + n, err = rw.Write([]byte("foobar")) + Expect(n).To(Equal(0)) + Expect(err).To(Equal(http.ErrContentLength)) + }) })