implement a sophisticated ack algorithm

fixes #354
This commit is contained in:
Marten Seemann 2017-01-09 13:02:01 +07:00
parent 634d2dc6b7
commit edf6ee3dfc
No known key found for this signature in database
GPG key ID: 3603F40B121FCDEA
8 changed files with 370 additions and 256 deletions

View file

@ -28,8 +28,8 @@ type SentPacketHandler interface {
// ReceivedPacketHandler handles ACKs needed to send for incoming packets // ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface { type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber) error ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
ReceivedStopWaiting(*frames.StopWaitingFrame) error ReceivedStopWaiting(*frames.StopWaitingFrame) error
GetAckFrame(dequeue bool) (*frames.AckFrame, error) GetAckFrame() *frames.AckFrame
} }

View file

@ -20,22 +20,34 @@ var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet n
type receivedPacketHandler struct { type receivedPacketHandler struct {
largestObserved protocol.PacketNumber largestObserved protocol.PacketNumber
ignorePacketsBelow protocol.PacketNumber ignorePacketsBelow protocol.PacketNumber
currentAckFrame *frames.AckFrame largestObservedReceivedTime time.Time
stateChanged bool // has an ACK for this state already been sent? Will be set to false every time a new packet arrives, and to false every time an ACK is sent
packetHistory *receivedPacketHistory packetHistory *receivedPacketHistory
largestObservedReceivedTime time.Time ackSendDelay time.Duration
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
ackAlarmResetCallback func(time.Time)
lastAck *frames.AckFrame
} }
// NewReceivedPacketHandler creates a new receivedPacketHandler // NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler() ReceivedPacketHandler { func NewReceivedPacketHandler(ackAlarmResetCallback func(time.Time)) ReceivedPacketHandler {
// create a stopped timer, see https://github.com/golang/go/issues/12721#issuecomment-143010182
timer := time.NewTimer(0)
<-timer.C
return &receivedPacketHandler{ return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(), packetHistory: newReceivedPacketHistory(),
ackAlarmResetCallback: ackAlarmResetCallback,
ackSendDelay: protocol.AckSendDelay,
} }
} }
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber) error { func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
if packetNumber == 0 { if packetNumber == 0 {
return errInvalidPacketNumber return errInvalidPacketNumber
} }
@ -55,14 +67,12 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
return err return err
} }
h.stateChanged = true
h.currentAckFrame = nil
if packetNumber > h.largestObserved { if packetNumber > h.largestObserved {
h.largestObserved = packetNumber h.largestObserved = packetNumber
h.largestObservedReceivedTime = time.Now() h.largestObservedReceivedTime = time.Now()
} }
h.maybeQueueAck(packetNumber, shouldInstigateAck)
return nil return nil
} }
@ -78,29 +88,79 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame)
return nil return nil
} }
func (h *receivedPacketHandler) GetAckFrame(dequeue bool) (*frames.AckFrame, error) { func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
if !h.stateChanged { var ackAlarmSet bool
return nil, nil h.packetsReceivedSinceLastAck++
if shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
} }
if dequeue { // always ack the first packet
h.stateChanged = false if h.lastAck == nil {
h.ackQueued = true
} }
if h.currentAckFrame != nil { // Always send an ack every 20 packets in order to allow the peer to discard
return h.currentAckFrame, nil // information from the SentPacketManager and provide an RTT measurement.
if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend {
h.ackQueued = true
}
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
h.ackQueued = true
}
// check if a new missing range above the previously was created
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().FirstPacketNumber > h.lastAck.LargestAcked {
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck {
h.ackQueued = true
} else {
if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay)
ackAlarmSet = true
}
}
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
ackAlarmSet = false
}
if ackAlarmSet {
h.ackAlarmResetCallback(h.ackAlarm)
}
}
func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame {
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
return nil
} }
ackRanges := h.packetHistory.GetAckRanges() ackRanges := h.packetHistory.GetAckRanges()
h.currentAckFrame = &frames.AckFrame{ ack := &frames.AckFrame{
LargestAcked: h.largestObserved, LargestAcked: h.largestObserved,
LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber, LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber,
PacketReceivedTime: h.largestObservedReceivedTime, PacketReceivedTime: h.largestObservedReceivedTime,
} }
if len(ackRanges) > 1 { if len(ackRanges) > 1 {
h.currentAckFrame.AckRanges = ackRanges ack.AckRanges = ackRanges
} }
return h.currentAckFrame, nil h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
return ack
} }

View file

@ -13,56 +13,62 @@ import (
var _ = Describe("receivedPacketHandler", func() { var _ = Describe("receivedPacketHandler", func() {
var ( var (
handler *receivedPacketHandler handler *receivedPacketHandler
ackAlarmCallbackCalled bool
) )
ackAlarmCallback := func(time.Time) {
ackAlarmCallbackCalled = true
}
BeforeEach(func() { BeforeEach(func() {
handler = NewReceivedPacketHandler().(*receivedPacketHandler) ackAlarmCallbackCalled = false
handler = NewReceivedPacketHandler(ackAlarmCallback).(*receivedPacketHandler)
}) })
Context("accepting packets", func() { Context("accepting packets", func() {
It("handles a packet that arrives late", func() { It("handles a packet that arrives late", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1)) err := handler.ReceivedPacket(protocol.PacketNumber(1), true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(3)) err = handler.ReceivedPacket(protocol.PacketNumber(3), true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(2)) err = handler.ReceivedPacket(protocol.PacketNumber(2), true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
It("rejects packets with packet number 0", func() { It("rejects packets with packet number 0", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(0)) err := handler.ReceivedPacket(protocol.PacketNumber(0), true)
Expect(err).To(MatchError(errInvalidPacketNumber)) Expect(err).To(MatchError(errInvalidPacketNumber))
}) })
It("rejects a duplicate package", func() { It("rejects a duplicate package", func() {
for i := 1; i < 5; i++ { for i := 1; i < 5; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i)) err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
err := handler.ReceivedPacket(4) err := handler.ReceivedPacket(4, true)
Expect(err).To(MatchError(ErrDuplicatePacket)) Expect(err).To(MatchError(ErrDuplicatePacket))
}) })
It("ignores a packet with PacketNumber less than the LeastUnacked of a previously received StopWaiting", func() { It("ignores a packet with PacketNumber less than the LeastUnacked of a previously received StopWaiting", func() {
err := handler.ReceivedPacket(5) err := handler.ReceivedPacket(5, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10}) err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(9) err = handler.ReceivedPacket(9, true)
Expect(err).To(MatchError(ErrPacketSmallerThanLastStopWaiting)) Expect(err).To(MatchError(ErrPacketSmallerThanLastStopWaiting))
}) })
It("does not ignore a packet with PacketNumber equal to LeastUnacked of a previously received StopWaiting", func() { It("does not ignore a packet with PacketNumber equal to LeastUnacked of a previously received StopWaiting", func() {
err := handler.ReceivedPacket(5) err := handler.ReceivedPacket(5, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10}) err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(10) err = handler.ReceivedPacket(10, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
It("saves the time when each packet arrived", func() { It("saves the time when each packet arrived", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(3)) err := handler.ReceivedPacket(protocol.PacketNumber(3), true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
}) })
@ -70,7 +76,7 @@ var _ = Describe("receivedPacketHandler", func() {
It("updates the largestObserved and the largestObservedReceivedTime", func() { It("updates the largestObserved and the largestObservedReceivedTime", func() {
handler.largestObserved = 3 handler.largestObserved = 3
handler.largestObservedReceivedTime = time.Now().Add(-1 * time.Second) handler.largestObservedReceivedTime = time.Now().Add(-1 * time.Second)
err := handler.ReceivedPacket(5) err := handler.ReceivedPacket(5, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5))) Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond)) Expect(handler.largestObservedReceivedTime).To(BeTemporally("~", time.Now(), 10*time.Millisecond))
@ -80,27 +86,27 @@ var _ = Describe("receivedPacketHandler", func() {
timestamp := time.Now().Add(-1 * time.Second) timestamp := time.Now().Add(-1 * time.Second)
handler.largestObserved = 5 handler.largestObserved = 5
handler.largestObservedReceivedTime = timestamp handler.largestObservedReceivedTime = timestamp
err := handler.ReceivedPacket(4) err := handler.ReceivedPacket(4, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5))) Expect(handler.largestObserved).To(Equal(protocol.PacketNumber(5)))
Expect(handler.largestObservedReceivedTime).To(Equal(timestamp)) Expect(handler.largestObservedReceivedTime).To(Equal(timestamp))
}) })
It("doesn't store more than MaxTrackedReceivedPackets packets", func() { It("doesn't store more than MaxTrackedReceivedPackets packets", func() {
err := handler.ReceivedPacket(1) err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for i := protocol.PacketNumber(3); i < 3+protocol.MaxTrackedReceivedPackets-1; i++ { for i := protocol.PacketNumber(3); i < 3+protocol.MaxTrackedReceivedPackets-1; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i)) err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
err = handler.ReceivedPacket(protocol.PacketNumber(protocol.MaxTrackedReceivedPackets) + 10) err = handler.ReceivedPacket(protocol.PacketNumber(protocol.MaxTrackedReceivedPackets)+10, true)
Expect(err).To(MatchError(errTooManyOutstandingReceivedPackets)) Expect(err).To(MatchError(errTooManyOutstandingReceivedPackets))
}) })
It("passes on errors from receivedPacketHistory", func() { It("passes on errors from receivedPacketHistory", func() {
var err error var err error
for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ { for i := protocol.PacketNumber(0); i < 5*protocol.MaxTrackedReceivedAckRanges; i++ {
err = handler.ReceivedPacket(2*i + 1) err = handler.ReceivedPacket(2*i+1, true)
// this will eventually return an error // this will eventually return an error
// details about when exactly the receivedPacketHistory errors are tested there // details about when exactly the receivedPacketHistory errors are tested there
if err != nil { if err != nil {
@ -120,7 +126,7 @@ var _ = Describe("receivedPacketHandler", func() {
It("increase the ignorePacketsBelow number, even if all packets below the LeastUnacked were already acked", func() { It("increase the ignorePacketsBelow number, even if all packets below the LeastUnacked were already acked", func() {
for i := 1; i < 20; i++ { for i := 1; i < 20; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i)) err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
err := handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: protocol.PacketNumber(12)}) err := handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: protocol.PacketNumber(12)})
@ -138,26 +144,129 @@ var _ = Describe("receivedPacketHandler", func() {
}) })
}) })
Context("ACK package generation", func() { Context("ACKs", func() {
Context("queueing ACKs", func() {
receiveAndAck10Packets := func() {
for i := 1; i <= 10; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
Expect(err).ToNot(HaveOccurred())
}
Expect(handler.GetAckFrame()).ToNot(BeNil())
Expect(handler.ackQueued).To(BeFalse())
ackAlarmCallbackCalled = false
}
It("always queues an ACK for the first packet", func() {
err := handler.ReceivedPacket(1, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(ackAlarmCallbackCalled).To(BeFalse())
})
It("only queues one ACK for many non-retransmittable packets", func() {
receiveAndAck10Packets()
for i := 11; i < 10+protocol.MaxPacketsReceivedBeforeAckSend; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
}
err := handler.ReceivedPacket(10+protocol.MaxPacketsReceivedBeforeAckSend, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(ackAlarmCallbackCalled).To(BeFalse())
})
It("queues an ACK for every second retransmittable packet, if they are arriving fast", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(ackAlarmCallbackCalled).To(BeTrue())
ackAlarmCallbackCalled = false
err = handler.ReceivedPacket(12, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
Expect(ackAlarmCallbackCalled).To(BeFalse())
})
It("only sets the timer when receiving a retransmittable packets", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.ackAlarm).To(BeZero())
err = handler.ReceivedPacket(12, true)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeFalse())
Expect(handler.ackAlarm).ToNot(BeZero())
Expect(ackAlarmCallbackCalled).To(BeTrue())
})
It("queues an ACK if it was reported missing before", func() {
receiveAndAck10Packets()
err := handler.ReceivedPacket(11, true)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(13, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame() // ACK: 1 and 3, missing: 2
Expect(ack).ToNot(BeNil())
Expect(ack.HasMissingRanges()).To(BeTrue())
Expect(handler.ackQueued).To(BeFalse())
err = handler.ReceivedPacket(12, false)
Expect(err).ToNot(HaveOccurred())
Expect(handler.ackQueued).To(BeTrue())
})
It("queues an ACK if it creates a new missing range", func() {
receiveAndAck10Packets()
for i := 11; i < 16; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
Expect(err).ToNot(HaveOccurred())
}
Expect(handler.GetAckFrame()).ToNot(BeNil())
handler.ReceivedPacket(20, true) // we now know that packets 16 to 19 are missing
Expect(handler.ackQueued).To(BeTrue())
})
})
Context("ACK generation", func() {
BeforeEach(func() {
handler.ackQueued = true
})
It("generates a simple ACK frame", func() { It("generates a simple ACK frame", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1)) err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(2)) err = handler.ReceivedPacket(2, true)
Expect(err).ToNot(HaveOccurred())
ack, err := handler.GetAckFrame(true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(2))) Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(2)))
Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(1))) Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(1)))
Expect(ack.AckRanges).To(BeEmpty()) Expect(ack.AckRanges).To(BeEmpty())
}) })
It("saves the last sent ACK", func() {
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(handler.lastAck).To(Equal(ack))
err = handler.ReceivedPacket(2, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = true
ack = handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(handler.lastAck).To(Equal(ack))
})
It("generates an ACK frame with missing packets", func() { It("generates an ACK frame with missing packets", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1)) err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(4)) err = handler.ReceivedPacket(4, true)
Expect(err).ToNot(HaveOccurred())
ack, err := handler.GetAckFrame(true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ack := handler.GetAckFrame()
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(4))) Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(4)))
Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(1))) Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(1)))
Expect(ack.AckRanges).To(HaveLen(2)) Expect(ack.AckRanges).To(HaveLen(2))
@ -165,106 +274,55 @@ var _ = Describe("receivedPacketHandler", func() {
Expect(ack.AckRanges[1]).To(Equal(frames.AckRange{FirstPacketNumber: 1, LastPacketNumber: 1})) Expect(ack.AckRanges[1]).To(Equal(frames.AckRange{FirstPacketNumber: 1, LastPacketNumber: 1}))
}) })
It("does not generate an ACK if an ACK has already been sent for the largest Packet", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1))
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(2))
Expect(err).ToNot(HaveOccurred())
ack, err := handler.GetAckFrame(true)
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil())
ack, err = handler.GetAckFrame(true)
Expect(err).ToNot(HaveOccurred())
Expect(ack).To(BeNil())
})
It("does not dequeue an ACK frame if told so", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(2))
Expect(err).ToNot(HaveOccurred())
ack, err := handler.GetAckFrame(false)
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil())
ack, err = handler.GetAckFrame(false)
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil())
ack, err = handler.GetAckFrame(false)
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil())
})
It("returns a cached ACK frame if the ACK was not dequeued", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(2))
Expect(err).ToNot(HaveOccurred())
ack, err := handler.GetAckFrame(false)
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil())
ack2, err := handler.GetAckFrame(false)
Expect(err).ToNot(HaveOccurred())
Expect(ack2).ToNot(BeNil())
Expect(&ack).To(Equal(&ack2))
})
It("generates a new ACK (and deletes the cached one) when a new packet arrives", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1))
Expect(err).ToNot(HaveOccurred())
ack, _ := handler.GetAckFrame(true)
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(1)))
err = handler.ReceivedPacket(protocol.PacketNumber(3))
Expect(err).ToNot(HaveOccurred())
ack, _ = handler.GetAckFrame(true)
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(3)))
})
It("generates a new ACK when an out-of-order packet arrives", func() {
err := handler.ReceivedPacket(protocol.PacketNumber(1))
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(protocol.PacketNumber(3))
Expect(err).ToNot(HaveOccurred())
ack, _ := handler.GetAckFrame(true)
Expect(ack).ToNot(BeNil())
Expect(ack.AckRanges).To(HaveLen(2))
err = handler.ReceivedPacket(protocol.PacketNumber(2))
Expect(err).ToNot(HaveOccurred())
ack, _ = handler.GetAckFrame(true)
Expect(ack).ToNot(BeNil())
Expect(ack.AckRanges).To(BeEmpty())
})
It("doesn't send old ACK ranges after receiving a StopWaiting", func() {
err := handler.ReceivedPacket(5)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(10)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(11)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedPacket(12)
Expect(err).ToNot(HaveOccurred())
err = handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: protocol.PacketNumber(11)})
Expect(err).ToNot(HaveOccurred())
ack, err := handler.GetAckFrame(true)
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(12)))
Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(11)))
Expect(ack.HasMissingRanges()).To(BeFalse())
})
It("deletes packets from the packetHistory after receiving a StopWaiting, after continuously received packets", func() { It("deletes packets from the packetHistory after receiving a StopWaiting, after continuously received packets", func() {
for i := 1; i <= 12; i++ { for i := 1; i <= 12; i++ {
err := handler.ReceivedPacket(protocol.PacketNumber(i)) err := handler.ReceivedPacket(protocol.PacketNumber(i), true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
err := handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: protocol.PacketNumber(6)}) err := handler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: protocol.PacketNumber(6)})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame // check that the packets were deleted from the receivedPacketHistory by checking the values in an ACK frame
ack, err := handler.GetAckFrame(true) ack := handler.GetAckFrame()
Expect(err).ToNot(HaveOccurred())
Expect(ack).ToNot(BeNil()) Expect(ack).ToNot(BeNil())
Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(12))) Expect(ack.LargestAcked).To(Equal(protocol.PacketNumber(12)))
Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(6))) Expect(ack.LowestAcked).To(Equal(protocol.PacketNumber(6)))
Expect(ack.HasMissingRanges()).To(BeFalse()) Expect(ack.HasMissingRanges()).To(BeFalse())
}) })
It("resets all counters needed for the ACK queueing decision when sending an ACK", func() {
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
handler.ackAlarm = time.Now().Add(-time.Minute)
Expect(handler.GetAckFrame()).ToNot(BeNil())
Expect(handler.packetsReceivedSinceLastAck).To(BeZero())
Expect(handler.ackAlarm).To(BeZero())
Expect(handler.retransmittablePacketsReceivedSinceLastAck).To(BeZero())
Expect(handler.ackQueued).To(BeFalse())
})
It("doesn't generate an ACK when none is queued and the timer is not set", func() {
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Time{}
Expect(handler.GetAckFrame()).To(BeNil())
})
It("doesn't generate an ACK when none is queued and the timer has not yet expired", func() {
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Now().Add(time.Minute)
Expect(handler.GetAckFrame()).To(BeNil())
})
It("generates an ACK when the timer has expired", func() {
err := handler.ReceivedPacket(1, true)
Expect(err).ToNot(HaveOccurred())
handler.ackQueued = false
handler.ackAlarm = time.Now().Add(-time.Minute)
Expect(handler.GetAckFrame()).ToNot(BeNil())
})
})
}) })
}) })

