From 93693349f3706a248c6c11c0e720cd4d9cb35fa8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 14 Jan 2020 18:05:23 +0700 Subject: [PATCH] collapse packetHandlerManager.Add and AddIfNotTaken into one function --- mock_packet_handler_manager_test.go | 20 ++++---------------- mock_session_runner_test.go | 6 ++++-- packet_handler_map.go | 17 ++++++----------- packet_handler_map_test.go | 4 ++-- server.go | 10 ++++------ server_test.go | 25 ++++++++++--------------- session.go | 2 +- 7 files changed, 31 insertions(+), 53 deletions(-) diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 13faa199..e6316339 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -35,9 +35,11 @@ func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorde } // Add mocks base method -func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) { +func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool { m.ctrl.T.Helper() - m.ctrl.Call(m, "Add", arg0, arg1) + ret := m.ctrl.Call(m, "Add", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 } // Add indicates an expected call of Add @@ -46,20 +48,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) } -// AddIfNotTaken mocks base method -func (m *MockPacketHandlerManager) AddIfNotTaken(arg0 protocol.ConnectionID, arg1 packetHandler) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddIfNotTaken", arg0, arg1) - ret0, _ := ret[0].(bool) - return ret0 -} - -// AddIfNotTaken indicates an expected call of AddIfNotTaken -func (mr *MockPacketHandlerManagerMockRecorder) AddIfNotTaken(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddIfNotTaken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddIfNotTaken), arg0, arg1) -} - // AddResetToken mocks base method func (m *MockPacketHandlerManager) AddResetToken(arg0 [16]byte, arg1 packetHandler) { m.ctrl.T.Helper() diff --git a/mock_session_runner_test.go b/mock_session_runner_test.go index 8eaeb750..704272c6 100644 --- a/mock_session_runner_test.go +++ b/mock_session_runner_test.go @@ -35,9 +35,11 @@ func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder { } // Add mocks base method -func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) { +func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool { m.ctrl.T.Helper() - m.ctrl.Call(m, "Add", arg0, arg1) + ret := m.ctrl.Call(m, "Add", arg0, arg1) + ret0, _ := ret[0].(bool) + return ret0 } // Add indicates an expected call of Add diff --git a/packet_handler_map.go b/packet_handler_map.go index 179d6583..e40b2776 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -95,22 +95,17 @@ func (h *packetHandlerMap) logUsage() { } } -func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) { - h.mutex.Lock() - h.handlers[string(id)] = handler - h.mutex.Unlock() -} - -func (h *packetHandlerMap) AddIfNotTaken(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { +func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { sid := string(id) + h.mutex.Lock() defer h.mutex.Unlock() - if _, ok := h.handlers[sid]; !ok { - h.handlers[sid] = handler - return true + if _, ok := h.handlers[sid]; ok { + return false } - return false + h.handlers[sid] = handler + return true } func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index e45147b6..05b985fc 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -178,8 +178,8 @@ var _ = Describe("Packet Handler Map", func() { It("says if a connection ID is already taken", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - Expect(handler.AddIfNotTaken(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) - Expect(handler.AddIfNotTaken(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) + Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) + Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) }) }) diff --git a/server.go b/server.go index 333f9d45..a30ed579 100644 --- a/server.go +++ b/server.go @@ -33,10 +33,9 @@ type unknownPacketHandler interface { type packetHandlerManager interface { Destroy() error + sessionRunner SetServer(unknownPacketHandler) CloseServer() - sessionRunner - AddIfNotTaken(protocol.ConnectionID, packetHandler) bool } type quicSession interface { @@ -476,10 +475,9 @@ func (s *baseServer) createNewSession( s.logger, version, ) - added := s.sessionHandler.AddIfNotTaken(clientDestConnID, sess) - // We're already keeping track of this connection ID. - // This might happen if we receive two copies of the Initial at the same time. - if !added { + if added := s.sessionHandler.Add(clientDestConnID, sess); !added { + // We're already keeping track of this connection ID. + // This might happen if we receive two copies of the Initial at the same time. return nil } s.sessionHandler.Add(srcConnID, sess) diff --git a/server_test.go b/server_test.go index f04f52ae..96161637 100644 --- a/server_test.go +++ b/server_test.go @@ -367,9 +367,10 @@ var _ = Describe("Server", func() { return sess } - phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true) - phm.EXPECT().Add(gomock.Any(), sess).Do(func(c protocol.ConnectionID, _ packetHandler) { + phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true) + phm.EXPECT().Add(gomock.Any(), sess).DoAndReturn(func(c protocol.ConnectionID, _ packetHandler) bool { Expect(c).To(Equal(newConnID)) + return true }) done := make(chan struct{}) @@ -410,7 +411,7 @@ var _ = Describe("Server", func() { p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false) + phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false) Expect(serv.handlePacketImpl(p)).To(BeFalse()) Expect(createdSession).To(BeTrue()) }) @@ -444,8 +445,7 @@ var _ = Describe("Server", func() { } phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize) - phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize) - phm.EXPECT().Add(gomock.Any(), gomock.Any()).Times(protocol.MaxAcceptQueueSize) + phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize) var wg sync.WaitGroup wg.Add(protocol.MaxAcceptQueueSize) @@ -505,8 +505,7 @@ var _ = Describe("Server", func() { } phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true) - phm.EXPECT().Add(gomock.Any(), gomock.Any()) + phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2) serv.handlePacket(p) Consistently(conn.dataWritten).ShouldNot(Receive()) @@ -604,8 +603,7 @@ var _ = Describe("Server", func() { return sess } phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true) - phm.EXPECT().Add(gomock.Any(), gomock.Any()) + phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2) serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) Consistently(done).ShouldNot(BeClosed()) cancel() // complete the handshake @@ -668,8 +666,7 @@ var _ = Describe("Server", func() { return sess } phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().AddIfNotTaken(gomock.Any(), sess).Return(true) - phm.EXPECT().Add(gomock.Any(), sess) + phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2) serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) Consistently(done).ShouldNot(BeClosed()) close(ready) @@ -712,8 +709,7 @@ var _ = Describe("Server", func() { defer GinkgoRecover() defer wg.Done() phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true) - phm.EXPECT().Add(gomock.Any(), gomock.Any()) + phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2) serv.handlePacket(getInitialWithRandomDestConnID()) Consistently(conn.dataWritten).ShouldNot(Receive()) }() @@ -764,8 +760,7 @@ var _ = Describe("Server", func() { } phm.EXPECT().GetStatelessResetToken(gomock.Any()) - phm.EXPECT().AddIfNotTaken(gomock.Any(), sess).Return(true) - phm.EXPECT().Add(gomock.Any(), sess) + phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2) serv.handlePacket(p) Consistently(conn.dataWritten).ShouldNot(Receive()) Eventually(sessionCreated).Should(BeClosed()) diff --git a/session.go b/session.go index e03aed8b..ebd14ab1 100644 --- a/session.go +++ b/session.go @@ -74,7 +74,7 @@ func (p *receivedPacket) Clone() *receivedPacket { } type sessionRunner interface { - Add(protocol.ConnectionID, packetHandler) + Add(protocol.ConnectionID, packetHandler) bool GetStatelessResetToken(protocol.ConnectionID) [16]byte Retire(protocol.ConnectionID) Remove(protocol.ConnectionID)