collapse packetHandlerManager.Add and AddIfNotTaken into one function

This commit is contained in:
Marten Seemann 2020-01-14 18:05:23 +07:00
parent 634169b061
commit 93693349f3
7 changed files with 31 additions and 53 deletions

View file

@ -35,9 +35,11 @@ func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorde
} }
// Add mocks base method // Add mocks base method
func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) { func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", arg0, arg1) ret := m.ctrl.Call(m, "Add", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
} }
// Add indicates an expected call of Add // Add indicates an expected call of Add
@ -46,20 +48,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1)
} }
// AddIfNotTaken mocks base method
func (m *MockPacketHandlerManager) AddIfNotTaken(arg0 protocol.ConnectionID, arg1 packetHandler) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddIfNotTaken", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// AddIfNotTaken indicates an expected call of AddIfNotTaken
func (mr *MockPacketHandlerManagerMockRecorder) AddIfNotTaken(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddIfNotTaken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddIfNotTaken), arg0, arg1)
}
// AddResetToken mocks base method // AddResetToken mocks base method
func (m *MockPacketHandlerManager) AddResetToken(arg0 [16]byte, arg1 packetHandler) { func (m *MockPacketHandlerManager) AddResetToken(arg0 [16]byte, arg1 packetHandler) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View file

@ -35,9 +35,11 @@ func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder {
} }
// Add mocks base method // Add mocks base method
func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) { func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", arg0, arg1) ret := m.ctrl.Call(m, "Add", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
} }
// Add indicates an expected call of Add // Add indicates an expected call of Add

View file

@ -95,23 +95,18 @@ func (h *packetHandlerMap) logUsage() {
} }
} }
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) { func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
h.mutex.Lock()
h.handlers[string(id)] = handler
h.mutex.Unlock()
}
func (h *packetHandlerMap) AddIfNotTaken(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
sid := string(id) sid := string(id)
h.mutex.Lock() h.mutex.Lock()
defer h.mutex.Unlock() defer h.mutex.Unlock()
if _, ok := h.handlers[sid]; !ok { if _, ok := h.handlers[sid]; ok {
return false
}
h.handlers[sid] = handler h.handlers[sid] = handler
return true return true
} }
return false
}
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
h.mutex.Lock() h.mutex.Lock()

View file

@ -178,8 +178,8 @@ var _ = Describe("Packet Handler Map", func() {
It("says if a connection ID is already taken", func() { It("says if a connection ID is already taken", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
Expect(handler.AddIfNotTaken(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
Expect(handler.AddIfNotTaken(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
}) })
}) })

View file

@ -33,10 +33,9 @@ type unknownPacketHandler interface {
type packetHandlerManager interface { type packetHandlerManager interface {
Destroy() error Destroy() error
sessionRunner
SetServer(unknownPacketHandler) SetServer(unknownPacketHandler)
CloseServer() CloseServer()
sessionRunner
AddIfNotTaken(protocol.ConnectionID, packetHandler) bool
} }
type quicSession interface { type quicSession interface {
@ -476,10 +475,9 @@ func (s *baseServer) createNewSession(
s.logger, s.logger,
version, version,
) )
added := s.sessionHandler.AddIfNotTaken(clientDestConnID, sess) if added := s.sessionHandler.Add(clientDestConnID, sess); !added {
// We're already keeping track of this connection ID. // We're already keeping track of this connection ID.
// This might happen if we receive two copies of the Initial at the same time. // This might happen if we receive two copies of the Initial at the same time.
if !added {
return nil return nil
} }
s.sessionHandler.Add(srcConnID, sess) s.sessionHandler.Add(srcConnID, sess)

View file

@ -367,9 +367,10 @@ var _ = Describe("Server", func() {
return sess return sess
} }
phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true) phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true)
phm.EXPECT().Add(gomock.Any(), sess).Do(func(c protocol.ConnectionID, _ packetHandler) { phm.EXPECT().Add(gomock.Any(), sess).DoAndReturn(func(c protocol.ConnectionID, _ packetHandler) bool {
Expect(c).To(Equal(newConnID)) Expect(c).To(Equal(newConnID))
return true
}) })
done := make(chan struct{}) done := make(chan struct{})
@ -410,7 +411,7 @@ var _ = Describe("Server", func() {
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false) phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false)
Expect(serv.handlePacketImpl(p)).To(BeFalse()) Expect(serv.handlePacketImpl(p)).To(BeFalse())
Expect(createdSession).To(BeTrue()) Expect(createdSession).To(BeTrue())
}) })
@ -444,8 +445,7 @@ var _ = Describe("Server", func() {
} }
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize) phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize) phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize)
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Times(protocol.MaxAcceptQueueSize)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(protocol.MaxAcceptQueueSize) wg.Add(protocol.MaxAcceptQueueSize)
@ -505,8 +505,7 @@ var _ = Describe("Server", func() {
} }
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true) phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2)
phm.EXPECT().Add(gomock.Any(), gomock.Any())
serv.handlePacket(p) serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive()) Consistently(conn.dataWritten).ShouldNot(Receive())
@ -604,8 +603,7 @@ var _ = Describe("Server", func() {
return sess return sess
} }
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true) phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2)
phm.EXPECT().Add(gomock.Any(), gomock.Any())
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
cancel() // complete the handshake cancel() // complete the handshake
@ -668,8 +666,7 @@ var _ = Describe("Server", func() {
return sess return sess
} }
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddIfNotTaken(gomock.Any(), sess).Return(true) phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
phm.EXPECT().Add(gomock.Any(), sess)
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
Consistently(done).ShouldNot(BeClosed()) Consistently(done).ShouldNot(BeClosed())
close(ready) close(ready)
@ -712,8 +709,7 @@ var _ = Describe("Server", func() {
defer GinkgoRecover() defer GinkgoRecover()
defer wg.Done() defer wg.Done()
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true) phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2)
phm.EXPECT().Add(gomock.Any(), gomock.Any())
serv.handlePacket(getInitialWithRandomDestConnID()) serv.handlePacket(getInitialWithRandomDestConnID())
Consistently(conn.dataWritten).ShouldNot(Receive()) Consistently(conn.dataWritten).ShouldNot(Receive())
}() }()
@ -764,8 +760,7 @@ var _ = Describe("Server", func() {
} }
phm.EXPECT().GetStatelessResetToken(gomock.Any()) phm.EXPECT().GetStatelessResetToken(gomock.Any())
phm.EXPECT().AddIfNotTaken(gomock.Any(), sess).Return(true) phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
phm.EXPECT().Add(gomock.Any(), sess)
serv.handlePacket(p) serv.handlePacket(p)
Consistently(conn.dataWritten).ShouldNot(Receive()) Consistently(conn.dataWritten).ShouldNot(Receive())
Eventually(sessionCreated).Should(BeClosed()) Eventually(sessionCreated).Should(BeClosed())

View file

@ -74,7 +74,7 @@ func (p *receivedPacket) Clone() *receivedPacket {
} }
type sessionRunner interface { type sessionRunner interface {
Add(protocol.ConnectionID, packetHandler) Add(protocol.ConnectionID, packetHandler) bool
GetStatelessResetToken(protocol.ConnectionID) [16]byte GetStatelessResetToken(protocol.ConnectionID) [16]byte
Retire(protocol.ConnectionID) Retire(protocol.ConnectionID)
Remove(protocol.ConnectionID) Remove(protocol.ConnectionID)