View file

@ -41,14 +41,14 @@ func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup *handshake.
} }
func (p *packetPacker) PackConnectionClose(frame *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { func (p *packetPacker) PackConnectionClose(frame *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
return p.packPacket(nil, []frames.Frame{frame}, leastUnacked, true, false) return p.packPacket(nil, []frames.Frame{frame}, leastUnacked, true)
} }
func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, maySendOnlyAck bool) (*packedPacket, error) { func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
return p.packPacket(stopWaitingFrame, controlFrames, leastUnacked, false, maySendOnlyAck) return p.packPacket(stopWaitingFrame, controlFrames, leastUnacked, false)
} }
func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, onlySendOneControlFrame, maySendOnlyAck bool) (*packedPacket, error) { func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, onlySendOneControlFrame bool) (*packedPacket, error) {
if len(controlFrames) > 0 { if len(controlFrames) > 0 {
p.controlFrames = append(p.controlFrames, controlFrames...) p.controlFrames = append(p.controlFrames, controlFrames...)
} }
@ -97,18 +97,6 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, con
if !onlySendOneControlFrame && len(payloadFrames) == 1 && stopWaitingFrame != nil { if !onlySendOneControlFrame && len(payloadFrames) == 1 && stopWaitingFrame != nil {
return nil, nil return nil, nil
} }
// Don't send out packets that only contain an ACK (plus optional STOP_WAITING), if requested
if !maySendOnlyAck {
if len(payloadFrames) == 1 {
if _, ok := payloadFrames[0].(*frames.AckFrame); ok {
return nil, nil
}
} else if len(payloadFrames) == 2 && stopWaitingFrame != nil {
if _, ok := payloadFrames[1].(*frames.AckFrame); ok {
return nil, nil
}
}
}
raw := getPacketBuffer() raw := getPacketBuffer()
buffer := bytes.NewBuffer(raw) buffer := bytes.NewBuffer(raw)

View file

@ -37,7 +37,7 @@ var _ = Describe("Packet packer", func() {
}) })
It("returns nil when no packet is queued", func() { It("returns nil when no packet is queued", func() {
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(p).To(BeNil()) Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
@ -48,7 +48,7 @@ var _ = Describe("Packet packer", func() {
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
} }
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
b := &bytes.Buffer{} b := &bytes.Buffer{}
@ -73,14 +73,14 @@ var _ = Describe("Packet packer", func() {
ErrorCode: 0x1337, ErrorCode: 0x1337,
ReasonPhrase: "foobar", ReasonPhrase: "foobar",
} }
p, err := packer.packPacket(&frames.StopWaitingFrame{LeastUnacked: 13}, []frames.Frame{&ccf, &frames.WindowUpdateFrame{StreamID: 37}}, 0, true, true) p, err := packer.packPacket(&frames.StopWaitingFrame{LeastUnacked: 13}, []frames.Frame{&ccf, &frames.WindowUpdateFrame{StreamID: 37}}, 0, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0]).To(Equal(&ccf)) Expect(p.frames[0]).To(Equal(&ccf))
}) })
It("packs only control frames", func() { It("packs only control frames", func() {
p, err := packer.PackPacket(nil, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true) p, err := packer.PackPacket(nil, []frames.Frame{&frames.RstStreamFrame{}}, 0)
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
@ -88,10 +88,10 @@ var _ = Describe("Packet packer", func() {
}) })
It("increases the packet number", func() { It("increases the packet number", func() {
p1, err := packer.PackPacket(nil, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true) p1, err := packer.PackPacket(nil, []frames.Frame{&frames.RstStreamFrame{}}, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p1).ToNot(BeNil()) Expect(p1).ToNot(BeNil())
p2, err := packer.PackPacket(nil, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true) p2, err := packer.PackPacket(nil, []frames.Frame{&frames.RstStreamFrame{}}, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p2).ToNot(BeNil()) Expect(p2).ToNot(BeNil())
Expect(p2.number).To(BeNumerically(">", p1.number)) Expect(p2.number).To(BeNumerically(">", p1.number))
@ -100,7 +100,7 @@ var _ = Describe("Packet packer", func() {
It("packs a StopWaitingFrame first", func() { It("packs a StopWaitingFrame first", func() {
packer.packetNumberGenerator.next = 15 packer.packetNumberGenerator.next = 15
swf := &frames.StopWaitingFrame{LeastUnacked: 10} swf := &frames.StopWaitingFrame{LeastUnacked: 10}
p, err := packer.PackPacket(swf, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true) p, err := packer.PackPacket(swf, []frames.Frame{&frames.RstStreamFrame{}}, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(p.frames).To(HaveLen(2)) Expect(p.frames).To(HaveLen(2))
@ -111,21 +111,21 @@ var _ = Describe("Packet packer", func() {
packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number packetNumber := protocol.PacketNumber(0xDECAFB) // will result in a 4 byte packet number
packer.packetNumberGenerator.next = packetNumber packer.packetNumberGenerator.next = packetNumber
swf := &frames.StopWaitingFrame{LeastUnacked: packetNumber - 0x100} swf := &frames.StopWaitingFrame{LeastUnacked: packetNumber - 0x100}
p, err := packer.PackPacket(swf, []frames.Frame{&frames.ConnectionCloseFrame{}}, 0, true) p, err := packer.PackPacket(swf, []frames.Frame{&frames.RstStreamFrame{}}, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames[0].(*frames.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) Expect(p.frames[0].(*frames.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen4))
}) })
It("does not pack a packet containing only a StopWaitingFrame", func() { It("does not pack a packet containing only a StopWaitingFrame", func() {
swf := &frames.StopWaitingFrame{LeastUnacked: 10} swf := &frames.StopWaitingFrame{LeastUnacked: 10}
p, err := packer.PackPacket(swf, []frames.Frame{}, 0, true) p, err := packer.PackPacket(swf, []frames.Frame{}, 0)
Expect(p).To(BeNil()) Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
It("packs a packet if it has queued control frames, but no new control frames", func() { It("packs a packet if it has queued control frames, but no new control frames", func() {
packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}} packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}}
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
}) })
@ -169,7 +169,7 @@ var _ = Describe("Packet packer", func() {
It("only increases the packet number when there is an actual packet to send", func() { It("only increases the packet number when there is an actual packet to send", func() {
packer.packetNumberGenerator.nextToSkip = 1000 packer.packetNumberGenerator.nextToSkip = 1000
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(p).To(BeNil()) Expect(p).To(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1))) Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1)))
@ -178,7 +178,7 @@ var _ = Describe("Packet packer", func() {
Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD},
} }
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err = packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(p.number).To(Equal(protocol.PacketNumber(1))) Expect(p.number).To(Equal(protocol.PacketNumber(1)))
@ -219,12 +219,12 @@ var _ = Describe("Packet packer", func() {
} }
streamFramer.AddFrameForRetransmission(f1) streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2) streamFramer.AddFrameForRetransmission(f2)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1))) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - 1)))
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err = packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
@ -246,7 +246,7 @@ var _ = Describe("Packet packer", func() {
streamFramer.AddFrameForRetransmission(f1) streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2) streamFramer.AddFrameForRetransmission(f2)
streamFramer.AddFrameForRetransmission(f3) streamFramer.AddFrameForRetransmission(f3)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
b := &bytes.Buffer{} b := &bytes.Buffer{}
@ -300,23 +300,23 @@ var _ = Describe("Packet packer", func() {
} }
streamFramer.AddFrameForRetransmission(f1) streamFramer.AddFrameForRetransmission(f1)
streamFramer.AddFrameForRetransmission(f2) streamFramer.AddFrameForRetransmission(f2)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err = packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(p.frames).To(HaveLen(2)) Expect(p.frames).To(HaveLen(2))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeTrue()) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeTrue())
Expect(p.frames[1].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames[1].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err = packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(p.frames).To(HaveLen(1)) Expect(p.frames).To(HaveLen(1))
Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
p, err = packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err = packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
@ -329,7 +329,7 @@ var _ = Describe("Packet packer", func() {
minLength, _ := f.MinLength(0) minLength, _ := f.MinLength(0)
f.Data = bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header f.Data = bytes.Repeat([]byte{'f'}, int(protocol.MaxFrameAndPublicHeaderSize-publicHeaderLen-minLength+1)) // + 1 since MinceLength is 1 bigger than the actual StreamFrame header
streamFramer.AddFrameForRetransmission(f) streamFramer.AddFrameForRetransmission(f)
p, err := packer.PackPacket(nil, []frames.Frame{}, 0, true) p, err := packer.PackPacket(nil, []frames.Frame{}, 0)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize))) Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize)))
@ -395,25 +395,21 @@ var _ = Describe("Packet packer", func() {
}) })
It("returns nil if we only have a single STOP_WAITING", func() { It("returns nil if we only have a single STOP_WAITING", func() {
p, err := packer.PackPacket(&frames.StopWaitingFrame{}, nil, 0, false) p, err := packer.PackPacket(&frames.StopWaitingFrame{}, nil, 0)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).To(BeNil())
}) })
It("returns nil if we only have a single STOP_WAITING and an ACK", func() { It("packs a single ACK", func() {
p, err := packer.PackPacket(&frames.StopWaitingFrame{}, []frames.Frame{&frames.AckFrame{}}, 0, false) ack := &frames.AckFrame{LargestAcked: 42}
p, err := packer.PackPacket(nil, []frames.Frame{ack}, 0)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil()) Expect(p).ToNot(BeNil())
}) Expect(p.frames[0]).To(Equal(ack))
It("returns nil if we only have a single ACK", func() {
p, err := packer.PackPacket(nil, []frames.Frame{&frames.AckFrame{}}, 0, false)
Expect(err).NotTo(HaveOccurred())
Expect(p).To(BeNil())
}) })
It("does not return nil if we only have a single ACK but request it to be sent", func() { It("does not return nil if we only have a single ACK but request it to be sent", func() {
p, err := packer.PackPacket(nil, []frames.Frame{&frames.AckFrame{}}, 0, true) p, err := packer.PackPacket(nil, []frames.Frame{&frames.AckFrame{}}, 0)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(p).ToNot(BeNil()) Expect(p).ToNot(BeNil())
}) })

