Don't pass the stop waiting frame to the packer as a parameter

This commit is contained in:
Lucas Clemente 2017-06-19 18:11:02 +02:00
parent 495399ede6
commit cbb640b093
3 changed files with 103 additions and 91 deletions

View file

@ -29,6 +29,7 @@ type packetPacker struct {
streamFramer *streamFramer streamFramer *streamFramer
controlFrames []frames.Frame controlFrames []frames.Frame
stopWaiting *frames.StopWaitingFrame
} }
func newPacketPacker(connectionID protocol.ConnectionID, func newPacketPacker(connectionID protocol.ConnectionID,
@ -54,28 +55,25 @@ func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame, lea
// in case the connection is closed, all queued control frames aren't of any use anymore // in case the connection is closed, all queued control frames aren't of any use anymore
// discard them and queue the ConnectionCloseFrame // discard them and queue the ConnectionCloseFrame
p.controlFrames = []frames.Frame{ccf} p.controlFrames = []frames.Frame{ccf}
return p.packPacket(nil, leastUnacked, nil) return p.packPacket(leastUnacked, nil)
} }
// RetransmitNonForwardSecurePacket retransmits a handshake packet, that was sent with less than forward-secure encryption // RetransmitNonForwardSecurePacket retransmits a handshake packet, that was sent with less than forward-secure encryption
func (p *packetPacker) RetransmitNonForwardSecurePacket(stopWaitingFrame *frames.StopWaitingFrame, packet *ackhandler.Packet) (*packedPacket, error) { func (p *packetPacker) RetransmitNonForwardSecurePacket(packet *ackhandler.Packet) (*packedPacket, error) {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure { if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
return nil, errors.New("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment") return nil, errors.New("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment")
} }
if stopWaitingFrame == nil {
return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame") return p.packPacket(0, packet)
}
return p.packPacket(stopWaitingFrame, 0, packet)
} }
// PackPacket packs a new packet // PackPacket packs a new packet
// the stopWaitingFrame is *guaranteed* to be included in the next packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { func (p *packetPacker) PackPacket(leastUnacked protocol.PacketNumber) (*packedPacket, error) {
return p.packPacket(stopWaitingFrame, leastUnacked, nil) return p.packPacket(leastUnacked, nil)
} }
func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, leastUnacked protocol.PacketNumber, handshakePacketToRetransmit *ackhandler.Packet) (*packedPacket, error) { func (p *packetPacker) packPacket(leastUnacked protocol.PacketNumber, handshakePacketToRetransmit *ackhandler.Packet) (*packedPacket, error) {
// handshakePacketToRetransmit is only set for handshake retransmissions // handshakePacketToRetransmit is only set for handshake retransmissions
isHandshakeRetransmission := (handshakePacketToRetransmit != nil) isHandshakeRetransmission := (handshakePacketToRetransmit != nil)
isCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame() isCryptoStreamFrame := p.streamFramer.HasCryptoStreamFrame()
@ -103,9 +101,9 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea
return nil, err return nil, err
} }
if stopWaitingFrame != nil { if p.stopWaiting != nil {
stopWaitingFrame.PacketNumber = publicHeader.PacketNumber p.stopWaiting.PacketNumber = publicHeader.PacketNumber
stopWaitingFrame.PacketNumberLen = publicHeader.PacketNumberLen p.stopWaiting.PacketNumberLen = publicHeader.PacketNumberLen
} }
// we're packing a ConnectionClose, don't add any StreamFrames // we're packing a ConnectionClose, don't add any StreamFrames
@ -116,7 +114,11 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea
var payloadFrames []frames.Frame var payloadFrames []frames.Frame
if isHandshakeRetransmission { if isHandshakeRetransmission {
payloadFrames = append(payloadFrames, stopWaitingFrame) // Find the SWF
if p.stopWaiting == nil {
return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")
}
payloadFrames = append(payloadFrames, p.stopWaiting)
// don't retransmit Acks and StopWaitings // don't retransmit Acks and StopWaitings
for _, f := range handshakePacketToRetransmit.Frames { for _, f := range handshakePacketToRetransmit.Frames {
switch f.(type) { switch f.(type) {
@ -134,7 +136,7 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea
payloadFrames = []frames.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)} payloadFrames = []frames.Frame{p.streamFramer.PopCryptoStreamFrame(maxLen)}
} else { } else {
maxSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength maxSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength
payloadFrames, err = p.composeNextPacket(stopWaitingFrame, maxSize, p.canSendData(encLevel)) payloadFrames, err = p.composeNextPacket(maxSize, p.canSendData(encLevel))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -145,12 +147,12 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea
return nil, nil return nil, nil
} }
// Don't send out packets that only contain a StopWaitingFrame // Don't send out packets that only contain a StopWaitingFrame
if len(payloadFrames) == 1 { if len(payloadFrames) == 1 && p.stopWaiting != nil {
if _, ok := payloadFrames[0].(*frames.StopWaitingFrame); ok { return nil, nil
return nil, nil
}
} }
p.stopWaiting = nil
raw := getPacketBuffer() raw := getPacketBuffer()
buffer := bytes.NewBuffer(raw) buffer := bytes.NewBuffer(raw)
@ -189,25 +191,21 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea
} }
func (p *packetPacker) composeNextPacket( func (p *packetPacker) composeNextPacket(
stopWaitingFrame *frames.StopWaitingFrame,
maxFrameSize protocol.ByteCount, maxFrameSize protocol.ByteCount,
canSendStreamFrames bool, canSendStreamFrames bool,
) ([]frames.Frame, error) { ) ([]frames.Frame, error) {
var payloadLength protocol.ByteCount var payloadLength protocol.ByteCount
var payloadFrames []frames.Frame var payloadFrames []frames.Frame
if stopWaitingFrame != nil { if p.stopWaiting != nil {
payloadFrames = append(payloadFrames, stopWaitingFrame) p.controlFrames = append(p.controlFrames, p.stopWaiting)
minLength, err := stopWaitingFrame.MinLength(p.version) }
for len(p.controlFrames) > 0 {
frame := p.controlFrames[len(p.controlFrames)-1]
minLength, err := frame.MinLength(p.version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
payloadLength += minLength
}
for len(p.controlFrames) > 0 {
frame := p.controlFrames[len(p.controlFrames)-1]
minLength, _ := frame.MinLength(p.version) // controlFrames does not contain any StopWaitingFrames. So it will *never* return an error
if payloadLength+minLength > maxFrameSize { if payloadLength+minLength > maxFrameSize {
break break
} }
@ -247,7 +245,11 @@ func (p *packetPacker) composeNextPacket(
} }
func (p *packetPacker) QueueControlFrameForNextPacket(f frames.Frame) { func (p *packetPacker) QueueControlFrameForNextPacket(f frames.Frame) {
p.controlFrames = append(p.controlFrames, f) if swf, ok := f.(*frames.StopWaitingFrame); ok {
p.stopWaiting = swf
} else {
p.controlFrames = append(p.controlFrames, f)
}
} }
func (p *packetPacker) getPublicHeader(leastUnacked protocol.PacketNumber, encLevel protocol.EncryptionLevel) *PublicHeader { func (p *packetPacker) getPublicHeader(leastUnacked protocol.PacketNumber, encLevel protocol.EncryptionLevel) *PublicHeader {

View file

@ -79,7 +79,7 @@ var _ = Describe("Packet packer", func() {
}) })
It("returns nil when no packet is queued", func() { It("returns nil when no packet is queued", func() {
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(p).To(BeNil()) Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
@ -90,7 +90,7 @@ var _ = Describe("Packet packer", func() {
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
} }
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
b := &bytes.Buffer{} b := &bytes.Buffer{}
@ -106,7 +106,7 @@ var _ = Describe("Packet packer", func() {
Data: []byte("foobar"), Data: []byte("foobar"),
} }
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
}) })
@ -171,7 +171,7 @@ var _ = Describe("Packet packer", func() {
It("packs only control frames", func() { It("packs only control frames", func() {
packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{}) packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{})
packer.QueueControlFrameForNextPacket(&frames.WindowUpdateFrame{}) packer.QueueControlFrameForNextPacket(&frames.WindowUpdateFrame{})
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(2)) Expect(p.frames).To(HaveLen(2))
@ -180,11 +180,11 @@ var _ = Describe("Packet packer", func() {
It("increases the packet number", func() { It("increases the packet number", func() {
packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{}) packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{})
p1, err := packer.PackPacket(nil, 0) p1, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p1).ToNot(BeNil()) Expect(p1).ToNot(BeNil())
packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{}) packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{})
p2, err := packer.PackPacket(nil, 0) p2, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p2).ToNot(BeNil()) Expect(p2).ToNot(BeNil())
Expect(p2.number).To(BeNumerically(">", p1.number)) Expect(p2.number).To(BeNumerically(">", p1.number))
@ -194,7 +194,8 @@ var _ = Describe("Packet packer", func() {
packer.packetNumberGenerator.next = 15 packer.packetNumberGenerator.next = 15
swf := &frames.StopWaitingFrame{LeastUnacked: 10} swf := &frames.StopWaitingFrame{LeastUnacked: 10}
packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{}) packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{})
p, err := packer.PackPacket(swf, 0) packer.QueueControlFrameForNextPacket(swf)
p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(p.frames).To(HaveLen(2)) Expect(p.frames).To(HaveLen(2))
@ -206,21 +207,23 @@ var _ = Describe("Packet packer", func() {
packer.packetNumberGenerator.next = packetNumber packer.packetNumberGenerator.next = packetNumber
swf := &frames.StopWaitingFrame{LeastUnacked: packetNumber - 0x100} swf := &frames.StopWaitingFrame{LeastUnacked: packetNumber - 0x100}
packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{}) packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{})
p, err := packer.PackPacket(swf, 0) packer.QueueControlFrameForNextPacket(swf)
p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames[0].(*frames.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) Expect(p.frames[0].(*frames.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
}) })
It("does not pack a packet containing only a StopWaitingFrame", func() { It("does not pack a packet containing only a StopWaitingFrame", func() {
swf := &frames.StopWaitingFrame{LeastUnacked: 10} swf := &frames.StopWaitingFrame{LeastUnacked: 10}
p, err := packer.PackPacket(swf, 0) packer.QueueControlFrameForNextPacket(swf)
p, err := packer.PackPacket(0)
Expect(p).To(BeNil()) Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
It("packs a packet if it has queued control frames, but no new control frames", func() { It("packs a packet if it has queued control frames, but no new control frames", func() {
packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}} packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}}
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
}) })
@ -231,7 +234,7 @@ var _ = Describe("Packet packer", func() {
packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}} packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}}
packer.connectionID = 0x1337 packer.connectionID = 0x1337
packer.version = 123 packer.version = 123
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
hdr, err := ParsePublicHeader(bytes.NewReader(p.raw), protocol.PerspectiveClient) hdr, err := ParsePublicHeader(bytes.NewReader(p.raw), protocol.PerspectiveClient)
@ -245,7 +248,7 @@ var _ = Describe("Packet packer", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure
packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}} packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}}
packer.connectionID = 0x1337 packer.connectionID = 0x1337
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
hdr, err := ParsePublicHeader(bytes.NewReader(p.raw), protocol.PerspectiveClient) hdr, err := ParsePublicHeader(bytes.NewReader(p.raw), protocol.PerspectiveClient)
@ -263,10 +266,10 @@ var _ = Describe("Packet packer", func() {
controlFrames = append(controlFrames, f) controlFrames = append(controlFrames, f)
} }
packer.controlFrames = controlFrames packer.controlFrames = controlFrames
payloadFrames, err := packer.composeNextPacket(nil, maxFrameSize, false) payloadFrames, err := packer.composeNextPacket(maxFrameSize, false)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(maxFramesPerPacket)) Expect(payloadFrames).To(HaveLen(maxFramesPerPacket))
payloadFrames, err = packer.composeNextPacket(nil, maxFrameSize, false) payloadFrames, err = packer.composeNextPacket(maxFrameSize, false)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(BeEmpty()) Expect(payloadFrames).To(BeEmpty())
}) })
@ -282,17 +285,17 @@ var _ = Describe("Packet packer", func() {
controlFrames = append(controlFrames, blockedFrame) controlFrames = append(controlFrames, blockedFrame)
} }
packer.controlFrames = controlFrames packer.controlFrames = controlFrames
payloadFrames, err := packer.composeNextPacket(nil, maxFrameSize, false) payloadFrames, err := packer.composeNextPacket(maxFrameSize, false)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(maxFramesPerPacket)) Expect(payloadFrames).To(HaveLen(maxFramesPerPacket))
payloadFrames, err = packer.composeNextPacket(nil, maxFrameSize, false) payloadFrames, err = packer.composeNextPacket(maxFrameSize, false)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(10)) Expect(payloadFrames).To(HaveLen(10))
}) })
It("only increases the packet number when there is an actual packet to send", func() { It("only increases the packet number when there is an actual packet to send", func() {
packer.packetNumberGenerator.nextToSkip = 1000 packer.packetNumberGenerator.nextToSkip = 1000
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(p).To(BeNil()) Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1))) Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1)))
@ -301,7 +304,7 @@ var _ = Describe("Packet packer", func() {
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
} }
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
p, err = packer.PackPacket(nil, 0) p, err = packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(p.number).To(Equal(protocol.PacketNumber(1))) Expect(p.number).To(Equal(protocol.PacketNumber(1)))
@ -319,11 +322,11 @@ var _ = Describe("Packet packer", func() {
maxStreamFrameDataLen := maxFrameSize - minLength maxStreamFrameDataLen := maxFrameSize - minLength
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)) f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen))
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(nil, maxFrameSize, true) payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1)) Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(payloadFrames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
payloadFrames, err = packer.composeNextPacket(nil, maxFrameSize, true) payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(BeEmpty()) Expect(payloadFrames).To(BeEmpty())
}) })
@ -342,12 +345,12 @@ var _ = Describe("Packet packer", func() {
} }
streamFramer.AddFrameForRetransmission(f1) streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2) streamFramer.AddFrameForRetransmission(f2)
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1))) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1)))
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
p, err = packer.PackPacket(nil, 0) p, err = packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
@ -369,7 +372,7 @@ var _ = Describe("Packet packer", func() {
streamFramer.AddFrameForRetransmission(f1) streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2) streamFramer.AddFrameForRetransmission(f2)
streamFramer.AddFrameForRetransmission(f3) streamFramer.AddFrameForRetransmission(f3)
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
b := &bytes.Buffer{} b := &bytes.Buffer{}
@ -394,17 +397,17 @@ var _ = Describe("Packet packer", func() {
maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLen - minLength maxStreamFrameDataLen := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLen - minLength
f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200) f.Data = bytes.Repeat([]byte{'f'}, int(maxStreamFrameDataLen)+200)
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(nil, maxFrameSize, true) payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1)) Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(payloadFrames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
Expect(payloadFrames[0].(*frames.StreamFrame).Data).To(HaveLen(int(maxStreamFrameDataLen))) Expect(payloadFrames[0].(*frames.StreamFrame).Data).To(HaveLen(int(maxStreamFrameDataLen)))
payloadFrames, err = packer.composeNextPacket(nil, maxFrameSize, true) payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1)) Expect(payloadFrames).To(HaveLen(1))
Expect(payloadFrames[0].(*frames.StreamFrame).Data).To(HaveLen(200)) Expect(payloadFrames[0].(*frames.StreamFrame).Data).To(HaveLen(200))
Expect(payloadFrames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(payloadFrames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
payloadFrames, err = packer.composeNextPacket(nil, maxFrameSize, true) payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(BeEmpty()) Expect(payloadFrames).To(BeEmpty())
}) })
@ -423,23 +426,23 @@ var _ = Describe("Packet packer", func() {
} }
streamFramer.AddFrameForRetransmission(f1) streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2) streamFramer.AddFrameForRetransmission(f2)
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
p, err = packer.PackPacket(nil, 0) p, err = packer.PackPacket(0)
Expect(p.frames).To(HaveLen(2)) Expect(p.frames).To(HaveLen(2))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeTrue()) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[1].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames[1].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
p, err = packer.PackPacket(nil, 0) p, err = packer.PackPacket(0)
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
p, err = packer.PackPacket(nil, 0) p, err = packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
@ -452,7 +455,7 @@ var _ = Describe("Packet packer", func() {
minLength, _ := f.MinLength(0) minLength, _ := f.MinLength(0)
f.Data = bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header f.Data = bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
@ -467,10 +470,10 @@ var _ = Describe("Packet packer", func() {
f.Data = bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen-minLength+2)) // + 2 since MinceLength is 1 bigger than the actual StreamFrame header f.Data = bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen-minLength+2)) // + 2 since MinceLength is 1 bigger than the actual StreamFrame header
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
payloadFrames, err := packer.composeNextPacket(nil, maxFrameSize, true) payloadFrames, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1)) Expect(payloadFrames).To(HaveLen(1))
payloadFrames, err = packer.composeNextPacket(nil, maxFrameSize, true) payloadFrames, err = packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(payloadFrames).To(HaveLen(1)) Expect(payloadFrames).To(HaveLen(1))
}) })
@ -482,7 +485,7 @@ var _ = Describe("Packet packer", func() {
Data: []byte("foobar"), Data: []byte("foobar"),
} }
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
@ -495,7 +498,7 @@ var _ = Describe("Packet packer", func() {
Data: []byte("foobar"), Data: []byte("foobar"),
} }
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
Expect(p.frames[0]).To(Equal(f)) Expect(p.frames[0]).To(Equal(f))
@ -508,7 +511,7 @@ var _ = Describe("Packet packer", func() {
Data: []byte("foobar"), Data: []byte("foobar"),
} }
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
@ -516,7 +519,7 @@ var _ = Describe("Packet packer", func() {
It("sends unencrypted stream data on the crypto stream", func() { It("sends unencrypted stream data on the crypto stream", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionUnencrypted
cryptoStream.dataForWriting = []byte("foobar") cryptoStream.dataForWriting = []byte("foobar")
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
@ -526,7 +529,7 @@ var _ = Describe("Packet packer", func() {
It("sends encrypted stream data on the crypto stream", func() { It("sends encrypted stream data on the crypto stream", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure packer.cryptoSetup.(*mockCryptoSetup).encLevelSealCrypto = protocol.EncryptionSecure
cryptoStream.dataForWriting = []byte("foobar") cryptoStream.dataForWriting = []byte("foobar")
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
@ -537,7 +540,7 @@ var _ = Describe("Packet packer", func() {
packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted
packer.QueueControlFrameForNextPacket(&frames.AckFrame{}) packer.QueueControlFrameForNextPacket(&frames.AckFrame{})
streamFramer.AddFrameForRetransmission(&frames.StreamFrame{StreamID: 3, Data: []byte("foobar")}) streamFramer.AddFrameForRetransmission(&frames.StreamFrame{StreamID: 3, Data: []byte("foobar")})
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(func() { _ = p.frames[0].(*frames.AckFrame) }).NotTo(Panic()) Expect(func() { _ = p.frames[0].(*frames.AckFrame) }).NotTo(Panic())
@ -553,7 +556,7 @@ var _ = Describe("Packet packer", func() {
Data: bytes.Repeat([]byte{'f'}, length), Data: bytes.Repeat([]byte{'f'}, length),
} }
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
_, err := packer.composeNextPacket(nil, maxFrameSize, true) _, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packer.controlFrames[0]).To(Equal(&frames.BlockedFrame{StreamID: 5})) Expect(packer.controlFrames[0]).To(Equal(&frames.BlockedFrame{StreamID: 5}))
}) })
@ -566,7 +569,7 @@ var _ = Describe("Packet packer", func() {
Data: bytes.Repeat([]byte{'f'}, length), Data: bytes.Repeat([]byte{'f'}, length),
} }
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
p, err := packer.composeNextPacket(nil, maxFrameSize, true) p, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).To(HaveLen(1)) Expect(p).To(HaveLen(1))
Expect(p[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
@ -579,14 +582,15 @@ var _ = Describe("Packet packer", func() {
Data: []byte("foobar"), Data: []byte("foobar"),
} }
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
_, err := packer.composeNextPacket(nil, maxFrameSize, true) _, err := packer.composeNextPacket(maxFrameSize, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packer.controlFrames[0]).To(Equal(&frames.BlockedFrame{StreamID: 0})) Expect(packer.controlFrames[0]).To(Equal(&frames.BlockedFrame{StreamID: 0}))
}) })
}) })
It("returns nil if we only have a single STOP_WAITING", func() { It("returns nil if we only have a single STOP_WAITING", func() {
p, err := packer.PackPacket(&frames.StopWaitingFrame{}, 0) packer.QueueControlFrameForNextPacket(&frames.StopWaitingFrame{})
p, err := packer.PackPacket(0)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
@ -594,7 +598,7 @@ var _ = Describe("Packet packer", func() {
It("packs a single ACK", func() { It("packs a single ACK", func() {
ack := &frames.AckFrame{LargestAcked: 42} ack := &frames.AckFrame{LargestAcked: 42}
packer.QueueControlFrameForNextPacket(ack) packer.QueueControlFrameForNextPacket(ack)
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(p.frames[0]).To(Equal(ack)) Expect(p.frames[0]).To(Equal(ack))
@ -603,7 +607,7 @@ var _ = Describe("Packet packer", func() {
It("does not return nil if we only have a single ACK but request it to be sent", func() { It("does not return nil if we only have a single ACK but request it to be sent", func() {
ack := &frames.AckFrame{} ack := &frames.AckFrame{}
packer.QueueControlFrameForNextPacket(ack) packer.QueueControlFrameForNextPacket(ack)
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
}) })
@ -611,7 +615,7 @@ var _ = Describe("Packet packer", func() {
It("queues a control frame to be sent in the next packet", func() { It("queues a control frame to be sent in the next packet", func() {
wuf := &frames.WindowUpdateFrame{StreamID: 5} wuf := &frames.WindowUpdateFrame{StreamID: 5}
packer.QueueControlFrameForNextPacket(wuf) packer.QueueControlFrameForNextPacket(wuf)
p, err := packer.PackPacket(nil, 0) p, err := packer.PackPacket(0)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(wuf)) Expect(p.frames[0]).To(Equal(wuf))
@ -624,12 +628,16 @@ var _ = Describe("Packet packer", func() {
Data: []byte("foobar"), Data: []byte("foobar"),
} }
BeforeEach(func() {
packer.QueueControlFrameForNextPacket(swf)
})
It("packs a retransmission for a packet sent with no encryption", func() { It("packs a retransmission for a packet sent with no encryption", func() {
packet := &ackhandler.Packet{ packet := &ackhandler.Packet{
EncryptionLevel: protocol.EncryptionUnencrypted, EncryptionLevel: protocol.EncryptionUnencrypted,
Frames: []frames.Frame{sf}, Frames: []frames.Frame{sf},
} }
p, err := packer.RetransmitNonForwardSecurePacket(swf, packet) p, err := packer.RetransmitNonForwardSecurePacket(packet)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(ContainElement(sf)) Expect(p.frames).To(ContainElement(sf))
Expect(p.frames).To(ContainElement(swf)) Expect(p.frames).To(ContainElement(swf))
@ -643,7 +651,7 @@ var _ = Describe("Packet packer", func() {
EncryptionLevel: protocol.EncryptionSecure, EncryptionLevel: protocol.EncryptionSecure,
Frames: []frames.Frame{sf}, Frames: []frames.Frame{sf},
} }
p, err := packer.RetransmitNonForwardSecurePacket(swf, packet) p, err := packer.RetransmitNonForwardSecurePacket(packet)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(ContainElement(sf)) Expect(p.frames).To(ContainElement(sf))
Expect(p.frames).To(ContainElement(swf)) Expect(p.frames).To(ContainElement(swf))
@ -658,7 +666,7 @@ var _ = Describe("Packet packer", func() {
EncryptionLevel: protocol.EncryptionSecure, EncryptionLevel: protocol.EncryptionSecure,
Frames: []frames.Frame{sf}, Frames: []frames.Frame{sf},
} }
p, err := packer.RetransmitNonForwardSecurePacket(swf, packet) p, err := packer.RetransmitNonForwardSecurePacket(packet)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure))
}) })
@ -674,7 +682,7 @@ var _ = Describe("Packet packer", func() {
&frames.AckFrame{}, &frames.AckFrame{},
}, },
} }
p, err := packer.RetransmitNonForwardSecurePacket(swf, packet) p, err := packer.RetransmitNonForwardSecurePacket(packet)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(3)) Expect(p.frames).To(HaveLen(3))
Expect(p.frames).To(ContainElement(sf)) Expect(p.frames).To(ContainElement(sf))
@ -688,7 +696,7 @@ var _ = Describe("Packet packer", func() {
EncryptionLevel: protocol.EncryptionSecure, EncryptionLevel: protocol.EncryptionSecure,
Frames: []frames.Frame{&frames.AckFrame{}, &frames.StopWaitingFrame{}}, Frames: []frames.Frame{&frames.AckFrame{}, &frames.StopWaitingFrame{}},
} }
p, err := packer.RetransmitNonForwardSecurePacket(swf, packet) p, err := packer.RetransmitNonForwardSecurePacket(packet)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
@ -705,7 +713,7 @@ var _ = Describe("Packet packer", func() {
}, },
}, },
} }
_, err := packer.RetransmitNonForwardSecurePacket(swf, packet) _, err := packer.RetransmitNonForwardSecurePacket(packet)
Expect(err).To(MatchError("PacketPacker BUG: packet too large")) Expect(err).To(MatchError("PacketPacker BUG: packet too large"))
}) })
@ -713,15 +721,15 @@ var _ = Describe("Packet packer", func() {
p := &ackhandler.Packet{ p := &ackhandler.Packet{
EncryptionLevel: protocol.EncryptionForwardSecure, EncryptionLevel: protocol.EncryptionForwardSecure,
} }
_, err := packer.RetransmitNonForwardSecurePacket(nil, p) _, err := packer.RetransmitNonForwardSecurePacket(p)
Expect(err).To(MatchError("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment")) Expect(err).To(MatchError("PacketPacker BUG: forward-secure encrypted handshake packets don't need special treatment"))
}) })
It("refuses to retransmit packets without a StopWaitingFrame", func() { It("refuses to retransmit packets without a StopWaitingFrame", func() {
p := &ackhandler.Packet{ packer.stopWaiting = nil
_, err := packer.RetransmitNonForwardSecurePacket(&ackhandler.Packet{
EncryptionLevel: protocol.EncryptionSecure, EncryptionLevel: protocol.EncryptionSecure,
} })
_, err := packer.RetransmitNonForwardSecurePacket(nil, p)
Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame")) Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a StopWaitingFrame"))
}) })
}) })

