diff --git a/zero_rtt_queue.go b/zero_rtt_queue.go index 0418ddf5..f2d6eb4b 100644 --- a/zero_rtt_queue.go +++ b/zero_rtt_queue.go @@ -30,11 +30,7 @@ func (h *zeroRTTQueue) Enqueue(connID protocol.ConnectionID, p *receivedPacket) if len(h.queue) >= protocol.Max0RTTQueues { return } - h.queue[cid] = &zeroRTTQueueEntry{timer: time.AfterFunc(protocol.Max0RTTQueueingDuration, func() { - h.mutex.Lock() - delete(h.queue, cid) - h.mutex.Unlock() - })} + h.queue[cid] = &zeroRTTQueueEntry{timer: time.AfterFunc(protocol.Max0RTTQueueingDuration, func() { h.deleteQueue(connID) })} } entry := h.queue[cid] if len(entry.packets) >= protocol.Max0RTTQueueLen { @@ -59,3 +55,17 @@ func (h *zeroRTTQueue) Dequeue(connID protocol.ConnectionID) *receivedPacket { } return p } + +func (h *zeroRTTQueue) deleteQueue(connID protocol.ConnectionID) { + h.mutex.Lock() + defer h.mutex.Unlock() + + entry, ok := h.queue[string(connID)] + if !ok { + return + } + for _, p := range entry.packets { + p.buffer.Release() + } + delete(h.queue, string(connID)) +} diff --git a/zero_rtt_queue_test.go b/zero_rtt_queue_test.go index ca9e6907..4b0bb45c 100644 --- a/zero_rtt_queue_test.go +++ b/zero_rtt_queue_test.go @@ -17,6 +17,22 @@ var _ = Describe("0-RTT queue", func() { q = newZeroRTTQueue() }) + AfterEach(func() { + // dequeue all packets to make sure the timers are stopped + q.mutex.Lock() + for connID := range q.queue { + for { + q.mutex.Unlock() + p := q.Dequeue(protocol.ConnectionID(connID)) + q.mutex.Lock() + if p != nil { + break + } + } + } + q.mutex.Unlock() + }) + It("stores a 0-RTT packet", func() { connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} p := &receivedPacket{data: []byte("foobar")} @@ -89,7 +105,7 @@ var _ = Describe("0-RTT queue", func() { It("deletes packets if they aren't dequeued after a short while", func() { connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - p := &receivedPacket{data: []byte("foobar")} + p := &receivedPacket{data: []byte("foobar"), buffer: getPacketBuffer()} q.Enqueue(connID, p) time.Sleep(protocol.Max0RTTQueueingDuration * 3 / 2) Expect(q.Dequeue(connID)).To(BeNil())