implement a sophisticated ack algorithm

fixes #354
This commit is contained in:
Marten Seemann 2017-01-09 13:02:01 +07:00
parent 634d2dc6b7
commit edf6ee3dfc
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
8 changed files with 370 additions and 256 deletions

View file

@ -28,8 +28,8 @@ type SentPacketHandler interface {
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber) error
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
ReceivedStopWaiting(*frames.StopWaitingFrame) error
GetAckFrame(dequeue bool) (*frames.AckFrame, error)
GetAckFrame() *frames.AckFrame
}

View file

@ -20,22 +20,34 @@ var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet n
type receivedPacketHandler struct {
largestObserved protocol.PacketNumber
ignorePacketsBelow protocol.PacketNumber
currentAckFrame *frames.AckFrame
stateChanged bool // has an ACK for this state already been sent? Will be set to false every time a new packet arrives, and to false every time an ACK is sent
largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory
largestObservedReceivedTime time.Time
ackSendDelay time.Duration
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
ackAlarmResetCallback func(time.Time)
lastAck *frames.AckFrame
}
// NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler() ReceivedPacketHandler {
func NewReceivedPacketHandler(ackAlarmResetCallback func(time.Time)) ReceivedPacketHandler {
// create a stopped timer, see https://github.com/golang/go/issues/12721#issuecomment-143010182
timer := time.NewTimer(0)
<-timer.C
return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(),
ackAlarmResetCallback: ackAlarmResetCallback,
ackSendDelay: protocol.AckSendDelay,
}
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber) error {
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
if packetNumber == 0 {
return errInvalidPacketNumber
}
@ -55,14 +67,12 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
return err
}
h.stateChanged = true
h.currentAckFrame = nil
if packetNumber > h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = time.Now()
}
h.maybeQueueAck(packetNumber, shouldInstigateAck)
return nil
}
@ -78,29 +88,79 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame)
return nil
}
func (h *receivedPacketHandler) GetAckFrame(dequeue bool) (*frames.AckFrame, error) {
if !h.stateChanged {
return nil, nil
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
var ackAlarmSet bool
h.packetsReceivedSinceLastAck++
if shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
}
if dequeue {
h.stateChanged = false
// always ack the first packet
if h.lastAck == nil {
h.ackQueued = true
}
if h.currentAckFrame != nil {
return h.currentAckFrame, nil
// Always send an ack every 20 packets in order to allow the peer to discard
// information from the SentPacketManager and provide an RTT measurement.
if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend {
h.ackQueued = true
}
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
h.ackQueued = true
}
// check if a new missing range above the previously was created
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().FirstPacketNumber > h.lastAck.LargestAcked {
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck {
h.ackQueued = true
} else {
if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay)
ackAlarmSet = true
}
}
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
ackAlarmSet = false
}
if ackAlarmSet {
h.ackAlarmResetCallback(h.ackAlarm)
}
}
func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame {
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
return nil
}
ackRanges := h.packetHistory.GetAckRanges()
h.currentAckFrame = &frames.AckFrame{
ack := &frames.AckFrame{
LargestAcked: h.largestObserved,
LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber,
PacketReceivedTime: h.largestObservedReceivedTime,
}
if len(ackRanges) > 1 {
h.currentAckFrame.AckRanges = ackRanges
ack.AckRanges = ackRanges
}
return h.currentAckFrame, nil
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
return ack
}

View file

