Merge pull request #3644 from lucas-clemente/short-header-packing

use a separate code path for handling short header packets
This commit is contained in:
Marten Seemann 2023-01-17 01:08:54 -08:00 committed by GitHub
commit 3d4bbc28ba
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
37 changed files with 1230 additions and 1439 deletions

View file

@ -878,7 +878,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
}
if wire.IsLongHeaderPacket(p.data[0]) {
hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen)
hdr, packetData, rest, err := wire.ParsePacket(p.data)
if err != nil {
if s.tracer != nil {
dropReason := logging.PacketDropHeaderParseError
@ -1030,7 +1030,7 @@ func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header)
return false
}
if err := s.handleUnpackedPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil {
if err := s.handleUnpackedLongHeaderPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil {
s.closeLocal(err)
return false
}
@ -1193,7 +1193,7 @@ func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) {
})
}
func (s *connection) handleUnpackedPacket(
func (s *connection) handleUnpackedLongHeaderPacket(
packet *unpackedPacket,
ecn protocol.ECN,
rcvTime time.Time,
@ -1212,7 +1212,7 @@ func (s *connection) handleUnpackedPacket(
s.tracer.NegotiatedVersion(s.version, clientVersions, serverVersions)
}
// The server can change the source connection ID with the first Handshake packet.
if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && packet.hdr.SrcConnectionID != s.handshakeDestConnID {
if s.perspective == protocol.PerspectiveClient && packet.hdr.SrcConnectionID != s.handshakeDestConnID {
cid := packet.hdr.SrcConnectionID
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid)
s.handshakeDestConnID = cid
@ -1827,30 +1827,27 @@ func (s *connection) maybeSendAckOnlyPacket() error {
if packet == nil {
return nil
}
s.logCoalescedPacket(packet)
for _, p := range packet.packets {
s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(time.Now(), s.retransmissionQueue))
}
s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer)
s.sendPackedCoalescedPacket(packet, time.Now())
return nil
}
packet, err := s.packer.PackPacket(true)
now := time.Now()
p, buffer, err := s.packer.PackPacket(true, now)
if err != nil {
if err == errNothingToPack {
return nil
}
return err
}
if packet == nil {
return nil
}
s.sendPackedPacket(packet, time.Now())
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false)
s.sendPackedShortHeaderPacket(buffer, p.Packet, now)
return nil
}
func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error {
// Queue probe packets until we actually send out a packet,
// or until there are no more packets to queue.
var packet *packedPacket
var packet *coalescedPacket
for {
if wasQueued := s.sentPacketHandler.QueueProbePacket(encLevel); !wasQueued {
break
@ -1882,10 +1879,10 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error {
return err
}
}
if packet == nil || packet.packetContents == nil {
if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) {
return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel)
}
s.sendPackedPacket(packet, time.Now())
s.sendPackedCoalescedPacket(packet, time.Now())
return nil
}
@ -1902,39 +1899,54 @@ func (s *connection) sendPacket() (bool, error) {
return false, err
}
s.sentFirstPacket = true
s.logCoalescedPacket(packet)
for _, p := range packet.packets {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
s.firstAckElicitingPacketAfterIdleSentTime = now
}
s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(now, s.retransmissionQueue))
}
s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer)
s.sendPackedCoalescedPacket(packet, now)
return true, nil
}
if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) {
packet, err := s.packer.PackMTUProbePacket(s.mtuDiscoverer.GetPing())
} else if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) {
ping, size := s.mtuDiscoverer.GetPing()
p, buffer, err := s.packer.PackMTUProbePacket(ping, size, now)
if err != nil {
return false, err
}
s.sendPackedPacket(packet, now)
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false)
s.sendPackedShortHeaderPacket(buffer, p.Packet, now)
return true, nil
}
packet, err := s.packer.PackPacket(false)
if err != nil || packet == nil {
p, buffer, err := s.packer.PackPacket(false, now)
if err != nil {
if err == errNothingToPack {
return false, nil
}
return false, err
}
s.sendPackedPacket(packet, now)
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false)
s.sendPackedShortHeaderPacket(buffer, p.Packet, now)
return true, nil
}
func (s *connection) sendPackedPacket(packet *packedPacket, now time.Time) {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && packet.IsAckEliciting() {
func (s *connection) sendPackedShortHeaderPacket(buffer *packetBuffer, p *ackhandler.Packet, now time.Time) {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && ackhandler.HasAckElicitingFrames(p.Frames) {
s.firstAckElicitingPacketAfterIdleSentTime = now
}
s.logPacket(packet)
s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket(now, s.retransmissionQueue))
s.sentPacketHandler.SentPacket(p)
s.connIDManager.SentPacket()
s.sendQueue.Send(buffer)
}
func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) {
s.logCoalescedPacket(packet)
for _, p := range packet.longHdrPackets {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
s.firstAckElicitingPacketAfterIdleSentTime = now
}
s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(now, s.retransmissionQueue))
}
if p := packet.shortHdrPacket; p != nil {
if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() {
s.firstAckElicitingPacketAfterIdleSentTime = now
}
s.sentPacketHandler.SentPacket(p.Packet)
}
s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer)
}
@ -1961,7 +1973,18 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
return packet.buffer.Data, s.conn.Write(packet.buffer.Data)
}
func (s *connection) logPacketContents(p *packetContents) {
func (s *connection) logLongHeaderPacket(p *longHeaderPacket) {
// quic-go logging
if s.logger.Debug() {
p.header.Log(s.logger)
if p.ack != nil {
wire.LogFrame(s.logger, p.ack, true)
}
for _, frame := range p.frames {
wire.LogFrame(s.logger, frame.Frame, true)
}
}
// tracing
if s.tracer != nil {
frames := make([]logging.Frame, 0, len(p.frames))
@ -1972,40 +1995,72 @@ func (s *connection) logPacketContents(p *packetContents) {
if p.ack != nil {
ack = logutils.ConvertAckFrame(p.ack)
}
s.tracer.SentPacket(p.header, p.length, ack, frames)
s.tracer.SentLongHeaderPacket(p.header, p.length, ack, frames)
}
}
func (s *connection) logShortHeaderPacket(
destConnID protocol.ConnectionID,
ackFrame *wire.AckFrame,
frames []ackhandler.Frame,
pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit,
size protocol.ByteCount,
isCoalesced bool,
) {
if s.logger.Debug() && !isCoalesced {
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT", pn, size, s.logID)
}
// quic-go logging
if s.logger.Debug() {
wire.LogShortHeader(s.logger, destConnID, pn, pnLen, kp)
if ackFrame != nil {
wire.LogFrame(s.logger, ackFrame, true)
}
for _, frame := range frames {
wire.LogFrame(s.logger, frame.Frame, true)
}
}
// quic-go logging
if !s.logger.Debug() {
return
}
p.header.Log(s.logger)
if p.ack != nil {
wire.LogFrame(s.logger, p.ack, true)
}
for _, frame := range p.frames {
wire.LogFrame(s.logger, frame.Frame, true)
// tracing
if s.tracer != nil {
fs := make([]logging.Frame, 0, len(frames))
for _, f := range frames {
fs = append(fs, logutils.ConvertFrame(f.Frame))
}
var ack *logging.AckFrame
if ackFrame != nil {
ack = logutils.ConvertAckFrame(ackFrame)
}
s.tracer.SentShortHeaderPacket(
&logging.ShortHeader{
DestConnectionID: destConnID,
PacketNumber: pn,
PacketNumberLen: pnLen,
KeyPhase: kp,
},
size,
ack,
fs,
)
}
}
func (s *connection) logCoalescedPacket(packet *coalescedPacket) {
if s.logger.Debug() {
if len(packet.packets) > 1 {
s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.packets), packet.buffer.Len(), s.logID)
if len(packet.longHdrPackets) > 1 {
s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.longHdrPackets), packet.buffer.Len(), s.logID)
} else {
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.packets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.packets[0].EncryptionLevel())
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.longHdrPackets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.longHdrPackets[0].EncryptionLevel())
}
}
for _, p := range packet.packets {
s.logPacketContents(p)
for _, p := range packet.longHdrPackets {
s.logLongHeaderPacket(p)
}
}
func (s *connection) logPacket(packet *packedPacket) {
if s.logger.Debug() {
s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.header.PacketNumber, packet.buffer.Len(), s.logID, packet.EncryptionLevel())
if p := packet.shortHdrPacket; p != nil {
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true)
}
s.logPacketContents(packet.packetContents)
}
// AcceptStream returns the next stream openend by the peer

View file

