dequeue all 0-RTT packets to the session in one go

This commit is contained in:
Marten Seemann 2021-01-09 13:34:05 +08:00
parent d1c5297c0b
commit 7a301aae3d
3 changed files with 47 additions and 45 deletions

View file

@ -463,13 +463,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
return nil
}
sess.handlePacket(p)
for {
p := s.zeroRTTQueue.Dequeue(hdr.DestConnectionID)
if p == nil {
break
}
sess.handlePacket(p)
}
s.zeroRTTQueue.DequeueToSession(hdr.DestConnectionID, sess)
return nil
}

View file

@ -45,21 +45,23 @@ func (h *zeroRTTQueue) Enqueue(connID protocol.ConnectionID, p *receivedPacket)
entry.packets = append(entry.packets, p)
}
func (h *zeroRTTQueue) Dequeue(connID protocol.ConnectionID) *receivedPacket {
func (h *zeroRTTQueue) DequeueToSession(connID protocol.ConnectionID, sess packetHandler) {
h.mutex.Lock()
defer h.mutex.Unlock()
h.dequeueToSession(connID, sess)
}
func (h *zeroRTTQueue) dequeueToSession(connID protocol.ConnectionID, sess packetHandler) {
entry, ok := h.queue[string(connID)]
if !ok {
return nil
return
}
p := entry.packets[0]
entry.packets = entry.packets[1:]
if len(entry.packets) == 0 {
entry.timer.Stop()
delete(h.queue, string(connID))
entry.timer.Stop()
for _, p := range entry.packets {
sess.handlePacket(p)
}
return p
delete(h.queue, string(connID))
}
func (h *zeroRTTQueue) deleteQueue(connID protocol.ConnectionID) {

View file

@ -4,6 +4,7 @@ import (
"encoding/binary"
"time"
"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol"
. "github.com/onsi/ginkgo"
@ -14,6 +15,16 @@ var _ = Describe("0-RTT queue", func() {
var q *zeroRTTQueue
queueDuration := scaleDuration(50 * time.Millisecond)
getQueuedPackets := func(connID protocol.ConnectionID) []*receivedPacket {
var packets []*receivedPacket
sess := NewMockPacketHandler(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) {
packets = append(packets, p)
}).AnyTimes()
q.DequeueToSession(connID, sess)
return packets
}
BeforeEach(func() {
q = newZeroRTTQueue()
q.queueDuration = queueDuration
@ -23,14 +34,9 @@ var _ = Describe("0-RTT queue", func() {
// dequeue all packets to make sure the timers are stopped
q.mutex.Lock()
for connID := range q.queue {
for {
q.mutex.Unlock()
p := q.Dequeue(protocol.ConnectionID(connID))
q.mutex.Lock()
if p != nil {
break
}
}
sess := NewMockPacketHandler(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any()).AnyTimes()
q.dequeueToSession(protocol.ConnectionID(connID), sess)
}
q.mutex.Unlock()
})
@ -39,32 +45,34 @@ var _ = Describe("0-RTT queue", func() {
connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
p := &receivedPacket{data: []byte("foobar")}
q.Enqueue(connID, p)
Expect(q.Dequeue(connID)).To(Equal(p))
Expect(q.Dequeue(connID)).To(BeNil())
queuedPackets := getQueuedPackets(connID)
Expect(queuedPackets).To(Equal([]*receivedPacket{p}))
})
It("returns a nil packet for unknown connection IDs", func() {
Expect(q.Dequeue(protocol.ConnectionID{0x42})).To(BeNil())
It("doesn't dequeue for unknown connection IDs", func() {
Expect(getQueuedPackets(protocol.ConnectionID{0x42})).To(BeEmpty())
})
It("only stores packets for Max0RTTQueues connection", func() {
getConnID := func(i int) protocol.ConnectionID {
connID := make([]byte, 4)
binary.BigEndian.PutUint32(connID, uint32(i))
return connID
}
// fill up the queues
for i := 0; i < protocol.Max0RTTQueues; i++ {
data := make([]byte, 4)
binary.BigEndian.PutUint32(data, uint32(i))
q.Enqueue(protocol.ConnectionID(data), &receivedPacket{data: data})
connID := getConnID(i)
q.Enqueue(connID, &receivedPacket{data: []byte(connID)})
}
// now try to enqueue a packet for another connection ID
connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
q.Enqueue(connID, &receivedPacket{data: []byte("foobar")})
Expect(q.Dequeue(connID)).To(BeNil())
// check that the other queues were all saved
for i := 0; i < protocol.Max0RTTQueues; i++ {
connID := make([]byte, 4)
binary.BigEndian.PutUint32(connID, uint32(i))
p := q.Dequeue(connID)
Expect(p).ToNot(BeNil())
Expect(binary.BigEndian.Uint32(p.data)).To(BeEquivalentTo(i))
queuedPackets := getQueuedPackets(getConnID(i))
Expect(queuedPackets).To(HaveLen(1))
Expect(binary.BigEndian.Uint32(queuedPackets[0].data)).To(BeEquivalentTo(i))
}
})
@ -78,12 +86,12 @@ var _ = Describe("0-RTT queue", func() {
// now try to enqueue a packet for another connection ID
connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}
q.Enqueue(connID, &receivedPacket{data: []byte("foobar")})
Expect(q.Dequeue(connID)).To(BeNil())
Expect(getQueuedPackets(connID)).To(BeEmpty())
// dequeue the packet from the first queue
Expect(q.Dequeue(protocol.ConnectionID{0, 0, 0, 0})).ToNot(BeNil())
Expect(getQueuedPackets(protocol.ConnectionID{0, 0, 0, 0})).ToNot(BeNil())
// now it should be possible to queue another packet
q.Enqueue(connID, &receivedPacket{data: []byte("foobar")})
Expect(q.Dequeue(connID)).ToNot(BeNil())
Expect(getQueuedPackets(connID)).ToNot(BeNil())
})
It("limits the number of packets it stores for one connection", func() {
@ -96,13 +104,11 @@ var _ = Describe("0-RTT queue", func() {
}
// The queue is full now. This packet will be dropped.
q.Enqueue(connID, &receivedPacket{data: []byte("foobar")})
for i := 0; i < protocol.Max0RTTQueueLen; i++ {
p := q.Dequeue(connID)
Expect(p).ToNot(BeNil())
queuedPackets := getQueuedPackets(connID)
Expect(queuedPackets).To(HaveLen(protocol.Max0RTTQueueLen))
for i, p := range queuedPackets {
Expect(binary.BigEndian.Uint32(p.data)).To(BeEquivalentTo(i))
}
// The queue should now be empty.
Expect(q.Dequeue(connID)).To(BeNil())
})
It("deletes packets if they aren't dequeued after a short while", func() {
@ -110,6 +116,6 @@ var _ = Describe("0-RTT queue", func() {
p := &receivedPacket{data: []byte("foobar"), buffer: getPacketBuffer()}
q.Enqueue(connID, p)
time.Sleep(queueDuration * 3 / 2)
Expect(q.Dequeue(connID)).To(BeNil())
Expect(getQueuedPackets(connID)).To(BeNil())
})
})