@ -13,56 +13,62 @@ import (
var _ = Describe("receivedPacketHandler", func() {
var (
handler *receivedPacketHandler
ackAlarmCallbackCalled bool
)
ackAlarmCallback := func(time.Time) {
ackAlarmCallbackCalled = true
}
BeforeEach(func() {
handler = NewReceivedPacketHandler().(*receivedPacketHandler)
ackAlarmCallbackCalled = false
handler = NewReceivedPacketHandler(ackAlarmCallback).(*receivedPacketHandler)
})
Context("accepting packets", func() {
It("handles a packet that arrives late", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1))
err := handler.ReceivedPacket(protocol.PacketNumber(1), true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(3))
err = handler.ReceivedPacket(protocol.PacketNumber(3), true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(2))
err = handler.ReceivedPacket(protocol.PacketNumber(2), true)
Expect(err).ToNot(HaveOccurred())
})
It("rejects packets with packet number 0", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(0))
err := handler.ReceivedPacket(protocol.PacketNumber(0), true)
Expect(err).To(MatchError(errInvalidPacketNumber))
})
It("rejects a duplicate package", func() {
for i := 1; i < 5; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i))
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
Expect(err).ToNot(HaveOccurred())
}
err := handler.ReceivedPacket(4)
err := handler.ReceivedPacket(4, true)
Expect(err).To(MatchError(ErrDuplicatePacket))
})
It("ignores a packet with PacketNumber less than the LeastUnacked of a previously received StopWaiting", func() {
err := handler.ReceivedPacket(5)
err := handler.ReceivedPacket(5, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10})
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(9)
err = handler.ReceivedPacket(9, true)
Expect(err).To(MatchError(ErrPacketSmallerThanLastStopWaiting))
})
It("does not ignore a packet with PacketNumber equal to LeastUnacked of a previously received StopWaiting", func() {
err := handler.ReceivedPacket(5)
err := handler.ReceivedPacket(5, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10})
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(10)
err = handler.ReceivedPacket(10, true)
Expect(err).ToNot(HaveOccurred())
})
It("saves the time when each packet arrived", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(3))
err := handler.ReceivedPacket(protocol.PacketNumber(3), true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
})
@ -70,7 +76,7 @@ var _ = Describe("receivedPacketHandler", func() {
It("updates the largestObserved and the largestObservedReceivedTime", func() {
handler.largestObserved = 3
handler.largestObservedReceivedTime = time.Now().Add(-1 * time.Second)
err := handler.ReceivedPacket(5)
err := handler.ReceivedPacket(5, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
@ -80,27 +86,27 @@ var _ = Describe("receivedPacketHandler", func() {
timestamp := time.Now().Add(-1 * time.Second)
handler.largestObserved = 5
handler.largestObservedReceivedTime = timestamp
err := handler.ReceivedPacket(4)
err := handler.ReceivedPacket(4, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
Expect(handler.largestObservedReceivedTime).To(Equal(timestamp))
})
It("doesn't store more than MaxTrackedReceivedPackets packets", func() {
err := handler.ReceivedPacket(1)
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
for i := protocol.PacketNumber(3); i < 3+protocol.MaxTrackedReceivedPackets-1; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i))
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
Expect(err).ToNot(HaveOccurred())
}
err = handler.ReceivedPacket(protocol.PacketNumber(protocol.MaxTrackedReceivedPackets) + 10)
err = handler.ReceivedPacket(protocol.PacketNumber(protocol.MaxTrackedReceivedPackets)+10, true)
Expect(err).To(MatchError(errTooManyOutstandingReceivedPackets))
})
It("passes on errors from receivedPacketHistory", func() {
var err error
for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ {
err = handler.ReceivedPacket(2*i + 1)
err = handler.ReceivedPacket(2*i+1, true)
// this will eventually return an error
// details about when exactly the receivedPacketHistory errors are tested there
if err != nil {
@ -120,7 +126,7 @@ var _ = Describe("receivedPacketHandler", func() {
It("increase the ignorePacketsBelow number, even if all packets below the LeastUnacked were already acked", func() {
for i := 1; i < 20; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i))
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
Expect(err).ToNot(HaveOccurred())
}
err := handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: protocol.PacketNumber(12)})
@ -138,26 +144,129 @@ var _ = Describe("receivedPacketHandler", func() {
})
})
Context("ACK package generation", func() {
Context("ACKs", func() {
Context("queueing ACKs", func() {
receiveAndAck10Packets := func() {
for i := 1; i <= 10; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
Expect(err).ToNot(HaveOccurred())
}
Expect(handler.GetAckFrame()).ToNot(BeNil())
Expect(handler.ackQueued).To(BeFalse())
ackAlarmCallbackCalled = false
}
It("always queues an ACK for the first packet", func() {
err := handler.ReceivedPacket(1, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(ackAlarmCallbackCalled).To(BeFalse())
})
It("only queues one ACK for many non-retransmittable packets", func() {
receiveAndAck10Packets()
for i := 11; i < 10+protocol.MaxPacketsReceivedBeforeAckSend; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
}
err := handler.ReceivedPacket(10+protocol.MaxPacketsReceivedBeforeAckSend, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(ackAlarmCallbackCalled).To(BeFalse())
})
It("queues an ACK for every second retransmittable packet, if they are arriving fast", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(ackAlarmCallbackCalled).To(BeTrue())
ackAlarmCallbackCalled = false
err = handler.ReceivedPacket(12, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(ackAlarmCallbackCalled).To(BeFalse())
})
It("only sets the timer when receiving a retransmittable packets", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.ackAlarm).To(BeZero())
err = handler.ReceivedPacket(12, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.ackAlarm).ToNot(BeZero())
Expect(ackAlarmCallbackCalled).To(BeTrue())
})
It("queues an ACK if it was reported missing before", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(13, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() // ACK: 1 and 3, missing: 2
Expect(ack).ToNot(BeNil())
Expect(ack.HasMissingRanges()).To(BeTrue())
Expect(handler.ackQueued).To(BeFalse())
err = handler.ReceivedPacket(12, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
})
It("queues an ACK if it creates a new missing range", func() {
receiveAndAck10Packets()
for i := 11; i < 16; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
Expect(err).ToNot(HaveOccurred())
}
Expect(handler.GetAckFrame()).ToNot(BeNil())
handler.ReceivedPacket(20, true) // we now know that packets 16 to 19 are missing
Expect(handler.ackQueued).To(BeTrue())
})
})
Context("ACK generation", func() {
BeforeEach(func() {
handler.ackQueued = true
})
It("generates a simple ACK frame", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1))
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(2))
Expect(err).ToNot(HaveOccurred())
ack, err := handler.GetAckFrame(true)
err = handler.ReceivedPacket(2, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(2)))
Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(1)))
Expect(ack.AckRanges).To(BeEmpty())
})
It("saves the last sent ACK", func() {
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(handler.lastAck).To(Equal(ack))
err = handler.ReceivedPacket(2, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = true
ack = handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(handler.lastAck).To(Equal(ack))
})
It("generates an ACK frame with missing packets", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1))
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(4))
Expect(err).ToNot(HaveOccurred())
ack, err := handler.GetAckFrame(true)
err = handler.ReceivedPacket(4, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(4)))
Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(1)))
Expect(ack.AckRanges).To(HaveLen(2))
@ -165,106 +274,55 @@ var _ = Describe("receivedPacketHandler", func() {
Expect(ack.AckRanges[1]).To(Equal(frames.AckRange{FirstPacketNumber: 1, LastPacketNumber: 1}))
})
It("does not generate an ACK if an ACK has already been sent for the largest Packet", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1))
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(2))
Expect(err).ToNot(HaveOccurred())
ack, err := handler.GetAckFrame(true)
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil())
ack, err = handler.GetAckFrame(true)
Expect(err).ToNot(HaveOccurred())
Expect(ack).To(BeNil())
})
It("does not dequeue an ACK frame if told so", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(2))
Expect(err).ToNot(HaveOccurred())
ack, err := handler.GetAckFrame(false)
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil())
ack, err = handler.GetAckFrame(false)
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil())
ack, err = handler.GetAckFrame(false)
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil())
})
It("returns a cached ACK frame if the ACK was not dequeued", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(2))
Expect(err).ToNot(HaveOccurred())
ack, err := handler.GetAckFrame(false)
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil())
ack2, err := handler.GetAckFrame(false)
Expect(err).ToNot(HaveOccurred())
Expect(ack2).ToNot(BeNil())
Expect(&ack).To(Equal(&ack2))
})
It("generates a new ACK (and deletes the cached one) when a new packet arrives", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1))
Expect(err).ToNot(HaveOccurred())
ack, _ := handler.GetAckFrame(true)
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(1)))
err = handler.ReceivedPacket(protocol.PacketNumber(3))
Expect(err).ToNot(HaveOccurred())
ack, _ = handler.GetAckFrame(true)
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(3)))
})
It("generates a new ACK when an out-of-order packet arrives", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1))
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(3))
Expect(err).ToNot(HaveOccurred())
ack, _ := handler.GetAckFrame(true)
Expect(ack).ToNot(BeNil())
Expect(ack.AckRanges).To(HaveLen(2))
err = handler.ReceivedPacket(protocol.PacketNumber(2))
Expect(err).ToNot(HaveOccurred())
ack, _ = handler.GetAckFrame(true)
Expect(ack).ToNot(BeNil())
Expect(ack.AckRanges).To(BeEmpty())
})
It("doesn't send old ACK ranges after receiving a StopWaiting", func() {
err := handler.ReceivedPacket(5)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(10)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(11)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(12)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: protocol.PacketNumber(11)})
Expect(err).ToNot(HaveOccurred())
ack, err := handler.GetAckFrame(true)
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(12)))
Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(11)))
Expect(ack.HasMissingRanges()).To(BeFalse())
})
It("deletes packets from the packetHistory after receiving a StopWaiting, after continuously received packets", func() {
for i := 1; i <= 12; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i))
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
Expect(err).ToNot(HaveOccurred())
}
err := handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: protocol.PacketNumber(6)})
Expect(err).ToNot(HaveOccurred())
// check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame
ack, err := handler.GetAckFrame(true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(12)))
Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(6)))
Expect(ack.HasMissingRanges()).To(BeFalse())
})
It("resets all counters needed for the ACK queueing decision when sending an ACK", func() {
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
handler.ackAlarm = time.Now().Add(-time.Minute)
Expect(handler.GetAckFrame()).ToNot(BeNil())
Expect(handler.packetsReceivedSinceLastAck).To(BeZero())
Expect(handler.ackAlarm).To(BeZero())
Expect(handler.retransmittablePacketsReceivedSinceLastAck).To(BeZero())
Expect(handler.ackQueued).To(BeFalse())
})
It("doesn't generate an ACK when none is queued and the timer is not set", func() {
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Time{}
Expect(handler.GetAckFrame()).To(BeNil())
})
It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() {
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Now().Add(time.Minute)
Expect(handler.GetAckFrame()).To(BeNil())
})
It("generates an ACK when the timer has expired", func() {
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Now().Add(-time.Minute)
Expect(handler.GetAckFrame()).ToNot(BeNil())
})
})
})
})