View file

@ -71,6 +71,12 @@ const MaxTrackedReceivedPackets = 2 * DefaultMaxCongestionWindow
// MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked // MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked
const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow
// MaxPacketsReceivedBeforeAckSend is the number of packets that can be received before an ACK frame is sent
const MaxPacketsReceivedBeforeAckSend = 20
// RetransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for
const RetransmittablePacketsBeforeAck = 2
// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames // MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames
// prevents DoS attacks against the streamFrameSorter // prevents DoS attacks against the streamFrameSorter
const MaxStreamFrameSorterGaps = 1000 const MaxStreamFrameSorterGaps = 1000

View file

@ -4,7 +4,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"runtime"
"sync/atomic" "sync/atomic"
"time" "time"
@ -77,7 +76,7 @@ type Session struct {
undecryptablePackets []*receivedPacket undecryptablePackets []*receivedPacket
aeadChanged chan struct{} aeadChanged chan struct{}
delayedAckOriginTime time.Time nextAckScheduledTime time.Time
connectionParameters handshake.ConnectionParametersManager connectionParameters handshake.ConnectionParametersManager
@ -99,12 +98,9 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
connectionParameters := handshake.NewConnectionParamatersManager(v) connectionParameters := handshake.NewConnectionParamatersManager(v)
var sentPacketHandler ackhandler.SentPacketHandler var sentPacketHandler ackhandler.SentPacketHandler
var receivedPacketHandler ackhandler.ReceivedPacketHandler
rttStats := &congestion.RTTStats{} rttStats := &congestion.RTTStats{}
sentPacketHandler = ackhandler.NewSentPacketHandler(rttStats) sentPacketHandler = ackhandler.NewSentPacketHandler(rttStats)
receivedPacketHandler = ackhandler.NewReceivedPacketHandler()
flowControlManager := flowcontrol.NewFlowControlManager(connectionParameters, rttStats) flowControlManager := flowcontrol.NewFlowControlManager(connectionParameters, rttStats)
now := time.Now() now := time.Now()
@ -118,7 +114,6 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
connectionParameters: connectionParameters, connectionParameters: connectionParameters,
sentPacketHandler: sentPacketHandler, sentPacketHandler: sentPacketHandler,
receivedPacketHandler: receivedPacketHandler,
flowControlManager: flowControlManager, flowControlManager: flowControlManager,
receivedPackets: make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets), receivedPackets: make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets),
@ -133,6 +128,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
sessionCreationTime: now, sessionCreationTime: now,
} }
session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged)
session.streamsMap = newStreamsMap(session.newStream, session.connectionParameters) session.streamsMap = newStreamsMap(session.newStream, session.connectionParameters)
cryptoStream, _ := session.GetOrOpenStream(1) cryptoStream, _ := session.GetOrOpenStream(1)
@ -195,9 +191,6 @@ runLoop:
// This is a bit unclean, but works properly, since the packet always // This is a bit unclean, but works properly, since the packet always
// begins with the public header and we never copy it. // begins with the public header and we never copy it.
putPacketBuffer(p.publicHeader.Raw) putPacketBuffer(p.publicHeader.Raw)
if s.delayedAckOriginTime.IsZero() {
s.delayedAckOriginTime = p.rcvTime
}
case <-s.aeadChanged: case <-s.aeadChanged:
s.tryDecryptingQueuedPackets() s.tryDecryptingQueuedPackets()
} }
@ -225,8 +218,8 @@ runLoop:
func (s *Session) maybeResetTimer() { func (s *Session) maybeResetTimer() {
nextDeadline := s.lastNetworkActivityTime.Add(s.idleTimeout()) nextDeadline := s.lastNetworkActivityTime.Add(s.idleTimeout())
if !s.delayedAckOriginTime.IsZero() { if !s.nextAckScheduledTime.IsZero() {
nextDeadline = utils.MinTime(nextDeadline, s.delayedAckOriginTime.Add(protocol.AckSendDelay)) nextDeadline = utils.MinTime(nextDeadline, s.nextAckScheduledTime)
} }
if rtoTime := s.sentPacketHandler.TimeOfFirstRTO(); !rtoTime.IsZero() { if rtoTime := s.sentPacketHandler.TimeOfFirstRTO(); !rtoTime.IsZero() {
nextDeadline = utils.MinTime(nextDeadline, rtoTime) nextDeadline = utils.MinTime(nextDeadline, rtoTime)
@ -291,7 +284,7 @@ func (s *Session) handlePacketImpl(p *receivedPacket) error {
// Only do this after decrypting, so we are sure the packet is not attacker-controlled // Only do this after decrypting, so we are sure the packet is not attacker-controlled
s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber) s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber) err = s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, packet.IsRetransmittable())
// ignore duplicate packets // ignore duplicate packets
if err == ackhandler.ErrDuplicatePacket { if err == ackhandler.ErrDuplicatePacket {
utils.Infof("Ignoring packet 0x%x due to ErrDuplicatePacket", hdr.PacketNumber) utils.Infof("Ignoring packet 0x%x due to ErrDuplicatePacket", hdr.PacketNumber)
@ -514,45 +507,26 @@ func (s *Session) sendPacket() error {
if err != nil { if err != nil {
return err return err
} }
for _, wuf := range windowUpdateFrames { for _, wuf := range windowUpdateFrames {
controlFrames = append(controlFrames, wuf) controlFrames = append(controlFrames, wuf)
} }
ack := s.receivedPacketHandler.GetAckFrame()
ack, err := s.receivedPacketHandler.GetAckFrame(false)
if err != nil {
return err
}
if ack != nil { if ack != nil {
controlFrames = append(controlFrames, ack) controlFrames = append(controlFrames, ack)
} }
// Check whether we are allowed to send a packet containing only an ACK
maySendOnlyAck := time.Now().Sub(s.delayedAckOriginTime) > protocol.AckSendDelay
if runtime.GOOS == "windows" {
maySendOnlyAck = true
}
hasRetransmission := s.streamFramer.HasFramesForRetransmission() hasRetransmission := s.streamFramer.HasFramesForRetransmission()
var stopWaitingFrame *frames.StopWaitingFrame var stopWaitingFrame *frames.StopWaitingFrame
if ack != nil || hasRetransmission { if ack != nil || hasRetransmission {
stopWaitingFrame = s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission) stopWaitingFrame = s.sentPacketHandler.GetStopWaitingFrame(hasRetransmission)
} }
packet, err := s.packer.PackPacket(stopWaitingFrame, controlFrames, s.sentPacketHandler.GetLeastUnacked(), maySendOnlyAck) packet, err := s.packer.PackPacket(stopWaitingFrame, controlFrames, s.sentPacketHandler.GetLeastUnacked())
if err != nil { if err != nil {
return err return err
} }
if packet == nil { if packet == nil {
return nil return nil
} }
// send every window update twice
// Pop the ACK frame now that we are sure we're gonna send it
_, err = s.receivedPacketHandler.GetAckFrame(true)
if err != nil {
return err
}
for _, f := range windowUpdateFrames { for _, f := range windowUpdateFrames {
s.packer.QueueControlFrameForNextPacket(f) s.packer.QueueControlFrameForNextPacket(f)
} }
@ -567,13 +541,13 @@ func (s *Session) sendPacket() error {
} }
s.logPacket(packet) s.logPacket(packet)
s.delayedAckOriginTime = time.Time{}
err = s.conn.write(packet.raw) err = s.conn.write(packet.raw)
putPacketBuffer(packet.raw) putPacketBuffer(packet.raw)
if err != nil { if err != nil {
return err return err
} }
s.nextAckScheduledTime = time.Time{}
} }
} }
@ -695,6 +669,11 @@ func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) {
return res, nil return res, nil
} }
func (s *Session) ackAlarmChanged(t time.Time) {
s.nextAckScheduledTime = t
s.maybeResetTimer()
}
// RemoteAddr returns the net.UDPAddr of the client // RemoteAddr returns the net.UDPAddr of the client
func (s *Session) RemoteAddr() *net.UDPAddr { func (s *Session) RemoteAddr() *net.UDPAddr {
return s.conn.RemoteAddr() return s.conn.RemoteAddr()

View file

@ -88,6 +88,22 @@ func newMockSentPacketHandler() ackhandler.SentPacketHandler {
return &mockSentPacketHandler{} return &mockSentPacketHandler{}
} }
var _ ackhandler.SentPacketHandler = &mockSentPacketHandler{}
type mockReceivedPacketHandler struct {
nextAckFrame *frames.AckFrame
}
func (m *mockReceivedPacketHandler) GetAckFrame() *frames.AckFrame { return m.nextAckFrame }
func (m *mockReceivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
panic("not implemented")
}
func (m *mockReceivedPacketHandler) ReceivedStopWaiting(*frames.StopWaitingFrame) error {
panic("not implemented")
}
var _ ackhandler.ReceivedPacketHandler = &mockReceivedPacketHandler{}
var _ = Describe("Session", func() { var _ = Describe("Session", func() {
var ( var (
session *Session session *Session
@ -602,7 +618,7 @@ var _ = Describe("Session", func() {
Context("sending packets", func() { Context("sending packets", func() {
It("sends ack frames", func() { It("sends ack frames", func() {
packetNumber := protocol.PacketNumber(0x035E) packetNumber := protocol.PacketNumber(0x035E)
session.receivedPacketHandler.ReceivedPacket(packetNumber) session.receivedPacketHandler.ReceivedPacket(packetNumber, true)
err := session.sendPacket() err := session.sendPacket()
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(conn.written).To(HaveLen(1)) Expect(conn.written).To(HaveLen(1))
@ -734,6 +750,17 @@ var _ = Describe("Session", func() {
s.(*stream).getDataForWriting(1000) // unblock s.(*stream).getDataForWriting(1000) // unblock
}) })
It("sets the timer to the ack timer", func() {
rph := &mockReceivedPacketHandler{}
rph.nextAckFrame = &frames.AckFrame{LargestAcked: 0x1337}
session.receivedPacketHandler = rph
go session.run()
session.ackAlarmChanged(time.Now().Add(10 * time.Millisecond))
time.Sleep(10 * time.Millisecond)
Eventually(func() int { return len(conn.written) }).ShouldNot(BeZero())
Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13})))
})
Context("bundling of small packets", func() { Context("bundling of small packets", func() {
It("bundles two small frames of different streams into one packet", func() { It("bundles two small frames of different streams into one packet", func() {
s1, err := session.GetOrOpenStream(5) s1, err := session.GetOrOpenStream(5)
@ -783,7 +810,7 @@ var _ = Describe("Session", func() {
It("sends a queued ACK frame only once", func() { It("sends a queued ACK frame only once", func() {
packetNumber := protocol.PacketNumber(0x1337) packetNumber := protocol.PacketNumber(0x1337)
session.receivedPacketHandler.ReceivedPacket(packetNumber) session.receivedPacketHandler.ReceivedPacket(packetNumber, true)
s, err := session.GetOrOpenStream(5) s, err := session.GetOrOpenStream(5)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())