@ -53,16 +53,33 @@ var _ = Describe("Connection", func() {
destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
getPacket := func(pn protocol.PacketNumber) *packedPacket {
getShortHeaderPacket := func(pn protocol.PacketNumber) (shortHeaderPacket, *packetBuffer) {
buffer := getPacketBuffer()
buffer.Data = append(buffer.Data, []byte("foobar")...)
return &packedPacket{
buffer: buffer,
packetContents: &packetContents{
header: &wire.ExtendedHeader{PacketNumber: pn},
return shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: pn}}, buffer
}
getCoalescedPacket := func(pn protocol.PacketNumber, isLongHeader bool) *coalescedPacket {
buffer := getPacketBuffer()
buffer.Data = append(buffer.Data, []byte("foobar")...)
packet := &coalescedPacket{buffer: buffer}
if isLongHeader {
packet.longHdrPackets = []*longHeaderPacket{{
header: &wire.ExtendedHeader{
Header: wire.Header{},
PacketNumber: pn,
},
length: 6, // foobar
},
}}
} else {
packet.shortHdrPacket = &shortHeaderPacket{
Packet: &ackhandler.Packet{
PacketNumber: pn,
Length: 6,
},
}
}
return packet
}
expectReplaceWithClosed := func() {
@ -555,11 +572,7 @@ var _ = Describe("Connection", func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
buf := &bytes.Buffer{}
hdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
PacketNumberLen: protocol.PacketNumberLen2,
}
Expect(hdr.Write(buf, conn.version)).To(Succeed())
Expect(wire.WriteShortHeader(buf, srcConnID, 42, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed())
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version)
@ -593,16 +606,16 @@ var _ = Describe("Connection", func() {
sph.EXPECT().HasPacingBudget().Return(true).AnyTimes()
// only expect a single SentPacket() call
sph.EXPECT().SentPacket(gomock.Any())
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
streamManager.EXPECT().CloseWithError(gomock.Any())
connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
cryptoSetup.EXPECT().Close()
conn.sentPacketHandler = sph
p := getPacket(1)
packer.EXPECT().PackPacket(false).Return(p, nil)
packer.EXPECT().PackPacket(false).Return(nil, nil).AnyTimes()
p, buffer := getShortHeaderPacket(1)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes()
runConn()
conn.queueControlFrame(&wire.PingFrame{})
conn.scheduleSending()
@ -635,7 +648,17 @@ var _ = Describe("Connection", func() {
conn.unpacker = unpacker
})
getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket {
getShortHeaderPacket := func(connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) *receivedPacket {
buf := &bytes.Buffer{}
Expect(wire.WriteShortHeader(buf, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed())
return &receivedPacket{
data: append(buf.Bytes(), data...),
buffer: getPacketBuffer(),
rcvTime: time.Now(),
}
}
getLongHeaderPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket {
buf := &bytes.Buffer{}
Expect(extHdr.Write(buf, conn.version)).To(Succeed())
return &receivedPacket{
@ -646,8 +669,7 @@ var _ = Describe("Connection", func() {
}
It("drops Retry packets", func() {
p := getPacket(&wire.ExtendedHeader{Header: wire.Header{
IsLongHeader: true,
p := getLongHeaderPacket(&wire.ExtendedHeader{Header: wire.Header{
Type: protocol.PacketTypeRetry,
DestConnectionID: destConnID,
SrcConnectionID: srcConnID,
@ -672,11 +694,10 @@ var _ = Describe("Connection", func() {
})
It("drops packets for which header decryption fails", func() {
p := getPacket(&wire.ExtendedHeader{
p := getLongHeaderPacket(&wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Version: conn.version,
Type: protocol.PacketTypeHandshake,
Version: conn.version,
},
PacketNumberLen: protocol.PacketNumberLen2,
}, nil)
@ -686,11 +707,10 @@ var _ = Describe("Connection", func() {
})
It("drops packets for which the version is unsupported", func() {
p := getPacket(&wire.ExtendedHeader{
p := getLongHeaderPacket(&wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Version: conn.version + 1,
Type: protocol.PacketTypeHandshake,
Version: conn.version + 1,
},
PacketNumberLen: protocol.PacketNumberLen2,
}, nil)
@ -706,9 +726,8 @@ var _ = Describe("Connection", func() {
}()
protocol.SupportedVersions = append(protocol.SupportedVersions, conn.version+1)
p := getPacket(&wire.ExtendedHeader{
p := getLongHeaderPacket(&wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: destConnID,
SrcConnectionID: srcConnID,
@ -723,7 +742,6 @@ var _ = Describe("Connection", func() {
It("informs the ReceivedPacketHandler about non-ack-eliciting packets", func() {
hdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: srcConnID,
Version: protocol.Version1,
@ -734,7 +752,7 @@ var _ = Describe("Connection", func() {
}
unpackedHdr := *hdr
unpackedHdr.PacketNumber = 0x1337
packet := getPacket(hdr, nil)
packet := getLongHeaderPacket(hdr, nil)
packet.ecn = protocol.ECNCE
rcvTime := time.Now().Add(-10 * time.Second)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{
@ -755,15 +773,10 @@ var _ = Describe("Connection", func() {
})
It("informs the ReceivedPacketHandler about ack-eliciting packets", func() {
hdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
PacketNumber: 0x37,
PacketNumberLen: protocol.PacketNumberLen1,
}
rcvTime := time.Now().Add(-10 * time.Second)
b, err := (&wire.PingFrame{}).Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred())
packet := getPacket(hdr, nil)
packet := getShortHeaderPacket(srcConnID, 0x37, nil)
packet.ecn = protocol.ECT1
unpacker.EXPECT().UnpackShortHeader(rcvTime, gomock.Any()).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, b, nil)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
@ -778,12 +791,7 @@ var _ = Describe("Connection", func() {
})
It("drops duplicate packets", func() {
hdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
PacketNumber: 0x37,
PacketNumberLen: protocol.PacketNumberLen1,
}
packet := getPacket(hdr, nil)
packet := getShortHeaderPacket(srcConnID, 0x37, nil)
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseOne, []byte("foobar"), nil)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true)
@ -803,9 +811,8 @@ var _ = Describe("Connection", func() {
conn.run()
}()
expectReplaceWithClosed()
p := getPacket(&wire.ExtendedHeader{
p := getLongHeaderPacket(&wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: srcConnID,
Version: conn.version,
@ -837,11 +844,7 @@ var _ = Describe("Connection", func() {
packer.EXPECT().PackCoalescedPacket(false) // only expect a single call
for i := 0; i < 3; i++ {
conn.handlePacket(getPacket(&wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
PacketNumber: 0x1337,
PacketNumberLen: protocol.PacketNumberLen2,
}, []byte("foobar")))
conn.handlePacket(getShortHeaderPacket(srcConnID, 0x1337+protocol.PacketNumber(i), []byte("foobar")))
}
go func() {
@ -876,11 +879,7 @@ var _ = Describe("Connection", func() {
packer.EXPECT().PackCoalescedPacket(false).Times(3) // only expect a single call
for i := 0; i < 3; i++ {
conn.handlePacket(getPacket(&wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
PacketNumber: 0x1337,
PacketNumberLen: protocol.PacketNumberLen2,
}, []byte("foobar")))
conn.handlePacket(getShortHeaderPacket(srcConnID, 0x1337+protocol.PacketNumber(i), []byte("foobar")))
}
go func() {
@ -919,10 +918,7 @@ var _ = Describe("Connection", func() {
}()
expectReplaceWithClosed()
mconn.EXPECT().Write(gomock.Any())
packet := getPacket(&wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
PacketNumberLen: protocol.PacketNumberLen1,
}, nil)
packet := getShortHeaderPacket(srcConnID, 0x42, nil)
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.handlePacket(packet)
@ -942,10 +938,7 @@ var _ = Describe("Connection", func() {
}()
expectReplaceWithClosed()
tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, gomock.Any(), logging.PacketDropHeaderParseError)
conn.handlePacket(getPacket(&wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
PacketNumberLen: protocol.PacketNumberLen1,
}, nil))
conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil))
Consistently(runErr).ShouldNot(Receive())
// make the go routine return
packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
@ -973,20 +966,15 @@ var _ = Describe("Connection", func() {
}()
expectReplaceWithClosed()
mconn.EXPECT().Write(gomock.Any())
packet := getPacket(&wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
PacketNumberLen: protocol.PacketNumberLen1,
}, nil)
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.handlePacket(packet)
conn.handlePacket(getShortHeaderPacket(srcConnID, 0x42, nil))
Eventually(conn.Context().Done()).Should(BeClosed())
})
It("ignores packets with a different source connection ID", func() {
hdr1 := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: destConnID,
SrcConnectionID: srcConnID,
@ -998,7 +986,6 @@ var _ = Describe("Connection", func() {
}
hdr2 := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: destConnID,
SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
@ -1016,12 +1003,12 @@ var _ = Describe("Connection", func() {
hdr: hdr1,
data: []byte{0}, // one PADDING frame
}, nil)
p1 := getPacket(hdr1, nil)
p1 := getLongHeaderPacket(hdr1, nil)
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any())
Expect(conn.handlePacketImpl(p1)).To(BeTrue())
// The next packet has to be ignored, since the source connection ID doesn't match.
p2 := getPacket(hdr2, nil)
p2 := getLongHeaderPacket(hdr2, nil)
tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.ByteCount(len(p2.data)), logging.PacketDropUnknownConnectionID)
Expect(conn.handlePacketImpl(p2)).To(BeFalse())
})
@ -1030,7 +1017,6 @@ var _ = Describe("Connection", func() {
conn.handshakeComplete = false
hdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: destConnID,
SrcConnectionID: srcConnID,
@ -1041,7 +1027,7 @@ var _ = Describe("Connection", func() {
PacketNumber: 1,
}
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable)
packet := getPacket(hdr, nil)
packet := getLongHeaderPacket(hdr, nil)
tracer.EXPECT().BufferedPacket(logging.PacketTypeHandshake, packet.Size())
Expect(conn.handlePacketImpl(packet)).To(BeFalse())
Expect(conn.undecryptablePackets).To(Equal([]*receivedPacket{packet}))
@ -1050,10 +1036,7 @@ var _ = Describe("Connection", func() {
Context("updating the remote address", func() {
It("doesn't support connection migration", func() {
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* one PADDING frame */, nil)
packet := getPacket(&wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
PacketNumberLen: protocol.PacketNumberLen1,
}, nil)
packet := getShortHeaderPacket(srcConnID, 0x42, nil)
packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any())
Expect(conn.handlePacketImpl(packet)).To(BeTrue())
@ -1067,7 +1050,6 @@ var _ = Describe("Connection", func() {
getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, *receivedPacket) {
hdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: connID,
SrcConnectionID: destConnID,
@ -1079,7 +1061,7 @@ var _ = Describe("Connection", func() {
hdrLen := hdr.GetLength(conn.version)
b := make([]byte, 1)
rand.Read(b)
packet := getPacket(hdr, bytes.Repeat(b, int(length)-3))
packet := getLongHeaderPacket(hdr, bytes.Repeat(b, int(length)-3))
return int(hdrLen), packet
}
@ -1090,7 +1072,7 @@ var _ = Describe("Connection", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{Header: wire.Header{IsLongHeader: true}},
hdr: &wire.ExtendedHeader{Header: wire.Header{}},
}, nil
})
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any())
@ -1106,7 +1088,7 @@ var _ = Describe("Connection", func() {
data: []byte{0},
hdr: &wire.ExtendedHeader{
PacketNumber: 1,
Header: wire.Header{SrcConnectionID: destConnID, IsLongHeader: true},
Header: wire.Header{SrcConnectionID: destConnID},
},
}, nil
})
@ -1118,7 +1100,7 @@ var _ = Describe("Connection", func() {
data: []byte{0},
hdr: &wire.ExtendedHeader{
PacketNumber: 2,
Header: wire.Header{SrcConnectionID: destConnID, IsLongHeader: true},
Header: wire.Header{SrcConnectionID: destConnID},
},
}, nil
})
@ -1141,7 +1123,7 @@ var _ = Describe("Connection", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{Header: wire.Header{IsLongHeader: true}},
hdr: &wire.ExtendedHeader{Header: wire.Header{}},
}, nil
}),
)
@ -1165,7 +1147,7 @@ var _ = Describe("Connection", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{Header: wire.Header{IsLongHeader: true}},
hdr: &wire.ExtendedHeader{Header: wire.Header{}},
}, nil
})
_, packet2 := getPacketWithLength(wrongConnID, 123)
@ -1227,13 +1209,18 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SentPacket(gomock.Any())
conn.sentPacketHandler = sph
runConn()
p := getPacket(1)
packer.EXPECT().PackPacket(false).Return(p, nil)
packer.EXPECT().PackPacket(false).Return(nil, nil).AnyTimes()
p, buffer := getShortHeaderPacket(1)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes()
sent := make(chan struct{})
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) })
tracer.EXPECT().SentPacket(p.header, p.buffer.Len(), nil, []logging.Frame{})
tracer.EXPECT().SentShortHeaderPacket(&logging.ShortHeader{
DestConnectionID: p.DestConnID,
PacketNumber: p.PacketNumber,
PacketNumberLen: p.PacketNumberLen,
KeyPhase: p.KeyPhase,
}, buffer.Len(), nil, []logging.Frame{})
conn.scheduleSending()
Eventually(sent).Should(BeClosed())
})
@ -1241,7 +1228,7 @@ var _ = Describe("Connection", func() {
It("doesn't send packets if there's nothing to send", func() {
conn.handshakeConfirmed = true
runConn()
packer.EXPECT().PackPacket(false).Return(nil, nil).AnyTimes()
packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes()
conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true)
conn.scheduleSending()
time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write()
@ -1272,14 +1259,14 @@ var _ = Describe("Connection", func() {
fc := mocks.NewMockConnectionFlowController(mockCtrl)
fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337))
fc.EXPECT().IsNewlyBlocked()
p := getPacket(1)
packer.EXPECT().PackPacket(false).Return(p, nil)
packer.EXPECT().PackPacket(false).Return(nil, nil).AnyTimes()
p, buffer := getShortHeaderPacket(1)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes()
conn.connFlowController = fc
runConn()
sent := make(chan struct{})
sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) })
tracer.EXPECT().SentPacket(p.header, p.length, nil, []logging.Frame{})
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), buffer.Len(), nil, []logging.Frame{})
conn.scheduleSending()
Eventually(sent).Should(BeClosed())
frames, _ := conn.framer.AppendControlFrames(nil, 1000)
@ -1326,7 +1313,7 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SendMode().Return(sendMode)
sph.EXPECT().SendMode().Return(ackhandler.SendNone)
sph.EXPECT().QueueProbePacket(encLevel)
p := getPacket(123)
p := getCoalescedPacket(123, enc != protocol.Encryption1RTT)
packer.EXPECT().MaybePackProbePacket(encLevel).Return(p, nil)
sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) {
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123)))
@ -1335,7 +1322,11 @@ var _ = Describe("Connection", func() {
runConn()
sent := make(chan struct{})
sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) })
tracer.EXPECT().SentPacket(p.header, p.length, gomock.Any(), gomock.Any())
if enc == protocol.Encryption1RTT {
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any())
} else {
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any())
}
conn.scheduleSending()
Eventually(sent).Should(BeClosed())
})
@ -1347,7 +1338,7 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SendMode().Return(sendMode)
sph.EXPECT().SendMode().Return(ackhandler.SendNone)
sph.EXPECT().QueueProbePacket(encLevel).Return(false)
p := getPacket(123)
p := getCoalescedPacket(123, enc != protocol.Encryption1RTT)
packer.EXPECT().MaybePackProbePacket(encLevel).Return(p, nil)
sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) {
Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123)))
@ -1356,7 +1347,11 @@ var _ = Describe("Connection", func() {
runConn()
sent := make(chan struct{})
sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) })
tracer.EXPECT().SentPacket(p.header, p.length, gomock.Any(), gomock.Any())
if enc == protocol.Encryption1RTT {
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any())
} else {
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any())
}
conn.scheduleSending()
Eventually(sent).Should(BeClosed())
// We're using a mock packet packer in this test.
@ -1374,7 +1369,7 @@ var _ = Describe("Connection", func() {
)
BeforeEach(func() {
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
sph = mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
conn.handshakeConfirmed = true
@ -1405,8 +1400,10 @@ var _ = Describe("Connection", func() {
sph.EXPECT().HasPacingBudget()
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour))
sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(3)
packer.EXPECT().PackPacket(false).Return(getPacket(10), nil)
packer.EXPECT().PackPacket(false).Return(getPacket(11), nil)
p, buffer := getShortHeaderPacket(10)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
p, buffer = getShortHeaderPacket(11)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any()).Times(2)
go func() {
@ -1422,8 +1419,9 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SentPacket(gomock.Any())
sph.EXPECT().HasPacingBudget().Return(true).AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2)
packer.EXPECT().PackPacket(false).Return(getPacket(10), nil)
packer.EXPECT().PackPacket(false).Return(nil, nil)
p, buffer := getShortHeaderPacket(10)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any())
go func() {
@ -1440,7 +1438,9 @@ var _ = Describe("Connection", func() {
sph.EXPECT().HasPacingBudget()
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour))
sph.EXPECT().SendMode().Return(ackhandler.SendAny)
packer.EXPECT().PackPacket(true).Return(getPacket(10), nil)
p, buffer := getShortHeaderPacket(10)
packer.EXPECT().PackPacket(true, gomock.Any()).Return(p, buffer, nil)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any())
go func() {
@ -1459,7 +1459,8 @@ var _ = Describe("Connection", func() {
sph.EXPECT().HasPacingBudget().Return(true)
sph.EXPECT().SendMode().Return(ackhandler.SendAny)
sph.EXPECT().SendMode().Return(ackhandler.SendAck)
packer.EXPECT().PackPacket(false).Return(getPacket(100), nil)
p, buffer := getShortHeaderPacket(100)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any())
go func() {
@ -1474,14 +1475,16 @@ var _ = Describe("Connection", func() {
It("paces packets", func() {
pacingDelay := scaleDuration(100 * time.Millisecond)
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
p1, buffer1 := getShortHeaderPacket(100)
p2, buffer2 := getShortHeaderPacket(101)
gomock.InOrder(
sph.EXPECT().HasPacingBudget().Return(true),
packer.EXPECT().PackPacket(false).Return(getPacket(100), nil),
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p1, buffer1, nil),
sph.EXPECT().SentPacket(gomock.Any()),
sph.EXPECT().HasPacingBudget(),
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)),
sph.EXPECT().HasPacingBudget().Return(true),
packer.EXPECT().PackPacket(false).Return(getPacket(101), nil),
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p2, buffer2, nil),
sph.EXPECT().SentPacket(gomock.Any()),
sph.EXPECT().HasPacingBudget(),
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)),
@ -1506,9 +1509,10 @@ var _ = Describe("Connection", func() {
sph.EXPECT().HasPacingBudget()
sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour))
sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(4)
packer.EXPECT().PackPacket(false).Return(getPacket(1000), nil)
packer.EXPECT().PackPacket(false).Return(getPacket(1001), nil)
packer.EXPECT().PackPacket(false).Return(getPacket(1002), nil)
for pn := protocol.PacketNumber(1000); pn < 1003; pn++ {
p, buffer := getShortHeaderPacket(pn)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
}
written := make(chan struct{}, 3)
sender.EXPECT().WouldBlock().AnyTimes()
sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }).Times(3)
@ -1538,8 +1542,9 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SentPacket(gomock.Any())
sph.EXPECT().HasPacingBudget().Return(true).AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
packer.EXPECT().PackPacket(false).Return(getPacket(1000), nil)
packer.EXPECT().PackPacket(false).Return(nil, nil)
p, buffer := getShortHeaderPacket(1000)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack)
sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) })
available <- struct{}{}
Eventually(written).Should(BeClosed())
@ -1561,8 +1566,9 @@ var _ = Describe("Connection", func() {
})
sph.EXPECT().HasPacingBudget().Return(true).AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
packer.EXPECT().PackPacket(false).Return(getPacket(1000), nil)
packer.EXPECT().PackPacket(false).Return(nil, nil)
p, buffer := getShortHeaderPacket(1000)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack)
sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) })
conn.scheduleSending()
@ -1575,7 +1581,8 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SentPacket(gomock.Any())
sph.EXPECT().HasPacingBudget().Return(true).AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny)
packer.EXPECT().PackPacket(false).Return(getPacket(1000), nil)
p, buffer := getShortHeaderPacket(1000)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
written := make(chan struct{}, 1)
sender.EXPECT().WouldBlock()
sender.EXPECT().WouldBlock().Return(true).Times(2)
@ -1596,8 +1603,9 @@ var _ = Describe("Connection", func() {
sph.EXPECT().HasPacingBudget().Return(true).AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
sender.EXPECT().WouldBlock().AnyTimes()
packer.EXPECT().PackPacket(false).Return(getPacket(1001), nil)
packer.EXPECT().PackPacket(false).Return(nil, nil)
p, buffer = getShortHeaderPacket(1001)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack)
sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} })
available <- struct{}{}
Eventually(written).Should(Receive())
@ -1611,7 +1619,7 @@ var _ = Describe("Connection", func() {
sph.EXPECT().HasPacingBudget().Return(true)
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
sender.EXPECT().WouldBlock().AnyTimes()
packer.EXPECT().PackPacket(false)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack)
// don't EXPECT any calls to mconn.Write()
go func() {
defer GinkgoRecover()
@ -1636,7 +1644,8 @@ var _ = Describe("Connection", func() {
mtuDiscoverer.EXPECT().ShouldSendProbe(gomock.Any()).Return(true)
ping := ackhandler.Frame{Frame: &wire.PingFrame{}}
mtuDiscoverer.EXPECT().GetPing().Return(ping, protocol.ByteCount(1234))
packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234)).Return(getPacket(1), nil)
p, buffer := getShortHeaderPacket(1)
packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234), gomock.Any()).Return(p, buffer, nil)
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
@ -1680,8 +1689,9 @@ var _ = Describe("Connection", func() {
sph.EXPECT().HasPacingBudget().Return(true).AnyTimes()
sph.EXPECT().SentPacket(gomock.Any())
conn.sentPacketHandler = sph
packer.EXPECT().PackPacket(false).Return(getPacket(1), nil)
packer.EXPECT().PackPacket(false).Return(nil, nil)
p, buffer := getShortHeaderPacket(1)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack)
go func() {
defer GinkgoRecover()
@ -1693,14 +1703,15 @@ var _ = Describe("Connection", func() {
// only EXPECT calls after scheduleSending is called
written := make(chan struct{})
sender.EXPECT().Send(gomock.Any()).Do(func(*packetBuffer) { close(written) })
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
conn.scheduleSending()
Eventually(written).Should(BeClosed())
})
It("sets the timer to the ack timer", func() {
packer.EXPECT().PackPacket(false).Return(getPacket(1234), nil)
packer.EXPECT().PackPacket(false).Return(nil, nil)
p, buffer := getShortHeaderPacket(1234)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(p, buffer, nil)
packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack)
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes()
@ -1717,7 +1728,7 @@ var _ = Describe("Connection", func() {
written := make(chan struct{})
sender.EXPECT().Send(gomock.Any()).Do(func(*packetBuffer) { close(written) })
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
@ -1736,23 +1747,17 @@ var _ = Describe("Connection", func() {
buffer.Data = append(buffer.Data, []byte("foobar")...)
packer.EXPECT().PackCoalescedPacket(false).Return(&coalescedPacket{
buffer: buffer,
packets: []*packetContents{
longHdrPackets: []*longHeaderPacket{
{
header: &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
},
Header: wire.Header{Type: protocol.PacketTypeInitial},
PacketNumber: 13,
},
length: 123,
},
{
header: &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
},
Header: wire.Header{Type: protocol.PacketTypeHandshake},
PacketNumber: 37,
},
length: 1234,
@ -1777,10 +1782,10 @@ var _ = Describe("Connection", func() {
}),
)
gomock.InOrder(
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) {
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) {
Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial))
}),
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) {
tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ *wire.AckFrame, _ []logging.Frame) {
Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake))
}),
)
@ -1919,23 +1924,18 @@ var _ = Describe("Connection", func() {
sph.EXPECT().SetHandshakeConfirmed()
sph.EXPECT().SentPacket(gomock.Any())
mconn.EXPECT().Write(gomock.Any())
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
conn.sentPacketHandler = sph
done := make(chan struct{})
connRunner.EXPECT().Retire(clientDestConnID)
packer.EXPECT().PackPacket(false).DoAndReturn(func(bool) (*packedPacket, error) {
packer.EXPECT().PackPacket(false, gomock.Any()).DoAndReturn(func(bool, time.Time) (shortHeaderPacket, *packetBuffer, error) {
frames, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount)
Expect(frames).ToNot(BeEmpty())
Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{}))
defer close(done)
return &packedPacket{
packetContents: &packetContents{
header: &wire.ExtendedHeader{},
},
buffer: getPacketBuffer(),
}, nil
return shortHeaderPacket{Packet: &ackhandler.Packet{}}, getPacketBuffer(), nil
})
packer.EXPECT().PackPacket(false).AnyTimes()
packer.EXPECT().PackPacket(false, gomock.Any()).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes()
go func() {
defer GinkgoRecover()
cryptoSetup.EXPECT().RunHandshake()
@ -2065,7 +2065,7 @@ var _ = Describe("Connection", func() {
setRemoteIdleTimeout(5 * time.Second)
conn.lastPacketReceivedTime = time.Now().Add(-5 * time.Second / 2)
sent := make(chan struct{})
packer.EXPECT().PackCoalescedPacket(false).Do(func(bool) (*packedPacket, error) {
packer.EXPECT().PackCoalescedPacket(false).Do(func(bool) (*coalescedPacket, error) {
close(sent)
return nil, nil
})
@ -2078,7 +2078,7 @@ var _ = Describe("Connection", func() {
setRemoteIdleTimeout(time.Hour)
conn.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond)
sent := make(chan struct{})
packer.EXPECT().PackCoalescedPacket(false).Do(func(bool) (*packedPacket, error) {
packer.EXPECT().PackCoalescedPacket(false).Do(func(bool) (*coalescedPacket, error) {
close(sent)
return nil, nil
})
@ -2437,7 +2437,6 @@ var _ = Describe("Client Connection", func() {
newConnID := protocol.ParseConnectionID([]byte{1, 3, 3, 7, 1, 3, 3, 7})
p := getPacket(&wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
SrcConnectionID: newConnID,
DestConnectionID: srcConnID,
@ -2478,7 +2477,6 @@ var _ = Describe("Client Connection", func() {
}, nil
})
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: srcConnID,
SrcConnectionID: destConnID,
@ -2635,7 +2633,6 @@ var _ = Describe("Client Connection", func() {
JustBeforeEach(func() {
retryHdr = &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
@ -2886,7 +2883,6 @@ var _ = Describe("Client Connection", func() {
hdr1 := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: destConnID,
SrcConnectionID: srcConnID,
@ -2898,7 +2894,6 @@ var _ = Describe("Client Connection", func() {
}
hdr2 := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: destConnID,
SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
@ -2926,7 +2921,6 @@ var _ = Describe("Client Connection", func() {
It("ignores 0-RTT packets", func() {
p := getPacket(&wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketType0RTT,
DestConnectionID: srcConnID,
Length: 2 + 6,

View file

@ -11,7 +11,7 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)
const version = protocol.VersionTLS
const version = protocol.Version1
func getRandomData(l int) []byte {
b := make([]byte, l)
@ -30,7 +30,6 @@ func getVNP(src, dest protocol.ArbitraryLenConnectionID, numVersions int) []byte
func main() {
headers := []wire.Header{
{ // Initial without token
IsLongHeader: true,
SrcConnectionID: protocol.ParseConnectionID(getRandomData(3)),
DestConnectionID: protocol.ParseConnectionID(getRandomData(8)),
Type: protocol.PacketTypeInitial,
@ -38,14 +37,12 @@ func main() {
Version: version,
},
{ // Initial without token, with zero-length src conn id
IsLongHeader: true,
DestConnectionID: protocol.ParseConnectionID(getRandomData(8)),
Type: protocol.PacketTypeInitial,
Length: protocol.ByteCount(rand.Intn(1000)),
Version: version,
},
{ // Initial with Token
IsLongHeader: true,
SrcConnectionID: protocol.ParseConnectionID(getRandomData(10)),
DestConnectionID: protocol.ParseConnectionID(getRandomData(19)),
Type: protocol.PacketTypeInitial,
@ -54,7 +51,6 @@ func main() {
Token: getRandomData(25),
},
{ // Handshake packet
IsLongHeader: true,
SrcConnectionID: protocol.ParseConnectionID(getRandomData(5)),
DestConnectionID: protocol.ParseConnectionID(getRandomData(10)),
Type: protocol.PacketTypeHandshake,
@ -62,14 +58,12 @@ func main() {
Version: version,
},
{ // Handshake packet, with zero-length src conn id
IsLongHeader: true,
DestConnectionID: protocol.ParseConnectionID(getRandomData(12)),
Type: protocol.PacketTypeHandshake,
Length: protocol.ByteCount(rand.Intn(1000)),
Version: version,
},
{ // 0-RTT packet
IsLongHeader: true,
SrcConnectionID: protocol.ParseConnectionID(getRandomData(8)),
DestConnectionID: protocol.ParseConnectionID(getRandomData(9)),
Type: protocol.PacketType0RTT,
@ -77,16 +71,12 @@ func main() {
Version: version,
},
{ // Retry Packet, with empty orig dest conn id
IsLongHeader: true,
SrcConnectionID: protocol.ParseConnectionID(getRandomData(8)),
DestConnectionID: protocol.ParseConnectionID(getRandomData(9)),
Type: protocol.PacketTypeRetry,
Token: getRandomData(1000),
Version: version,
},
{ // Short-Header
DestConnectionID: protocol.ParseConnectionID(getRandomData(8)),
},
}
for _, h := range headers {
@ -111,6 +101,15 @@ func main() {
}
}
// short header
b := &bytes.Buffer{}
if err := wire.WriteShortHeader(b, protocol.ParseConnectionID(getRandomData(8)), 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne); err != nil {
log.Fatal(err)
}
if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), header.PrefixLen); err != nil {
log.Fatal(err)
}
vnps := [][]byte{
getVNP(
protocol.ArbitraryLenConnectionID(getRandomData(8)),

View file

@ -30,8 +30,14 @@ func Fuzz(data []byte) int {
if err != nil {
return 0
}
if !wire.IsLongHeaderPacket(data[0]) {
wire.ParseShortHeader(data, connIDLen)
return 1
}
is0RTTPacket := wire.Is0RTTPacket(data)
hdr, _, _, err := wire.ParsePacket(data, connIDLen)
hdr, _, _, err := wire.ParsePacket(data)
if err != nil {
return 0
}
@ -55,7 +61,7 @@ func Fuzz(data []byte) int {
}
// We always use a 2-byte encoding for the Length field in Long Header packets.
// Serializing the header will fail when using a higher value.
if hdr.IsLongHeader && hdr.Length > 16383 {
if hdr.Length > 16383 {
return 1
}
b := &bytes.Buffer{}

View file

@ -16,16 +16,13 @@ import (
)
var (
sentHeaders []*logging.ExtendedHeader
sentHeaders []*logging.ShortHeader
receivedHeaders []*logging.ShortHeader
)
func countKeyPhases() (sent, received int) {
lastKeyPhase := protocol.KeyPhaseOne
for _, hdr := range sentHeaders {
if hdr.IsLongHeader {
continue
}
if hdr.KeyPhase != lastKeyPhase {
sent++
lastKeyPhase = hdr.KeyPhase
@ -45,7 +42,7 @@ type keyUpdateConnTracer struct {
logging.NullConnectionTracer
}
func (t *keyUpdateConnTracer) SentPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
func (t *keyUpdateConnTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, _ *logging.AckFrame, _ []logging.Frame) {
sentHeaders = append(sentHeaders, hdr)
}

View file

@ -95,34 +95,56 @@ var _ = Describe("MITM test", func() {
sendRandomPacketsOfSameType := func(conn net.PacketConn, remoteAddr net.Addr, raw []byte) {
defer GinkgoRecover()
hdr, _, _, err := wire.ParsePacket(raw, connIDLen)
Expect(err).ToNot(HaveOccurred())
replyHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: hdr.IsLongHeader,
DestConnectionID: hdr.DestConnectionID,
SrcConnectionID: hdr.SrcConnectionID,
Type: hdr.Type,
Version: hdr.Version,
},
PacketNumber: protocol.PacketNumber(mrand.Int31n(math.MaxInt32 / 4)),
PacketNumberLen: protocol.PacketNumberLen(mrand.Int31n(4) + 1),
}
const numPackets = 10
ticker := time.NewTicker(rtt / numPackets)
for i := 0; i < numPackets; i++ {
payloadLen := mrand.Int31n(100)
replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1))
buf := &bytes.Buffer{}
Expect(replyHdr.Write(buf, version)).To(Succeed())
b := make([]byte, payloadLen)
mrand.Read(b)
buf.Write(b)
if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
return
defer ticker.Stop()
if wire.IsLongHeaderPacket(raw[0]) {
hdr, _, _, err := wire.ParsePacket(raw)
Expect(err).ToNot(HaveOccurred())
replyHdr := &wire.ExtendedHeader{
Header: wire.Header{
DestConnectionID: hdr.DestConnectionID,
SrcConnectionID: hdr.SrcConnectionID,
Type: hdr.Type,
Version: hdr.Version,
},
PacketNumber: protocol.PacketNumber(mrand.Int31n(math.MaxInt32 / 4)),
PacketNumberLen: protocol.PacketNumberLen(mrand.Int31n(4) + 1),
}
for i := 0; i < numPackets; i++ {
payloadLen := mrand.Int31n(100)
replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1))
buf := &bytes.Buffer{}
Expect(replyHdr.Write(buf, version)).To(Succeed())
b := make([]byte, payloadLen)
mrand.Read(b)
buf.Write(b)
if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
return
}
<-ticker.C
}
} else {
connID, err := wire.ParseConnectionID(raw, connIDLen)
Expect(err).ToNot(HaveOccurred())
_, pn, pnLen, _, err := wire.ParseShortHeader(raw, connIDLen)
if err != nil { // normally, ParseShortHeader is called after decrypting the header
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
}
for i := 0; i < numPackets; i++ {
buf := &bytes.Buffer{}
Expect(wire.WriteShortHeader(buf, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2)))).To(Succeed())
payloadLen := mrand.Int31n(100)
b := make([]byte, payloadLen)
mrand.Read(b)
buf.Write(b)
if _, err := conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
return
}
<-ticker.C
}
<-ticker.C
}
}
@ -317,7 +339,7 @@ var _ = Describe("MITM test", func() {
if dir == quicproxy.DirectionIncoming {
defer GinkgoRecover()
hdr, _, _, err := wire.ParsePacket(raw, connIDLen)
hdr, _, _, err := wire.ParsePacket(raw)
Expect(err).ToNot(HaveOccurred())
if hdr.Type != protocol.PacketTypeInitial {
@ -359,7 +381,7 @@ var _ = Describe("MITM test", func() {
defer GinkgoRecover()
defer close(done)
hdr, _, _, err := wire.ParsePacket(raw, connIDLen)
hdr, _, _, err := wire.ParsePacket(raw)
Expect(err).ToNot(HaveOccurred())
if hdr.Type != protocol.PacketTypeInitial {
@ -392,7 +414,7 @@ var _ = Describe("MITM test", func() {
if dir == quicproxy.DirectionIncoming {
defer GinkgoRecover()
hdr, _, _, err := wire.ParsePacket(raw, connIDLen)
hdr, _, _, err := wire.ParsePacket(raw)
Expect(err).ToNot(HaveOccurred())
if hdr.Type != protocol.PacketTypeInitial || injected {
return 0
@ -420,7 +442,7 @@ var _ = Describe("MITM test", func() {
if dir == quicproxy.DirectionIncoming {
defer GinkgoRecover()
hdr, _, _, err := wire.ParsePacket(raw, connIDLen)
hdr, _, _, err := wire.ParsePacket(raw)
Expect(err).ToNot(HaveOccurred())
if hdr.Type != protocol.PacketTypeInitial || injected {
return 0

View file

@ -358,9 +358,9 @@ type shortHeaderPacket struct {
type packetTracer struct {
logging.NullConnectionTracer
closed chan struct{}
rcvdShortHdr []shortHeaderPacket
sent, rcvdLongHdr []packet
closed chan struct{}
sentShortHdr, rcvdShortHdr []shortHeaderPacket
rcvdLongHdr []packet
}
func newPacketTracer() *packetTracer {
@ -375,16 +375,18 @@ func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ log
t.rcvdShortHdr = append(t.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
}
func (t *packetTracer) SentPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, ack *wire.AckFrame, frames []logging.Frame) {
func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, ack *wire.AckFrame, frames []logging.Frame) {
if ack != nil {
frames = append(frames, ack)
}
t.sent = append(t.sent, packet{time: time.Now(), hdr: hdr, frames: frames})
t.sentShortHdr = append(t.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
}
func (t *packetTracer) Close() { close(t.closed) }
func (t *packetTracer) getSentPackets() []packet {
func (t *packetTracer) getSentShortHeaderPackets() []shortHeaderPacket {
<-t.closed
return t.sent
return t.sentShortHdr
}
func (t *packetTracer) getRcvdLongHeaderPackets() []packet {

View file

@ -213,7 +213,7 @@ var _ = Describe("Timeout tests", func() {
}()
Eventually(done, 2*idleTimeout).Should(BeClosed())
var lastAckElicitingPacketSentAt time.Time
for _, p := range tr.getSentPackets() {
for _, p := range tr.getSentShortHeaderPackets() {
var hasAckElicitingFrame bool
for _, f := range p.frames {
if _, ok := f.(*logging.AckFrame); ok {

View file

@ -35,7 +35,10 @@ var _ = Describe("0-RTT", func() {
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
for len(data) > 0 {
hdr, _, rest, err := wire.ParsePacket(data, 0)
if !wire.IsLongHeaderPacket(data[0]) {
break
}
hdr, _, rest, err := wire.ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1)
@ -347,15 +350,20 @@ var _ = Describe("0-RTT", func() {
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port),
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
hdr, _, _, err := wire.ParsePacket(data, 0)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1)
if wire.IsLongHeaderPacket(data[0]) {
hdr, _, _, err := wire.ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1)
}
}
return rtt / 2
},
DropPacket: func(_ quicproxy.Direction, data []byte) bool {
hdr, _, _, err := wire.ParsePacket(data, 0)
if !wire.IsLongHeaderPacket(data[0]) {
return false
}
hdr, _, _, err := wire.ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
// drop 25% of the 0-RTT packets
@ -390,7 +398,7 @@ var _ = Describe("0-RTT", func() {
countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) {
for len(data) > 0 {
hdr, _, rest, err := wire.ParsePacket(data, 0)
hdr, _, rest, err := wire.ParsePacket(data)
if err != nil {
return
}

View file

@ -31,7 +31,6 @@ var _ = Describe("QUIC Proxy", func() {
b := &bytes.Buffer{}
hdr := wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Version: protocol.VersionTLS,
Length: 4 + protocol.ByteCount(len(payload)),
@ -48,7 +47,7 @@ var _ = Describe("QUIC Proxy", func() {
}
readPacketNumber := func(b []byte) protocol.PacketNumber {
hdr, data, _, err := wire.ParsePacket(b, 0)
hdr, data, _, err := wire.ParsePacket(b)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial))
extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.VersionTLS)

View file

@ -255,16 +255,28 @@ func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0)
}
// SentPacket mocks base method.
func (m *MockConnectionTracer) SentPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) {
// SentLongHeaderPacket mocks base method.
func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3)
m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3)
}
// SentPacket indicates an expected call of SentPacket.
func (mr *MockConnectionTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
// SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket.
func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentPacket), arg0, arg1, arg2, arg3)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3)
}
// SentShortHeaderPacket mocks base method.
func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []logging.Frame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3)
}
// SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket.
func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3)
}
// SentTransportParameters mocks base method.

View file

@ -50,7 +50,6 @@ func ComposeInitialPacket(srcConnID protocol.ConnectionID, destConnID protocol.C
length := payloadSize + int(pnLength) + sealer.Overhead()
hdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: srcConnID,
DestConnectionID: destConnID,
@ -88,7 +87,6 @@ func ComposeRetryPacket(
) []byte {
hdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
SrcConnectionID: srcConnID,
DestConnectionID: destConnID,

View file

@ -42,12 +42,7 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool
if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil {
return false, err
}
var reservedBitsValid bool
if h.IsLongHeader {
reservedBitsValid, err = h.parseLongHeader(b, v)
} else {
reservedBitsValid, err = h.parseShortHeader(b, v)
}
reservedBitsValid, err := h.parseLongHeader(b, v)
if err != nil {
return false, err
}
@ -65,21 +60,6 @@ func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumb
return true, nil
}
func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
h.KeyPhase = protocol.KeyPhaseZero
if h.typeByte&0x4 > 0 {
h.KeyPhase = protocol.KeyPhaseOne
}
if err := h.readPacketNumber(b); err != nil {
return false, err
}
if h.typeByte&0x18 != 0 {
return false, nil
}
return true, nil
}
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
switch h.PacketNumberLen {
@ -121,10 +101,7 @@ func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) erro
if h.SrcConnectionID.Len() > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
}
if h.IsLongHeader {
return h.writeLongHeader(b, ver)
}
return h.writeShortHeader(b, ver)
return h.writeLongHeader(b, ver)
}
func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.VersionNumber) error {
@ -177,34 +154,7 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.Versi
b.Write(h.Token)
}
quicvarint.WriteWithLen(b, uint64(h.Length), 2)
return h.writePacketNumber(b)
}
func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error {
typeByte := 0x40 | uint8(h.PacketNumberLen-1)
if h.KeyPhase == protocol.KeyPhaseOne {
typeByte |= byte(1 << 2)
}
b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes())
return h.writePacketNumber(b)
}
func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
case protocol.PacketNumberLen3:
utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
return nil
return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
// ParsedLen returns the number of bytes that were consumed when parsing the header
@ -213,37 +163,43 @@ func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
}
// GetLength determines the length of the Header.
func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount {
if h.IsLongHeader {
length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
if h.Type == protocol.PacketTypeInitial {
length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
}
return length
func (h *ExtendedHeader) GetLength(_ protocol.VersionNumber) protocol.ByteCount {
length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
if h.Type == protocol.PacketTypeInitial {
length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
}
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
length += protocol.ByteCount(h.PacketNumberLen)
return length
}
// Log logs the Header
func (h *ExtendedHeader) Log(logger utils.Logger) {
if h.IsLongHeader {
var token string
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
if len(h.Token) == 0 {
token = "Token: (empty), "
} else {
token = fmt.Sprintf("Token: %#x, ", h.Token)
}
if h.Type == protocol.PacketTypeRetry {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version)
return
}
var token string
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
if len(h.Token) == 0 {
token = "Token: (empty), "
} else {
token = fmt.Sprintf("Token: %#x, ", h.Token)
}
if h.Type == protocol.PacketTypeRetry {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version)
return
}
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
} else {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
}
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
}
func writePacketNumber(b *bytes.Buffer, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) error {
switch pnLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(pn))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(pn))
case protocol.PacketNumberLen3:
utils.BigEndian.WriteUint24(b, uint32(pn))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(pn))
default:
return fmt.Errorf("invalid packet number length: %d", pnLen)
}
return nil
}

View file

@ -14,7 +14,7 @@ import (
)
var _ = Describe("Header", func() {
const versionIETFHeader = protocol.VersionTLS // a QUIC version that uses the IETF Header format
const versionIETFHeader = protocol.Version1
Context("Writing", func() {
var buf *bytes.Buffer
@ -29,7 +29,6 @@ var _ = Describe("Header", func() {
It("writes", func() {
Expect((&ExtendedHeader{
Header: Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}),
SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}),
@ -55,7 +54,6 @@ var _ = Describe("Header", func() {
It("writes a header with a 20 byte connection ID", func() {
err := (&ExtendedHeader{
Header: Header{
IsLongHeader: true,
SrcConnectionID: srcConnID,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}), // connection IDs must be at most 20 bytes long
Version: 0x1020304,
@ -72,10 +70,9 @@ var _ = Describe("Header", func() {
token := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
Expect((&ExtendedHeader{
Header: Header{
IsLongHeader: true,
Version: 0x1020304,
Type: protocol.PacketTypeInitial,
Token: token,
Version: 0x1020304,
Type: protocol.PacketTypeInitial,
Token: token,
},
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
@ -88,10 +85,9 @@ var _ = Describe("Header", func() {
It("uses a 2-byte encoding for the length on Initial packets", func() {
Expect((&ExtendedHeader{
Header: Header{
IsLongHeader: true,
Version: 0x1020304,
Type: protocol.PacketTypeInitial,
Length: 37,
Version: 0x1020304,
Type: protocol.PacketTypeInitial,
Length: 37,
},
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
@ -104,10 +100,9 @@ var _ = Describe("Header", func() {
It("writes a Retry packet", func() {
token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.")
Expect((&ExtendedHeader{Header: Header{
IsLongHeader: true,
Version: protocol.Version1,
Type: protocol.PacketTypeRetry,
Token: token,
Version: protocol.Version1,
Type: protocol.PacketTypeRetry,
Token: token,
}}).Write(buf, versionIETFHeader)).To(Succeed())
expected := []byte{0xc0 | 0b11<<4}
expected = appendVersion(expected, protocol.Version1)
@ -122,9 +117,8 @@ var _ = Describe("Header", func() {
It("writes an Initial", func() {
Expect((&ExtendedHeader{
Header: Header{
IsLongHeader: true,
Version: protocol.Version2,
Type: protocol.PacketTypeInitial,
Version: protocol.Version2,
Type: protocol.PacketTypeInitial,
},
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
@ -135,10 +129,9 @@ var _ = Describe("Header", func() {
It("writes a Retry packet", func() {
token := []byte("Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.")
Expect((&ExtendedHeader{Header: Header{
IsLongHeader: true,
Version: protocol.Version2,
Type: protocol.PacketTypeRetry,
Token: token,
Version: protocol.Version2,
Type: protocol.PacketTypeRetry,
Token: token,
}}).Write(buf, versionIETFHeader)).To(Succeed())
expected := []byte{0xc0 | 0b11<<4}
expected = appendVersion(expected, protocol.Version2)
@ -151,9 +144,8 @@ var _ = Describe("Header", func() {
It("writes a Handshake Packet", func() {
Expect((&ExtendedHeader{
Header: Header{
IsLongHeader: true,
Version: protocol.Version2,
Type: protocol.PacketTypeHandshake,
Version: protocol.Version2,
Type: protocol.PacketTypeHandshake,
},
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
@ -164,9 +156,8 @@ var _ = Describe("Header", func() {
It("writes a 0-RTT Packet", func() {
Expect((&ExtendedHeader{
Header: Header{
IsLongHeader: true,
Version: protocol.Version2,
Type: protocol.PacketType0RTT,
Version: protocol.Version2,
Type: protocol.PacketType0RTT,
},
PacketNumber: 0xdecafbad,
PacketNumberLen: protocol.PacketNumberLen4,
@ -174,74 +165,6 @@ var _ = Describe("Header", func() {
Expect(buf.Bytes()[0]>>4&0b11 == 0b10)
})
})
Context("short header", func() {
It("writes a header with connection ID", func() {
Expect((&ExtendedHeader{
Header: Header{
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}),
},
PacketNumberLen: protocol.PacketNumberLen1,
PacketNumber: 0x42,
}).Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Bytes()).To(Equal([]byte{
0x40,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
0x42, // packet number
}))
})
It("writes a header without connection ID", func() {
Expect((&ExtendedHeader{
PacketNumberLen: protocol.PacketNumberLen1,
PacketNumber: 0x42,
}).Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Bytes()).To(Equal([]byte{
0x40,
0x42, // packet number
}))
})
It("writes a header with a 2 byte packet number", func() {
Expect((&ExtendedHeader{
PacketNumberLen: protocol.PacketNumberLen2,
PacketNumber: 0x765,
}).Write(buf, versionIETFHeader)).To(Succeed())
expected := []byte{0x40 | 0x1}
expected = append(expected, []byte{0x7, 0x65}...) // packet number
Expect(buf.Bytes()).To(Equal(expected))
})
It("writes a header with a 4 byte packet number", func() {
Expect((&ExtendedHeader{
PacketNumberLen: protocol.PacketNumberLen4,
PacketNumber: 0x12345678,
}).Write(buf, versionIETFHeader)).To(Succeed())
expected := []byte{0x40 | 0x3}
expected = append(expected, []byte{0x12, 0x34, 0x56, 0x78}...)
Expect(buf.Bytes()).To(Equal(expected))
})
It("errors when given an invalid packet number length", func() {
err := (&ExtendedHeader{
PacketNumberLen: 5,
PacketNumber: 0xdecafbad,
}).Write(buf, versionIETFHeader)
Expect(err).To(MatchError("invalid packet number length: 5"))
})
It("writes the Key Phase Bit", func() {
Expect((&ExtendedHeader{
KeyPhase: protocol.KeyPhaseOne,
PacketNumberLen: protocol.PacketNumberLen1,
PacketNumber: 0x42,
}).Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Bytes()).To(Equal([]byte{
0x40 | 0x4,
0x42, // packet number
}))
})
})
})
Context("getting the length", func() {
@ -254,7 +177,6 @@ var _ = Describe("Header", func() {
It("has the right length for the Long Header, for a short length", func() {
h := &ExtendedHeader{
Header: Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
@ -271,7 +193,6 @@ var _ = Describe("Header", func() {
It("has the right length for the Long Header, for a long length", func() {
h := &ExtendedHeader{
Header: Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
@ -288,7 +209,6 @@ var _ = Describe("Header", func() {
It("has the right length for an Initial that has a short length", func() {
h := &ExtendedHeader{
Header: Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
@ -305,7 +225,6 @@ var _ = Describe("Header", func() {
It("has the right length for an Initial not containing a Token", func() {
h := &ExtendedHeader{
Header: Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
@ -322,7 +241,6 @@ var _ = Describe("Header", func() {
It("has the right length for an Initial containing a Token", func() {
h := &ExtendedHeader{
Header: Header{
IsLongHeader: true,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
Type: protocol.PacketTypeInitial,
@ -336,39 +254,6 @@ var _ = Describe("Header", func() {
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Len()).To(Equal(expectedLen))
})
It("has the right length for a Short Header containing a connection ID", func() {
h := &ExtendedHeader{
Header: Header{
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
},
PacketNumberLen: protocol.PacketNumberLen1,
}
Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 8 + 1)))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Len()).To(Equal(10))
})
It("has the right length for a short header without a connection ID", func() {
h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}
Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 1)))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Len()).To(Equal(2))
})
It("has the right length for a short header with a 2 byte packet number", func() {
h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen2}
Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 2)))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Len()).To(Equal(3))
})
It("has the right length for a short header with a 5 byte packet number", func() {
h := &ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen4}
Expect(h.GetLength(versionIETFHeader)).To(Equal(protocol.ByteCount(1 + 4)))
Expect(h.Write(buf, versionIETFHeader)).To(Succeed())
Expect(buf.Len()).To(Equal(5))
})
})
Context("Logging", func() {
@ -391,7 +276,6 @@ var _ = Describe("Header", func() {
It("logs Long Headers", func() {
(&ExtendedHeader{
Header: Header{
IsLongHeader: true,
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}),
SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37}),
Type: protocol.PacketTypeHandshake,
@ -407,7 +291,6 @@ var _ = Describe("Header", func() {
It("logs Initial Packets with a Token", func() {
(&ExtendedHeader{
Header: Header{
IsLongHeader: true,
DestConnectionID: protocol.ParseConnectionID([]byte{0xca, 0xfe, 0x13, 0x37}),
SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
Type: protocol.PacketTypeInitial,
@ -424,7 +307,6 @@ var _ = Describe("Header", func() {
It("logs Initial packets without a Token", func() {
(&ExtendedHeader{
Header: Header{
IsLongHeader: true,
DestConnectionID: protocol.ParseConnectionID([]byte{0xca, 0xfe, 0x13, 0x37}),
SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
Type: protocol.PacketTypeInitial,
@ -440,7 +322,6 @@ var _ = Describe("Header", func() {
It("logs Retry packets with a Token", func() {
(&ExtendedHeader{
Header: Header{
IsLongHeader: true,
DestConnectionID: protocol.ParseConnectionID([]byte{0xca, 0xfe, 0x13, 0x37}),
SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
Type: protocol.PacketTypeRetry,
@ -450,17 +331,5 @@ var _ = Describe("Header", func() {
}).Log(logger)
Expect(buf.String()).To(ContainSubstring("Long Header{Type: Retry, DestConnectionID: cafe1337, SrcConnectionID: decafbad, Token: 0x123456, Version: 0xfeed}"))
})
It("logs Short Headers containing a connection ID", func() {
(&ExtendedHeader{
Header: Header{
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}),
},
KeyPhase: protocol.KeyPhaseOne,
PacketNumber: 1337,
PacketNumberLen: 4,
}).Log(logger)
Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}"))
})
})
})

View file

@ -121,9 +121,8 @@ var ErrUnsupportedVersion = errors.New("unsupported version")
// The Header is the version independent part of the header
type Header struct {
IsLongHeader bool
typeByte byte
Type protocol.PacketType
typeByte byte
Type protocol.PacketType
Version protocol.VersionNumber
SrcConnectionID protocol.ConnectionID
@ -140,24 +139,22 @@ type Header struct {
// If the packet has a long header, the packet is cut according to the length field.
// If we understand the version, the packet is header up unto the packet number.
// Otherwise, only the invariant part of the header is parsed.
func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* packet data */, []byte /* rest */, error) {
hdr, err := parseHeader(bytes.NewReader(data), shortHeaderConnIDLen)
func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
if len(data) == 0 || !IsLongHeaderPacket(data[0]) {
return nil, nil, nil, errors.New("not a long header packet")
}
hdr, err := parseHeader(bytes.NewReader(data))
if err != nil {
if err == ErrUnsupportedVersion {
return hdr, nil, nil, ErrUnsupportedVersion
}
return nil, nil, nil, err
}
var rest []byte
if hdr.IsLongHeader {
if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
}
packetLen := int(hdr.ParsedLen() + hdr.Length)
rest = data[packetLen:]
data = data[:packetLen]
if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
}
return hdr, data, rest, nil
packetLen := int(hdr.ParsedLen() + hdr.Length)
return hdr, data[:packetLen], data[packetLen:], nil
}
// ParseHeader parses the header.
@ -165,43 +162,17 @@ func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* pack
// For long header packets:
// * if we understand the version: up to the packet number
// * if not, only the invariant part of the header
func parseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
func parseHeader(b *bytes.Reader) (*Header, error) {
startLen := b.Len()
h, err := parseHeaderImpl(b, shortHeaderConnIDLen)
if err != nil {
return h, err
}
h.parsedLen = protocol.ByteCount(startLen - b.Len())
return h, err
}
func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
}
h := &Header{
typeByte: typeByte,
IsLongHeader: IsLongHeaderPacket(typeByte),
}
if !h.IsLongHeader {
if h.typeByte&0x40 == 0 {
return nil, errors.New("not a QUIC packet")
}
if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
return nil, err
}
return h, nil
}
return h, h.parseLongHeader(b)
}
func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
var err error
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
return err
h := &Header{typeByte: typeByte}
err = h.parseLongHeader(b)
h.parsedLen = protocol.ByteCount(startLen - b.Len())
return h, err
}
func (h *Header) parseLongHeader(b *bytes.Reader) error {
@ -321,8 +292,5 @@ func (h *Header) toExtendedHeader() *ExtendedHeader {
// PacketType is the type of the packet, for logging purposes
func (h *Header) PacketType() string {
if h.IsLongHeader {
return h.Type.String()
}
return "1-RTT"
return h.Type.String()
}

View file

@ -18,7 +18,6 @@ var _ = Describe("Header Parsing", func() {
buf := &bytes.Buffer{}
Expect((&ExtendedHeader{
Header: Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}),
@ -31,44 +30,10 @@ var _ = Describe("Header Parsing", func() {
Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
})
It("parses the connection ID of a short header packet", func() {
buf := &bytes.Buffer{}
Expect((&ExtendedHeader{
Header: Header{
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}),
},
PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed())
buf.Write([]byte("foobar"))
connID, err := ParseConnectionID(buf.Bytes(), 4)
Expect(err).ToNot(HaveOccurred())
Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
})
It("errors on EOF, for short header packets", func() {
buf := &bytes.Buffer{}
Expect((&ExtendedHeader{
Header: Header{
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
},
PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed())
data := buf.Bytes()[:buf.Len()-2] // cut the packet number
_, err := ParseConnectionID(data, 8)
Expect(err).ToNot(HaveOccurred())
for i := 0; i < len(data); i++ {
b := make([]byte, i)
copy(b, data[:i])
_, err := ParseConnectionID(b, 8)
Expect(err).To(MatchError(io.EOF))
}
})
It("errors on EOF, for long header packets", func() {
buf := &bytes.Buffer{}
Expect((&ExtendedHeader{
Header: Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}),
SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 8, 9}),
@ -219,10 +184,9 @@ var _ = Describe("Header Parsing", func() {
data = append(data, []byte("foobar")...)
Expect(IsVersionNegotiationPacket(data)).To(BeFalse())
hdr, pdata, rest, err := ParsePacket(data, 0)
hdr, pdata, rest, err := ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
Expect(pdata).To(Equal(data))
Expect(hdr.IsLongHeader).To(BeTrue())
Expect(hdr.DestConnectionID).To(Equal(destConnID))
Expect(hdr.SrcConnectionID).To(Equal(srcConnID))
Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial))
@ -247,7 +211,7 @@ var _ = Describe("Header Parsing", func() {
0xde, 0xca, 0xfb, 0xad, // dest conn ID
0xde, 0xad, 0xbe, 0xef, // src conn ID
}
_, _, _, err := ParsePacket(data, 0)
_, _, _, err := ParsePacket(data)
Expect(err).To(MatchError("not a QUIC packet"))
})
@ -261,9 +225,8 @@ var _ = Describe("Header Parsing", func() {
0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1, // src conn ID
'f', 'o', 'o', 'b', 'a', 'r', // unspecified bytes
}
hdr, _, rest, err := ParsePacket(data, 0)
hdr, _, rest, err := ParsePacket(data)
Expect(err).To(MatchError(ErrUnsupportedVersion))
Expect(hdr.IsLongHeader).To(BeTrue())
Expect(hdr.Version).To(Equal(protocol.VersionNumber(0xdeadbeef)))
Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8})))
Expect(hdr.SrcConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1})))
@ -278,7 +241,7 @@ var _ = Describe("Header Parsing", func() {
data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // source connection ID
data = append(data, encodeVarInt(0)...) // length
data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...)
hdr, _, _, err := ParsePacket(data, 0)
hdr, _, _, err := ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketType0RTT))
Expect(hdr.SrcConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})))
@ -293,7 +256,7 @@ var _ = Describe("Header Parsing", func() {
data = append(data, 0) // src conn ID len
data = append(data, encodeVarInt(0)...) // length
data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...)
hdr, _, _, err := ParsePacket(data, 0)
hdr, _, _, err := ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.SrcConnectionID).To(BeZero())
Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})))
@ -307,7 +270,7 @@ var _ = Describe("Header Parsing", func() {
data = append(data, 0x0) // src conn ID len
data = append(data, encodeVarInt(0)...) // length
data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...)
_, _, _, err := ParsePacket(data, 0)
_, _, _, err := ParsePacket(data)
Expect(err).To(MatchError(protocol.ErrInvalidConnectionIDLen))
})
@ -319,7 +282,7 @@ var _ = Describe("Header Parsing", func() {
data = append(data, encodeVarInt(0)...) // length
data = append(data, []byte{0x1, 0x23}...)
hdr, _, _, err := ParsePacket(data, 0)
hdr, _, _, err := ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(b, protocol.Version1)
@ -338,7 +301,7 @@ var _ = Describe("Header Parsing", func() {
data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID
data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token
data = append(data, []byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}...)
hdr, pdata, rest, err := ParsePacket(data, 0)
hdr, pdata, rest, err := ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(hdr.Version).To(Equal(protocol.Version1))
@ -358,7 +321,7 @@ var _ = Describe("Header Parsing", func() {
data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // source connection ID
data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token
data = append(data, []byte{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}...)
hdr, pdata, rest, err := ParsePacket(data, 0)
hdr, pdata, rest, err := ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(hdr.Version).To(Equal(protocol.Version2))
@ -376,7 +339,7 @@ var _ = Describe("Header Parsing", func() {
data = append(data, []byte{'f', 'o', 'o', 'b', 'a', 'r'}...) // token
data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...)
// this results in a token length of 0
_, _, _, err := ParsePacket(data, 0)
_, _, _, err := ParsePacket(data)
Expect(err).To(MatchError(io.EOF))
})
@ -388,7 +351,7 @@ var _ = Describe("Header Parsing", func() {
data = append(data, encodeVarInt(0x42)...) // length, 1 byte
data = append(data, []byte{0x12, 0x34}...) // packet number
_, _, _, err := ParsePacket(data, 0)
_, _, _, err := ParsePacket(data)
Expect(err).To(MatchError(io.EOF))
})
@ -398,7 +361,7 @@ var _ = Describe("Header Parsing", func() {
data = append(data, []byte{0x0, 0x0}...) // connection ID lengths
data = append(data, encodeVarInt(2)...) // length
data = append(data, []byte{0x12, 0x34}...) // packet number
hdr, _, _, err := ParsePacket(data, 0)
hdr, _, _, err := ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake))
extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1)
@ -414,8 +377,8 @@ var _ = Describe("Header Parsing", func() {
data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // dest conn ID
data = append(data, 0x8) // src conn ID len
data = append(data, []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}...) // src conn ID
for i := 0; i < len(data); i++ {
_, _, _, err := ParsePacket(data[:i], 0)
for i := 1; i < len(data); i++ {
_, _, _, err := ParsePacket(data[:i])
Expect(err).To(Equal(io.EOF))
}
})
@ -429,7 +392,7 @@ var _ = Describe("Header Parsing", func() {
data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number
for i := hdrLen; i < len(data); i++ {
data = data[:i]
hdr, _, _, err := ParsePacket(data, 0)
hdr, _, _, err := ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
_, err = hdr.ParseExtended(b, protocol.Version1)
@ -446,7 +409,7 @@ var _ = Describe("Header Parsing", func() {
hdrLen := len(data)
for i := hdrLen; i < len(data); i++ {
data = data[:i]
hdr, _, _, err := ParsePacket(data, 0)
hdr, _, _, err := ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
_, err = hdr.ParseExtended(b, protocol.Version1)
@ -458,7 +421,6 @@ var _ = Describe("Header Parsing", func() {
It("cuts packets", func() {
buf := &bytes.Buffer{}
hdr := Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
Length: 2 + 6,
@ -472,7 +434,7 @@ var _ = Describe("Header Parsing", func() {
hdrRaw := append([]byte{}, buf.Bytes()...)
buf.Write([]byte("foobar")) // payload of the first packet
buf.Write([]byte("raboof")) // second packet
parsedHdr, data, rest, err := ParsePacket(buf.Bytes(), 4)
parsedHdr, data, rest, err := ParsePacket(buf.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(parsedHdr.Type).To(Equal(hdr.Type))
Expect(parsedHdr.DestConnectionID).To(Equal(hdr.DestConnectionID))
@ -484,7 +446,6 @@ var _ = Describe("Header Parsing", func() {
buf := &bytes.Buffer{}
Expect((&ExtendedHeader{
Header: Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
Length: 3,
@ -493,7 +454,7 @@ var _ = Describe("Header Parsing", func() {
PacketNumber: 0x1337,
PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed())
_, _, _, err := ParsePacket(buf.Bytes(), 4)
_, _, _, err := ParsePacket(buf.Bytes())
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("packet length (2 bytes) is smaller than the expected length (3 bytes)"))
})
@ -502,7 +463,6 @@ var _ = Describe("Header Parsing", func() {
buf := &bytes.Buffer{}
Expect((&ExtendedHeader{
Header: Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
Length: 1000,
@ -512,160 +472,19 @@ var _ = Describe("Header Parsing", func() {
PacketNumberLen: 2,
}).Write(buf, protocol.Version1)).To(Succeed())
buf.Write(make([]byte, 500-2 /* for packet number length */))
_, _, _, err := ParsePacket(buf.Bytes(), 4)
_, _, _, err := ParsePacket(buf.Bytes())
Expect(err).To(MatchError("packet length (500 bytes) is smaller than the expected length (1000 bytes)"))
})
})
})
Context("Short Headers", func() {
It("reads a Short Header with a 8 byte connection ID", func() {
connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})
data := append([]byte{0x40}, connID.Bytes()...)
data = append(data, 0x42) // packet number
Expect(IsVersionNegotiationPacket(data)).To(BeFalse())
hdr, pdata, rest, err := ParsePacket(data, 8)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.IsLongHeader).To(BeFalse())
Expect(hdr.DestConnectionID).To(Equal(connID))
b := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(b, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero))
Expect(extHdr.DestConnectionID).To(Equal(connID))
Expect(extHdr.SrcConnectionID).To(BeZero())
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42)))
Expect(hdr.ParsedLen()).To(BeEquivalentTo(len(data) - 1))
Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 1))
Expect(pdata).To(Equal(data))
Expect(rest).To(BeEmpty())
})
It("errors if 0x40 is not set", func() {
connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})
data := append([]byte{0x0}, connID.Bytes()...)
_, _, _, err := ParsePacket(data, 8)
Expect(err).To(MatchError("not a QUIC packet"))
})
It("errors if the 4th or 5th bit are set", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
data := append([]byte{0x40 | 0x10 /* set the 4th bit */}, connID.Bytes()...)
data = append(data, 0x42) // packet number
hdr, _, _, err := ParsePacket(data, 5)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.IsLongHeader).To(BeFalse())
extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.Version1)
Expect(err).To(MatchError(ErrInvalidReservedBits))
Expect(extHdr).ToNot(BeNil())
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42)))
})
It("reads a Short Header with a 5 byte connection ID", func() {
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
data := append([]byte{0x40}, connID.Bytes()...)
data = append(data, 0x42) // packet number
hdr, pdata, rest, err := ParsePacket(data, 5)
Expect(err).ToNot(HaveOccurred())
Expect(pdata).To(HaveLen(len(data)))
Expect(hdr.IsLongHeader).To(BeFalse())
Expect(hdr.DestConnectionID).To(Equal(connID))
b := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(b, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero))
Expect(extHdr.DestConnectionID).To(Equal(connID))
Expect(extHdr.SrcConnectionID).To(BeZero())
Expect(rest).To(BeEmpty())
})
It("reads the Key Phase Bit", func() {
data := []byte{
0x40 ^ 0x4,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID
}
data = append(data, 11) // packet number
hdr, _, _, err := ParsePacket(data, 6)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.IsLongHeader).To(BeFalse())
b := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(b, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseOne))
Expect(b.Len()).To(BeZero())
})
It("reads a header with a 2 byte packet number", func() {
data := []byte{
0x40 | 0x1,
0xde, 0xad, 0xbe, 0xef, // connection ID
}
data = append(data, []byte{0x13, 0x37}...) // packet number
hdr, _, _, err := ParsePacket(data, 4)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(b, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.IsLongHeader).To(BeFalse())
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337)))
Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen2))
Expect(b.Len()).To(BeZero())
})
It("reads a header with a 3 byte packet number", func() {
data := []byte{
0x40 | 0x2,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x1, 0x2, 0x3, 0x4, // connection ID
}
data = append(data, []byte{0x99, 0xbe, 0xef}...) // packet number
hdr, _, _, err := ParsePacket(data, 10)
Expect(err).ToNot(HaveOccurred())
b := bytes.NewReader(data)
extHdr, err := hdr.ParseExtended(b, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(extHdr.IsLongHeader).To(BeFalse())
Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x99beef)))
Expect(extHdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen3))
Expect(b.Len()).To(BeZero())
})
It("errors on EOF, when parsing the header", func() {
data := []byte{
0x40 ^ 0x2,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37, // connection ID
}
for i := 0; i < len(data); i++ {
data = data[:i]
_, _, _, err := ParsePacket(data, 8)
Expect(err).To(Equal(io.EOF))
}
})
It("errors on EOF, when parsing the extended header", func() {
data := []byte{
0x40 ^ 0x3,
0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, // connection ID
}
hdrLen := len(data)
data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // packet number
for i := hdrLen; i < len(data); i++ {
data = data[:i]
hdr, _, _, err := ParsePacket(data, 6)
Expect(err).ToNot(HaveOccurred())
_, err = hdr.ParseExtended(bytes.NewReader(data), protocol.Version1)
Expect(err).To(Equal(io.EOF))
}
})
})
It("distinguishes long and short header packets", func() {
Expect(IsLongHeaderPacket(0x40)).To(BeFalse())
Expect(IsLongHeaderPacket(0x80 ^ 0x40 ^ 0x12)).To(BeTrue())
})
It("tells its packet type for logging", func() {
Expect((&Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}).PacketType()).To(Equal("Handshake"))
Expect((&Header{}).PacketType()).To(Equal("1-RTT"))
Expect((&Header{Type: protocol.PacketTypeInitial}).PacketType()).To(Equal("Initial"))
Expect((&Header{Type: protocol.PacketTypeHandshake}).PacketType()).To(Equal("Handshake"))
})
})

View file

@ -1,6 +1,7 @@
package wire
import (
"bytes"
"errors"
"fmt"
"io"
@ -9,6 +10,9 @@ import (
"github.com/lucas-clemente/quic-go/internal/utils"
)
// ParseShortHeader parses a short header packet.
// It must be called after header protection was removed.
// Otherwise, the check for the reserved bits will (most likely) fail.
func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.PacketNumber, _ protocol.PacketNumberLen, _ protocol.KeyPhaseBit, _ error) {
if len(data) == 0 {
return 0, 0, 0, 0, io.EOF
@ -50,6 +54,21 @@ func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.Packet
return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err
}
// WriteShortHeader writes a short header.
func WriteShortHeader(b *bytes.Buffer, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) error {
typeByte := 0x40 | uint8(pnLen-1)
if kp == protocol.KeyPhaseOne {
typeByte |= byte(1 << 2)
}
b.WriteByte(typeByte)
b.Write(connID.Bytes())
return writePacketNumber(b, pn, pnLen)
}
func ShortHeaderLen(dest protocol.ConnectionID, pnLen protocol.PacketNumberLen) protocol.ByteCount {
return 1 + protocol.ByteCount(dest.Len()) + protocol.ByteCount(pnLen)
}
func LogShortHeader(logger utils.Logger, dest protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", dest, pn, pnLen, kp)
}

View file

@ -70,6 +70,25 @@ var _ = Describe("Short Header", func() {
})
})
It("determines the length", func() {
Expect(ShortHeaderLen(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), protocol.PacketNumberLen3)).To(BeEquivalentTo(8))
Expect(ShortHeaderLen(protocol.ParseConnectionID([]byte{}), protocol.PacketNumberLen1)).To(BeEquivalentTo(2))
})
Context("writing", func() {
It("writes a short header packet", func() {
b := &bytes.Buffer{}
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
Expect(WriteShortHeader(b, connID, 1337, 4, protocol.KeyPhaseOne)).To(Succeed())
l, pn, pnLen, kp, err := ParseShortHeader(b.Bytes(), 4)
Expect(err).ToNot(HaveOccurred())
Expect(pn).To(Equal(protocol.PacketNumber(1337)))
Expect(pnLen).To(Equal(protocol.PacketNumberLen4))
Expect(kp).To(Equal(protocol.KeyPhaseOne))
Expect(l).To(Equal(b.Len()))
})
})
Context("logging", func() {
var (
buf *bytes.Buffer

View file

@ -121,7 +121,8 @@ type ConnectionTracer interface {
SentTransportParameters(*TransportParameters)
ReceivedTransportParameters(*TransportParameters)
RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT
SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame)
SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame)
SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame)
ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber)
ReceivedRetry(*Header)
ReceivedLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame)

View file

@ -254,16 +254,28 @@ func (mr *MockConnectionTracerMockRecorder) RestoredTransportParameters(arg0 int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RestoredTransportParameters", reflect.TypeOf((*MockConnectionTracer)(nil).RestoredTransportParameters), arg0)
}
// SentPacket mocks base method.
func (m *MockConnectionTracer) SentPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) {
// SentLongHeaderPacket mocks base method.
func (m *MockConnectionTracer) SentLongHeaderPacket(arg0 *wire.ExtendedHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SentPacket", arg0, arg1, arg2, arg3)
m.ctrl.Call(m, "SentLongHeaderPacket", arg0, arg1, arg2, arg3)
}
// SentPacket indicates an expected call of SentPacket.
func (mr *MockConnectionTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
// SentLongHeaderPacket indicates an expected call of SentLongHeaderPacket.
func (mr *MockConnectionTracerMockRecorder) SentLongHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentPacket), arg0, arg1, arg2, arg3)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentLongHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentLongHeaderPacket), arg0, arg1, arg2, arg3)
}
// SentShortHeaderPacket mocks base method.
func (m *MockConnectionTracer) SentShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 *wire.AckFrame, arg3 []Frame) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SentShortHeaderPacket", arg0, arg1, arg2, arg3)
}
// SentShortHeaderPacket indicates an expected call of SentShortHeaderPacket.
func (mr *MockConnectionTracerMockRecorder) SentShortHeaderPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentShortHeaderPacket", reflect.TypeOf((*MockConnectionTracer)(nil).SentShortHeaderPacket), arg0, arg1, arg2, arg3)
}
// SentTransportParameters mocks base method.

View file

@ -104,9 +104,15 @@ func (m *connTracerMultiplexer) RestoredTransportParameters(tp *TransportParamet
}
}
func (m *connTracerMultiplexer) SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) {
func (m *connTracerMultiplexer) SentLongHeaderPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) {
for _, t := range m.tracers {
t.SentPacket(hdr, size, ack, frames)
t.SentLongHeaderPacket(hdr, size, ack, frames)
}
}
func (m *connTracerMultiplexer) SentShortHeaderPacket(hdr *ShortHeader, size ByteCount, ack *AckFrame, frames []Frame) {
for _, t := range m.tracers {
t.SentShortHeaderPacket(hdr, size, ack, frames)
}
}

View file

@ -157,13 +157,22 @@ var _ = Describe("Tracing", func() {
tracer.RestoredTransportParameters(tp)
})
It("traces the SentPacket event", func() {
It("traces the SentLongHeaderPacket event", func() {
hdr := &ExtendedHeader{Header: Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}}
ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}}
ping := &PingFrame{}
tr1.EXPECT().SentPacket(hdr, ByteCount(1337), ack, []Frame{ping})
tr2.EXPECT().SentPacket(hdr, ByteCount(1337), ack, []Frame{ping})
tracer.SentPacket(hdr, 1337, ack, []Frame{ping})
tr1.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping})
tr2.EXPECT().SentLongHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping})
tracer.SentLongHeaderPacket(hdr, 1337, ack, []Frame{ping})
})
It("traces the SentShortHeaderPacket event", func() {
hdr := &ShortHeader{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}
ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}}
ping := &PingFrame{}
tr1.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping})
tr2.EXPECT().SentShortHeaderPacket(hdr, ByteCount(1337), ack, []Frame{ping})
tracer.SentShortHeaderPacket(hdr, 1337, ack, []Frame{ping})
})
It("traces the ReceivedVersionNegotiationPacket event", func() {

View file

@ -31,11 +31,12 @@ func (n NullConnectionTracer) StartedConnection(local, remote net.Addr, srcConnI
func (n NullConnectionTracer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) {
}
func (n NullConnectionTracer) ClosedConnection(err error) {}
func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) SentPacket(*ExtendedHeader, ByteCount, *AckFrame, []Frame) {}
func (n NullConnectionTracer) ClosedConnection(err error) {}
func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {}
func (n NullConnectionTracer) SentLongHeaderPacket(*ExtendedHeader, ByteCount, *AckFrame, []Frame) {}
func (n NullConnectionTracer) SentShortHeaderPacket(*ShortHeader, ByteCount, *AckFrame, []Frame) {}
func (n NullConnectionTracer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) {
}
func (n NullConnectionTracer) ReceivedRetry(*Header) {}

View file

@ -6,9 +6,6 @@ import (
// PacketTypeFromHeader determines the packet type from a *wire.Header.
func PacketTypeFromHeader(hdr *Header) PacketType {
if !hdr.IsLongHeader {
return PacketType1RTT
}
if hdr.Version == 0 {
return PacketTypeVersionNegotiation
}

View file

@ -12,49 +12,38 @@ var _ = Describe("Packet Header", func() {
Context("determining the packet type from the header", func() {
It("recognizes Initial packets", func() {
Expect(PacketTypeFromHeader(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Version: protocol.VersionTLS,
Type: protocol.PacketTypeInitial,
Version: protocol.VersionTLS,
})).To(Equal(PacketTypeInitial))
})
It("recognizes Handshake packets", func() {
Expect(PacketTypeFromHeader(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Version: protocol.VersionTLS,
Type: protocol.PacketTypeHandshake,
Version: protocol.VersionTLS,
})).To(Equal(PacketTypeHandshake))
})
It("recognizes Retry packets", func() {
Expect(PacketTypeFromHeader(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
Version: protocol.VersionTLS,
Type: protocol.PacketTypeRetry,
Version: protocol.VersionTLS,
})).To(Equal(PacketTypeRetry))
})
It("recognizes 0-RTT packets", func() {
Expect(PacketTypeFromHeader(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketType0RTT,
Version: protocol.VersionTLS,
Type: protocol.PacketType0RTT,
Version: protocol.VersionTLS,
})).To(Equal(PacketType0RTT))
})
It("recognizes Version Negotiation packets", func() {
Expect(PacketTypeFromHeader(&wire.Header{IsLongHeader: true})).To(Equal(PacketTypeVersionNegotiation))
})
It("recognizes 1-RTT packets", func() {
Expect(PacketTypeFromHeader(&wire.Header{})).To(Equal(PacketType1RTT))
Expect(PacketTypeFromHeader(&wire.Header{})).To(Equal(PacketTypeVersionNegotiation))
})
It("handles unrecognized packet types", func() {
Expect(PacketTypeFromHeader(&wire.Header{
IsLongHeader: true,
Version: protocol.VersionTLS,
})).To(Equal(PacketTypeNotDetermined))
Expect(PacketTypeFromHeader(&wire.Header{Version: protocol.VersionTLS})).To(Equal(PacketTypeNotDetermined))
})
})
})

View file

@ -6,6 +6,7 @@ package quic
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
ackhandler "github.com/lucas-clemente/quic-go/internal/ackhandler"
@ -50,10 +51,10 @@ func (mr *MockPackerMockRecorder) HandleTransportParameters(arg0 interface{}) *g
}
// MaybePackProbePacket mocks base method.
func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel) (*packedPacket, error) {
func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel) (*coalescedPacket, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MaybePackProbePacket", arg0)
ret0, _ := ret[0].(*packedPacket)
ret0, _ := ret[0].(*coalescedPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@ -110,33 +111,35 @@ func (mr *MockPackerMockRecorder) PackConnectionClose(arg0 interface{}) *gomock.
}
// PackMTUProbePacket mocks base method.
func (m *MockPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) {
func (m *MockPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time) (shortHeaderPacket, *packetBuffer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PackMTUProbePacket", ping, size)
ret0, _ := ret[0].(*packedPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
ret := m.ctrl.Call(m, "PackMTUProbePacket", ping, size, now)
ret0, _ := ret[0].(shortHeaderPacket)
ret1, _ := ret[1].(*packetBuffer)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// PackMTUProbePacket indicates an expected call of PackMTUProbePacket.
func (mr *MockPackerMockRecorder) PackMTUProbePacket(ping, size interface{}) *gomock.Call {
func (mr *MockPackerMockRecorder) PackMTUProbePacket(ping, size, now interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackMTUProbePacket", reflect.TypeOf((*MockPacker)(nil).PackMTUProbePacket), ping, size)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackMTUProbePacket", reflect.TypeOf((*MockPacker)(nil).PackMTUProbePacket), ping, size, now)
}
// PackPacket mocks base method.
func (m *MockPacker) PackPacket(onlyAck bool) (*packedPacket, error) {
func (m *MockPacker) PackPacket(onlyAck bool, now time.Time) (shortHeaderPacket, *packetBuffer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PackPacket", onlyAck)
ret0, _ := ret[0].(*packedPacket)
ret1, _ := ret[1].(error)
return ret0, ret1
ret := m.ctrl.Call(m, "PackPacket", onlyAck, now)
ret0, _ := ret[0].(shortHeaderPacket)
ret1, _ := ret[1].(*packetBuffer)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// PackPacket indicates an expected call of PackPacket.
func (mr *MockPackerMockRecorder) PackPacket(onlyAck interface{}) *gomock.Call {
func (mr *MockPackerMockRecorder) PackPacket(onlyAck, now interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket), onlyAck)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket), onlyAck, now)
}
// SetMaxPacketSize mocks base method.

View file

@ -40,7 +40,6 @@ var _ = Describe("Packet Handler Map", func() {
buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: t,
DestConnectionID: connID,
Length: length,

View file

@ -15,15 +15,17 @@ import (
"github.com/lucas-clemente/quic-go/internal/wire"
)
var errNothingToPack = errors.New("nothing to pack")
type packer interface {
PackCoalescedPacket(onlyAck bool) (*coalescedPacket, error)
PackPacket(onlyAck bool) (*packedPacket, error)
MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error)
PackPacket(onlyAck bool, now time.Time) (shortHeaderPacket, *packetBuffer, error)
MaybePackProbePacket(protocol.EncryptionLevel) (*coalescedPacket, error)
PackConnectionClose(*qerr.TransportError) (*coalescedPacket, error)
PackApplicationClose(*qerr.ApplicationError) (*coalescedPacket, error)
SetMaxPacketSize(protocol.ByteCount)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error)
PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time) (shortHeaderPacket, *packetBuffer, error)
HandleTransportParameters(*wire.TransportParameters)
SetToken([]byte)
@ -39,12 +41,7 @@ type payload struct {
length protocol.ByteCount
}
type packedPacket struct {
buffer *packetBuffer
*packetContents
}
type packetContents struct {
type longHeaderPacket struct {
header *wire.ExtendedHeader
ack *wire.AckFrame
frames []ackhandler.Frame
@ -54,15 +51,24 @@ type packetContents struct {
isMTUProbePacket bool
}
type coalescedPacket struct {
buffer *packetBuffer
packets []*packetContents
type shortHeaderPacket struct {
*ackhandler.Packet
// used for logging
DestConnID protocol.ConnectionID
Ack *wire.AckFrame
PacketNumberLen protocol.PacketNumberLen
KeyPhase protocol.KeyPhaseBit
}
func (p *packetContents) EncryptionLevel() protocol.EncryptionLevel {
if !p.header.IsLongHeader {
return protocol.Encryption1RTT
}
func (p *shortHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.Frames) }
type coalescedPacket struct {
buffer *packetBuffer
longHdrPackets []*longHeaderPacket
shortHdrPacket *shortHeaderPacket
}
func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel {
//nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data).
switch p.header.Type {
case protocol.PacketTypeInitial:
@ -76,11 +82,9 @@ func (p *packetContents) EncryptionLevel() protocol.EncryptionLevel {
}
}
func (p *packetContents) IsAckEliciting() bool {
return ackhandler.HasAckElicitingFrames(p.frames)
}
func (p *longHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.frames) }
func (p *packetContents) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet {
func (p *longHeaderPacket) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet {
largestAcked := protocol.InvalidPacketNumber
if p.ack != nil {
largestAcked = p.ack.LargestAcked()
@ -90,12 +94,13 @@ func (p *packetContents) ToAckHandlerPacket(now time.Time, q *retransmissionQueu
if p.frames[i].OnLost != nil {
continue
}
//nolint:exhaustive // Short header packets are handled separately.
switch encLevel {
case protocol.EncryptionInitial:
p.frames[i].OnLost = q.AddInitial
case protocol.EncryptionHandshake:
p.frames[i].OnLost = q.AddHandshake
case protocol.Encryption0RTT, protocol.Encryption1RTT:
case protocol.Encryption0RTT:
p.frames[i].OnLost = q.AddAppData
}
}
@ -226,10 +231,14 @@ func (p *packetPacker) packConnectionClose(
reason string,
) (*coalescedPacket, error) {
var sealers [4]sealer
var hdrs [4]*wire.ExtendedHeader
var hdrs [3]*wire.ExtendedHeader
var payloads [4]*payload
var size protocol.ByteCount
var numPackets uint8
var connID protocol.ConnectionID
var oneRTTPacketNumber protocol.PacketNumber
var oneRTTPacketNumberLen protocol.PacketNumberLen
var keyPhase protocol.KeyPhaseBit // only set for 1-RTT
var numLongHdrPackets uint8
encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT}
for i, encLevel := range encLevels {
if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT {
@ -254,7 +263,6 @@ func (p *packetPacker) packConnectionClose(
var sealer sealer
var err error
var keyPhase protocol.KeyPhaseBit // only set for 1-RTT
switch encLevel {
case protocol.EncryptionInitial:
sealer, err = p.cryptoSetup.GetInitialSealer()
@ -279,17 +287,22 @@ func (p *packetPacker) packConnectionClose(
sealers[i] = sealer
var hdr *wire.ExtendedHeader
if encLevel == protocol.Encryption1RTT {
hdr = p.getShortHeader(keyPhase)
connID = p.getDestConnID()
oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, payload)
} else {
hdr = p.getLongHeader(encLevel)
hdrs[i] = hdr
size += p.longHeaderPacketLength(hdr, payload) + protocol.ByteCount(sealer.Overhead())
numLongHdrPackets++
}
hdrs[i] = hdr
payloads[i] = payload
size += p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead())
numPackets++
}
contents := make([]*packetContents, 0, numPackets)
buffer := getPacketBuffer()
packet := &coalescedPacket{
buffer: buffer,
longHdrPackets: make([]*longHeaderPacket, 0, numLongHdrPackets),
}
for i, encLevel := range encLevels {
if sealers[i] == nil {
continue
@ -298,19 +311,27 @@ func (p *packetPacker) packConnectionClose(
if encLevel == protocol.EncryptionInitial {
paddingLen = p.initialPaddingLen(payloads[i].frames, size)
}
c, err := p.appendPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i], false)
if err != nil {
return nil, err
if encLevel == protocol.Encryption1RTT {
shortHdrPacket, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, sealers[i], false)
if err != nil {
return nil, err
}
packet.shortHdrPacket = shortHdrPacket
} else {
longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i])
if err != nil {
return nil, err
}
packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket)
}
contents = append(contents, c)
}
return &coalescedPacket{buffer: buffer, packets: contents}, nil
return packet, nil
}
// packetLength calculates the length of the serialized packet.
// longHeaderPacketLength calculates the length of a serialized long header packet.
// It takes into account that packets that have a tiny payload need to be padded,
// such that len(payload) + packet number len >= 4 + AEAD overhead
func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload) protocol.ByteCount {
func (p *packetPacker) longHeaderPacketLength(hdr *wire.ExtendedHeader, payload *payload) protocol.ByteCount {
var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(hdr.PacketNumberLen)
if payload.length < 4-pnLen {
@ -319,6 +340,17 @@ func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload)
return hdr.GetLength(p.version) + payload.length + paddingLen
}
// shortHeaderPacketLength calculates the length of a serialized short header packet.
// It takes into account that packets that have a tiny payload need to be padded,
// such that len(payload) + packet number len >= 4 + AEAD overhead
func (p *packetPacker) shortHeaderPacketLength(connID protocol.ConnectionID, pnLen protocol.PacketNumberLen, payload *payload) protocol.ByteCount {
var paddingLen protocol.ByteCount
if payload.length < 4-protocol.ByteCount(pnLen) {
paddingLen = 4 - protocol.ByteCount(pnLen) - payload.length
}
return wire.ShortHeaderLen(connID, pnLen) + payload.length + paddingLen
}
// size is the expected size of the packet, if no padding was applied.
func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount {
// For the server, only ack-eliciting Initial packets need to be padded.
@ -339,9 +371,12 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro
if p.perspective == protocol.PerspectiveClient {
maxPacketSize = protocol.MinInitialPacketSize
}
var initialHdr, handshakeHdr, appDataHdr *wire.ExtendedHeader
var initialPayload, handshakePayload, appDataPayload *payload
var numPackets int
var (
initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader
initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload *payload
oneRTTPacketNumber protocol.PacketNumber
oneRTTPacketNumberLen protocol.PacketNumberLen
)
// Try packing an Initial packet.
initialSealer, err := p.cryptoSetup.GetInitialSealer()
if err != nil && err != handshake.ErrKeysDropped {
@ -351,8 +386,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro
if initialSealer != nil {
initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, onlyAck, true)
if initialPayload != nil {
size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead())
numPackets++
size += p.longHeaderPacketLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead())
}
}
@ -367,94 +401,108 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro
if handshakeSealer != nil {
handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, onlyAck, size == 0)
if handshakePayload != nil {
s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead())
s := p.longHeaderPacketLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead())
size += s
numPackets++
}
}
}
// Add a 0-RTT / 1-RTT packet.
var appDataSealer sealer
appDataEncLevel := protocol.Encryption1RTT
var zeroRTTSealer sealer
var oneRTTSealer handshake.ShortHeaderSealer
var connID protocol.ConnectionID
var kp protocol.KeyPhaseBit
if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) {
var sErr error
var oneRTTSealer handshake.ShortHeaderSealer
oneRTTSealer, sErr = p.cryptoSetup.Get1RTTSealer()
appDataSealer = oneRTTSealer
if sErr != nil && p.perspective == protocol.PerspectiveClient {
appDataSealer, sErr = p.cryptoSetup.Get0RTTSealer()
appDataEncLevel = protocol.Encryption0RTT
var err error
oneRTTSealer, err = p.cryptoSetup.Get1RTTSealer()
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
return nil, err
}
if appDataSealer != nil && sErr == nil {
//nolint:exhaustive // 0-RTT and 1-RTT are the only two application data encryption levels.
switch appDataEncLevel {
case protocol.Encryption0RTT:
appDataHdr, appDataPayload = p.maybeGetAppDataPacketFor0RTT(appDataSealer, maxPacketSize-size)
case protocol.Encryption1RTT:
appDataHdr, appDataPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, maxPacketSize-size, onlyAck, size == 0)
if err == nil { // 1-RTT
kp = oneRTTSealer.KeyPhase()
connID = p.getDestConnID()
oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen)
oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxPacketSize-size, onlyAck, size == 0)
if oneRTTPayload != nil {
size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead())
}
if appDataHdr != nil && appDataPayload != nil {
size += p.packetLength(appDataHdr, appDataPayload) + protocol.ByteCount(appDataSealer.Overhead())
numPackets++
} else if p.perspective == protocol.PerspectiveClient { // 0-RTT
var err error
zeroRTTSealer, err = p.cryptoSetup.Get0RTTSealer()
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
return nil, err
}
if zeroRTTSealer != nil {
zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxPacketSize-size)
if zeroRTTPayload != nil {
size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload) + protocol.ByteCount(zeroRTTSealer.Overhead())
}
}
}
}
if numPackets == 0 {
if initialPayload == nil && handshakePayload == nil && zeroRTTPayload == nil && oneRTTPayload == nil {
return nil, nil
}
buffer := getPacketBuffer()
packet := &coalescedPacket{
buffer: buffer,
packets: make([]*packetContents, 0, numPackets),
buffer: buffer,
longHdrPackets: make([]*longHeaderPacket, 0, 3),
}
if initialPayload != nil {
padding := p.initialPaddingLen(initialPayload.frames, size)
cont, err := p.appendPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, false)
cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer)
if err != nil {
return nil, err
}
packet.packets = append(packet.packets, cont)
packet.longHdrPackets = append(packet.longHdrPackets, cont)
}
if handshakePayload != nil {
cont, err := p.appendPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, false)
cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer)
if err != nil {
return nil, err
}
packet.packets = append(packet.packets, cont)
packet.longHdrPackets = append(packet.longHdrPackets, cont)
}
if appDataPayload != nil {
cont, err := p.appendPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer, false)
if zeroRTTPayload != nil {
longHdrPacket, err := p.appendLongHeaderPacket(buffer, zeroRTTHdr, zeroRTTPayload, 0, protocol.Encryption0RTT, zeroRTTSealer)
if err != nil {
return nil, err
}
packet.packets = append(packet.packets, cont)
packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket)
} else if oneRTTPayload != nil {
shortHdrPacket, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, oneRTTSealer, false)
if err != nil {
return nil, err
}
packet.shortHdrPacket = shortHdrPacket
}
return packet, nil
}
// PackPacket packs a packet in the application data packet number space.
// It should be called after the handshake is confirmed.
func (p *packetPacker) PackPacket(onlyAck bool) (*packedPacket, error) {
func (p *packetPacker) PackPacket(onlyAck bool, now time.Time) (shortHeaderPacket, *packetBuffer, error) {
sealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
return shortHeaderPacket{}, nil, err
}
hdr, payload := p.maybeGetShortHeaderPacket(sealer, p.maxPacketSize, onlyAck, true)
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
connID := p.getDestConnID()
hdrLen := wire.ShortHeaderLen(connID, pnLen)
payload := p.maybeGetShortHeaderPacket(sealer, hdrLen, p.maxPacketSize, onlyAck, true)
if payload == nil {
return nil, nil
return shortHeaderPacket{}, nil, errNothingToPack
}
kp := sealer.KeyPhase()
buffer := getPacketBuffer()
cont, err := p.appendPacket(buffer, hdr, payload, 0, protocol.Encryption1RTT, sealer, false)
packet, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, payload, 0, sealer, false)
if err != nil {
return nil, err
return shortHeaderPacket{}, nil, err
}
return &packedPacket{
buffer: buffer,
packetContents: cont,
}, nil
return *packet, buffer, nil
}
func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) {
@ -535,11 +583,9 @@ func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize
return hdr, payload
}
func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) {
hdr := p.getShortHeader(sealer.KeyPhase())
maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead())
payload := p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed)
return hdr, payload
func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, hdrLen protocol.ByteCount, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool) *payload {
maxPayloadSize := maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead())
return p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed)
}
func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool) *payload {
@ -636,10 +682,33 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAc
return payload
}
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*packedPacket, error) {
func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (*coalescedPacket, error) {
if encLevel == protocol.Encryption1RTT {
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
}
kp := s.KeyPhase()
connID := p.getDestConnID()
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdrLen := wire.ShortHeaderLen(connID, pnLen)
payload := p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(s.Overhead())-hdrLen, false, true)
if payload == nil {
return nil, nil
}
buffer := getPacketBuffer()
packet := &coalescedPacket{buffer: buffer}
shortHdrPacket, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, payload, 0, s, false)
if err != nil {
return nil, err
}
packet.shortHdrPacket = shortHdrPacket
return packet, nil
}
var hdr *wire.ExtendedHeader
var payload *payload
var sealer sealer
var sealer handshake.LongHeaderSealer
//nolint:exhaustive // Probe packets are never sent for 0-RTT.
switch encLevel {
case protocol.EncryptionInitial:
@ -656,67 +725,47 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) (
return nil, err
}
hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true)
case protocol.Encryption1RTT:
oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
}
sealer = oneRTTSealer
hdr = p.getShortHeader(oneRTTSealer.KeyPhase())
payload = p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), false, true)
default:
panic("unknown encryption level")
}
if payload == nil {
return nil, nil
}
size := p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead())
buffer := getPacketBuffer()
packet := &coalescedPacket{buffer: buffer}
size := p.longHeaderPacketLength(hdr, payload) + protocol.ByteCount(sealer.Overhead())
var padding protocol.ByteCount
if encLevel == protocol.EncryptionInitial {
padding = p.initialPaddingLen(payload.frames, size)
}
buffer := getPacketBuffer()
cont, err := p.appendPacket(buffer, hdr, payload, padding, encLevel, sealer, false)
longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, payload, padding, encLevel, sealer)
if err != nil {
return nil, err
}
return &packedPacket{
buffer: buffer,
packetContents: cont,
}, nil
packet.longHdrPackets = []*longHeaderPacket{longHdrPacket}
return packet, nil
}
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount) (*packedPacket, error) {
func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time) (shortHeaderPacket, *packetBuffer, error) {
payload := &payload{
frames: []ackhandler.Frame{ping},
length: ping.Length(p.version),
}
buffer := getPacketBuffer()
sealer, err := p.cryptoSetup.Get1RTTSealer()
s, err := p.cryptoSetup.Get1RTTSealer()
if err != nil {
return nil, err
return shortHeaderPacket{}, nil, err
}
hdr := p.getShortHeader(sealer.KeyPhase())
padding := size - p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead())
contents, err := p.appendPacket(buffer, hdr, payload, padding, protocol.Encryption1RTT, sealer, true)
if err != nil {
return nil, err
}
contents.isMTUProbePacket = true
return &packedPacket{
buffer: buffer,
packetContents: contents,
}, nil
}
func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader {
connID := p.getDestConnID()
pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdr := &wire.ExtendedHeader{}
hdr.PacketNumber = pn
hdr.PacketNumberLen = pnLen
hdr.DestConnectionID = p.getDestConnID()
hdr.KeyPhase = kp
return hdr
padding := size - p.shortHeaderPacketLength(connID, pnLen, payload) - protocol.ByteCount(s.Overhead())
packet, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, s.KeyPhase(), payload, padding, s, true)
if err != nil {
return shortHeaderPacket{}, nil, err
}
return *packet, buffer, nil
}
func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
@ -725,7 +774,6 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex
PacketNumber: pn,
PacketNumberLen: pnLen,
}
hdr.IsLongHeader = true
hdr.Version = p.version
hdr.SrcConnectionID = p.srcConnID
hdr.DestConnectionID = p.getDestConnID()
@ -743,25 +791,118 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex
return hdr
}
func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, isMTUProbePacket bool) (*packetContents, error) {
func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer) (*longHeaderPacket, error) {
var paddingLen protocol.ByteCount
pnLen := protocol.ByteCount(header.PacketNumberLen)
if payload.length < 4-pnLen {
paddingLen = 4 - pnLen - payload.length
}
paddingLen += padding
if header.IsLongHeader {
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen
}
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + payload.length + paddingLen
hdrOffset := buffer.Len()
raw := buffer.Data[len(buffer.Data):]
buf := bytes.NewBuffer(buffer.Data)
startLen := buf.Len()
if err := header.Write(buf, p.version); err != nil {
return nil, err
}
payloadOffset := buf.Len()
raw := buffer.Data[:payloadOffset]
raw = raw[:buf.Len()-startLen]
payloadOffset := protocol.ByteCount(len(raw))
pn := p.pnManager.PopPacketNumber(encLevel)
if pn != header.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
raw, err := p.appendPacketPayload(raw, payload, paddingLen)
if err != nil {
return nil, err
}
raw = p.encryptPacket(raw, sealer, pn, payloadOffset, pnLen)
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
return &longHeaderPacket{
header: header,
ack: payload.ack,
frames: payload.frames,
length: protocol.ByteCount(len(raw)),
}, nil
}
func (p *packetPacker) appendShortHeaderPacket(
buffer *packetBuffer,
connID protocol.ConnectionID,
pn protocol.PacketNumber,
pnLen protocol.PacketNumberLen,
kp protocol.KeyPhaseBit,
payload *payload,
padding protocol.ByteCount,
sealer sealer,
isMTUProbePacket bool,
) (*shortHeaderPacket, error) {
var paddingLen protocol.ByteCount
if payload.length < 4-protocol.ByteCount(pnLen) {
paddingLen = 4 - protocol.ByteCount(pnLen) - payload.length
}
paddingLen += padding
raw := buffer.Data[len(buffer.Data):]
buf := bytes.NewBuffer(buffer.Data)
startLen := buf.Len()
if err := wire.WriteShortHeader(buf, connID, pn, pnLen, kp); err != nil {
return nil, err
}
raw = raw[:buf.Len()-startLen]
payloadOffset := protocol.ByteCount(len(raw))
if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
raw, err := p.appendPacketPayload(raw, payload, paddingLen)
if err != nil {
return nil, err
}
if !isMTUProbePacket {
if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > p.maxPacketSize {
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
}
}
raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen))
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
// create the ackhandler.Packet
largestAcked := protocol.InvalidPacketNumber
if payload.ack != nil {
largestAcked = payload.ack.LargestAcked()
}
for i := range payload.frames {
if payload.frames[i].OnLost != nil {
continue
}
payload.frames[i].OnLost = p.retransmissionQueue.AddAppData
}
ap := ackhandler.GetPacket()
ap.PacketNumber = pn
ap.LargestAcked = largestAcked
ap.Frames = payload.frames
ap.Length = protocol.ByteCount(len(raw))
ap.EncryptionLevel = protocol.Encryption1RTT
ap.SendTime = time.Now()
ap.IsPathMTUProbePacket = isMTUProbePacket
return &shortHeaderPacket{
Packet: ap,
DestConnID: connID,
Ack: payload.ack,
PacketNumberLen: pnLen,
KeyPhase: kp,
}, nil
}
func (p *packetPacker) appendPacketPayload(raw []byte, payload *payload, paddingLen protocol.ByteCount) ([]byte, error) {
payloadOffset := len(raw)
if payload.ack != nil {
var err error
raw, err = payload.ack.Append(raw, p.version)
@ -783,30 +924,16 @@ func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedH
if payloadSize := protocol.ByteCount(len(raw)-payloadOffset) - paddingLen; payloadSize != payload.length {
return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize)
}
if !isMTUProbePacket {
if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > p.maxPacketSize {
return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
}
}
return raw, nil
}
// encrypt the packet
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[hdrOffset:payloadOffset])
raw = raw[0 : len(raw)+sealer.Overhead()]
func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.PacketNumber, payloadOffset, pnLen protocol.ByteCount) []byte {
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], pn, raw[:payloadOffset])
raw = raw[:len(raw)+sealer.Overhead()]
// apply header protection
pnOffset := payloadOffset - int(header.PacketNumberLen)
sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[hdrOffset], raw[pnOffset:payloadOffset])
buffer.Data = raw
num := p.pnManager.PopPacketNumber(encLevel)
if num != header.PacketNumber {
return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
}
return &packetContents{
header: header,
ack: payload.ack,
frames: payload.frames,
length: buffer.Len() - hdrOffset,
}, nil
pnOffset := payloadOffset - pnLen
sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[0], raw[pnOffset:payloadOffset])
return raw
}
func (p *packetPacker) SetToken(token []byte) {

File diff suppressed because it is too large Load diff

View file

@ -27,18 +27,24 @@ var _ = Describe("Packet Unpacker", func() {
payload = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
)
getHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) {
getLongHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) {
buf := &bytes.Buffer{}
ExpectWithOffset(1, extHdr.Write(buf, version)).To(Succeed())
hdrLen := buf.Len()
if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) {
buf.Write(make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen)))
}
hdr, _, _, err := wire.ParsePacket(buf.Bytes(), connID.Len())
hdr, _, _, err := wire.ParsePacket(buf.Bytes())
ExpectWithOffset(1, err).ToNot(HaveOccurred())
return hdr, buf.Bytes()[:hdrLen]
}
getShortHeader := func(connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) []byte {
buf := &bytes.Buffer{}
Expect(wire.WriteShortHeader(buf, connID, pn, pnLen, kp)).To(Succeed())
return buf.Bytes()
}
BeforeEach(func() {
cs = mocks.NewMockCryptoSetup(mockCtrl)
unpacker = newPacketUnpacker(cs, 4, version).(*packetUnpacker)
@ -47,7 +53,6 @@ var _ = Describe("Packet Unpacker", func() {
It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: connID,
Version: version,
@ -55,7 +60,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumber: 1337,
PacketNumberLen: protocol.PacketNumberLen2,
}
hdr, hdrRaw := getHeader(extHdr)
hdr, hdrRaw := getLongHeader(extHdr)
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
@ -67,12 +72,9 @@ var _ = Describe("Packet Unpacker", func() {
})
It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() {
_, hdrRaw := getHeader(&wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 1337,
PacketNumberLen: protocol.PacketNumberLen2,
})
data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
buf := &bytes.Buffer{}
Expect(wire.WriteShortHeader(buf, connID, 1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)).To(Succeed())
data := append(buf.Bytes(), make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...)
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
_, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), data)
@ -83,7 +85,6 @@ var _ = Describe("Packet Unpacker", func() {
It("opens Initial packets", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Length: 3 + 6, // packet number len + payload
DestConnectionID: connID,
@ -92,7 +93,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumber: 2,
PacketNumberLen: 3,
}
hdr, hdrRaw := getHeader(extHdr)
hdr, hdrRaw := getLongHeader(extHdr)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
gomock.InOrder(
cs.EXPECT().GetInitialOpener().Return(opener, nil),
@ -109,7 +110,6 @@ var _ = Describe("Packet Unpacker", func() {
It("opens 0-RTT packets", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketType0RTT,
Length: 3 + 6, // packet number len + payload
DestConnectionID: connID,
@ -118,7 +118,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumber: 20,
PacketNumberLen: 2,
}
hdr, hdrRaw := getHeader(extHdr)
hdr, hdrRaw := getLongHeader(extHdr)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
gomock.InOrder(
cs.EXPECT().Get0RTTOpener().Return(opener, nil),
@ -133,13 +133,7 @@ var _ = Describe("Packet Unpacker", func() {
})
It("opens short header packets", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
KeyPhase: protocol.KeyPhaseOne,
PacketNumber: 99,
PacketNumberLen: protocol.PacketNumberLen4,
}
_, hdrRaw := getHeader(extHdr)
hdrRaw := getShortHeader(connID, 99, protocol.PacketNumberLen4, protocol.KeyPhaseOne)
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
now := time.Now()
gomock.InOrder(
@ -157,12 +151,7 @@ var _ = Describe("Packet Unpacker", func() {
})
It("returns the error when getting the opener fails", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337,
PacketNumberLen: 2,
}
_, hdrRaw := getHeader(extHdr)
hdrRaw := getShortHeader(connID, 0x1337, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable)
_, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable))
@ -171,7 +160,6 @@ var _ = Describe("Packet Unpacker", func() {
It("errors on empty packets, for long header packets", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: connID,
Version: Version1,
@ -179,7 +167,7 @@ var _ = Describe("Packet Unpacker", func() {
KeyPhase: protocol.KeyPhaseOne,
PacketNumberLen: protocol.PacketNumberLen4,
}
hdr, hdrRaw := getHeader(extHdr)
hdr, hdrRaw := getLongHeader(extHdr)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
gomock.InOrder(
cs.EXPECT().GetHandshakeOpener().Return(opener, nil),
@ -195,12 +183,7 @@ var _ = Describe("Packet Unpacker", func() {
})
It("errors on empty packets, for short header packets", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
KeyPhase: protocol.KeyPhaseOne,
PacketNumberLen: protocol.PacketNumberLen4,
}
_, hdrRaw := getHeader(extHdr)
hdrRaw := getShortHeader(connID, 0x42, protocol.PacketNumberLen4, protocol.KeyPhaseOne)
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
now := time.Now()
gomock.InOrder(
@ -219,7 +202,6 @@ var _ = Describe("Packet Unpacker", func() {
It("returns the error when unpacking fails", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Length: 3, // packet number len
DestConnectionID: connID,
@ -228,7 +210,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumber: 2,
PacketNumberLen: 3,
}
hdr, hdrRaw := getHeader(extHdr)
hdr, hdrRaw := getLongHeader(extHdr)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
@ -242,7 +224,6 @@ var _ = Describe("Packet Unpacker", func() {
It("defends against the timing side-channel when the reserved bits are wrong, for long header packets", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: connID,
Version: version,
@ -250,7 +231,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumber: 0x1337,
PacketNumberLen: 2,
}
hdr, hdrRaw := getHeader(extHdr)
hdr, hdrRaw := getLongHeader(extHdr)
hdrRaw[0] |= 0xc
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
@ -262,12 +243,7 @@ var _ = Describe("Packet Unpacker", func() {
})
It("defends against the timing side-channel when the reserved bits are wrong, for short header packets", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337,
PacketNumberLen: 2,
}
_, hdrRaw := getHeader(extHdr)
hdrRaw := getShortHeader(connID, 0x1337, protocol.PacketNumberLen2, protocol.KeyPhaseZero)
hdrRaw[0] |= 0x18
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
@ -281,7 +257,6 @@ var _ = Describe("Packet Unpacker", func() {
It("returns the decryption error, when unpacking a packet with wrong reserved bits fails, for long headers", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: connID,
Version: version,
@ -289,7 +264,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumber: 0x1337,
PacketNumberLen: 2,
}
hdr, hdrRaw := getHeader(extHdr)
hdr, hdrRaw := getLongHeader(extHdr)
hdrRaw[0] |= 0x18
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
@ -301,12 +276,7 @@ var _ = Describe("Packet Unpacker", func() {
})
It("returns the decryption error, when unpacking a packet with wrong reserved bits fails, for short headers", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337,
PacketNumberLen: 2,
}
_, hdrRaw := getHeader(extHdr)
hdrRaw := getShortHeader(connID, 0x1337, protocol.PacketNumberLen2, protocol.KeyPhaseZero)
hdrRaw[0] |= 0x18
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
@ -320,7 +290,6 @@ var _ = Describe("Packet Unpacker", func() {
It("decrypts the header", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Length: 3, // packet number len
DestConnectionID: connID,
@ -329,7 +298,7 @@ var _ = Describe("Packet Unpacker", func() {
PacketNumber: 0x1337,
PacketNumberLen: 2,
}
hdr, hdrRaw := getHeader(extHdr)
hdr, hdrRaw := getLongHeader(extHdr)
origHdrRaw := append([]byte{}, hdrRaw...) // save a copy of the header
firstHdrByte := hdrRaw[0]
hdrRaw[0] ^= 0xff // invert the first byte

View file

@ -154,7 +154,7 @@ func (e eventConnectionClosed) MarshalJSONObject(enc *gojay.Encoder) {
}
type eventPacketSent struct {
Header packetHeader
Header gojay.MarshalerJSONObject // either a shortHeader or a packetHeader
Length logging.ByteCount
PayloadLength logging.ByteCount
Frames frames

View file

@ -32,30 +32,14 @@ var _ = Describe("Packet Header", func() {
checkEncoding(data, expected)
}
It("marshals a header for a 1-RTT packet", func() {
check(
&wire.ExtendedHeader{
PacketNumber: 42,
KeyPhase: protocol.KeyPhaseZero,
},
map[string]interface{}{
"packet_type": "1RTT",
"packet_number": 42,
"dcil": 0,
"key_phase_bit": "0",
},
)
})
It("marshals a header with a payload length", func() {
check(
&wire.ExtendedHeader{
PacketNumber: 42,
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Length: 123,
Version: protocol.VersionNumber(0xdecafbad),
Type: protocol.PacketTypeInitial,
Length: 123,
Version: protocol.VersionNumber(0xdecafbad),
},
},
map[string]interface{}{
@ -73,11 +57,10 @@ var _ = Describe("Packet Header", func() {
&wire.ExtendedHeader{
PacketNumber: 4242,
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Length: 123,
Version: protocol.VersionNumber(0xdecafbad),
Token: []byte{0xde, 0xad, 0xbe, 0xef},
Type: protocol.PacketTypeInitial,
Length: 123,
Version: protocol.VersionNumber(0xdecafbad),
Token: []byte{0xde, 0xad, 0xbe, 0xef},
},
},
map[string]interface{}{
@ -95,7 +78,6 @@ var _ = Describe("Packet Header", func() {
check(
&wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
SrcConnectionID: protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44}),
Version: protocol.VersionNumber(0xdecafbad),
@ -118,9 +100,8 @@ var _ = Describe("Packet Header", func() {
&wire.ExtendedHeader{
PacketNumber: 0,
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Version: protocol.VersionNumber(0xdecafbad),
Type: protocol.PacketTypeHandshake,
Version: protocol.VersionNumber(0xdecafbad),
},
},
map[string]interface{}{
@ -138,7 +119,6 @@ var _ = Describe("Packet Header", func() {
&wire.ExtendedHeader{
PacketNumber: 42,
Header: wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
SrcConnectionID: protocol.ParseConnectionID([]byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}),
Version: protocol.VersionNumber(0xdecafbad),
@ -154,22 +134,5 @@ var _ = Describe("Packet Header", func() {
},
)
})
It("marshals a 1-RTT header with a destination connection ID", func() {
check(
&wire.ExtendedHeader{
PacketNumber: 42,
Header: wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})},
KeyPhase: protocol.KeyPhaseOne,
},
map[string]interface{}{
"packet_type": "1RTT",
"packet_number": 42,
"dcil": 4,
"dcid": "deadbeef",
"key_phase_bit": "1",
},
)
})
})
})

View file

@ -274,7 +274,15 @@ func (t *connectionTracer) toTransportParameters(tp *wire.TransportParameters) *
}
}
func (t *connectionTracer) SentPacket(hdr *wire.ExtendedHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
func (t *connectionTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
t.sentPacket(*transformLongHeader(hdr), packetSize, hdr.Length, ack, frames)
}
func (t *connectionTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
t.sentPacket(*transformShortHeader(hdr), packetSize, 0, ack, frames)
}
func (t *connectionTracer) sentPacket(hdr gojay.MarshalerJSONObject, packetSize, payloadLen logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
numFrames := len(frames)
if ack != nil {
numFrames++
@ -286,12 +294,11 @@ func (t *connectionTracer) SentPacket(hdr *wire.ExtendedHeader, packetSize loggi
for _, f := range frames {
fs = append(fs, frame{Frame: f})
}
header := *transformLongHeader(hdr)
t.mutex.Lock()
t.recordEvent(time.Now(), &eventPacketSent{
Header: header,
Header: hdr,
Length: packetSize,
PayloadLength: hdr.Length,
PayloadLength: payloadLen,
Frames: fs,
})
t.mutex.Unlock()
@ -319,12 +326,11 @@ func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, p
fs[i] = frame{Frame: f}
}
header := *transformShortHeader(hdr)
hdrLen := 1 + hdr.DestConnectionID.Len() + int(hdr.PacketNumberLen)
t.mutex.Lock()
t.recordEvent(time.Now(), &eventPacketReceived{
Header: header,
Length: packetSize,
PayloadLength: packetSize - protocol.ByteCount(hdrLen),
PayloadLength: packetSize - wire.ShortHeaderLen(hdr.DestConnectionID, hdr.PacketNumberLen),
Frames: fs,
})
t.mutex.Unlock()

View file

@ -29,8 +29,8 @@ func scaleDuration(t time.Duration) time.Duration {
func checkEncoding(data []byte, expected map[string]interface{}) {
// unmarshal the data
m := make(map[string]interface{})
ExpectWithOffset(1, json.Unmarshal(data, &m)).To(Succeed())
ExpectWithOffset(1, m).To(HaveLen(len(expected)))
ExpectWithOffset(2, json.Unmarshal(data, &m)).To(Succeed())
ExpectWithOffset(2, m).To(HaveLen(len(expected)))
for key, value := range expected {
switch v := value.(type) {
case bool, string, map[string]interface{}:

View file

@ -420,11 +420,10 @@ var _ = Describe("Tracing", func() {
Expect(ev).To(HaveKeyWithValue("initial_max_stream_data_uni", float64(300)))
})
It("records a sent packet, without an ACK", func() {
tracer.SentPacket(
It("records a sent long header packet, without an ACK", func() {
tracer.SentLongHeaderPacket(
&logging.ExtendedHeader{
Header: logging.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}),
@ -460,11 +459,11 @@ var _ = Describe("Tracing", func() {
Expect(frames[1].(map[string]interface{})).To(HaveKeyWithValue("frame_type", "stream"))
})
It("records a sent packet, without an ACK", func() {
tracer.SentPacket(
&logging.ExtendedHeader{
Header: logging.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4})},
PacketNumber: 1337,
It("records a sent short header packet, without an ACK", func() {
tracer.SentShortHeaderPacket(
&logging.ShortHeader{
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
PacketNumber: 1337,
},
123,
&logging.AckFrame{AckRanges: []logging.AckRange{{Smallest: 1, Largest: 10}}},
@ -490,7 +489,6 @@ var _ = Describe("Tracing", func() {
tracer.ReceivedLongHeaderPacket(
&logging.ExtendedHeader{
Header: logging.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}),
@ -561,7 +559,6 @@ var _ = Describe("Tracing", func() {
It("records a received Retry packet", func() {
tracer.ReceivedRetry(
&logging.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}),

View file

@ -348,7 +348,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
}
// If we're creating a new connection, the packet will be passed to the connection.
// The header will then be parsed again.
hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDGenerator.ConnectionIDLen())
hdr, _, _, err := wire.ParsePacket(p.data)
if err != nil {
if s.config.Tracer != nil {
s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError)
@ -364,7 +364,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s
return false
}
if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial {
if hdr.Type != protocol.PacketTypeInitial {
// Drop long header packets.
// There's little point in sending a Stateless Reset, since the client
// might not have received the token yet.
@ -566,7 +566,6 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack
return err
}
replyHdr := &wire.ExtendedHeader{}
replyHdr.IsLongHeader = true
replyHdr.Type = protocol.PacketTypeRetry
replyHdr.Version = hdr.Version
replyHdr.SrcConnectionID = srcConnID
@ -635,7 +634,6 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han
ccf := &wire.ConnectionCloseFrame{ErrorCode: uint64(errorCode)}
replyHdr := &wire.ExtendedHeader{}
replyHdr.IsLongHeader = true
replyHdr.Type = protocol.PacketTypeInitial
replyHdr.Version = hdr.Version
replyHdr.SrcConnectionID = hdr.DestConnectionID

View file

@ -44,9 +44,7 @@ var _ = Describe("Server", func() {
getPacket := func(hdr *wire.Header, p []byte) *receivedPacket {
buffer := getPacketBuffer()
buf := bytes.NewBuffer(buffer.Data)
if hdr.IsLongHeader {
hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
}
hdr.Length = 4 + protocol.ByteCount(len(p)) + 16
Expect((&wire.ExtendedHeader{
Header: *hdr,
PacketNumber: 0x42,
@ -69,7 +67,6 @@ var _ = Describe("Server", func() {
getInitial := func(destConnID protocol.ConnectionID) *receivedPacket {
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: destConnID,
@ -90,7 +87,7 @@ var _ = Describe("Server", func() {
}
parseHeader := func(data []byte) *wire.Header {
hdr, _, _, err := wire.ParsePacket(data, 0)
hdr, _, _, err := wire.ParsePacket(data)
Expect(err).ToNot(HaveOccurred())
return hdr
}
@ -202,7 +199,6 @@ var _ = Describe("Server", func() {
Context("handling packets", func() {
It("drops Initial packets with a too short connection ID", func() {
p := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
Version: serv.config.Versions[0],
@ -215,7 +211,6 @@ var _ = Describe("Server", func() {
It("drops too small Initial", func() {
p := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}),
Version: serv.config.Versions[0],
@ -229,9 +224,8 @@ var _ = Describe("Server", func() {
It("drops non-Initial packets", func() {
p := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
Version: serv.config.Versions[0],
Type: protocol.PacketTypeHandshake,
Version: serv.config.Versions[0],
}, []byte("invalid"))
tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket)
serv.handlePacket(p)
@ -249,7 +243,6 @@ var _ = Describe("Server", func() {
)
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
@ -326,7 +319,6 @@ var _ = Describe("Server", func() {
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
packet := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
SrcConnectionID: srcConnID,
DestConnectionID: destConnID,
@ -358,7 +350,6 @@ var _ = Describe("Server", func() {
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
packet := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
SrcConnectionID: srcConnID,
DestConnectionID: destConnID,
@ -397,7 +388,6 @@ var _ = Describe("Server", func() {
srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})
destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6})
p := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeHandshake,
SrcConnectionID: srcConnID,
DestConnectionID: destConnID,
@ -419,7 +409,6 @@ var _ = Describe("Server", func() {
It("replies with a Retry packet, if a token is required", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
@ -451,7 +440,6 @@ var _ = Describe("Server", func() {
It("creates a connection, if no token is required", func() {
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
@ -663,7 +651,7 @@ var _ = Describe("Server", func() {
}
wg.Wait()
p := getInitialWithRandomDestConnID()
hdr, _, _, err := wire.ParsePacket(p.data, 0)
hdr, _, _, err := wire.ParsePacket(p.data)
Expect(err).ToNot(HaveOccurred())
tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
done := make(chan struct{})
@ -767,10 +755,9 @@ var _ = Describe("Server", func() {
token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
packet := getPacket(&wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
Token: token,
Version: serv.config.Versions[0],
Type: protocol.PacketTypeInitial,
Token: token,
Version: serv.config.Versions[0],
}, make([]byte, protocol.MinInitialPacketSize))
packet.remoteAddr = raddr
conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1)
@ -787,7 +774,6 @@ var _ = Describe("Server", func() {
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
@ -826,7 +812,6 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred())
time.Sleep(2 * time.Millisecond) // make sure the token is expired
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
@ -860,7 +845,6 @@ var _ = Describe("Server", func() {
token, err := serv.tokenGenerator.NewToken(&net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
@ -892,7 +876,6 @@ var _ = Describe("Server", func() {
Expect(err).ToNot(HaveOccurred())
time.Sleep(2 * time.Millisecond) // make sure the token is expired
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
@ -918,7 +901,6 @@ var _ = Describe("Server", func() {
token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{})
Expect(err).ToNot(HaveOccurred())
hdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeInitial,
SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}),
DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),