refactor parsing of headers with invalid reserved bits

This commit is contained in:
Marten Seemann 2019-12-09 14:56:08 +08:00
parent 9d6ae474d7
commit 490d9ddab2
2 changed files with 20 additions and 14 deletions

View file

@ -28,15 +28,15 @@ type ExtendedHeader struct {
PacketNumber protocol.PacketNumber
}
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) {
// read the (now unencrypted) first byte
var err error
h.typeByte, err = b.ReadByte()
if err != nil {
return nil, err
return false, err
}
if _, err := b.Seek(int64(h.ParsedLen())-1, io.SeekCurrent); err != nil {
return nil, err
return false, err
}
if h.IsLongHeader {
return h.parseLongHeader(b, v)
@ -44,31 +44,29 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*Exte
return h.parseShortHeader(b, v)
}
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (*ExtendedHeader, error) {
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
if err := h.readPacketNumber(b); err != nil {
return nil, err
return false, err
}
var err error
if h.typeByte&0xc != 0 {
err = ErrInvalidReservedBits
return false, nil
}
return h, err
return true, nil
}
func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (*ExtendedHeader, error) {
func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
h.KeyPhase = protocol.KeyPhaseZero
if h.typeByte&0x4 > 0 {
h.KeyPhase = protocol.KeyPhaseOne
}
if err := h.readPacketNumber(b); err != nil {
return nil, err
return false, err
}
var err error
if h.typeByte&0x18 != 0 {
err = ErrInvalidReservedBits
return false, nil
}
return h, err
return true, nil
}
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {

View file

@ -241,7 +241,15 @@ func (h *Header) ParsedLen() protocol.ByteCount {
// ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header.
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
return h.toExtendedHeader().parse(b, ver)
extHdr := h.toExtendedHeader()
reservedBitsValid, err := extHdr.parse(b, ver)
if err != nil {
return nil, err
}
if !reservedBitsValid {
return extHdr, ErrInvalidReservedBits
}
return extHdr, nil
}
func (h *Header) toExtendedHeader() *ExtendedHeader {