diff --git a/packet_handler_map.go b/packet_handler_map.go index 3b327135..a458ee98 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -238,6 +238,10 @@ func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { func (h *packetHandlerMap) CloseServer() { h.mutex.Lock() + if h.server == nil { + h.mutex.Unlock() + return + } h.server = nil var wg sync.WaitGroup for _, handler := range h.handlers { diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 95d9737c..06cb9b8f 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -251,6 +251,7 @@ var _ = Describe("Packet Handler Map", func() { }) It("closes all server sessions", func() { + handler.SetServer(NewMockUnknownPacketHandler(mockCtrl)) clientSess := NewMockPacketHandler(mockCtrl) clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient) serverSess := NewMockPacketHandler(mockCtrl) diff --git a/server.go b/server.go index 29cbed67..1df61d17 100644 --- a/server.go +++ b/server.go @@ -274,12 +274,12 @@ func (s *baseServer) accept(ctx context.Context) (quicSession, error) { // Close the server func (s *baseServer) Close() error { + s.sessionHandler.CloseServer() s.mutex.Lock() defer s.mutex.Unlock() if s.closed { return nil } - s.sessionHandler.CloseServer() if s.serverError == nil { s.serverError = errors.New("server closed") }