add a function to distinguish between long and short header packets (#3498)

This commit is contained in:
Marten Seemann 2022-08-11 16:35:01 +02:00 committed by GitHub
parent bebff462c8
commit 80fd1b57c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 21 additions and 11 deletions

View file

@ -13,6 +13,7 @@ import (
"github.com/lucas-clemente/quic-go"
quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@ -73,7 +74,7 @@ var _ = Describe("Datagram test", func() {
return false
}
// don't drop Long Header packets
if packet[0]&0x80 == 1 {
if wire.IsLongHeaderPacket(packet[0]) {
return false
}
drop := mrand.Int()%10 == 0

View file

@ -732,7 +732,7 @@ var _ = Describe("0-RTT", func() {
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: ln.Addr().String(),
DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
if dir == quicproxy.DirectionIncoming && data[0]&0x80 > 0 && data[0]&0x30>>4 == 0 { // Initial packet from client
if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client
return rtt/2 + rtt
}
return rtt / 2

View file

@ -19,8 +19,7 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti
if len(data) == 0 {
return nil, io.EOF
}
isLongHeader := data[0]&0x80 > 0
if !isLongHeader {
if !IsLongHeaderPacket(data[0]) {
if len(data) < shortHeaderConnIDLen+1 {
return nil, io.EOF
}
@ -36,12 +35,17 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti
return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil
}
// IsLongHeaderPacket says if this is a Long Header packet
func IsLongHeaderPacket(firstByte byte) bool {
return firstByte&0x80 > 0
}
// IsVersionNegotiationPacket says if this is a version negotiation packet
func IsVersionNegotiationPacket(b []byte) bool {
if len(b) < 5 {
return false
}
return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
return IsLongHeaderPacket(b[0]) && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
}
// Is0RTTPacket says if this is a 0-RTT packet.
@ -50,7 +54,7 @@ func Is0RTTPacket(b []byte) bool {
if len(b) < 5 {
return false
}
if b[0]&0x80 == 0 {
if !IsLongHeaderPacket(b[0]) {
return false
}
version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5]))
@ -129,7 +133,7 @@ func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error)
h := &Header{
typeByte: typeByte,
IsLongHeader: typeByte&0x80 > 0,
IsLongHeader: IsLongHeaderPacket(typeByte),
}
if !h.IsLongHeader {

View file

@ -576,6 +576,11 @@ var _ = Describe("Header Parsing", func() {
})
})
It("distinguishes long and short header packets", func() {
Expect(IsLongHeaderPacket(0x40)).To(BeFalse())
Expect(IsLongHeaderPacket(0x80 ^ 0x40 ^ 0x12)).To(BeTrue())
})
It("tells its packet type for logging", func() {
Expect((&Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}).PacketType()).To(Equal("Handshake"))
Expect((&Header{}).PacketType()).To(Equal("1-RTT"))

View file

@ -56,7 +56,7 @@ var _ = Describe("Version Negotiation Packets", func() {
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
versions := []protocol.VersionNumber{1001, 1003}
data := ComposeVersionNegotiation(destConnID, srcConnID, versions)
Expect(data[0] & 0x80).ToNot(BeZero())
Expect(IsLongHeaderPacket(data[0])).To(BeTrue())
hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data))
Expect(err).ToNot(HaveOccurred())
Expect(hdr.DestConnectionID).To(Equal(destConnID))

View file

@ -390,7 +390,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
return
}
}
if p.data[0]&0x80 == 0 {
if !wire.IsLongHeaderPacket(p.data[0]) {
go h.maybeSendStatelessReset(p, connID)
return
}
@ -433,7 +433,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
// stateless resets are always short header packets
if data[0]&0x80 != 0 {
if wire.IsLongHeaderPacket(data[0]) {
return false
}
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {

View file

@ -453,7 +453,7 @@ var _ = Describe("Packet Handler Map", func() {
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) {
defer close(done)
Expect(b[0] & 0x80).To(BeZero()) // short header packet
Expect(wire.IsLongHeaderPacket(b[0])).To(BeFalse()) // short header packet
Expect(b).To(HaveLen(protocol.MinStatelessResetSize))
})
handler.handlePacket(&receivedPacket{