diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 5aee0361..70b1199e 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -7,7 +7,6 @@ import ( "io" mrand "math/rand" "net" - "sort" "sync" "sync/atomic" "time" @@ -214,7 +213,7 @@ var _ = Describe("0-RTT", func() { ln, proxy.LocalPort(), clientTLSConf, - &quic.Config{ConnectionIDLength: connIDLen}, + getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}), PRData, ) @@ -237,8 +236,7 @@ var _ = Describe("0-RTT", func() { Expect(num0RTT).ToNot(BeZero()) zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) - sort.Slice(zeroRTTPackets, func(i, j int) bool { return zeroRTTPackets[i] < zeroRTTPackets[j] }) - Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) + Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0))) }) } diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 3b71f7f2..bcd16038 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -100,6 +100,21 @@ func (mr *MockPacketHandlerManagerMockRecorder) Destroy() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Destroy", reflect.TypeOf((*MockPacketHandlerManager)(nil).Destroy)) } +// Get mocks base method. +func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(packetHandler) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0) +} + // GetStatelessResetToken mocks base method. func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) protocol.StatelessResetToken { m.ctrl.T.Helper() diff --git a/packet_handler_map.go b/packet_handler_map.go index 1f643412..99e2bfb1 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -48,16 +48,14 @@ type packetHandlerMap struct { closeQueue chan closePacket - handlers map[protocol.ConnectionID]packetHandler - resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler - server unknownPacketHandler - numZeroRTTEntries int + handlers map[protocol.ConnectionID]packetHandler + resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler + server unknownPacketHandler listening chan struct{} // is closed when listen returns closed bool deleteRetiredConnsAfter time.Duration - zeroRTTQueueDuration time.Duration statelessResetEnabled bool statelessResetMutex sync.Mutex @@ -130,7 +128,6 @@ func newPacketHandlerMap( handlers: make(map[protocol.ConnectionID]packetHandler), resetTokens: make(map[protocol.StatelessResetToken]packetHandler), deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, - zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, closeQueue: make(chan closePacket, 4), statelessResetEnabled: statelessResetKey != nil, tracer: tracer, @@ -174,6 +171,14 @@ func (h *packetHandlerMap) logUsage() { } } +func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) { + h.mutex.Lock() + defer h.mutex.Unlock() + + handler, ok := h.handlers[id] + return handler, ok +} + func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { h.mutex.Lock() defer h.mutex.Unlock() @@ -191,23 +196,11 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co h.mutex.Lock() defer h.mutex.Unlock() - var q *zeroRTTQueue - if handler, ok := h.handlers[clientDestConnID]; ok { - q, ok = handler.(*zeroRTTQueue) - if !ok { - h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) - return false - } - q.retireTimer.Stop() - h.numZeroRTTEntries-- - if h.numZeroRTTEntries < 0 { - panic("number of 0-RTT queues < 0") - } + if _, ok := h.handlers[clientDestConnID]; ok { + h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) + return false } conn := fn() - if q != nil { - q.EnqueueAll(conn) - } h.handlers[clientDestConnID] = conn h.handlers[newConnID] = conn h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) @@ -395,17 +388,9 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { if isStatelessReset := h.maybeHandleStatelessReset(p.data); isStatelessReset { return } - if handler, ok := h.handlers[connID]; ok { - if ha, ok := handler.(*zeroRTTQueue); ok { // only enqueue 0-RTT packets in the 0-RTT queue - if wire.Is0RTTPacket(p.data) { - ha.handlePacket(p) - return - } - } else { // existing connection - handler.handlePacket(p) - return - } + handler.handlePacket(p) + return } if !wire.IsLongHeaderPacket(p.data[0]) { go h.maybeSendStatelessReset(p, connID) @@ -415,38 +400,6 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) return } - if wire.Is0RTTPacket(p.data) { - if h.numZeroRTTEntries >= protocol.Max0RTTQueues { - if h.tracer != nil { - h.tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) - } - return - } - h.numZeroRTTEntries++ - queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)} - h.handlers[connID] = queue - queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { - h.mutex.Lock() - defer h.mutex.Unlock() - // The entry might have been replaced by an actual connection. - // Only delete it if it's still a 0-RTT queue. - if handler, ok := h.handlers[connID]; ok { - if q, ok := handler.(*zeroRTTQueue); ok { - delete(h.handlers, connID) - h.numZeroRTTEntries-- - if h.numZeroRTTEntries < 0 { - panic("number of 0-RTT queues < 0") - } - q.Clear() - if h.logger.Debug() { - h.logger.Debugf("Removing 0-RTT queue for %s.", connID) - } - } - } - }) - queue.handlePacket(p) - return - } h.server.handlePacket(p) } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index f893c839..0397e3f0 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -271,93 +271,6 @@ var _ = Describe("Packet Handler Map", func() { }) }) - Context("0-RTT", func() { - JustBeforeEach(func() { - handler.zeroRTTQueueDuration = time.Hour - server := NewMockUnknownPacketHandler(mockCtrl) - // we don't expect any calls to server.handlePacket - handler.SetServer(server) - }) - - It("queues 0-RTT packets", func() { - server := NewMockUnknownPacketHandler(mockCtrl) - // don't EXPECT any calls to server.handlePacket - handler.SetServer(server) - connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) - p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} - p2 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2)} - p3 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 3)} - handler.handlePacket(p1) - handler.handlePacket(p2) - handler.handlePacket(p3) - conn := NewMockPacketHandler(mockCtrl) - done := make(chan struct{}) - gomock.InOrder( - conn.EXPECT().handlePacket(p1), - conn.EXPECT().handlePacket(p2), - conn.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }), - ) - handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn }) - Eventually(done).Should(BeClosed()) - }) - - It("directs 0-RTT packets to existing connections", func() { - connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) - conn := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn }) - p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} - conn.EXPECT().handlePacket(p1) - handler.handlePacket(p1) - }) - - It("limits the number of 0-RTT queues", func() { - for i := 0; i < protocol.Max0RTTQueues; i++ { - b := make([]byte, 8) - rand.Read(b) - p := &receivedPacket{data: getPacketWithPacketType( - protocol.ParseConnectionID(b), - protocol.PacketType0RTT, - 1, - )} - handler.handlePacket(p) - } - // We're already storing the maximum number of queues. This packet will be dropped. - connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) - tracer.EXPECT().DroppedPacket(gomock.Any(), logging.PacketType0RTT, gomock.Any(), logging.PacketDropDOSPrevention) - handler.handlePacket(&receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}) - // Don't EXPECT any handlePacket() calls. - conn := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn }) - time.Sleep(20 * time.Millisecond) - }) - - It("deletes queues if no connection is created for this connection ID", func() { - queueDuration := scaleDuration(10 * time.Millisecond) - handler.zeroRTTQueueDuration = queueDuration - - server := NewMockUnknownPacketHandler(mockCtrl) - // don't EXPECT any calls to server.handlePacket - handler.SetServer(server) - connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) - p1 := &receivedPacket{ - data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1), - buffer: getPacketBuffer(), - } - p2 := &receivedPacket{ - data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2), - buffer: getPacketBuffer(), - } - handler.handlePacket(p1) - handler.handlePacket(p2) - // wait a bit. The queue should now already be deleted. - time.Sleep(queueDuration * 3) - // Don't EXPECT any handlePacket() calls. - conn := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn }) - time.Sleep(20 * time.Millisecond) - }) - }) - Context("stateless resets", func() { BeforeEach(func() { connIDLen = 5 diff --git a/server.go b/server.go index edacdd85..d984beea 100644 --- a/server.go +++ b/server.go @@ -36,6 +36,7 @@ type unknownPacketHandler interface { } type packetHandlerManager interface { + Get(protocol.ConnectionID) (packetHandler, bool) AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool Destroy() error connRunner @@ -54,6 +55,11 @@ type quicConn interface { shutdown() } +type zeroRTTQueue struct { + packets []*receivedPacket + expiration time.Time +} + // A Listener of QUIC type baseServer struct { mutex sync.Mutex @@ -74,6 +80,9 @@ type baseServer struct { receivedPackets chan *receivedPacket + nextZeroRTTCleanup time.Time + zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true + // set as a member, so they can be set in the tests newConn func( sendConn, @@ -214,6 +223,9 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl logger: utils.DefaultLogger.WithPrefix("server"), acceptEarlyConns: acceptEarly, } + if acceptEarly { + s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} + } go s.run() connHandler.SetServer(s) s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) @@ -310,6 +322,10 @@ func (s *baseServer) handlePacket(p *receivedPacket) { } func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer still in use? */ { + if !s.nextZeroRTTCleanup.IsZero() && p.rcvTime.After(s.nextZeroRTTCleanup) { + defer s.cleanupZeroRTTQueues(p.rcvTime) + } + if wire.IsVersionNegotiationPacket(p.data) { s.logger.Debugf("Dropping Version Negotiation packet.") if s.config.Tracer != nil { @@ -344,6 +360,17 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s } return false } + + if wire.Is0RTTPacket(p.data) { + if !s.acceptEarlyConns { + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + return s.handle0RTTPacket(p) + } + // If we're creating a new connection, the packet will be passed to the connection. // The header will then be parsed again. hdr, _, _, err := wire.ParsePacket(p.data) @@ -383,6 +410,74 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s return true } +func (s *baseServer) handle0RTTPacket(p *receivedPacket) bool { + connID, err := wire.ParseConnectionID(p.data, 0) + if err != nil { + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropHeaderParseError) + } + return false + } + + // check again if we might have a connection now + if handler, ok := s.connHandler.Get(connID); ok { + handler.handlePacket(p) + return true + } + + if q, ok := s.zeroRTTQueues[connID]; ok { + if len(q.packets) >= protocol.Max0RTTQueueLen { + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + } + return false + } + q.packets = append(q.packets, p) + return true + } + + if len(s.zeroRTTQueues) >= protocol.Max0RTTQueues { + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + } + return false + } + queue := &zeroRTTQueue{packets: make([]*receivedPacket, 1, 8)} + queue.packets[0] = p + expiration := p.rcvTime.Add(protocol.Max0RTTQueueingDuration) + queue.expiration = expiration + if s.nextZeroRTTCleanup.IsZero() || s.nextZeroRTTCleanup.After(expiration) { + s.nextZeroRTTCleanup = expiration + } + s.zeroRTTQueues[connID] = queue + return true +} + +func (s *baseServer) cleanupZeroRTTQueues(now time.Time) { + // Iterate over all queues to find those that are expired. + // This is ok since we're placing a pretty low limit on the number of queues. + var nextCleanup time.Time + for connID, q := range s.zeroRTTQueues { + if q.expiration.After(now) { + if nextCleanup.IsZero() || nextCleanup.After(q.expiration) { + nextCleanup = q.expiration + } + continue + } + for _, p := range q.packets { + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + } + p.buffer.Release() + } + delete(s.zeroRTTQueues, connID) + if s.logger.Debug() { + s.logger.Debugf("Removing 0-RTT queue for %s.", connID) + } + } + s.nextZeroRTTCleanup = nextCleanup +} + // validateToken returns false if: // - address is invalid // - token is expired @@ -412,6 +507,14 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro return errors.New("too short connection ID") } + // The server queues packets for a while, and we might already have established a connection by now. + // This results in a second check in the connection map. + // That's ok since it's not the hot path (it's only taken by some Initial and 0-RTT packets). + if handler, ok := s.connHandler.Get(hdr.DestConnectionID); ok { + handler.handlePacket(p) + return nil + } + var ( token *handshake.Token retrySrcConnID *protocol.ConnectionID @@ -429,7 +532,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro } clientAddrIsValid := s.validateToken(token, p.remoteAddr) - if token != nil && !clientAddrIsValid { // For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error. // We just ignore them, and act as if there was no token on this packet at all. @@ -450,6 +552,8 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro } } if token == nil && s.config.RequireAddressValidation(p.remoteAddr) { + // Retry invalidates all 0-RTT packets sent. + delete(s.zeroRTTQueues, hdr.DestConnectionID) go func() { defer p.buffer.Release() if err := s.sendRetry(p.remoteAddr, hdr, p.info); err != nil { @@ -510,8 +614,18 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro hdr.Version, ) conn.handlePacket(p) + + if q, ok := s.zeroRTTQueues[hdr.DestConnectionID]; ok { + for _, p := range q.packets { + conn.handlePacket(p) + } + delete(s.zeroRTTQueues, hdr.DestConnectionID) + } + return conn }); !added { + // TODO: don't just drop the packet + // Properly reject the connection attempt. return nil } go conn.run() diff --git a/server_test.go b/server_test.go index e2c801e9..a349efc1 100644 --- a/server_test.go +++ b/server_test.go @@ -215,8 +215,7 @@ var _ = Describe("Server", func() { Type: protocol.PacketTypeInitial, DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: serv.config.Versions[0], - }, make([]byte, protocol.MinInitialPacketSize-100), - ) + }, make([]byte, protocol.MinInitialPacketSize-100)) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) serv.handlePacket(p) // make sure there are no Write calls on the packet conn @@ -234,6 +233,21 @@ var _ = Describe("Server", func() { time.Sleep(50 * time.Millisecond) }) + It("passes packets to existing connections", func() { + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + p := getPacket(&wire.Header{ + Type: protocol.PacketTypeInitial, + DestConnectionID: connID, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize)) + conn := NewMockPacketHandler(mockCtrl) + phm.EXPECT().Get(connID).Return(conn, true) + handled := make(chan struct{}) + conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) }) + serv.handlePacket(p) + Eventually(handled).Should(BeClosed()) + }) + It("creates a connection when the token is accepted", func() { serv.config.RequireAddressValidation = func(net.Addr) bool { return true } raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} @@ -243,10 +257,11 @@ var _ = Describe("Server", func() { protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), ) Expect(err).ToNot(HaveOccurred()) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), - DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), + DestConnectionID: connID, Version: protocol.Version1, Token: retryToken, } @@ -257,7 +272,9 @@ var _ = Describe("Server", func() { rand.Read(token[:]) var newConnID protocol.ConnectionID - phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + + phm.EXPECT().Get(connID) + phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { newConnID = c phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { newConnID = c @@ -408,11 +425,12 @@ var _ = Describe("Server", func() { }) It("replies with a Retry packet, if a token is required", func() { + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) serv.config.RequireAddressValidation = func(net.Addr) bool { return true } hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), - DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), + DestConnectionID: connID, Version: protocol.Version1, } packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) @@ -435,15 +453,17 @@ var _ = Describe("Server", func() { Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID, hdr.Version)[:])) return len(b), nil }) + phm.EXPECT().Get(connID) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) It("creates a connection, if no token is required", func() { + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) hdr := &wire.Header{ Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), - DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), + DestConnectionID: connID, Version: protocol.Version1, } p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) @@ -452,15 +472,18 @@ var _ = Describe("Server", func() { rand.Read(token[:]) var newConnID protocol.ConnectionID - phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { - newConnID = c - phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { + gomock.InOrder( + phm.EXPECT().Get(connID), + phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { newConnID = c - return token - }) - fn() - return true - }) + phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { + newConnID = c + return token + }) + fn() + return true + }), + ) tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) conn := NewMockQUICConn(mockCtrl) @@ -513,6 +536,7 @@ var _ = Describe("Server", func() { }) It("drops packets if the receive queue is full", func() { + phm.EXPECT().Get(gomock.Any()).AnyTimes() phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) fn() @@ -597,8 +621,10 @@ var _ = Describe("Server", func() { return conn } - p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})) - phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}), gomock.Any(), gomock.Any()).Return(false) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) + p := getInitial(connID) + phm.EXPECT().Get(connID) + phm.EXPECT().AddWithConnID(connID, gomock.Any(), gomock.Any()).Return(false) Expect(serv.handlePacketImpl(p)).To(BeTrue()) Expect(createdConn).To(BeFalse()) }) @@ -632,6 +658,7 @@ var _ = Describe("Server", func() { return conn } + phm.EXPECT().Get(gomock.Any()).Times(protocol.MaxAcceptQueueSize + 1) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) fn() @@ -702,6 +729,7 @@ var _ = Describe("Server", func() { return conn } + phm.EXPECT().Get(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) fn() @@ -765,6 +793,7 @@ var _ = Describe("Server", func() { tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) done := make(chan struct{}) + phm.EXPECT().Get(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ConnectionID, _ func() packetHandler) { close(done) }) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) @@ -801,6 +830,7 @@ var _ = Describe("Server", func() { checkInvalidToken(b, hdr) return len(b), nil }) + phm.EXPECT().Get(gomock.Any()) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) @@ -837,6 +867,7 @@ var _ = Describe("Server", func() { checkInvalidToken(b, hdr) return len(b), nil }) + phm.EXPECT().Get(gomock.Any()) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) @@ -864,6 +895,7 @@ var _ = Describe("Server", func() { Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) return len(b), nil }) + phm.EXPECT().Get(gomock.Any()) serv.handlePacket(packet) // make sure there are no Write calls on the packet conn Eventually(done).Should(BeClosed()) @@ -893,6 +925,7 @@ var _ = Describe("Server", func() { defer close(done) return len(b), nil }) + phm.EXPECT().Get(gomock.Any()) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) @@ -913,6 +946,7 @@ var _ = Describe("Server", func() { packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} done := make(chan struct{}) tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(done) }) + phm.EXPECT().Get(gomock.Any()) serv.handlePacket(packet) // make sure there are no Write calls on the packet conn time.Sleep(50 * time.Millisecond) @@ -997,6 +1031,7 @@ var _ = Describe("Server", func() { conn.EXPECT().Context().Return(context.Background()) return conn } + phm.EXPECT().Get(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) fn() @@ -1070,6 +1105,7 @@ var _ = Describe("Server", func() { conn.EXPECT().Context().Return(context.Background()) return conn } + phm.EXPECT().Get(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) fn() @@ -1115,6 +1151,7 @@ var _ = Describe("Server", func() { return conn } + phm.EXPECT().Get(gomock.Any()).AnyTimes() phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) fn() @@ -1175,6 +1212,7 @@ var _ = Describe("Server", func() { return conn } + phm.EXPECT().Get(gomock.Any()) phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { phm.EXPECT().GetStatelessResetToken(gomock.Any()) fn() @@ -1202,4 +1240,215 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) }) + + Context("0-RTT", func() { + var ( + serv *earlyServer + phm *MockPacketHandlerManager + tracer *mocklogging.MockTracer + ) + + BeforeEach(func() { + tracer = mocklogging.NewMockTracer(mockCtrl) + ln, err := ListenEarly(conn, tlsConf, &Config{Tracer: tracer}) + Expect(err).ToNot(HaveOccurred()) + serv = ln.(*earlyServer) + phm = NewMockPacketHandlerManager(mockCtrl) + serv.connHandler = phm + }) + + AfterEach(func() { + phm.EXPECT().CloseServer().MaxTimes(1) + serv.Close() + }) + + It("passes packets to existing connections", func() { + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + p := getPacket(&wire.Header{ + Type: protocol.PacketType0RTT, + DestConnectionID: connID, + Version: serv.config.Versions[0], + }, make([]byte, 100)) + conn := NewMockPacketHandler(mockCtrl) + phm.EXPECT().Get(connID).Return(conn, true) + handled := make(chan struct{}) + conn.EXPECT().handlePacket(p).Do(func(*receivedPacket) { close(handled) }) + serv.handlePacket(p) + Eventually(handled).Should(BeClosed()) + }) + + It("queues 0-RTT packets, up to Max0RTTQueueSize", func() { + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + + var zeroRTTPackets []*receivedPacket + + for i := 0; i < protocol.Max0RTTQueueLen; i++ { + p := getPacket(&wire.Header{ + Type: protocol.PacketType0RTT, + DestConnectionID: connID, + Version: serv.config.Versions[0], + }, make([]byte, 100+i)) + phm.EXPECT().Get(connID) + serv.handlePacket(p) + zeroRTTPackets = append(zeroRTTPackets, p) + } + + // send one more packet, this one should be dropped + p := getPacket(&wire.Header{ + Type: protocol.PacketType0RTT, + DestConnectionID: connID, + Version: serv.config.Versions[0], + }, make([]byte, 200)) + phm.EXPECT().Get(connID) + tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention) + serv.handlePacket(p) + + initial := getPacket(&wire.Header{ + Type: protocol.PacketTypeInitial, + DestConnectionID: connID, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize)) + called := make(chan struct{}) + serv.newConn = func( + _ sendConn, + _ connRunner, + _ protocol.ConnectionID, + _ *protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ protocol.StatelessResetToken, + _ *Config, + _ *tls.Config, + _ *handshake.TokenGenerator, + _ bool, + _ logging.ConnectionTracer, + _ uint64, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicConn { + conn := NewMockQUICConn(mockCtrl) + var calls []*gomock.Call + calls = append(calls, conn.EXPECT().handlePacket(initial)) + for _, p := range zeroRTTPackets { + calls = append(calls, conn.EXPECT().handlePacket(p)) + } + gomock.InOrder(calls...) + conn.EXPECT().run() + conn.EXPECT().earlyConnReady() + conn.EXPECT().Context().Return(context.Background()) + close(called) + return conn + } + + tracer.EXPECT().TracerForConnection(gomock.Any(), gomock.Any(), gomock.Any()) + phm.EXPECT().Get(connID) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }) + serv.handlePacket(initial) + Eventually(called).Should(BeClosed()) + }) + + It("limits the number of queues", func() { + for i := 0; i < protocol.Max0RTTQueues; i++ { + b := make([]byte, 16) + rand.Read(b) + connID := protocol.ParseConnectionID(b) + p := getPacket(&wire.Header{ + Type: protocol.PacketType0RTT, + DestConnectionID: connID, + Version: serv.config.Versions[0], + }, make([]byte, 100+i)) + phm.EXPECT().Get(connID) + serv.handlePacket(p) + } + + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + p := getPacket(&wire.Header{ + Type: protocol.PacketType0RTT, + DestConnectionID: connID, + Version: serv.config.Versions[0], + }, make([]byte, 200)) + phm.EXPECT().Get(connID) + dropped := make(chan struct{}) + tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { + close(dropped) + }) + serv.handlePacket(p) + Eventually(dropped).Should(BeClosed()) + }) + + It("drops queues after a while", func() { + now := time.Now() + + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + p := getPacket(&wire.Header{ + Type: protocol.PacketType0RTT, + DestConnectionID: connID, + Version: serv.config.Versions[0], + }, make([]byte, 200)) + p.rcvTime = now + + connID2 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 9}) + p2Time := now.Add(protocol.Max0RTTQueueingDuration / 2) + p2 := getPacket(&wire.Header{ + Type: protocol.PacketType0RTT, + DestConnectionID: connID2, + Version: serv.config.Versions[0], + }, make([]byte, 300)) + p2.rcvTime = p2Time // doesn't trigger the cleanup of the first packet + + dropped1 := make(chan struct{}) + dropped2 := make(chan struct{}) + // need to register the call before handling the packet to avoid race condition + gomock.InOrder( + tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketType0RTT, p.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { + close(dropped1) + }), + tracer.EXPECT().DroppedPacket(p2.remoteAddr, logging.PacketType0RTT, p2.Size(), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { + close(dropped2) + }), + ) + + phm.EXPECT().Get(connID) + serv.handlePacket(p) + + // There's no cleanup Go routine. + // Cleanup is triggered when new packets are received. + + phm.EXPECT().Get(connID2) + serv.handlePacket(p2) + // make sure no cleanup is executed + Consistently(dropped1, 50*time.Millisecond).ShouldNot(BeClosed()) + + // There's no cleanup Go routine. + // Cleanup is triggered when new packets are received. + connID3 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 0}) + p3 := getPacket(&wire.Header{ + Type: protocol.PacketType0RTT, + DestConnectionID: connID3, + Version: serv.config.Versions[0], + }, make([]byte, 200)) + p3.rcvTime = now.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup + phm.EXPECT().Get(connID3) + serv.handlePacket(p3) + Eventually(dropped1).Should(BeClosed()) + Consistently(dropped2, 50*time.Millisecond).ShouldNot(BeClosed()) + + // make sure the second packet is also cleaned up + connID4 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 1}) + p4 := getPacket(&wire.Header{ + Type: protocol.PacketType0RTT, + DestConnectionID: connID4, + Version: serv.config.Versions[0], + }, make([]byte, 200)) + p4.rcvTime = p2Time.Add(protocol.Max0RTTQueueingDuration + time.Nanosecond) // now triggers the cleanup + phm.EXPECT().Get(connID4) + serv.handlePacket(p4) + Eventually(dropped2).Should(BeClosed()) + }) + }) }) diff --git a/zero_rtt_queue.go b/zero_rtt_queue.go deleted file mode 100644 index b81a936e..00000000 --- a/zero_rtt_queue.go +++ /dev/null @@ -1,34 +0,0 @@ -package quic - -import ( - "time" - - "github.com/quic-go/quic-go/internal/protocol" -) - -type zeroRTTQueue struct { - queue []*receivedPacket - retireTimer *time.Timer -} - -var _ packetHandler = &zeroRTTQueue{} - -func (h *zeroRTTQueue) handlePacket(p *receivedPacket) { - if len(h.queue) < protocol.Max0RTTQueueLen { - h.queue = append(h.queue, p) - } -} -func (h *zeroRTTQueue) shutdown() {} -func (h *zeroRTTQueue) destroy(error) {} -func (h *zeroRTTQueue) getPerspective() protocol.Perspective { return protocol.PerspectiveClient } -func (h *zeroRTTQueue) EnqueueAll(sess packetHandler) { - for _, p := range h.queue { - sess.handlePacket(p) - } -} - -func (h *zeroRTTQueue) Clear() { - for _, p := range h.queue { - p.buffer.Release() - } -}