From dad30e7c338a5691f495fd05a4a4f3e60fc3ee57 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 29 May 2020 15:54:25 +0700 Subject: [PATCH] only create a single session for duplicate Initials --- client.go | 2 - mock_packet_handler_manager_test.go | 14 +++++ packet_handler_map.go | 15 ++++++ packet_handler_map_test.go | 8 +++ server.go | 61 +++++++++++----------- server_test.go | 81 ++++++++++++++++++++--------- 6 files changed, 123 insertions(+), 58 deletions(-) diff --git a/client.go b/client.go index ee46c454..66a9b713 100644 --- a/client.go +++ b/client.go @@ -271,8 +271,6 @@ func (c *client) dial(ctx context.Context) error { c.version, ) c.mutex.Unlock() - // It's not possible to use the stateless reset token for the client's (first) connection ID, - // since there's no way to securely communicate it to the server. c.packetHandlers.Add(c.srcConnID, c) errorChan := make(chan error, 1) diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index e6316339..77515635 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -60,6 +60,20 @@ func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1) } +// AddWithConnID mocks base method +func (m *MockPacketHandlerManager) AddWithConnID(arg0, arg1 protocol.ConnectionID, arg2 func() packetHandler) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddWithConnID", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + return ret0 +} + +// AddWithConnID indicates an expected call of AddWithConnID +func (mr *MockPacketHandlerManagerMockRecorder) AddWithConnID(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddWithConnID", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddWithConnID), arg0, arg1, arg2) +} + // CloseServer mocks base method func (m *MockPacketHandlerManager) CloseServer() { m.ctrl.T.Helper() diff --git a/packet_handler_map.go b/packet_handler_map.go index 1daec42b..acce56c0 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -118,6 +118,21 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) return true } +func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.ConnectionID, fn func() packetHandler) bool { + sid := string(clientDestConnID) + h.mutex.Lock() + defer h.mutex.Unlock() + + if _, ok := h.handlers[sid]; ok { + return false + } + + sess := fn() + h.handlers[sid] = sess + h.handlers[string(newConnID)] = sess + return true +} + func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { h.mutex.Lock() delete(h.handlers, string(id)) diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 4580747e..a110b8d9 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -181,6 +181,14 @@ var _ = Describe("Packet Handler Map", func() { Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) }) + + It("says if a connection ID is already taken, for AddWithConnID", func() { + clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + newConnID1 := protocol.ConnectionID{1, 2, 3, 4} + newConnID2 := protocol.ConnectionID{4, 3, 2, 1} + Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue()) + Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse()) + }) }) Context("running a server", func() { diff --git a/server.go b/server.go index b20f0b13..a88737f0 100644 --- a/server.go +++ b/server.go @@ -34,6 +34,7 @@ type unknownPacketHandler interface { } type packetHandlerManager interface { + AddWithConnID(protocol.ConnectionID, protocol.ConnectionID, func() packetHandler) bool Destroy() error sessionRunner SetServer(unknownPacketHandler) @@ -421,39 +422,39 @@ func (s *baseServer) createNewSession( srcConnID protocol.ConnectionID, version protocol.VersionNumber, ) quicSession { - var qlogger qlog.Tracer - if s.config.GetLogWriter != nil { - // Use the same connection ID that is passed to the client's GetLogWriter callback. - connID := clientDestConnID - if origDestConnID.Len() > 0 { - connID = origDestConnID + var sess quicSession + if added := s.sessionHandler.AddWithConnID(clientDestConnID, srcConnID, func() packetHandler { + var qlogger qlog.Tracer + if s.config.GetLogWriter != nil { + // Use the same connection ID that is passed to the client's GetLogWriter callback. + connID := clientDestConnID + if origDestConnID.Len() > 0 { + connID = origDestConnID + } + if w := s.config.GetLogWriter(connID); w != nil { + qlogger = qlog.NewTracer(w, protocol.PerspectiveServer, connID) + } } - if w := s.config.GetLogWriter(connID); w != nil { - qlogger = qlog.NewTracer(w, protocol.PerspectiveServer, connID) - } - } - sess := s.newSession( - &conn{pconn: s.conn, currentAddr: remoteAddr}, - s.sessionHandler, - origDestConnID, - clientDestConnID, - destConnID, - srcConnID, - s.sessionHandler.GetStatelessResetToken(srcConnID), - s.config, - s.tlsConf, - s.tokenGenerator, - s.acceptEarlySessions, - qlogger, - s.logger, - version, - ) - if added := s.sessionHandler.Add(clientDestConnID, sess); !added { - // We're already keeping track of this connection ID. - // This might happen if we receive two copies of the Initial at the same time. + sess = s.newSession( + &conn{pconn: s.conn, currentAddr: remoteAddr}, + s.sessionHandler, + origDestConnID, + clientDestConnID, + destConnID, + srcConnID, + s.sessionHandler.GetStatelessResetToken(srcConnID), + s.config, + s.tlsConf, + s.tokenGenerator, + s.acceptEarlySessions, + qlogger, + s.logger, + version, + ) + return sess + }); !added { return nil } - s.sessionHandler.Add(srcConnID, sess) go sess.run() go s.handleNewSession(sess) return sess diff --git a/server_test.go b/server_test.go index 9803a638..5ab73b22 100644 --- a/server_test.go +++ b/server_test.go @@ -403,10 +403,7 @@ var _ = Describe("Server", func() { var token [16]byte rand.Read(token[:]) var newConnID protocol.ConnectionID - phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte { - newConnID = c - return token - }) + sess := NewMockQuicSession(mockCtrl) serv.newSession = func( _ connection, @@ -439,9 +436,13 @@ var _ = Describe("Server", func() { return sess } - phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true) - phm.EXPECT().Add(gomock.Any(), sess).DoAndReturn(func(c protocol.ConnectionID, _ packetHandler) bool { - Expect(c).To(Equal(newConnID)) + phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + newConnID = c + phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte { + newConnID = c + return token + }) + fn() return true }) @@ -502,14 +503,20 @@ var _ = Describe("Server", func() { Expect(serv.handlePacketImpl(zeroRTTPacket)).To(BeTrue()) // Then receive the Initial packet. phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + fn() + return true + }) Expect(serv.handlePacketImpl(initialPacket)).To(BeTrue()) Expect(createdSession).To(BeTrue()) }) It("drops packets if the receive queue is full", func() { - phm.EXPECT().GetStatelessResetToken(gomock.Any()).AnyTimes() - phm.EXPECT().Add(gomock.Any(), gomock.Any()).AnyTimes() + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }).AnyTimes() serv.config.AcceptToken = func(net.Addr, *Token) bool { return true } acceptSession := make(chan struct{}) @@ -532,7 +539,12 @@ var _ = Describe("Server", func() { ) quicSession { <-acceptSession atomic.AddUint32(&counter, 1) - return nil + sess := NewMockQuicSession(mockCtrl) + sess.EXPECT().handlePacket(gomock.Any()) + sess.EXPECT().run() + sess.EXPECT().Context().Return(context.Background()) + sess.EXPECT().HandshakeComplete().Return(context.Background()) + return sess } serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})) @@ -577,10 +589,9 @@ var _ = Describe("Server", func() { } p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false) + phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false) Expect(serv.handlePacketImpl(p)).To(BeTrue()) - Expect(createdSession).To(BeTrue()) + Expect(createdSession).To(BeFalse()) }) It("rejects new connection attempts if the accept queue is full", func() { @@ -612,8 +623,11 @@ var _ = Describe("Server", func() { return sess } - phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize) - phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }).Times(protocol.MaxAcceptQueueSize) var wg sync.WaitGroup wg.Add(protocol.MaxAcceptQueueSize) @@ -673,8 +687,11 @@ var _ = Describe("Server", func() { return sess } - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }) serv.handlePacket(p) Consistently(conn.dataWritten).ShouldNot(Receive()) @@ -772,8 +789,11 @@ var _ = Describe("Server", func() { sess.EXPECT().Context().Return(context.Background()) return sess } - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }) serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) Consistently(done).ShouldNot(BeClosed()) cancel() // complete the handshake @@ -836,8 +856,11 @@ var _ = Describe("Server", func() { sess.EXPECT().Context().Return(context.Background()) return sess } - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }) serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) Consistently(done).ShouldNot(BeClosed()) close(ready) @@ -874,8 +897,11 @@ var _ = Describe("Server", func() { return sess } - phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize) - phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }).Times(protocol.MaxAcceptQueueSize) for i := 0; i < protocol.MaxAcceptQueueSize; i++ { serv.handlePacket(getInitialWithRandomDestConnID()) } @@ -927,8 +953,11 @@ var _ = Describe("Server", func() { return sess } - phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2) + phm.EXPECT().AddWithConnID(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().GetStatelessResetToken(gomock.Any()) + fn() + return true + }) serv.handlePacket(p) Consistently(conn.dataWritten).ShouldNot(Receive()) Eventually(sessionCreated).Should(BeClosed())