diff --git a/connection.go b/connection.go index 38d15274..1b8b3c4d 100644 --- a/connection.go +++ b/connection.go @@ -1879,7 +1879,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel) error { return err } } - if packet == nil || len(packet.packets) == 0 { + if packet == nil || (len(packet.longHdrPackets) == 0 && packet.shortHdrPacket == nil) { return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel) } s.sendPackedCoalescedPacket(packet, time.Now()) @@ -1935,12 +1935,18 @@ func (s *connection) sendPackedShortHeaderPacket(buffer *packetBuffer, p *ackhan func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time.Time) { s.logCoalescedPacket(packet) - for _, p := range packet.packets { + for _, p := range packet.longHdrPackets { if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { s.firstAckElicitingPacketAfterIdleSentTime = now } s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(now, s.retransmissionQueue)) } + if p := packet.shortHdrPacket; p != nil { + if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && p.IsAckEliciting() { + s.firstAckElicitingPacketAfterIdleSentTime = now + } + s.sentPacketHandler.SentPacket(p.Packet) + } s.connIDManager.SentPacket() s.sendQueue.Send(packet.buffer) } @@ -1967,7 +1973,7 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) { return packet.buffer.Data, s.conn.Write(packet.buffer.Data) } -func (s *connection) logLongHeaderPacket(p *packetContents) { +func (s *connection) logLongHeaderPacket(p *longHeaderPacket) { // quic-go logging if s.logger.Debug() { p.header.Log(s.logger) @@ -2043,18 +2049,17 @@ func (s *connection) logShortHeaderPacket( func (s *connection) logCoalescedPacket(packet *coalescedPacket) { if s.logger.Debug() { - if len(packet.packets) > 1 { - s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.packets), packet.buffer.Len(), s.logID) + if len(packet.longHdrPackets) > 1 { + s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.longHdrPackets), packet.buffer.Len(), s.logID) } else { - s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.packets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.packets[0].EncryptionLevel()) + s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.longHdrPackets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.longHdrPackets[0].EncryptionLevel()) } } - for _, p := range packet.packets { - if p.header.IsLongHeader { - s.logLongHeaderPacket(p) - } else { - s.logShortHeaderPacket(p.header.DestConnectionID, p.ack, p.frames, p.header.PacketNumber, p.header.PacketNumberLen, p.header.KeyPhase, p.length, true) - } + for _, p := range packet.longHdrPackets { + s.logLongHeaderPacket(p) + } + if p := packet.shortHdrPacket; p != nil { + s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, p.Length, true) } } diff --git a/connection_test.go b/connection_test.go index 36c5a120..ef433641 100644 --- a/connection_test.go +++ b/connection_test.go @@ -59,16 +59,27 @@ var _ = Describe("Connection", func() { return shortHeaderPacket{Packet: &ackhandler.Packet{PacketNumber: pn}}, buffer } - getCoalescedPacket := func(pn protocol.PacketNumber) *coalescedPacket { + getCoalescedPacket := func(pn protocol.PacketNumber, isLongHeader bool) *coalescedPacket { buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) - return &coalescedPacket{ - buffer: buffer, - packets: []*packetContents{{ - header: &wire.ExtendedHeader{PacketNumber: pn}, + packet := &coalescedPacket{buffer: buffer} + if isLongHeader { + packet.longHdrPackets = []*longHeaderPacket{{ + header: &wire.ExtendedHeader{ + Header: wire.Header{IsLongHeader: true}, + PacketNumber: pn, + }, length: 6, // foobar - }}, + }} + } else { + packet.shortHdrPacket = &shortHeaderPacket{ + Packet: &ackhandler.Packet{ + PacketNumber: pn, + Length: 6, + }, + } } + return packet } expectReplaceWithClosed := func() { @@ -1337,7 +1348,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(sendMode) sph.EXPECT().SendMode().Return(ackhandler.SendNone) sph.EXPECT().QueueProbePacket(encLevel) - p := getCoalescedPacket(123) + p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel).Return(p, nil) sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) @@ -1346,7 +1357,11 @@ var _ = Describe("Connection", func() { runConn() sent := make(chan struct{}) sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.packets[0].length, gomock.Any(), gomock.Any()) + if enc == protocol.Encryption1RTT { + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) + } else { + tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any()) + } conn.scheduleSending() Eventually(sent).Should(BeClosed()) }) @@ -1358,7 +1373,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(sendMode) sph.EXPECT().SendMode().Return(ackhandler.SendNone) sph.EXPECT().QueueProbePacket(encLevel).Return(false) - p := getCoalescedPacket(123) + p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) packer.EXPECT().MaybePackProbePacket(encLevel).Return(p, nil) sph.EXPECT().SentPacket(gomock.Any()).Do(func(packet *ackhandler.Packet) { Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(123))) @@ -1367,7 +1382,11 @@ var _ = Describe("Connection", func() { runConn() sent := make(chan struct{}) sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) - tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.packets[0].length, gomock.Any(), gomock.Any()) + if enc == protocol.Encryption1RTT { + tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any()) + } else { + tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any()) + } conn.scheduleSending() Eventually(sent).Should(BeClosed()) // We're using a mock packet packer in this test. @@ -1763,7 +1782,7 @@ var _ = Describe("Connection", func() { buffer.Data = append(buffer.Data, []byte("foobar")...) packer.EXPECT().PackCoalescedPacket(false).Return(&coalescedPacket{ buffer: buffer, - packets: []*packetContents{ + longHdrPackets: []*longHeaderPacket{ { header: &wire.ExtendedHeader{ Header: wire.Header{ diff --git a/packet_packer.go b/packet_packer.go index 5b33c56c..83d5fa86 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -41,7 +41,7 @@ type payload struct { length protocol.ByteCount } -type packetContents struct { +type longHeaderPacket struct { header *wire.ExtendedHeader ack *wire.AckFrame frames []ackhandler.Frame @@ -60,14 +60,17 @@ type shortHeaderPacket struct { KeyPhase protocol.KeyPhaseBit } +func (p *shortHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.Frames) } + type coalescedPacket struct { - buffer *packetBuffer - packets []*packetContents + buffer *packetBuffer + longHdrPackets []*longHeaderPacket + shortHdrPacket *shortHeaderPacket } -func (p *packetContents) EncryptionLevel() protocol.EncryptionLevel { +func (p *longHeaderPacket) EncryptionLevel() protocol.EncryptionLevel { if !p.header.IsLongHeader { - return protocol.Encryption1RTT + panic("this shouldn't happen any more") } //nolint:exhaustive // Will never be called for Retry packets (and they don't have encrypted data). switch p.header.Type { @@ -82,11 +85,9 @@ func (p *packetContents) EncryptionLevel() protocol.EncryptionLevel { } } -func (p *packetContents) IsAckEliciting() bool { - return ackhandler.HasAckElicitingFrames(p.frames) -} +func (p *longHeaderPacket) IsAckEliciting() bool { return ackhandler.HasAckElicitingFrames(p.frames) } -func (p *packetContents) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet { +func (p *longHeaderPacket) ToAckHandlerPacket(now time.Time, q *retransmissionQueue) *ackhandler.Packet { largestAcked := protocol.InvalidPacketNumber if p.ack != nil { largestAcked = p.ack.LargestAcked() @@ -96,12 +97,13 @@ func (p *packetContents) ToAckHandlerPacket(now time.Time, q *retransmissionQueu if p.frames[i].OnLost != nil { continue } + //nolint:exhaustive // Short header packets are handled separately. switch encLevel { case protocol.EncryptionInitial: p.frames[i].OnLost = q.AddInitial case protocol.EncryptionHandshake: p.frames[i].OnLost = q.AddHandshake - case protocol.Encryption0RTT, protocol.Encryption1RTT: + case protocol.Encryption0RTT: p.frames[i].OnLost = q.AddAppData } } @@ -235,7 +237,8 @@ func (p *packetPacker) packConnectionClose( var hdrs [4]*wire.ExtendedHeader var payloads [4]*payload var size protocol.ByteCount - var numPackets uint8 + var keyPhase protocol.KeyPhaseBit // only set for 1-RTT + var numLongHdrPackets uint8 encLevels := [4]protocol.EncryptionLevel{protocol.EncryptionInitial, protocol.EncryptionHandshake, protocol.Encryption0RTT, protocol.Encryption1RTT} for i, encLevel := range encLevels { if p.perspective == protocol.PerspectiveServer && encLevel == protocol.Encryption0RTT { @@ -260,7 +263,6 @@ func (p *packetPacker) packConnectionClose( var sealer sealer var err error - var keyPhase protocol.KeyPhaseBit // only set for 1-RTT switch encLevel { case protocol.EncryptionInitial: sealer, err = p.cryptoSetup.GetInitialSealer() @@ -292,10 +294,15 @@ func (p *packetPacker) packConnectionClose( hdrs[i] = hdr payloads[i] = payload size += p.packetLength(hdr, payload) + protocol.ByteCount(sealer.Overhead()) - numPackets++ + if encLevel != protocol.Encryption1RTT { + numLongHdrPackets++ + } } - contents := make([]*packetContents, 0, numPackets) buffer := getPacketBuffer() + packet := &coalescedPacket{ + buffer: buffer, + longHdrPackets: make([]*longHeaderPacket, 0, numLongHdrPackets), + } for i, encLevel := range encLevels { if sealers[i] == nil { continue @@ -304,19 +311,21 @@ func (p *packetPacker) packConnectionClose( if encLevel == protocol.EncryptionInitial { paddingLen = p.initialPaddingLen(payloads[i].frames, size) } - var c *packetContents - var err error if encLevel == protocol.Encryption1RTT { - c, err = p.appendShortHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, sealers[i], false) + shortHdrPacket, err := p.appendShortHeaderPacket(buffer, hdrs[i].PacketNumber, hdrs[i].PacketNumberLen, keyPhase, payloads[i], paddingLen, sealers[i], false) + if err != nil { + return nil, err + } + packet.shortHdrPacket = shortHdrPacket } else { - c, err = p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i]) + longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i]) + if err != nil { + return nil, err + } + packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) } - if err != nil { - return nil, err - } - contents = append(contents, c) } - return &coalescedPacket{buffer: buffer, packets: contents}, nil + return packet, nil } // packetLength calculates the length of the serialized packet. @@ -388,6 +397,7 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro // Add a 0-RTT / 1-RTT packet. var appDataSealer sealer + var kp protocol.KeyPhaseBit appDataEncLevel := protocol.Encryption1RTT if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) { var sErr error @@ -404,7 +414,8 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro case protocol.Encryption0RTT: appDataHdr, appDataPayload = p.maybeGetAppDataPacketFor0RTT(appDataSealer, maxPacketSize-size) case protocol.Encryption1RTT: - appDataHdr, appDataPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, maxPacketSize-size, onlyAck, size == 0) + kp = oneRTTSealer.KeyPhase() + appDataHdr, appDataPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, kp, maxPacketSize-size, onlyAck, size == 0) } if appDataHdr != nil && appDataPayload != nil { size += p.packetLength(appDataHdr, appDataPayload) + protocol.ByteCount(appDataSealer.Overhead()) @@ -419,8 +430,8 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro buffer := getPacketBuffer() packet := &coalescedPacket{ - buffer: buffer, - packets: make([]*packetContents, 0, numPackets), + buffer: buffer, + longHdrPackets: make([]*longHeaderPacket, 0, numPackets), } if initialPayload != nil { padding := p.initialPaddingLen(initialPayload.frames, size) @@ -428,27 +439,29 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro if err != nil { return nil, err } - packet.packets = append(packet.packets, cont) + packet.longHdrPackets = append(packet.longHdrPackets, cont) } if handshakePayload != nil { cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer) if err != nil { return nil, err } - packet.packets = append(packet.packets, cont) + packet.longHdrPackets = append(packet.longHdrPackets, cont) } if appDataPayload != nil { - var cont *packetContents - var err error if appDataEncLevel == protocol.Encryption0RTT { - cont, err = p.appendLongHeaderPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer) + longHdrPacket, err := p.appendLongHeaderPacket(buffer, appDataHdr, appDataPayload, 0, appDataEncLevel, appDataSealer) + if err != nil { + return nil, err + } + packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) } else { - cont, err = p.appendShortHeaderPacket(buffer, appDataHdr, appDataPayload, 0, appDataSealer, false) + shortHdrPacket, err := p.appendShortHeaderPacket(buffer, appDataHdr.PacketNumber, appDataHdr.PacketNumberLen, kp, appDataPayload, 0, appDataSealer, false) + if err != nil { + return nil, err + } + packet.shortHdrPacket = shortHdrPacket } - if err != nil { - return nil, err - } - packet.packets = append(packet.packets, cont) } return packet, nil } @@ -460,22 +473,17 @@ func (p *packetPacker) PackPacket(onlyAck bool, now time.Time) (shortHeaderPacke if err != nil { return shortHeaderPacket{}, nil, err } - hdr, payload := p.maybeGetShortHeaderPacket(sealer, p.maxPacketSize, onlyAck, true) + kp := sealer.KeyPhase() + hdr, payload := p.maybeGetShortHeaderPacket(sealer, kp, p.maxPacketSize, onlyAck, true) if payload == nil { return shortHeaderPacket{}, nil, errNothingToPack } buffer := getPacketBuffer() - cont, err := p.appendShortHeaderPacket(buffer, hdr, payload, 0, sealer, false) + packet, err := p.appendShortHeaderPacket(buffer, hdr.PacketNumber, hdr.PacketNumberLen, kp, payload, 0, sealer, false) if err != nil { return shortHeaderPacket{}, nil, err } - return shortHeaderPacket{ - Packet: cont.ToAckHandlerPacket(now, p.retransmissionQueue), - DestConnID: hdr.DestConnectionID, - Ack: payload.ack, - PacketNumberLen: hdr.PacketNumberLen, - KeyPhase: hdr.KeyPhase, - }, buffer, nil + return *packet, buffer, nil } func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) { @@ -556,8 +564,8 @@ func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize return hdr, payload } -func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) { - hdr := p.getShortHeader(sealer.KeyPhase()) +func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, kp protocol.KeyPhaseBit, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) { + hdr := p.getShortHeader(kp) maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) payload := p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed) return hdr, payload @@ -661,6 +669,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( var hdr *wire.ExtendedHeader var payload *payload var sealer sealer + var kp protocol.KeyPhaseBit //nolint:exhaustive // Probe packets are never sent for 0-RTT. switch encLevel { case protocol.EncryptionInitial: @@ -682,6 +691,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( if err != nil { return nil, err } + kp = oneRTTSealer.KeyPhase() sealer = oneRTTSealer hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) payload = p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), false, true) @@ -697,20 +707,22 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( padding = p.initialPaddingLen(payload.frames, size) } buffer := getPacketBuffer() - var cont *packetContents - var err error + packet := &coalescedPacket{buffer: buffer} if encLevel == protocol.Encryption1RTT { - cont, err = p.appendShortHeaderPacket(buffer, hdr, payload, padding, sealer, false) - } else { - cont, err = p.appendLongHeaderPacket(buffer, hdr, payload, padding, encLevel, sealer) + shortHdrPacket, err := p.appendShortHeaderPacket(buffer, hdr.PacketNumber, hdr.PacketNumberLen, kp, payload, padding, sealer, false) + if err != nil { + return nil, err + } + packet.shortHdrPacket = shortHdrPacket + return packet, nil } + + longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdr, payload, padding, encLevel, sealer) if err != nil { return nil, err } - return &coalescedPacket{ - buffer: buffer, - packets: []*packetContents{cont}, - }, nil + packet.longHdrPackets = []*longHeaderPacket{longHdrPacket} + return packet, nil } func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.ByteCount, now time.Time) (shortHeaderPacket, *packetBuffer, error) { @@ -725,18 +737,11 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B } hdr := p.getShortHeader(sealer.KeyPhase()) padding := size - p.packetLength(hdr, payload) - protocol.ByteCount(sealer.Overhead()) - cont, err := p.appendShortHeaderPacket(buffer, hdr, payload, padding, sealer, true) + packet, err := p.appendShortHeaderPacket(buffer, hdr.PacketNumber, hdr.PacketNumberLen, sealer.KeyPhase(), payload, padding, sealer, true) if err != nil { return shortHeaderPacket{}, nil, err } - cont.isMTUProbePacket = true - return shortHeaderPacket{ - Packet: cont.ToAckHandlerPacket(now, p.retransmissionQueue), - DestConnID: hdr.DestConnectionID, - Ack: payload.ack, - PacketNumberLen: hdr.PacketNumberLen, - KeyPhase: hdr.KeyPhase, - }, buffer, nil + return *packet, buffer, nil } func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader { @@ -773,7 +778,7 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex return hdr } -func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer) (*packetContents, error) { +func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer) (*longHeaderPacket, error) { if !header.IsLongHeader { panic("shouldn't have called appendLongHeaderPacket") } @@ -808,7 +813,7 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire raw = p.encryptPacket(raw, sealer, pn, payloadOffset, pnLen) buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] - return &packetContents{ + return &longHeaderPacket{ header: header, ack: payload.ack, frames: payload.frames, @@ -816,28 +821,33 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire }, nil } -func (p *packetPacker) appendShortHeaderPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, sealer sealer, isMTUProbePacket bool) (*packetContents, error) { - if header.IsLongHeader { - panic("shouldn't have called appendShortHeaderPacket") - } +func (p *packetPacker) appendShortHeaderPacket( + buffer *packetBuffer, + pn protocol.PacketNumber, + pnLen protocol.PacketNumberLen, + kp protocol.KeyPhaseBit, + payload *payload, + padding protocol.ByteCount, + sealer sealer, + isMTUProbePacket bool, +) (*shortHeaderPacket, error) { var paddingLen protocol.ByteCount - pnLen := protocol.ByteCount(header.PacketNumberLen) - if payload.length < 4-pnLen { - paddingLen = 4 - pnLen - payload.length + if payload.length < 4-protocol.ByteCount(pnLen) { + paddingLen = 4 - protocol.ByteCount(pnLen) - payload.length } paddingLen += padding raw := buffer.Data[len(buffer.Data):] buf := bytes.NewBuffer(buffer.Data) startLen := buf.Len() - if err := header.Write(buf, p.version); err != nil { + connID := p.getDestConnID() + if err := wire.WriteShortHeader(buf, connID, pn, pnLen, kp); err != nil { return nil, err } raw = raw[:buf.Len()-startLen] payloadOffset := protocol.ByteCount(len(raw)) - pn := p.pnManager.PopPacketNumber(protocol.Encryption1RTT) - if pn != header.PacketNumber { + if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) { return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } @@ -850,14 +860,36 @@ func (p *packetPacker) appendShortHeaderPacket(buffer *packetBuffer, header *wir return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) } } - raw = p.encryptPacket(raw, sealer, pn, payloadOffset, pnLen) + raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen)) buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] - return &packetContents{ - header: header, - ack: payload.ack, - frames: payload.frames, - length: protocol.ByteCount(len(raw)), + // create the ackhandler.Packet + largestAcked := protocol.InvalidPacketNumber + if payload.ack != nil { + largestAcked = payload.ack.LargestAcked() + } + for i := range payload.frames { + if payload.frames[i].OnLost != nil { + continue + } + payload.frames[i].OnLost = p.retransmissionQueue.AddAppData + } + + ap := ackhandler.GetPacket() + ap.PacketNumber = pn + ap.LargestAcked = largestAcked + ap.Frames = payload.frames + ap.Length = protocol.ByteCount(len(raw)) + ap.EncryptionLevel = protocol.Encryption1RTT + ap.SendTime = time.Now() + ap.IsPathMTUProbePacket = isMTUProbePacket + + return &shortHeaderPacket{ + Packet: ap, + DestConnID: connID, + Ack: payload.ack, + PacketNumberLen: pnLen, + KeyPhase: kp, }, nil } diff --git a/packet_packer_test.go b/packet_packer_test.go index 7a9a0065..b23b424f 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -193,8 +193,10 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].frames).To(Equal([]ackhandler.Frame{{Frame: f}})) + Expect(p.longHdrPackets).To(BeEmpty()) + Expect(p.shortHdrPacket).ToNot(BeNil()) + Expect(p.shortHdrPacket.Frames).To(HaveLen(1)) + Expect(p.shortHdrPacket.Frames[0].Frame).To(Equal(f)) hdrRawEncrypted := append([]byte{}, hdrRaw...) hdrRawEncrypted[0] ^= 0xff hdrRawEncrypted[len(hdrRaw)-2] ^= 0xff @@ -241,10 +243,10 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(true) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].ack).To(Equal(ack)) - Expect(p.packets[0].frames).To(BeEmpty()) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.longHdrPackets[0].ack).To(Equal(ack)) + Expect(p.longHdrPackets[0].frames).To(BeEmpty()) Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) parsePacket(p.buffer.Data) }) @@ -258,10 +260,10 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(true) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].ack).To(Equal(ack)) - Expect(p.packets[0].frames).To(BeEmpty()) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.longHdrPackets[0].ack).To(Equal(ack)) + Expect(p.longHdrPackets[0].frames).To(BeEmpty()) Expect(p.buffer.Len()).To(BeNumerically("<", 100)) parsePacket(p.buffer.Data) }) @@ -278,10 +280,10 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(true) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - Expect(p.packets[0].ack).To(Equal(ack)) - Expect(p.packets[0].frames).To(BeEmpty()) + Expect(p.longHdrPackets).To(BeEmpty()) + Expect(p.shortHdrPacket).ToNot(BeNil()) + Expect(p.shortHdrPacket.Ack).To(Equal(ack)) + Expect(p.shortHdrPacket.Frames).To(BeEmpty()) parsePacket(p.buffer.Data) }) @@ -330,10 +332,10 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(false) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].header.Type).To(Equal(protocol.PacketType0RTT)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) - Expect(p.packets[0].frames).To(Equal([]ackhandler.Frame{cf})) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].header.Type).To(Equal(protocol.PacketType0RTT)) + Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) + Expect(p.longHdrPackets[0].frames).To(Equal([]ackhandler.Frame{cf})) }) }) @@ -348,11 +350,11 @@ var _ = Describe("Packet packer", func() { quicErr.FrameType = 0x1234 p, err := packer.PackConnectionClose(quicErr) Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeHandshake)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].header.Type).To(Equal(protocol.PacketTypeHandshake)) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := p.longHdrPackets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) Expect(ccf.ErrorCode).To(BeEquivalentTo(0x100 + 0x42)) Expect(ccf.FrameType).To(BeEquivalentTo(0x1234)) @@ -371,11 +373,10 @@ var _ = Describe("Packet packer", func() { ErrorMessage: "test error", }) Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].header.IsLongHeader).To(BeFalse()) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(p.longHdrPackets).To(BeEmpty()) + Expect(p.shortHdrPacket.Frames).To(HaveLen(1)) + Expect(p.shortHdrPacket.Frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := p.shortHdrPacket.Frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.CryptoBufferExceeded)) Expect(ccf.ReasonPhrase).To(Equal("test error")) @@ -396,28 +397,28 @@ var _ = Describe("Packet packer", func() { ErrorMessage: "test error", }) Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(3)) - Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(p.packets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(p.longHdrPackets).To(HaveLen(2)) + Expect(p.longHdrPackets[0].header.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(p.longHdrPackets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := p.longHdrPackets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) Expect(ccf.ReasonPhrase).To(BeEmpty()) - Expect(p.packets[1].header.Type).To(Equal(protocol.PacketTypeHandshake)) - Expect(p.packets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf = p.packets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(p.longHdrPackets[1].header.Type).To(Equal(protocol.PacketTypeHandshake)) + Expect(p.longHdrPackets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) + Expect(p.longHdrPackets[1].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf = p.longHdrPackets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) Expect(ccf.ReasonPhrase).To(BeEmpty()) - Expect(p.packets[2].header.IsLongHeader).To(BeFalse()) - Expect(p.packets[2].header.PacketNumber).To(Equal(protocol.PacketNumber(3))) - Expect(p.packets[2].frames).To(HaveLen(1)) - Expect(p.packets[2].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf = p.packets[2].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(p.shortHdrPacket).ToNot(BeNil()) + Expect(p.shortHdrPacket.PacketNumber).To(Equal(protocol.PacketNumber(3))) + Expect(p.shortHdrPacket.Frames).To(HaveLen(1)) + Expect(p.shortHdrPacket.Frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf = p.shortHdrPacket.Frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeTrue()) Expect(ccf.ErrorCode).To(BeEquivalentTo(0x1337)) Expect(ccf.ReasonPhrase).To(Equal("test error")) @@ -438,21 +439,21 @@ var _ = Describe("Packet packer", func() { ErrorMessage: "test error", }) Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(2)) + Expect(p.longHdrPackets).To(HaveLen(1)) Expect(p.buffer.Len()).To(BeNumerically("<", protocol.MinInitialPacketSize)) - Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeHandshake)) - Expect(p.packets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(p.longHdrPackets[0].header.Type).To(Equal(protocol.PacketTypeHandshake)) + Expect(p.longHdrPackets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := p.longHdrPackets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) Expect(ccf.ReasonPhrase).To(BeEmpty()) - Expect(p.packets[1].header.IsLongHeader).To(BeFalse()) - Expect(p.packets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf = p.packets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(p.shortHdrPacket).ToNot(BeNil()) + Expect(p.shortHdrPacket.PacketNumber).To(Equal(protocol.PacketNumber(2))) + Expect(p.shortHdrPacket.Frames).To(HaveLen(1)) + Expect(p.shortHdrPacket.Frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf = p.shortHdrPacket.Frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeTrue()) Expect(ccf.ErrorCode).To(BeEquivalentTo(0x1337)) Expect(ccf.ReasonPhrase).To(Equal("test error")) @@ -473,22 +474,22 @@ var _ = Describe("Packet packer", func() { ErrorMessage: "test error", }) Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(2)) + Expect(p.longHdrPackets).To(HaveLen(2)) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) - Expect(p.packets[0].header.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(p.packets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := p.packets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(p.longHdrPackets[0].header.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(p.longHdrPackets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := p.longHdrPackets[0].frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeFalse()) Expect(ccf.ErrorCode).To(BeEquivalentTo(qerr.ApplicationErrorErrorCode)) Expect(ccf.ReasonPhrase).To(BeEmpty()) - Expect(p.packets[1].header.Type).To(Equal(protocol.PacketType0RTT)) - Expect(p.packets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf = p.packets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) + Expect(p.longHdrPackets[1].header.Type).To(Equal(protocol.PacketType0RTT)) + Expect(p.longHdrPackets[1].header.PacketNumber).To(Equal(protocol.PacketNumber(2))) + Expect(p.longHdrPackets[1].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf = p.longHdrPackets[1].frames[0].Frame.(*wire.ConnectionCloseFrame) Expect(ccf.IsApplicationError).To(BeTrue()) Expect(ccf.ErrorCode).To(BeEquivalentTo(0x1337)) Expect(ccf.ReasonPhrase).To(Equal("test error")) @@ -657,7 +658,7 @@ var _ = Describe("Packet packer", func() { packet, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(packet).ToNot(BeNil()) - Expect(packet.packets).To(HaveLen(1)) + Expect(packet.longHdrPackets).To(HaveLen(1)) // cut off the tag that the mock sealer added // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, packer.getDestConnID().Len()) @@ -978,10 +979,10 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) hdrs := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(1)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) @@ -1005,9 +1006,9 @@ var _ = Describe("Packet packer", func() { }) p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].header.IsLongHeader).To(BeTrue()) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[0].header.IsLongHeader).To(BeTrue()) Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) parsePacket(p.buffer.Data) }) @@ -1033,13 +1034,14 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) - Expect(p.packets).To(HaveLen(2)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) - Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("handshake"))) + Expect(p.longHdrPackets).To(HaveLen(2)) + Expect(p.shortHdrPacket).To(BeNil()) + Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + Expect(p.longHdrPackets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(p.longHdrPackets[1].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[1].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("handshake"))) hdrs := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(2)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) @@ -1065,13 +1067,14 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) - Expect(p.packets).To(HaveLen(2)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) - Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) + Expect(p.longHdrPackets).To(HaveLen(2)) + Expect(p.shortHdrPacket).To(BeNil()) + Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + Expect(p.longHdrPackets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(p.longHdrPackets[1].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) hdrs := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(2)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) @@ -1094,13 +1097,13 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) - Expect(p.packets).To(HaveLen(2)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) - Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) + Expect(p.longHdrPackets).To(HaveLen(2)) + Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[0].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) + Expect(p.longHdrPackets[1].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(p.longHdrPackets[1].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) hdrs := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(2)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) @@ -1127,13 +1130,13 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) - Expect(p.packets).To(HaveLen(2)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) - Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + Expect(p.shortHdrPacket).ToNot(BeNil()) + Expect(p.shortHdrPacket.Frames).To(HaveLen(1)) + Expect(p.shortHdrPacket.Frames[0].Frame).To(BeAssignableToTypeOf(&wire.PingFrame{})) hdrs := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(2)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) @@ -1163,13 +1166,13 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) - Expect(p.packets).To(HaveLen(2)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) - Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + Expect(p.longHdrPackets).To(HaveLen(2)) + Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("initial"))) + Expect(p.longHdrPackets[1].EncryptionLevel()).To(Equal(protocol.Encryption0RTT)) + Expect(p.longHdrPackets[1].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[1].frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) hdrs := parsePacket(p.buffer.Data) Expect(hdrs).To(HaveLen(2)) Expect(hdrs[0].Type).To(Equal(protocol.PacketTypeInitial)) @@ -1196,13 +1199,13 @@ var _ = Describe("Packet packer", func() { p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically("<", 100)) - Expect(p.packets).To(HaveLen(2)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) - Expect(p.packets[0].frames).To(HaveLen(1)) - Expect(p.packets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("handshake"))) - Expect(p.packets[1].EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - Expect(p.packets[1].frames).To(HaveLen(1)) - Expect(p.packets[1].frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets[0].frames[0].Frame.(*wire.CryptoFrame).Data).To(Equal([]byte("handshake"))) + Expect(p.shortHdrPacket).ToNot(BeNil()) + Expect(p.shortHdrPacket.Frames).To(HaveLen(1)) + Expect(p.shortHdrPacket.Frames[0].Frame.(*wire.StreamFrame).Data).To(Equal([]byte("foobar"))) hdr, _, rest, err := wire.ParsePacket(p.buffer.Data, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) @@ -1229,8 +1232,9 @@ var _ = Describe("Packet packer", func() { }) p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.shortHdrPacket).To(BeNil()) + Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) Expect(len(p.buffer.Data)).To(BeEquivalentTo(maxPacketSize - protocol.MinCoalescedPacketSize)) parsePacket(p.buffer.Data) }) @@ -1248,7 +1252,8 @@ var _ = Describe("Packet packer", func() { packet, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(packet).ToNot(BeNil()) - Expect(packet.packets).To(HaveLen(1)) + Expect(packet.longHdrPackets).To(HaveLen(1)) + Expect(packet.shortHdrPacket).To(BeNil()) // cut off the tag that the mock sealer added // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, packer.getDestConnID().Len()) @@ -1287,10 +1292,10 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().HasData() p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.packets[0].frames).To(Equal([]ackhandler.Frame{{Frame: f}})) - Expect(p.packets[0].header.IsLongHeader).To(BeTrue()) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.longHdrPackets[0].frames).To(Equal([]ackhandler.Frame{{Frame: f}})) + Expect(p.longHdrPackets[0].header.IsLongHeader).To(BeTrue()) }) It("sends an Initial packet containing only an ACK", func() { @@ -1304,8 +1309,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].ack).To(Equal(ack)) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].ack).To(Equal(ack)) }) It("doesn't pack anything if there's nothing to send at Initial and Handshake keys are not yet available", func() { @@ -1332,8 +1337,8 @@ var _ = Describe("Packet packer", func() { pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].ack).To(Equal(ack)) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].ack).To(Equal(ack)) }) for _, pers := range []protocol.Perspective{protocol.PerspectiveServer, protocol.PerspectiveClient} { @@ -1357,10 +1362,10 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].header.Token).To(Equal(token)) - Expect(p.packets[0].frames).To(HaveLen(1)) - cf := p.packets[0].frames[0].Frame.(*wire.CryptoFrame) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].header.Token).To(Equal(token)) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) + cf := p.longHdrPackets[0].frames[0].Frame.(*wire.CryptoFrame) Expect(cf.Data).To(Equal([]byte("foobar"))) }) } @@ -1381,9 +1386,9 @@ var _ = Describe("Packet packer", func() { packer.perspective = protocol.PerspectiveClient p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) - Expect(p.packets).To(HaveLen(1)) - Expect(p.packets[0].ack).To(Equal(ack)) - Expect(p.packets[0].frames).To(HaveLen(1)) + Expect(p.longHdrPackets).To(HaveLen(1)) + Expect(p.longHdrPackets[0].ack).To(Equal(ack)) + Expect(p.longHdrPackets[0].frames).To(HaveLen(1)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) }) }) @@ -1405,8 +1410,8 @@ var _ = Describe("Packet packer", func() { p, err := packer.MaybePackProbePacket(protocol.EncryptionInitial) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.packets).To(HaveLen(1)) - packet := p.packets[0] + Expect(p.longHdrPackets).To(HaveLen(1)) + packet := p.longHdrPackets[0] Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) @@ -1428,8 +1433,8 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.packets).To(HaveLen(1)) - packet := p.packets[0] + Expect(p.longHdrPackets).To(HaveLen(1)) + packet := p.longHdrPackets[0] Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) @@ -1452,8 +1457,8 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.packets).To(HaveLen(1)) - packet := p.packets[0] + Expect(p.longHdrPackets).To(HaveLen(1)) + packet := p.longHdrPackets[0] Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) Expect(packet.frames).To(HaveLen(1)) Expect(packet.frames[0].Frame).To(Equal(f)) @@ -1473,8 +1478,8 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.packets).To(HaveLen(1)) - packet := p.packets[0] + Expect(p.longHdrPackets).To(HaveLen(1)) + packet := p.longHdrPackets[0] Expect(packet.EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) Expect(packet.frames).To(HaveLen(1)) Expect(packet.frames[0].Frame).To(BeAssignableToTypeOf(&wire.CryptoFrame{})) @@ -1496,11 +1501,11 @@ var _ = Describe("Packet packer", func() { p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.packets).To(HaveLen(1)) - packet := p.packets[0] - Expect(packet.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - Expect(packet.frames).To(HaveLen(1)) - Expect(packet.frames[0].Frame).To(Equal(f)) + Expect(p.longHdrPackets).To(BeEmpty()) + Expect(p.shortHdrPacket).ToNot(BeNil()) + packet := p.shortHdrPacket + Expect(packet.Frames).To(HaveLen(1)) + Expect(packet.Frames[0].Frame).To(Equal(f)) }) It("packs a full size 1-RTT probe packet", func() { @@ -1521,12 +1526,12 @@ var _ = Describe("Packet packer", func() { p, err := packer.MaybePackProbePacket(protocol.Encryption1RTT) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.packets).To(HaveLen(1)) - packet := p.packets[0] - Expect(packet.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) - Expect(packet.frames).To(HaveLen(1)) - Expect(packet.frames[0].Frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - Expect(packet.length).To(Equal(maxPacketSize)) + Expect(p.longHdrPackets).To(BeEmpty()) + Expect(p.shortHdrPacket).ToNot(BeNil()) + packet := p.shortHdrPacket + Expect(packet.Frames).To(HaveLen(1)) + Expect(packet.Frames[0].Frame).To(BeAssignableToTypeOf(&wire.StreamFrame{})) + Expect(packet.Length).To(Equal(maxPacketSize)) }) It("returns nil if there's no probe data to send", func() { @@ -1557,10 +1562,10 @@ var _ = Describe("Packet packer", func() { }) }) -var _ = Describe("Converting to AckHandler packets", func() { +var _ = Describe("Converting to ackhandler.Packet", func() { It("convert a packet", func() { - packet := &packetContents{ - header: &wire.ExtendedHeader{Header: wire.Header{}}, + packet := &longHeaderPacket{ + header: &wire.ExtendedHeader{Header: wire.Header{IsLongHeader: true, Type: protocol.PacketTypeInitial}}, frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}}}, ack: &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100, Smallest: 80}}}, length: 42, @@ -1574,27 +1579,19 @@ var _ = Describe("Converting to AckHandler packets", func() { }) It("sets the LargestAcked to invalid, if the packet doesn't have an ACK frame", func() { - packet := &packetContents{ - header: &wire.ExtendedHeader{Header: wire.Header{}}, + packet := &longHeaderPacket{ + header: &wire.ExtendedHeader{Header: wire.Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}}, frames: []ackhandler.Frame{{Frame: &wire.MaxDataFrame{}}, {Frame: &wire.PingFrame{}}}, } p := packet.ToAckHandlerPacket(time.Now(), nil) Expect(p.LargestAcked).To(Equal(protocol.InvalidPacketNumber)) }) - It("marks MTU probe packets", func() { - packet := &packetContents{ - header: &wire.ExtendedHeader{Header: wire.Header{}}, - isMTUProbePacket: true, - } - Expect(packet.ToAckHandlerPacket(time.Now(), nil).IsPathMTUProbePacket).To(BeTrue()) - }) - DescribeTable( "doesn't overwrite the OnLost callback, if it is set", func(hdr wire.Header) { var pingLost bool - packet := &packetContents{ + packet := &longHeaderPacket{ header: &wire.ExtendedHeader{Header: hdr}, frames: []ackhandler.Frame{ {Frame: &wire.MaxDataFrame{}}, @@ -1610,6 +1607,5 @@ var _ = Describe("Converting to AckHandler packets", func() { Entry(protocol.EncryptionInitial.String(), wire.Header{IsLongHeader: true, Type: protocol.PacketTypeInitial}), Entry(protocol.EncryptionHandshake.String(), wire.Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}), Entry(protocol.Encryption0RTT.String(), wire.Header{IsLongHeader: true, Type: protocol.PacketType0RTT}), - Entry(protocol.Encryption1RTT.String(), wire.Header{}), ) })