mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
Merge pull request #2184 from lucas-clemente/http3-read-frame-error
close the connection on unexpected frames when reading a HTTP/3 body
This commit is contained in:
commit
672328ca30
5 changed files with 42 additions and 25 deletions
|
@ -1,7 +1,7 @@
|
||||||
package http3
|
package http3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
@ -19,22 +19,26 @@ type body struct {
|
||||||
reqDone chan<- struct{}
|
reqDone chan<- struct{}
|
||||||
reqDoneClosed bool
|
reqDoneClosed bool
|
||||||
|
|
||||||
|
onFrameError func()
|
||||||
|
|
||||||
bytesRemainingInFrame uint64
|
bytesRemainingInFrame uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ io.ReadCloser = &body{}
|
var _ io.ReadCloser = &body{}
|
||||||
|
|
||||||
func newRequestBody(str quic.Stream) *body {
|
func newRequestBody(str quic.Stream, onFrameError func()) *body {
|
||||||
return &body{
|
return &body{
|
||||||
str: str,
|
str: str,
|
||||||
isRequest: true,
|
onFrameError: onFrameError,
|
||||||
|
isRequest: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newResponseBody(str quic.Stream, done chan<- struct{}) *body {
|
func newResponseBody(str quic.Stream, done chan<- struct{}, onFrameError func()) *body {
|
||||||
return &body{
|
return &body{
|
||||||
str: str,
|
str: str,
|
||||||
reqDone: done,
|
onFrameError: onFrameError,
|
||||||
|
reqDone: done,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,7 +66,10 @@ func (r *body) readImpl(b []byte) (int, error) {
|
||||||
r.bytesRemainingInFrame = f.Length
|
r.bytesRemainingInFrame = f.Length
|
||||||
break parseLoop
|
break parseLoop
|
||||||
default:
|
default:
|
||||||
return 0, errors.New("unexpected frame")
|
r.onFrameError()
|
||||||
|
// parseNextFrame skips over unknown frame types
|
||||||
|
// Therefore, this condition is only entered when we parsed another known frame type.
|
||||||
|
return 0, fmt.Errorf("peer sent an unexpected frame: %T", f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,12 +29,15 @@ func (t bodyType) String() string {
|
||||||
|
|
||||||
var _ = Describe("Body", func() {
|
var _ = Describe("Body", func() {
|
||||||
var (
|
var (
|
||||||
rb *body
|
rb *body
|
||||||
str *mockquic.MockStream
|
str *mockquic.MockStream
|
||||||
buf *bytes.Buffer
|
buf *bytes.Buffer
|
||||||
reqDone chan struct{}
|
reqDone chan struct{}
|
||||||
|
errorCbCalled bool
|
||||||
)
|
)
|
||||||
|
|
||||||
|
errorCb := func() { errorCbCalled = true }
|
||||||
|
|
||||||
getDataFrame := func(data []byte) []byte {
|
getDataFrame := func(data []byte) []byte {
|
||||||
b := &bytes.Buffer{}
|
b := &bytes.Buffer{}
|
||||||
(&dataFrame{Length: uint64(len(data))}).Write(b)
|
(&dataFrame{Length: uint64(len(data))}).Write(b)
|
||||||
|
@ -44,6 +47,7 @@ var _ = Describe("Body", func() {
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
buf = &bytes.Buffer{}
|
buf = &bytes.Buffer{}
|
||||||
|
errorCbCalled = false
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} {
|
for _, bt := range []bodyType{bodyTypeRequest, bodyTypeResponse} {
|
||||||
|
@ -61,10 +65,10 @@ var _ = Describe("Body", func() {
|
||||||
|
|
||||||
switch bodyType {
|
switch bodyType {
|
||||||
case bodyTypeRequest:
|
case bodyTypeRequest:
|
||||||
rb = newRequestBody(str)
|
rb = newRequestBody(str, errorCb)
|
||||||
case bodyTypeResponse:
|
case bodyTypeResponse:
|
||||||
reqDone = make(chan struct{})
|
reqDone = make(chan struct{})
|
||||||
rb = newResponseBody(str, reqDone)
|
rb = newResponseBody(str, reqDone, errorCb)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -144,10 +148,11 @@ var _ = Describe("Body", func() {
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("errors on unexpected frames", func() {
|
It("errors on unexpected frames, and calls the error callback", func() {
|
||||||
(&settingsFrame{}).Write(buf)
|
(&settingsFrame{}).Write(buf)
|
||||||
_, err := rb.Read([]byte{0})
|
_, err := rb.Read([]byte{0})
|
||||||
Expect(err).To(MatchError("unexpected frame"))
|
Expect(err).To(MatchError("peer sent an unexpected frame: *http3.settingsFrame"))
|
||||||
|
Expect(errorCbCalled).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
|
||||||
if bodyType == bodyTypeRequest {
|
if bodyType == bodyTypeRequest {
|
||||||
|
|
|
@ -237,7 +237,9 @@ func (c *client) doRequest(
|
||||||
res.Header.Add(hf.Name, hf.Value)
|
res.Header.Add(hf.Name, hf.Value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
respBody := newResponseBody(str, reqDone)
|
respBody := newResponseBody(str, reqDone, func() {
|
||||||
|
c.session.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "")
|
||||||
|
})
|
||||||
if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
if requestGzip && res.Header.Get("Content-Encoding") == "gzip" {
|
||||||
res.Header.Del("Content-Encoding")
|
res.Header.Del("Content-Encoding")
|
||||||
res.Header.Del("Content-Length")
|
res.Header.Del("Content-Length")
|
||||||
|
|
|
@ -180,7 +180,10 @@ func (s *Server) handleConn(sess quic.Session) {
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
defer ginkgo.GinkgoRecover()
|
defer ginkgo.GinkgoRecover()
|
||||||
if rerr := s.handleRequest(str, decoder); rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
|
rerr := s.handleRequest(str, decoder, func() {
|
||||||
|
sess.CloseWithError(quic.ErrorCode(errorFrameUnexpected), "")
|
||||||
|
})
|
||||||
|
if rerr.err != nil || rerr.streamErr != 0 || rerr.connErr != 0 {
|
||||||
s.logger.Debugf("Handling request failed: %s", err)
|
s.logger.Debugf("Handling request failed: %s", err)
|
||||||
if rerr.streamErr != 0 {
|
if rerr.streamErr != 0 {
|
||||||
str.CancelWrite(quic.ErrorCode(rerr.streamErr))
|
str.CancelWrite(quic.ErrorCode(rerr.streamErr))
|
||||||
|
@ -206,7 +209,7 @@ func (s *Server) maxHeaderBytes() uint64 {
|
||||||
return uint64(s.Server.MaxHeaderBytes)
|
return uint64(s.Server.MaxHeaderBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder) requestError {
|
func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
|
||||||
frame, err := parseNextFrame(str)
|
frame, err := parseNextFrame(str)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return newStreamError(errorRequestIncomplete, err)
|
return newStreamError(errorRequestIncomplete, err)
|
||||||
|
@ -232,7 +235,7 @@ func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder) requestE
|
||||||
// TODO: use the right error code
|
// TODO: use the right error code
|
||||||
return newStreamError(errorGeneralProtocolError, err)
|
return newStreamError(errorGeneralProtocolError, err)
|
||||||
}
|
}
|
||||||
req.Body = newRequestBody(str)
|
req.Body = newRequestBody(str, onFrameError)
|
||||||
|
|
||||||
if s.logger.Debug() {
|
if s.logger.Debug() {
|
||||||
s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
|
s.logger.Infof("%s %s%s, on stream %d", req.Method, req.Host, req.RequestURI, str.StreamID())
|
||||||
|
|
|
@ -117,7 +117,7 @@ var _ = Describe("Server", func() {
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
|
||||||
Expect(s.handleRequest(str, qpackDecoder)).To(Equal(requestError{}))
|
Expect(s.handleRequest(str, qpackDecoder, nil)).To(Equal(requestError{}))
|
||||||
var req *http.Request
|
var req *http.Request
|
||||||
Eventually(requestChan).Should(Receive(&req))
|
Eventually(requestChan).Should(Receive(&req))
|
||||||
Expect(req.Host).To(Equal("www.example.com"))
|
Expect(req.Host).To(Equal("www.example.com"))
|
||||||
|
@ -133,7 +133,7 @@ var _ = Describe("Server", func() {
|
||||||
return responseBuf.Write(p)
|
return responseBuf.Write(p)
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
|
|
||||||
serr := s.handleRequest(str, qpackDecoder)
|
serr := s.handleRequest(str, qpackDecoder, nil)
|
||||||
Expect(serr.err).ToNot(HaveOccurred())
|
Expect(serr.err).ToNot(HaveOccurred())
|
||||||
hfs := decodeHeader(responseBuf)
|
hfs := decodeHeader(responseBuf)
|
||||||
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
|
Expect(hfs).To(HaveKeyWithValue(":status", []string{"200"}))
|
||||||
|
@ -152,7 +152,7 @@ var _ = Describe("Server", func() {
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
str.EXPECT().CancelRead(gomock.Any())
|
str.EXPECT().CancelRead(gomock.Any())
|
||||||
|
|
||||||
serr := s.handleRequest(str, qpackDecoder)
|
serr := s.handleRequest(str, qpackDecoder, nil)
|
||||||
Expect(serr.err).ToNot(HaveOccurred())
|
Expect(serr.err).ToNot(HaveOccurred())
|
||||||
hfs := decodeHeader(responseBuf)
|
hfs := decodeHeader(responseBuf)
|
||||||
Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"}))
|
Expect(hfs).To(HaveKeyWithValue(":status", []string{"500"}))
|
||||||
|
@ -294,7 +294,7 @@ var _ = Describe("Server", func() {
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
|
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
|
||||||
|
|
||||||
serr := s.handleRequest(str, qpackDecoder)
|
serr := s.handleRequest(str, qpackDecoder, nil)
|
||||||
Expect(serr.err).ToNot(HaveOccurred())
|
Expect(serr.err).ToNot(HaveOccurred())
|
||||||
Eventually(handlerCalled).Should(BeClosed())
|
Eventually(handlerCalled).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
@ -317,7 +317,7 @@ var _ = Describe("Server", func() {
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
|
str.EXPECT().CancelRead(quic.ErrorCode(errorEarlyResponse))
|
||||||
|
|
||||||
serr := s.handleRequest(str, qpackDecoder)
|
serr := s.handleRequest(str, qpackDecoder, nil)
|
||||||
Expect(serr.err).ToNot(HaveOccurred())
|
Expect(serr.err).ToNot(HaveOccurred())
|
||||||
Eventually(handlerCalled).Should(BeClosed())
|
Eventually(handlerCalled).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue