diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go new file mode 100644 index 00000000..120bd8a4 --- /dev/null +++ b/mock_packet_handler_manager_test.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: PacketHandlerManager) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockPacketHandlerManager is a mock of PacketHandlerManager interface +type MockPacketHandlerManager struct { + ctrl *gomock.Controller + recorder *MockPacketHandlerManagerMockRecorder +} + +// MockPacketHandlerManagerMockRecorder is the mock recorder for MockPacketHandlerManager +type MockPacketHandlerManagerMockRecorder struct { + mock *MockPacketHandlerManager +} + +// NewMockPacketHandlerManager creates a new mock instance +func NewMockPacketHandlerManager(ctrl *gomock.Controller) *MockPacketHandlerManager { + mock := &MockPacketHandlerManager{ctrl: ctrl} + mock.recorder = &MockPacketHandlerManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorder { + return m.recorder +} + +// Add mocks base method +func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) { + m.ctrl.Call(m, "Add", arg0, arg1) +} + +// Add indicates an expected call of Add +func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) +} + +// Close mocks base method +func (m *MockPacketHandlerManager) Close() { + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close +func (mr *MockPacketHandlerManagerMockRecorder) Close() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close)) +} + +// Get mocks base method +func (m *MockPacketHandlerManager) Get(arg0 protocol.ConnectionID) (packetHandler, bool) { + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(packetHandler) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// Get indicates an expected call of Get +func (mr *MockPacketHandlerManagerMockRecorder) Get(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPacketHandlerManager)(nil).Get), arg0) +} + +// Remove mocks base method +func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { + m.ctrl.Call(m, "Remove", arg0) +} + +// Remove indicates an expected call of Remove +func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) +} diff --git a/mock_session_handler_test.go b/mock_session_handler_test.go deleted file mode 100644 index 522cee2a..00000000 --- a/mock_session_handler_test.go +++ /dev/null @@ -1,78 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go (interfaces: SessionHandler) - -// Package quic is a generated GoMock package. -package quic - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// MockSessionHandler is a mock of SessionHandler interface -type MockSessionHandler struct { - ctrl *gomock.Controller - recorder *MockSessionHandlerMockRecorder -} - -// MockSessionHandlerMockRecorder is the mock recorder for MockSessionHandler -type MockSessionHandlerMockRecorder struct { - mock *MockSessionHandler -} - -// NewMockSessionHandler creates a new mock instance -func NewMockSessionHandler(ctrl *gomock.Controller) *MockSessionHandler { - mock := &MockSessionHandler{ctrl: ctrl} - mock.recorder = &MockSessionHandlerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockSessionHandler) EXPECT() *MockSessionHandlerMockRecorder { - return m.recorder -} - -// Add mocks base method -func (m *MockSessionHandler) Add(arg0 protocol.ConnectionID, arg1 packetHandler) { - m.ctrl.Call(m, "Add", arg0, arg1) -} - -// Add indicates an expected call of Add -func (mr *MockSessionHandlerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSessionHandler)(nil).Add), arg0, arg1) -} - -// Close mocks base method -func (m *MockSessionHandler) Close() { - m.ctrl.Call(m, "Close") -} - -// Close indicates an expected call of Close -func (mr *MockSessionHandlerMockRecorder) Close() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSessionHandler)(nil).Close)) -} - -// Get mocks base method -func (m *MockSessionHandler) Get(arg0 protocol.ConnectionID) (packetHandler, bool) { - ret := m.ctrl.Call(m, "Get", arg0) - ret0, _ := ret[0].(packetHandler) - ret1, _ := ret[1].(bool) - return ret0, ret1 -} - -// Get indicates an expected call of Get -func (mr *MockSessionHandlerMockRecorder) Get(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSessionHandler)(nil).Get), arg0) -} - -// Remove mocks base method -func (m *MockSessionHandler) Remove(arg0 protocol.ConnectionID) { - m.ctrl.Call(m, "Remove", arg0) -} - -// Remove indicates an expected call of Remove -func (mr *MockSessionHandlerMockRecorder) Remove(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSessionHandler)(nil).Remove), arg0) -} diff --git a/mockgen.go b/mockgen.go index cf0470c7..feaaf595 100644 --- a/mockgen.go +++ b/mockgen.go @@ -13,6 +13,6 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner" //go:generate sh -c "./mockgen_private.sh quic mock_quic_session_test.go github.com/lucas-clemente/quic-go quicSession QuicSession" -//go:generate sh -c "./mockgen_private.sh quic mock_session_handler_test.go github.com/lucas-clemente/quic-go sessionHandler SessionHandler" +//go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager PacketHandlerManager" //go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'" //go:generate sh -c "goimports -w mock*_test.go" diff --git a/packet_handler_map.go b/packet_handler_map.go new file mode 100644 index 00000000..ea8334e6 --- /dev/null +++ b/packet_handler_map.go @@ -0,0 +1,74 @@ +package quic + +import ( + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type packetHandlerMap struct { + mutex sync.RWMutex + + handlers map[string] /* string(ConnectionID)*/ packetHandler + closed bool + + deleteClosedSessionsAfter time.Duration +} + +var _ packetHandlerManager = &packetHandlerMap{} + +func newPacketHandlerMap() packetHandlerManager { + return &packetHandlerMap{ + handlers: make(map[string]packetHandler), + deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, + } +} + +func (h *packetHandlerMap) Get(id protocol.ConnectionID) (packetHandler, bool) { + h.mutex.RLock() + sess, ok := h.handlers[string(id)] + h.mutex.RUnlock() + return sess, ok +} + +func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) { + h.mutex.Lock() + h.handlers[string(id)] = handler + h.mutex.Unlock() +} + +func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { + h.mutex.Lock() + h.handlers[string(id)] = nil + h.mutex.Unlock() + + time.AfterFunc(h.deleteClosedSessionsAfter, func() { + h.mutex.Lock() + delete(h.handlers, string(id)) + h.mutex.Unlock() + }) +} + +func (h *packetHandlerMap) Close() { + h.mutex.Lock() + if h.closed { + h.mutex.Unlock() + return + } + h.closed = true + + var wg sync.WaitGroup + for _, handler := range h.handlers { + if handler != nil { + wg.Add(1) + go func(handler packetHandler) { + // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped + _ = handler.Close(nil) + wg.Done() + }(handler) + } + } + h.mutex.Unlock() + wg.Wait() +} diff --git a/session_map_test.go b/packet_handler_map_test.go similarity index 90% rename from session_map_test.go rename to packet_handler_map_test.go index ad02725d..a08c8e2b 100644 --- a/session_map_test.go +++ b/packet_handler_map_test.go @@ -8,11 +8,11 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Session Handler", func() { - var handler *sessionMap +var _ = Describe("Packet Handler Map", func() { + var handler *packetHandlerMap BeforeEach(func() { - handler = newSessionMap().(*sessionMap) + handler = newPacketHandlerMap().(*packetHandlerMap) }) It("adds and gets", func() { diff --git a/server.go b/server.go index 9c4af1f0..ba6970a4 100644 --- a/server.go +++ b/server.go @@ -22,6 +22,13 @@ type packetHandler interface { Close(error) error } +type packetHandlerManager interface { + Add(protocol.ConnectionID, packetHandler) + Get(protocol.ConnectionID) (packetHandler, bool) + Remove(protocol.ConnectionID) + Close() +} + type quicSession interface { Session handlePacket(*receivedPacket) @@ -46,13 +53,6 @@ func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectio var _ sessionRunner = &runner{} -type sessionHandler interface { - Add(protocol.ConnectionID, packetHandler) - Get(protocol.ConnectionID) (packetHandler, bool) - Remove(protocol.ConnectionID) - Close() -} - // A Listener of QUIC type server struct { tlsConf *tls.Config @@ -66,7 +66,7 @@ type server struct { certChain crypto.CertChain scfg *handshake.ServerConfig - sessionHandler sessionHandler + sessionHandler packetHandlerManager serverError error @@ -129,7 +129,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, certChain: certChain, scfg: scfg, newSession: newSession, - sessionHandler: newSessionMap(), + sessionHandler: newPacketHandlerMap(), sessionQueue: make(chan Session, 5), errorChan: make(chan struct{}), supportsTLS: supportsTLS, diff --git a/server_test.go b/server_test.go index 6e82d92c..fd2d8500 100644 --- a/server_test.go +++ b/server_test.go @@ -84,11 +84,11 @@ var _ = Describe("Server", func() { firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID) connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} sessions = make([]*MockQuicSession, 0) - sessionHandler *MockSessionHandler + sessionHandler *MockPacketHandlerManager ) BeforeEach(func() { - sessionHandler = NewMockSessionHandler(mockCtrl) + sessionHandler = NewMockPacketHandlerManager(mockCtrl) newMockSession := func( _ connection, runner sessionRunner, diff --git a/session_map.go b/session_map.go deleted file mode 100644 index 561630b1..00000000 --- a/session_map.go +++ /dev/null @@ -1,74 +0,0 @@ -package quic - -import ( - "sync" - "time" - - "github.com/lucas-clemente/quic-go/internal/protocol" -) - -type sessionMap struct { - mutex sync.RWMutex - - sessions map[string] /* string(ConnectionID)*/ packetHandler - closed bool - - deleteClosedSessionsAfter time.Duration -} - -var _ sessionHandler = &sessionMap{} - -func newSessionMap() sessionHandler { - return &sessionMap{ - sessions: make(map[string]packetHandler), - deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, - } -} - -func (h *sessionMap) Get(id protocol.ConnectionID) (packetHandler, bool) { - h.mutex.RLock() - sess, ok := h.sessions[string(id)] - h.mutex.RUnlock() - return sess, ok -} - -func (h *sessionMap) Add(id protocol.ConnectionID, sess packetHandler) { - h.mutex.Lock() - h.sessions[string(id)] = sess - h.mutex.Unlock() -} - -func (h *sessionMap) Remove(id protocol.ConnectionID) { - h.mutex.Lock() - h.sessions[string(id)] = nil - h.mutex.Unlock() - - time.AfterFunc(h.deleteClosedSessionsAfter, func() { - h.mutex.Lock() - delete(h.sessions, string(id)) - h.mutex.Unlock() - }) -} - -func (h *sessionMap) Close() { - h.mutex.Lock() - if h.closed { - h.mutex.Unlock() - return - } - h.closed = true - - var wg sync.WaitGroup - for _, session := range h.sessions { - if session != nil { - wg.Add(1) - go func(sess packetHandler) { - // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped - _ = sess.Close(nil) - wg.Done() - }(session) - } - } - h.mutex.Unlock() - wg.Wait() -}