View file

@ -41,14 +41,14 @@ func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup *handshake.
}
func (p *packetPacker) PackConnectionClose(frame *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
return p.packPacket(nil, []frames.Frame{frame}, leastUnacked, true, false)
return p.packPacket(nil, []frames.Frame{frame}, leastUnacked, true)
}
func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, maySendOnlyAck bool) (*packedPacket, error) {
return p.packPacket(stopWaitingFrame, controlFrames, leastUnacked, false, maySendOnlyAck)
func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
return p.packPacket(stopWaitingFrame, controlFrames, leastUnacked, false)
}
func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, onlySendOneControlFrame, maySendOnlyAck bool) (*packedPacket, error) {
func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, onlySendOneControlFrame bool) (*packedPacket, error) {
if len(controlFrames) > 0 {
p.controlFrames = append(p.controlFrames, controlFrames...)
}
@ -97,18 +97,6 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, con
if !onlySendOneControlFrame && len(payloadFrames) == 1 && stopWaitingFrame != nil {
return nil, nil
}
// Don't send out packets that only contain an ACK (plus optional STOP_WAITING), if requested
if !maySendOnlyAck {
if len(payloadFrames) == 1 {
if _, ok := payloadFrames[0].(*frames.AckFrame); ok {
return nil, nil
}
} else if len(payloadFrames) == 2 && stopWaitingFrame != nil {
if _, ok := payloadFrames[1].(*frames.AckFrame); ok {
return nil, nil
}
}
}
raw := getPacketBuffer()
buffer := bytes.NewBuffer(raw)

View file

@ -37,7 +37,7 @@ var _ = Describe("Packet packer", func() {
})
It("returns nil when no packet is queued", func() {
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred())
})
@ -48,7 +48,7 @@ var _ = Describe("Packet packer", func() {
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
}
streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
b := &bytes.Buffer{}
@ -73,14 +73,14 @@ var _ = Describe("Packet packer", func() {
ErrorCode: 0x1337,
ReasonPhrase: "foobar",
}
p, err := packer.packPacket(&frames.StopWaitingFrame{LeastUnacked: 13}, []frames.Frame{&ccf, &frames.WindowUpdateFrame{StreamID: 37}}, 0, true, true)
p, err := packer.packPacket(&frames.StopWaitingFrame{LeastUnacked: 13}, []frames.Frame{&ccf, &frames.WindowUpdateFrame{StreamID: 37}}, 0, true)
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(&ccf))
})
It("packs only control frames", func() {
p, err := packer.PackPacket(nil, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true)
p, err := packer.PackPacket(nil, []frames.Frame{&frames.RstStreamFrame{}}, 0)
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
@ -88,10 +88,10 @@ var _ = Describe("Packet packer", func() {
})
It("increases the packet number", func() {
p1, err := packer.PackPacket(nil, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true)
p1, err := packer.PackPacket(nil, []frames.Frame{&frames.RstStreamFrame{}}, 0)
Expect(err).ToNot(HaveOccurred())
Expect(p1).ToNot(BeNil())
p2, err := packer.PackPacket(nil, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true)
p2, err := packer.PackPacket(nil, []frames.Frame{&frames.RstStreamFrame{}}, 0)
Expect(err).ToNot(HaveOccurred())
Expect(p2).ToNot(BeNil())
Expect(p2.number).To(BeNumerically(">", p1.number))
@ -100,7 +100,7 @@ var _ = Describe("Packet packer", func() {
It("packs a StopWaitingFrame first", func() {
packer.packetNumberGenerator.next = 15
swf := &frames.StopWaitingFrame{LeastUnacked: 10}
p, err := packer.PackPacket(swf, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true)
p, err := packer.PackPacket(swf, []frames.Frame{&frames.RstStreamFrame{}}, 0)
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.frames).To(HaveLen(2))
@ -111,21 +111,21 @@ var _ = Describe("Packet packer", func() {
packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number
packer.packetNumberGenerator.next = packetNumber
swf := &frames.StopWaitingFrame{LeastUnacked: packetNumber - 0x100}
p, err := packer.PackPacket(swf, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true)
p, err := packer.PackPacket(swf, []frames.Frame{&frames.RstStreamFrame{}}, 0)
Expect(err).ToNot(HaveOccurred())
Expect(p.frames[0].(*frames.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
})
It("does not pack a packet containing only a StopWaitingFrame", func() {
swf := &frames.StopWaitingFrame{LeastUnacked: 10}
p, err := packer.PackPacket(swf, []frames.Frame{}, 0, true)
p, err := packer.PackPacket(swf, []frames.Frame{}, 0)
Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred())
})
It("packs a packet if it has queued control frames, but no new control frames", func() {
packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}}
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
})
@ -169,7 +169,7 @@ var _ = Describe("Packet packer", func() {
It("only increases the packet number when there is an actual packet to send", func() {
packer.packetNumberGenerator.nextToSkip = 1000
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1)))
@ -178,7 +178,7 @@ var _ = Describe("Packet packer", func() {
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
}
streamFramer.AddFrameForRetransmission(f)
p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err = packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.number).To(Equal(protocol.PacketNumber(1)))
@ -219,12 +219,12 @@ var _ = Describe("Packet packer", func() {
}
streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1)))
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err = packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
@ -246,7 +246,7 @@ var _ = Describe("Packet packer", func() {
streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2)
streamFramer.AddFrameForRetransmission(f3)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred())
b := &bytes.Buffer{}
@ -300,23 +300,23 @@ var _ = Describe("Packet packer", func() {
}
streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err = packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(p.frames).To(HaveLen(2))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[1].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
Expect(err).ToNot(HaveOccurred())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err = packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err = packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil())
})
@ -329,7 +329,7 @@ var _ = Describe("Packet packer", func() {
minLength, _ := f.MinLength(0)
f.Data = bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header
streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
@ -395,25 +395,21 @@ var _ = Describe("Packet packer", func() {
})
It("returns nil if we only have a single STOP_WAITING", func() {
p, err := packer.PackPacket(&frames.StopWaitingFrame{}, nil, 0, false)
p, err := packer.PackPacket(&frames.StopWaitingFrame{}, nil, 0)
Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil())
})
It("returns nil if we only have a single STOP_WAITING and an ACK", func() {
p, err := packer.PackPacket(&frames.StopWaitingFrame{}, []frames.Frame{&frames.AckFrame{}}, 0, false)
It("packs a single ACK", func() {
ack := &frames.AckFrame{LargestAcked: 42}
p, err := packer.PackPacket(nil, []frames.Frame{ack}, 0)
Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil())
})
It("returns nil if we only have a single ACK", func() {
p, err := packer.PackPacket(nil, []frames.Frame{&frames.AckFrame{}}, 0, false)
Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil())
Expect(p).ToNot(BeNil())
Expect(p.frames[0]).To(Equal(ack))
})
It("does not return nil if we only have a single ACK but request it to be sent", func() {
p, err := packer.PackPacket(nil, []frames.Frame{&frames.AckFrame{}}, 0, true)
p, err := packer.PackPacket(nil, []frames.Frame{&frames.AckFrame{}}, 0)
Expect(err).NotTo(HaveOccurred())
Expect(p).ToNot(BeNil())
})

View file

@ -71,6 +71,12 @@ const MaxTrackedReceivedPackets = 2 * DefaultMaxCongestionWindow
// MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked
const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow
// MaxPacketsReceivedBeforeAckSend is the number of packets that can be received before an ACK frame is sent
const MaxPacketsReceivedBeforeAckSend = 20
// RetransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for
const RetransmittablePacketsBeforeAck = 2
// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames
// prevents DoS attacks against the streamFrameSorter
const MaxStreamFrameSorterGaps = 1000

View file

@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net"
"runtime"
"sync/atomic"
"time"
@ -77,7 +76,7 @@ type Session struct {
undecryptablePackets []*receivedPacket
aeadChanged chan struct{}
delayedAckOriginTime time.Time
nextAckScheduledTime time.Time
connectionParameters handshake.ConnectionParametersManager
@ -99,12 +98,9 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
connectionParameters := handshake.NewConnectionParamatersManager(v)
var sentPacketHandler ackhandler.SentPacketHandler
var receivedPacketHandler ackhandler.ReceivedPacketHandler
rttStats := &congestion.RTTStats{}
sentPacketHandler = ackhandler.NewSentPacketHandler(rttStats)
receivedPacketHandler = ackhandler.NewReceivedPacketHandler()
flowControlManager := flowcontrol.NewFlowControlManager(connectionParameters, rttStats)
now := time.Now()
@ -118,7 +114,6 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
connectionParameters: connectionParameters,
sentPacketHandler: sentPacketHandler,
receivedPacketHandler: receivedPacketHandler,
flowControlManager: flowControlManager,
receivedPackets: make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets),
@ -133,6 +128,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
sessionCreationTime: now,
}
session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged)
session.streamsMap = newStreamsMap(session.newStream, session.connectionParameters)
cryptoStream, _ := session.GetOrOpenStream(1)
@ -195,9 +191,6 @@ runLoop:
// This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it.
putPacketBuffer(p.publicHeader.Raw)
if s.delayedAckOriginTime.IsZero() {
s.delayedAckOriginTime = p.rcvTime
}
case <-s.aeadChanged:
s.tryDecryptingQueuedPackets()
}
@ -225,8 +218,8 @@ runLoop:
func (s *Session) maybeResetTimer() {
nextDeadline := s.lastNetworkActivityTime.Add(s.idleTimeout())
if !s.delayedAckOriginTime.IsZero() {
nextDeadline = utils.MinTime(nextDeadline, s.delayedAckOriginTime.Add(protocol.AckSendDelay))
if !s.nextAckScheduledTime.IsZero() {
nextDeadline = utils.MinTime(nextDeadline, s.nextAckScheduledTime)
}
if rtoTime := s.sentPacketHandler.TimeOfFirstRTO(); !rtoTime.IsZero() {
nextDeadline = utils.MinTime(nextDeadline, rtoTime)
@ -291,7 +284,7 @@ func (s *Session) handlePacketImpl(p *receivedPacket) error {
// Only do this after decrypting, so we are sure the packet is not attacker-controlled
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber)
err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, packet.IsRetransmittable())
// ignore duplicate packets
if err == ackhandler.ErrDuplicatePacket {
utils.Infof("Ignoring packet 0x%x due to ErrDuplicatePacket", hdr.PacketNumber)
@ -514,45 +507,26 @@ func (s *Session) sendPacket() error {
if err != nil {
return err
}
for _, wuf := range windowUpdateFrames {
controlFrames = append(controlFrames, wuf)
}
ack, err := s.receivedPacketHandler.GetAckFrame(false)
if err != nil {
return err
}
ack := s.receivedPacketHandler.GetAckFrame()
if ack != nil {
controlFrames = append(controlFrames, ack)
}
// Check whether we are allowed to send a packet containing only an ACK
maySendOnlyAck := time.Now().Sub(s.delayedAckOriginTime) > protocol.AckSendDelay
if runtime.GOOS == "windows" {
maySendOnlyAck = true
}
hasRetransmission := s.streamFramer.HasFramesForRetransmission()
var stopWaitingFrame *frames.StopWaitingFrame
if ack != nil || hasRetransmission {
stopWaitingFrame = s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission)
}
packet, err := s.packer.PackPacket(stopWaitingFrame, controlFrames, s.sentPacketHandler.GetLeastUnacked(), maySendOnlyAck)
packet, err := s.packer.PackPacket(stopWaitingFrame, controlFrames, s.sentPacketHandler.GetLeastUnacked())
if err != nil {
return err
}
if packet == nil {
return nil
}
// Pop the ACK frame now that we are sure we're gonna send it
_, err = s.receivedPacketHandler.GetAckFrame(true)
if err != nil {
return err
}
// send every window update twice
for _, f := range windowUpdateFrames {
s.packer.QueueControlFrameForNextPacket(f)
}
@ -567,13 +541,13 @@ func (s *Session) sendPacket() error {
}
s.logPacket(packet)
s.delayedAckOriginTime = time.Time{}
err = s.conn.write(packet.raw)
putPacketBuffer(packet.raw)
if err != nil {
return err
}
s.nextAckScheduledTime = time.Time{}
}
}
@ -695,6 +669,11 @@ func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) {
return res, nil
}
func (s *Session) ackAlarmChanged(t time.Time) {
s.nextAckScheduledTime = t
s.maybeResetTimer()
}
// RemoteAddr returns the net.UDPAddr of the client
func (s *Session) RemoteAddr() *net.UDPAddr {
return s.conn.RemoteAddr()

View file

@ -88,6 +88,22 @@ func newMockSentPacketHandler() ackhandler.SentPacketHandler {
return &mockSentPacketHandler{}
}
var _ ackhandler.SentPacketHandler = &mockSentPacketHandler{}
type mockReceivedPacketHandler struct {
nextAckFrame *frames.AckFrame
}
func (m *mockReceivedPacketHandler) GetAckFrame() *frames.AckFrame { return m.nextAckFrame }
func (m *mockReceivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
panic("not implemented")
}
func (m *mockReceivedPacketHandler) ReceivedStopWaiting(*frames.StopWaitingFrame) error {
panic("not implemented")
}
var _ ackhandler.ReceivedPacketHandler = &mockReceivedPacketHandler{}
var _ = Describe("Session", func() {
var (
session *Session
@ -602,7 +618,7 @@ var _ = Describe("Session", func() {
Context("sending packets", func() {
It("sends ack frames", func() {
packetNumber := protocol.PacketNumber(0x035E)
session.receivedPacketHandler.ReceivedPacket(packetNumber)
session.receivedPacketHandler.ReceivedPacket(packetNumber, true)
err := session.sendPacket()
Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(1))
@ -734,6 +750,17 @@ var _ = Describe("Session", func() {
s.(*stream).getDataForWriting(1000) // unblock
})
It("sets the timer to the ack timer", func() {
rph := &mockReceivedPacketHandler{}
rph.nextAckFrame = &frames.AckFrame{LargestAcked: 0x1337}
session.receivedPacketHandler = rph
go session.run()
session.ackAlarmChanged(time.Now().Add(10 * time.Millisecond))
time.Sleep(10 * time.Millisecond)
Eventually(func() int { return len(conn.written) }).ShouldNot(BeZero())
Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13})))
})
Context("bundling of small packets", func() {
It("bundles two small frames of different streams into one packet", func() {
s1, err := session.GetOrOpenStream(5)
@ -783,7 +810,7 @@ var _ = Describe("Session", func() {
It("sends a queued ACK frame only once", func() {
packetNumber := protocol.PacketNumber(0x1337)
session.receivedPacketHandler.ReceivedPacket(packetNumber)
session.receivedPacketHandler.ReceivedPacket(packetNumber, true)
s, err := session.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred())