diff --git a/http3/body.go b/http3/body.go index e450a56a..e0b19321 100644 --- a/http3/body.go +++ b/http3/body.go @@ -1,7 +1,7 @@ package http3 import ( - "errors" + "fmt" "io" "github.com/lucas-clemente/quic-go" @@ -19,22 +19,26 @@ type body struct { reqDone chan<- struct{} reqDoneClosed bool + onFrameError func() + bytesRemainingInFrame uint64 } var _ io.ReadCloser = &body{} -func newRequestBody(str quic.Stream) *body { +func newRequestBody(str quic.Stream, onFrameError func()) *body { return &body{ - str: str, - isRequest: true, + str: str, + onFrameError: onFrameError, + isRequest: true, } } -func newResponseBody(str quic.Stream, done chan<- struct{}) *body { +func newResponseBody(str quic.Stream, done chan<- struct{}, onFrameError func()) *body { return &body{ - str: str, - reqDone: done, + str: str, + onFrameError: onFrameError, + reqDone: done, } } @@ -62,7 +66,10 @@ func (r *body) readImpl(b []byte) (int, error) { r.bytesRemainingInFrame = f.Length break parseLoop default: - return 0, errors.New("unexpected frame") + r.onFrameError() + // parseNextFrame skips over unknown frame types + // Therefore, this condition is only entered when we parsed another known frame type. + return 0, fmt.Errorf("peer sent an unexpected frame: %T", f) } } } diff --git a/http3/body_test.go b/http3/body_test.go index 9843237e..17e90d7f 100644 --- a/http3/body_test.go +++ b/http3/body_test.go @@ -29,12 +29,15 @@ func (t bodyType) String() string { var _ = Describe("Body", func() { var ( - rb *body - str *mockquic.MockStream - buf *bytes.Buffer - reqDone chan struct{} + rb *body + str *mockquic.MockStream + buf *bytes.Buffer + reqDone chan struct{} + errorCbCalled bool ) + errorCb := func() { errorCbCalled = true } + getDataFrame := func(data []byte) []byte { b := &bytes.Buffer{} (&dataFrame{Length: uint64(len(data))}).Write(b) @@ -44,6 +47,7 @@ var _ = Describe("Body", func() { BeforeEach(func() { buf = &bytes.Buffer{} + errorCbCalled = false }) for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} { @@ -61,10 +65,10 @@ var _ = Describe("Body", func() { switch bodyType { case bodyTypeRequest: - rb = newRequestBody(str) + rb = newRequestBody(str, errorCb) case bodyTypeResponse: reqDone = make(chan struct{}) - rb = newResponseBody(str, reqDone) + rb = newResponseBody(str, reqDone, errorCb) } }) @@ -144,10 +148,11 @@ var _ = Describe("Body", func() { Expect(err).To(HaveOccurred()) }) - It("errors on unexpected frames", func() { + It("errors on unexpected frames, and calls the error callback", func() { (&settingsFrame{}).Write(buf) _, err := rb.Read([]byte{0}) - Expect(err).To(MatchError("unexpected frame")) + Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame")) + Expect(errorCbCalled).To(BeTrue()) }) if bodyType == bodyTypeRequest { diff --git a/http3/client.go b/http3/client.go index a105f51c..ffe949d7 100644 --- a/http3/client.go +++ b/http3/client.go @@ -237,7 +237,9 @@ func (c *client) doRequest( res.Header.Add(hf.Name, hf.Value) } } - respBody := newResponseBody(str, reqDone) + respBody := newResponseBody(str, reqDone, func() { + c.session.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "") + }) if requestGzip && res.Header.Get("Content-Encoding") == "gzip" { res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") diff --git a/http3/server.go b/http3/server.go index bd2ad8e0..16bffd54 100644 --- a/http3/server.go +++ b/http3/server.go @@ -180,7 +180,10 @@ func (s *Server) handleConn(sess quic.Session) { } go func() { defer ginkgo.GinkgoRecover() - if rerr := s.handleRequest(str, decoder); rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 { + rerr := s.handleRequest(str, decoder, func() { + sess.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "") + }) + if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 { s.logger.Debugf("Handling request failed: %s", err) if rerr.streamErr != 0 { str.CancelWrite(quic.ErrorCode(rerr.streamErr)) @@ -206,7 +209,7 @@ func (s *Server) maxHeaderBytes() uint64 { return uint64(s.Server.MaxHeaderBytes) } -func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder) requestError { +func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError { frame, err := parseNextFrame(str) if err != nil { return newStreamError(errorRequestIncomplete, err) @@ -232,7 +235,7 @@ func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder) requestE // TODO: use the right error code return newStreamError(errorGeneralProtocolError, err) } - req.Body = newRequestBody(str) + req.Body = newRequestBody(str, onFrameError) if s.logger.Debug() { s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID()) diff --git a/http3/server_test.go b/http3/server_test.go index 985c8354..d02f415c 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -117,7 +117,7 @@ var _ = Describe("Server", func() { return len(p), nil }).AnyTimes() - Expect(s.handleRequest(str, qpackDecoder)).To(Equal(requestError{})) + Expect(s.handleRequest(str, qpackDecoder, nil)).To(Equal(requestError{})) var req *http.Request Eventually(requestChan).Should(Receive(&req)) Expect(req.Host).To(Equal("www.example.com")) @@ -133,7 +133,7 @@ var _ = Describe("Server", func() { return responseBuf.Write(p) }).AnyTimes() - serr := s.handleRequest(str, qpackDecoder) + serr := s.handleRequest(str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"})) @@ -152,7 +152,7 @@ var _ = Describe("Server", func() { }).AnyTimes() str.EXPECT().CancelRead(gomock.Any()) - serr := s.handleRequest(str, qpackDecoder) + serr := s.handleRequest(str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) hfs := decodeHeader(responseBuf) Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"})) @@ -294,7 +294,7 @@ var _ = Describe("Server", func() { }).AnyTimes() str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse)) - serr := s.handleRequest(str, qpackDecoder) + serr := s.handleRequest(str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) Eventually(handlerCalled).Should(BeClosed()) }) @@ -317,7 +317,7 @@ var _ = Describe("Server", func() { }).AnyTimes() str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse)) - serr := s.handleRequest(str, qpackDecoder) + serr := s.handleRequest(str, qpackDecoder, nil) Expect(serr.err).ToNot(HaveOccurred()) Eventually(handlerCalled).Should(BeClosed()) })