mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 20:57:36 +03:00
collapse packetHandlerManager.Add and AddIfNotTaken into one function
This commit is contained in:
parent
634169b061
commit
93693349f3
7 changed files with 31 additions and 53 deletions
|
@ -35,9 +35,11 @@ func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorde
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add mocks base method
|
// Add mocks base method
|
||||||
func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) {
|
func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "Add", arg0, arg1)
|
ret := m.ctrl.Call(m, "Add", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add indicates an expected call of Add
|
// Add indicates an expected call of Add
|
||||||
|
@ -46,20 +48,6 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddIfNotTaken mocks base method
|
|
||||||
func (m *MockPacketHandlerManager) AddIfNotTaken(arg0 protocol.ConnectionID, arg1 packetHandler) bool {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "AddIfNotTaken", arg0, arg1)
|
|
||||||
ret0, _ := ret[0].(bool)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddIfNotTaken indicates an expected call of AddIfNotTaken
|
|
||||||
func (mr *MockPacketHandlerManagerMockRecorder) AddIfNotTaken(arg0, arg1 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddIfNotTaken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddIfNotTaken), arg0, arg1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddResetToken mocks base method
|
// AddResetToken mocks base method
|
||||||
func (m *MockPacketHandlerManager) AddResetToken(arg0 [16]byte, arg1 packetHandler) {
|
func (m *MockPacketHandlerManager) AddResetToken(arg0 [16]byte, arg1 packetHandler) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|
|
@ -35,9 +35,11 @@ func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add mocks base method
|
// Add mocks base method
|
||||||
func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) {
|
func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
m.ctrl.Call(m, "Add", arg0, arg1)
|
ret := m.ctrl.Call(m, "Add", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add indicates an expected call of Add
|
// Add indicates an expected call of Add
|
||||||
|
|
|
@ -95,23 +95,18 @@ func (h *packetHandlerMap) logUsage() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) {
|
func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
|
||||||
h.mutex.Lock()
|
|
||||||
h.handlers[string(id)] = handler
|
|
||||||
h.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *packetHandlerMap) AddIfNotTaken(id protocol.ConnectionID, handler packetHandler) bool /* was added */ {
|
|
||||||
sid := string(id)
|
sid := string(id)
|
||||||
|
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
defer h.mutex.Unlock()
|
defer h.mutex.Unlock()
|
||||||
|
|
||||||
if _, ok := h.handlers[sid]; !ok {
|
if _, ok := h.handlers[sid]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
h.handlers[sid] = handler
|
h.handlers[sid] = handler
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
func (h *packetHandlerMap) Remove(id protocol.ConnectionID) {
|
||||||
h.mutex.Lock()
|
h.mutex.Lock()
|
||||||
|
|
|
@ -178,8 +178,8 @@ var _ = Describe("Packet Handler Map", func() {
|
||||||
|
|
||||||
It("says if a connection ID is already taken", func() {
|
It("says if a connection ID is already taken", func() {
|
||||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||||
Expect(handler.AddIfNotTaken(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
|
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue())
|
||||||
Expect(handler.AddIfNotTaken(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
|
Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -33,10 +33,9 @@ type unknownPacketHandler interface {
|
||||||
|
|
||||||
type packetHandlerManager interface {
|
type packetHandlerManager interface {
|
||||||
Destroy() error
|
Destroy() error
|
||||||
|
sessionRunner
|
||||||
SetServer(unknownPacketHandler)
|
SetServer(unknownPacketHandler)
|
||||||
CloseServer()
|
CloseServer()
|
||||||
sessionRunner
|
|
||||||
AddIfNotTaken(protocol.ConnectionID, packetHandler) bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type quicSession interface {
|
type quicSession interface {
|
||||||
|
@ -476,10 +475,9 @@ func (s *baseServer) createNewSession(
|
||||||
s.logger,
|
s.logger,
|
||||||
version,
|
version,
|
||||||
)
|
)
|
||||||
added := s.sessionHandler.AddIfNotTaken(clientDestConnID, sess)
|
if added := s.sessionHandler.Add(clientDestConnID, sess); !added {
|
||||||
// We're already keeping track of this connection ID.
|
// We're already keeping track of this connection ID.
|
||||||
// This might happen if we receive two copies of the Initial at the same time.
|
// This might happen if we receive two copies of the Initial at the same time.
|
||||||
if !added {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
s.sessionHandler.Add(srcConnID, sess)
|
s.sessionHandler.Add(srcConnID, sess)
|
||||||
|
|
|
@ -367,9 +367,10 @@ var _ = Describe("Server", func() {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true)
|
phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true)
|
||||||
phm.EXPECT().Add(gomock.Any(), sess).Do(func(c protocol.ConnectionID, _ packetHandler) {
|
phm.EXPECT().Add(gomock.Any(), sess).DoAndReturn(func(c protocol.ConnectionID, _ packetHandler) bool {
|
||||||
Expect(c).To(Equal(newConnID))
|
Expect(c).To(Equal(newConnID))
|
||||||
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
@ -410,7 +411,7 @@ var _ = Describe("Server", func() {
|
||||||
|
|
||||||
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9})
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
phm.EXPECT().AddIfNotTaken(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false)
|
phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, sess).Return(false)
|
||||||
Expect(serv.handlePacketImpl(p)).To(BeFalse())
|
Expect(serv.handlePacketImpl(p)).To(BeFalse())
|
||||||
Expect(createdSession).To(BeTrue())
|
Expect(createdSession).To(BeTrue())
|
||||||
})
|
})
|
||||||
|
@ -444,8 +445,7 @@ var _ = Describe("Server", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
phm.EXPECT().GetStatelessResetToken(gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
||||||
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true).Times(protocol.MaxAcceptQueueSize)
|
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2 * protocol.MaxAcceptQueueSize)
|
||||||
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Times(protocol.MaxAcceptQueueSize)
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(protocol.MaxAcceptQueueSize)
|
wg.Add(protocol.MaxAcceptQueueSize)
|
||||||
|
@ -505,8 +505,7 @@ var _ = Describe("Server", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true)
|
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2)
|
||||||
phm.EXPECT().Add(gomock.Any(), gomock.Any())
|
|
||||||
|
|
||||||
serv.handlePacket(p)
|
serv.handlePacket(p)
|
||||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
|
@ -604,8 +603,7 @@ var _ = Describe("Server", func() {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true)
|
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2)
|
||||||
phm.EXPECT().Add(gomock.Any(), gomock.Any())
|
|
||||||
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
|
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
cancel() // complete the handshake
|
cancel() // complete the handshake
|
||||||
|
@ -668,8 +666,7 @@ var _ = Describe("Server", func() {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
phm.EXPECT().AddIfNotTaken(gomock.Any(), sess).Return(true)
|
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
|
||||||
phm.EXPECT().Add(gomock.Any(), sess)
|
|
||||||
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
|
serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever)
|
||||||
Consistently(done).ShouldNot(BeClosed())
|
Consistently(done).ShouldNot(BeClosed())
|
||||||
close(ready)
|
close(ready)
|
||||||
|
@ -712,8 +709,7 @@ var _ = Describe("Server", func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
phm.EXPECT().AddIfNotTaken(gomock.Any(), gomock.Any()).Return(true)
|
phm.EXPECT().Add(gomock.Any(), gomock.Any()).Return(true).Times(2)
|
||||||
phm.EXPECT().Add(gomock.Any(), gomock.Any())
|
|
||||||
serv.handlePacket(getInitialWithRandomDestConnID())
|
serv.handlePacket(getInitialWithRandomDestConnID())
|
||||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
}()
|
}()
|
||||||
|
@ -764,8 +760,7 @@ var _ = Describe("Server", func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
phm.EXPECT().GetStatelessResetToken(gomock.Any())
|
||||||
phm.EXPECT().AddIfNotTaken(gomock.Any(), sess).Return(true)
|
phm.EXPECT().Add(gomock.Any(), sess).Return(true).Times(2)
|
||||||
phm.EXPECT().Add(gomock.Any(), sess)
|
|
||||||
serv.handlePacket(p)
|
serv.handlePacket(p)
|
||||||
Consistently(conn.dataWritten).ShouldNot(Receive())
|
Consistently(conn.dataWritten).ShouldNot(Receive())
|
||||||
Eventually(sessionCreated).Should(BeClosed())
|
Eventually(sessionCreated).Should(BeClosed())
|
||||||
|
|
|
@ -74,7 +74,7 @@ func (p *receivedPacket) Clone() *receivedPacket {
|
||||||
}
|
}
|
||||||
|
|
||||||
type sessionRunner interface {
|
type sessionRunner interface {
|
||||||
Add(protocol.ConnectionID, packetHandler)
|
Add(protocol.ConnectionID, packetHandler) bool
|
||||||
GetStatelessResetToken(protocol.ConnectionID) [16]byte
|
GetStatelessResetToken(protocol.ConnectionID) [16]byte
|
||||||
Retire(protocol.ConnectionID)
|
Retire(protocol.ConnectionID)
|
||||||
Remove(protocol.ConnectionID)
|
Remove(protocol.ConnectionID)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue