diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 398c7e76..a76eb72e 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -7,6 +7,7 @@ import ( "io/ioutil" mrand "math/rand" "net" + "sort" "sync" "sync/atomic" "time" @@ -204,8 +205,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -219,8 +221,10 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).ToNot(BeZero()) - // TODO(#2629): ensure that this is a contiguous block of packets, starting at packet 0 - Expect(get0RTTPackets(tracer.getRcvdPackets())).ToNot(BeEmpty()) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdPackets()) + Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) + sort.Slice(zeroRTTPackets, func(i, j int) bool { return zeroRTTPackets[i] < zeroRTTPackets[j] }) + Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) }) // Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets. @@ -294,8 +298,7 @@ var _ = Describe("0-RTT", func() { num0RTT := atomic.LoadUint32(num0RTTPackets) fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) Expect(num0RTT).To(Or(BeEquivalentTo(2), BeEquivalentTo(3))) // the FIN might be sent in a separate packet - // TODO(#2629): check that packets are sent - // Expect(get0RTTPackets(tracer.getRcvdPackets())).ToNot(BeEmpty()) + Expect(get0RTTPackets(tracer.getRcvdPackets())).To(HaveLen(int(num0RTT))) }) It("transfers 0-RTT data, when 0-RTT packets are lost", func() { @@ -426,9 +429,8 @@ var _ = Describe("0-RTT", func() { Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra Expect(secondCounter).To(BeNumerically("~", firstCounter, 20)) zeroRTTPackets := get0RTTPackets(tracer.getRcvdPackets()) - // TODO(#2629): We should receive 5 packets here. - Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 1)) - Expect(zeroRTTPackets[0]).To(BeNumerically(">", protocol.PacketNumber(1))) + Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5)) + Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) }) It("rejects 0-RTT when the server's transport parameters changed", func() { @@ -565,6 +567,41 @@ var _ = Describe("0-RTT", func() { Expect(num0RTT).ToNot(BeZero()) Expect(get0RTTPackets(tracer.getRcvdPackets())).To(BeEmpty()) }) + + It("queues 0-RTT packets, if the Initial is delayed", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + tracer := newRcvdPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{version}, + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: ln.Addr().String(), + DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { + if dir == quicproxy.DirectionIncoming && data[0]&0x80 > 0 && data[0]&0x30>>4 == 0 { // Initial packet from client + return rtt/2 + rtt + } + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + transfer0RTTData(ln, proxy.LocalPort(), clientConf, PRData, true) + + Expect(tracer.rcvdPackets[0].Type).To(Equal(protocol.PacketTypeInitial)) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdPackets()) + Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) + Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) + }) }) } }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 65947481..edae8af2 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -26,6 +26,38 @@ func (e statelessResetErr) Error() string { return fmt.Sprintf("received a stateless reset with token %x", e.token) } +type zeroRTTQueue struct { + queue []*receivedPacket + retireTimer *time.Timer +} + +var _ packetHandler = &zeroRTTQueue{} + +func (h *zeroRTTQueue) handlePacket(p *receivedPacket) { + if len(h.queue) < protocol.Max0RTTQueueLen { + h.queue = append(h.queue, p) + } +} +func (h *zeroRTTQueue) shutdown() {} +func (h *zeroRTTQueue) destroy(error) {} +func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient } +func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) { + for _, p := range h.queue { + sess.handlePacket(p) + } +} + +func (h *zeroRTTQueue) Clear() { + for _, p := range h.queue { + p.buffer.Release() + } +} + +type packetHandlerMapEntry struct { + packetHandler packetHandler + is0RTTQueue bool +} + // The packetHandlerMap stores packetHandlers, identified by connection ID. // It is used: // * by the server to store sessions @@ -36,14 +68,16 @@ type packetHandlerMap struct { conn connection connIDLen int - handlers map[string] /* string(ConnectionID)*/ packetHandler - resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler - server unknownPacketHandler + handlers map[string] /* string(ConnectionID)*/ packetHandlerMapEntry + resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler + server unknownPacketHandler + numZeroRTTEntries int listening chan struct{} // is closed when listen returns closed bool deleteRetiredSessionsAfter time.Duration + zeroRTTQueueDuration time.Duration statelessResetEnabled bool statelessResetMutex sync.Mutex @@ -107,9 +141,10 @@ func newPacketHandlerMap( conn: conn, connIDLen: connIDLen, listening: make(chan struct{}), - handlers: make(map[string]packetHandler), + handlers: make(map[string]packetHandlerMapEntry), resetTokens: make(map[protocol.StatelessResetToken]packetHandler), deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, + zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, statelessResetEnabled: len(statelessResetKey) > 0, statelessResetHasher: hmac.New(sha256.New, statelessResetKey), tracer: tracer, @@ -157,7 +192,7 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) return false } - h.handlers[string(id)] = handler + h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} h.logger.Debugf("Adding connection ID %s.", id) return true } @@ -166,14 +201,25 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co h.mutex.Lock() defer h.mutex.Unlock() - if _, ok := h.handlers[string(clientDestConnID)]; ok { - h.logger.Debugf("Not adding connection ID %s for a new session, as it already exists.", clientDestConnID) - return false + var q *zeroRTTQueue + if entry, ok := h.handlers[string(clientDestConnID)]; ok { + if !entry.is0RTTQueue { + h.logger.Debugf("Not adding connection ID %s for a new session, as it already exists.", clientDestConnID) + return false + } + q = entry.packetHandler.(*zeroRTTQueue) + q.retireTimer.Stop() + h.numZeroRTTEntries-- + if h.numZeroRTTEntries < 0 { + panic("number of 0-RTT queues < 0") + } } - sess := fn() - h.handlers[string(clientDestConnID)] = sess - h.handlers[string(newConnID)] = sess + if q != nil { + q.EnqueueAll(sess) + } + h.handlers[string(clientDestConnID)] = packetHandlerMapEntry{packetHandler: sess} + h.handlers[string(newConnID)] = packetHandlerMapEntry{packetHandler: sess} h.logger.Debugf("Adding connection IDs %s and %s for a new session.", clientDestConnID, newConnID) return true } @@ -197,7 +243,7 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) { h.mutex.Lock() - h.handlers[string(id)] = handler + h.handlers[string(id)] = packetHandlerMapEntry{packetHandler: handler} h.mutex.Unlock() h.logger.Debugf("Replacing session for connection ID %s with a closed session.", id) @@ -236,14 +282,14 @@ func (h *packetHandlerMap) CloseServer() { } h.server = nil var wg sync.WaitGroup - for _, handler := range h.handlers { - if handler.getPerspective() == protocol.PerspectiveServer { + for _, entry := range h.handlers { + if entry.packetHandler.getPerspective() == protocol.PerspectiveServer { wg.Add(1) go func(handler packetHandler) { // blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped handler.shutdown() wg.Done() - }(handler) + }(entry.packetHandler) } } h.mutex.Unlock() @@ -268,12 +314,12 @@ func (h *packetHandlerMap) close(e error) error { } var wg sync.WaitGroup - for _, handler := range h.handlers { + for _, entry := range h.handlers { wg.Add(1) go func(handler packetHandler) { handler.destroy(e) wg.Done() - }(handler) + }(entry.packetHandler) } if h.server != nil { @@ -319,9 +365,16 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { return } - if handler, ok := h.handlers[string(connID)]; ok { // existing session - handler.handlePacket(p) - return + if entry, ok := h.handlers[string(connID)]; ok { + if entry.is0RTTQueue { // only enqueue 0-RTT packets in the 0-RTT queue + if wire.Is0RTTPacket(p.data) { + entry.packetHandler.handlePacket(p) + return + } + } else { // existing session + entry.packetHandler.handlePacket(p) + return + } } if p.data[0]&0x80 == 0 { go h.maybeSendStatelessReset(p, connID) @@ -331,6 +384,36 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) return } + if wire.Is0RTTPacket(p.data) { + if h.numZeroRTTEntries >= protocol.Max0RTTQueues { + return + } + h.numZeroRTTEntries++ + queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)} + h.handlers[string(connID)] = packetHandlerMapEntry{ + packetHandler: queue, + is0RTTQueue: true, + } + queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { + h.mutex.Lock() + defer h.mutex.Unlock() + // The entry might have been replaced by an actual session. + // Only delete it if it's still a 0-RTT queue. + if entry, ok := h.handlers[string(connID)]; ok && entry.is0RTTQueue { + delete(h.handlers, string(connID)) + h.numZeroRTTEntries-- + if h.numZeroRTTEntries < 0 { + panic("number of 0-RTT queues < 0") + } + entry.packetHandler.(*zeroRTTQueue).Clear() + if h.logger.Debug() { + h.logger.Debugf("Removing 0-RTT queue for %s.", connID) + } + } + }) + queue.handlePacket(p) + return + } h.server.handlePacket(p) } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index d22ba78a..76de8c16 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -36,12 +36,12 @@ var _ = Describe("Packet Handler Map", func() { statelessResetKey []byte ) - getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) []byte { + getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { buf := &bytes.Buffer{} Expect((&wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, - Type: protocol.PacketTypeHandshake, + Type: t, DestConnectionID: connID, Length: length, Version: protocol.VersionTLS, @@ -52,7 +52,7 @@ var _ = Describe("Packet Handler Map", func() { } getPacket := func(connID protocol.ConnectionID) []byte { - return getPacketWithLength(connID, 2) + return getPacketWithPacketType(connID, protocol.PacketTypeHandshake, 2) } BeforeEach(func() { @@ -274,6 +274,88 @@ var _ = Describe("Packet Handler Map", func() { }) }) + Context("0-RTT", func() { + JustBeforeEach(func() { + handler.zeroRTTQueueDuration = time.Hour + server := NewMockUnknownPacketHandler(mockCtrl) + // we don't expect any calls to server.handlePacket + handler.SetServer(server) + }) + + It("queues 0-RTT packets", func() { + server := NewMockUnknownPacketHandler(mockCtrl) + // don't EXPECT any calls to server.handlePacket + handler.SetServer(server) + connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} + p2 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2)} + p3 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 3)} + handler.handlePacket(p1) + handler.handlePacket(p2) + handler.handlePacket(p3) + sess := NewMockPacketHandler(mockCtrl) + done := make(chan struct{}) + gomock.InOrder( + sess.EXPECT().handlePacket(p1), + sess.EXPECT().handlePacket(p2), + sess.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }), + ) + handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) + Eventually(done).Should(BeClosed()) + }) + + It("directs 0-RTT packets to existing sessions", func() { + connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + sess := NewMockPacketHandler(mockCtrl) + handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) + p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} + sess.EXPECT().handlePacket(p1) + handler.handlePacket(p1) + }) + + It("limits the number of 0-RTT queues", func() { + for i := 0; i < protocol.Max0RTTQueues; i++ { + connID := make(protocol.ConnectionID, 8) + rand.Read(connID) + p := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} + handler.handlePacket(p) + } + // We're already storing the maximum number of queues. This packet will be dropped. + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9} + handler.handlePacket(&receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}) + // Don't EXPECT any handlePacket() calls. + sess := NewMockPacketHandler(mockCtrl) + handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) + time.Sleep(20 * time.Millisecond) + }) + + It("deletes queues if no session is created for this connection ID", func() { + queueDuration := scaleDuration(10 * time.Millisecond) + handler.zeroRTTQueueDuration = queueDuration + + server := NewMockUnknownPacketHandler(mockCtrl) + // don't EXPECT any calls to server.handlePacket + handler.SetServer(server) + connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + p1 := &receivedPacket{ + data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1), + buffer: getPacketBuffer(), + } + p2 := &receivedPacket{ + data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2), + buffer: getPacketBuffer(), + } + handler.handlePacket(p1) + handler.handlePacket(p2) + // wait a bit. The queue should now already be deleted. + time.Sleep(queueDuration * 3) + // Don't EXPECT any handlePacket() calls. + sess := NewMockPacketHandler(mockCtrl) + handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return sess }) + time.Sleep(20 * time.Millisecond) + }) + }) + Context("stateless resets", func() { BeforeEach(func() { connIDLen = 5 diff --git a/server.go b/server.go index 475cb3fd..d8a125be 100644 --- a/server.go +++ b/server.go @@ -68,7 +68,6 @@ type baseServer struct { tokenGenerator *handshake.TokenGenerator - zeroRTTQueue *zeroRTTQueue sessionHandler packetHandlerManager receivedPackets chan *receivedPacket @@ -200,7 +199,6 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl config: config, tokenGenerator: tokenGenerator, sessionHandler: sessionHandler, - zeroRTTQueue: newZeroRTTQueue(), sessionQueue: make(chan quicSession), errorChan: make(chan struct{}), running: make(chan struct{}), @@ -365,20 +363,15 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s go s.sendVersionNegotiationPacket(p, hdr) return false } - if hdr.IsLongHeader { - if hdr.Type == protocol.PacketType0RTT { - s.zeroRTTQueue.Enqueue(hdr.DestConnectionID, p) - return true - } else if hdr.Type != protocol.PacketTypeInitial { - // Drop long header packets. - // There's little point in sending a Stateless Reset, since the client - // might not have received the token yet. - s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) - } - return false + if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial { + // Drop long header packets. + // There's little point in sending a Stateless Reset, since the client + // might not have received the token yet. + s.logger.Debugf("Dropping long header packet of type %s (%d bytes)", hdr.Type, len(p.data)) + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropUnexpectedPacket) } + return false } s.logger.Debugf("<- Received Initial packet.") @@ -490,7 +483,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro p.buffer.Release() return nil } - s.zeroRTTQueue.DequeueToSession(hdr.DestConnectionID, sess) return nil } diff --git a/server_test.go b/server_test.go index 6de2ee71..b97c8ccc 100644 --- a/server_test.go +++ b/server_test.go @@ -629,60 +629,6 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) - It("passes queued 0-RTT packets to the session", func() { - serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } - var createdSession bool - sess := NewMockQuicSession(mockCtrl) - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9} - initialPacket := getInitial(connID) - zeroRTTPacket := getPacket(&wire.Header{ - IsLongHeader: true, - Type: protocol.PacketType0RTT, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: connID, - Version: protocol.VersionTLS, - }, []byte("foobar")) - sess.EXPECT().Context().Return(context.Background()).MaxTimes(1) - sess.EXPECT().HandshakeComplete().Return(context.Background()).MaxTimes(1) - sess.EXPECT().run().MaxTimes(1) - gomock.InOrder( - sess.EXPECT().handlePacket(initialPacket), - sess.EXPECT().handlePacket(zeroRTTPacket), - ) - serv.newSession = func( - _ sendConn, - runner sessionRunner, - _ protocol.ConnectionID, - _ *protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.ConnectionID, - _ protocol.StatelessResetToken, - _ *Config, - _ *tls.Config, - _ *handshake.TokenGenerator, - _ bool, - _ logging.ConnectionTracer, - _ utils.Logger, - _ protocol.VersionNumber, - ) quicSession { - createdSession = true - return sess - } - - // Receive the 0-RTT packet first. - Expect(serv.handlePacketImpl(zeroRTTPacket)).To(BeTrue()) - // Then receive the Initial packet. - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { - fn() - return true - }) - tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()) - Expect(serv.handlePacketImpl(initialPacket)).To(BeTrue()) - Expect(createdSession).To(BeTrue()) - }) - It("drops packets if the receive queue is full", func() { phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) diff --git a/zero_rtt_queue.go b/zero_rtt_queue.go deleted file mode 100644 index b0173f9d..00000000 --- a/zero_rtt_queue.go +++ /dev/null @@ -1,79 +0,0 @@ -package quic - -import ( - "sync" - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -type zeroRTTQueueEntry struct { - timer *time.Timer - packets []*receivedPacket -} - -type zeroRTTQueue struct { - mutex sync.Mutex - queue map[string]*zeroRTTQueueEntry - queueDuration time.Duration // so we can set it in tests -} - -func newZeroRTTQueue() *zeroRTTQueue { - return &zeroRTTQueue{ - queue: make(map[string]*zeroRTTQueueEntry), - queueDuration: protocol.Max0RTTQueueingDuration, - } -} - -func (h *zeroRTTQueue) Enqueue(connID protocol.ConnectionID, p *receivedPacket) { - h.mutex.Lock() - defer h.mutex.Unlock() - - cid := string(connID) - if _, ok := h.queue[cid]; !ok { - if len(h.queue) >= protocol.Max0RTTQueues { - return - } - h.queue[cid] = &zeroRTTQueueEntry{timer: time.AfterFunc(h.queueDuration, func() { - h.deleteQueue(connID) - })} - } - entry := h.queue[cid] - if len(entry.packets) >= protocol.Max0RTTQueueLen { - return - } - entry.packets = append(entry.packets, p) -} - -func (h *zeroRTTQueue) DequeueToSession(connID protocol.ConnectionID, sess packetHandler) { - h.mutex.Lock() - defer h.mutex.Unlock() - - h.dequeueToSession(connID, sess) -} - -func (h *zeroRTTQueue) dequeueToSession(connID protocol.ConnectionID, sess packetHandler) { - entry, ok := h.queue[string(connID)] - if !ok { - return - } - entry.timer.Stop() - for _, p := range entry.packets { - sess.handlePacket(p) - } - delete(h.queue, string(connID)) -} - -func (h *zeroRTTQueue) deleteQueue(connID protocol.ConnectionID) { - h.mutex.Lock() - defer h.mutex.Unlock() - - entry, ok := h.queue[string(connID)] - if !ok { - return - } - for _, p := range entry.packets { - p.buffer.Release() - } - delete(h.queue, string(connID)) -} diff --git a/zero_rtt_queue_test.go b/zero_rtt_queue_test.go deleted file mode 100644 index b4d37996..00000000 --- a/zero_rtt_queue_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package quic - -import ( - "encoding/binary" - "time" - - "github.com/golang/mock/gomock" - "github.com/lucas-clemente/quic-go/internal/protocol" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("0-RTT queue", func() { - var q *zeroRTTQueue - queueDuration := scaleDuration(50 * time.Millisecond) - - getQueuedPackets := func(connID protocol.ConnectionID) []*receivedPacket { - var packets []*receivedPacket - sess := NewMockPacketHandler(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - packets = append(packets, p) - }).AnyTimes() - q.DequeueToSession(connID, sess) - return packets - } - - BeforeEach(func() { - q = newZeroRTTQueue() - q.queueDuration = queueDuration - }) - - AfterEach(func() { - // dequeue all packets to make sure the timers are stopped - q.mutex.Lock() - for connID := range q.queue { - sess := NewMockPacketHandler(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()).AnyTimes() - q.dequeueToSession(protocol.ConnectionID(connID), sess) - } - q.mutex.Unlock() - }) - - It("stores a 0-RTT packet", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - p := &receivedPacket{data: []byte("foobar")} - q.Enqueue(connID, p) - queuedPackets := getQueuedPackets(connID) - Expect(queuedPackets).To(Equal([]*receivedPacket{p})) - }) - - It("doesn't dequeue for unknown connection IDs", func() { - Expect(getQueuedPackets(protocol.ConnectionID{0x42})).To(BeEmpty()) - }) - - It("only stores packets for Max0RTTQueues connection", func() { - getConnID := func(i int) protocol.ConnectionID { - connID := make([]byte, 4) - binary.BigEndian.PutUint32(connID, uint32(i)) - return connID - } - - // fill up the queues - for i := 0; i < protocol.Max0RTTQueues; i++ { - connID := getConnID(i) - q.Enqueue(connID, &receivedPacket{data: []byte(connID)}) - } - // now try to enqueue a packet for another connection ID - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - q.Enqueue(connID, &receivedPacket{data: []byte("foobar")}) - // check that the other queues were all saved - for i := 0; i < protocol.Max0RTTQueues; i++ { - queuedPackets := getQueuedPackets(getConnID(i)) - Expect(queuedPackets).To(HaveLen(1)) - Expect(binary.BigEndian.Uint32(queuedPackets[0].data)).To(BeEquivalentTo(i)) - } - }) - - It("removes queues when packets are dequeued", func() { - // fill up the queues - for i := 0; i < protocol.Max0RTTQueues; i++ { - data := make([]byte, 4) - binary.BigEndian.PutUint32(data, uint32(i)) - q.Enqueue(protocol.ConnectionID(data), &receivedPacket{data: data}) - } - // now try to enqueue a packet for another connection ID - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - q.Enqueue(connID, &receivedPacket{data: []byte("foobar")}) - Expect(getQueuedPackets(connID)).To(BeEmpty()) - // dequeue the packet from the first queue - Expect(getQueuedPackets(protocol.ConnectionID{0, 0, 0, 0})).ToNot(BeNil()) - // now it should be possible to queue another packet - q.Enqueue(connID, &receivedPacket{data: []byte("foobar")}) - Expect(getQueuedPackets(connID)).ToNot(BeNil()) - }) - - It("limits the number of packets it stores for one connection", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - // fill up the queue - for i := 0; i < protocol.Max0RTTQueueLen; i++ { - data := make([]byte, 4) - binary.BigEndian.PutUint32(data, uint32(i)) - q.Enqueue(connID, &receivedPacket{data: data}) - } - // The queue is full now. This packet will be dropped. - q.Enqueue(connID, &receivedPacket{data: []byte("foobar")}) - queuedPackets := getQueuedPackets(connID) - Expect(queuedPackets).To(HaveLen(protocol.Max0RTTQueueLen)) - for i, p := range queuedPackets { - Expect(binary.BigEndian.Uint32(p.data)).To(BeEquivalentTo(i)) - } - }) - - It("deletes packets if they aren't dequeued after a short while", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} - p := &receivedPacket{data: []byte("foobar"), buffer: getPacketBuffer()} - q.Enqueue(connID, p) - time.Sleep(queueDuration * 3 / 2) - Expect(getQueuedPackets(connID)).To(BeNil()) - }) -})