View file

@ -589,9 +589,9 @@ func (s *session) sendPacket() error {
continue continue
} }
utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber) utils.Debugf("\tDequeueing handshake retransmission for packet 0x%x", retransmitPacket.PacketNumber)
stopWaitingFrame := s.sentPacketHandler.GetStopWaitingFrame(true) s.packer.QueueControlFrameForNextPacket(s.sentPacketHandler.GetStopWaitingFrame(true))
var packet *packedPacket var packet *packedPacket
packet, err := s.packer.RetransmitNonForwardSecurePacket(stopWaitingFrame, retransmitPacket) packet, err := s.packer.RetransmitNonForwardSecurePacket(retransmitPacket)
if err != nil { if err != nil {
return err return err
} }
@ -629,11 +629,13 @@ func (s *session) sendPacket() error {
s.packer.QueueControlFrameForNextPacket(ack) s.packer.QueueControlFrameForNextPacket(ack)
} }
hasRetransmission := s.streamFramer.HasFramesForRetransmission() hasRetransmission := s.streamFramer.HasFramesForRetransmission()
var stopWaitingFrame *frames.StopWaitingFrame
if ack != nil || hasRetransmission { if ack != nil || hasRetransmission {
stopWaitingFrame = s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) swf := s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission)
if swf != nil {
s.packer.QueueControlFrameForNextPacket(swf)
}
} }
packet, err := s.packer.PackPacket(stopWaitingFrame, s.sentPacketHandler.GetLeastUnacked()) packet, err := s.packer.PackPacket(s.sentPacketHandler.GetLeastUnacked())
if err != nil { if err != nil {
return err return err
} }