mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
fix race condition when closing the server after a Read failed
This commit is contained in:
parent
c135b4f1e3
commit
dc75123836
4 changed files with 21 additions and 26 deletions
|
@ -33,20 +33,6 @@ func (m *MockUnknownPacketHandler) EXPECT() *MockUnknownPacketHandlerMockRecorde
|
|||
return m.recorder
|
||||
}
|
||||
|
||||
// closeWithError mocks base method
|
||||
func (m *MockUnknownPacketHandler) closeWithError(arg0 error) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "closeWithError", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// closeWithError indicates an expected call of closeWithError
|
||||
func (mr *MockUnknownPacketHandlerMockRecorder) closeWithError(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "closeWithError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).closeWithError), arg0)
|
||||
}
|
||||
|
||||
// handlePacket mocks base method
|
||||
func (m *MockUnknownPacketHandler) handlePacket(arg0 *receivedPacket) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -58,3 +44,15 @@ func (mr *MockUnknownPacketHandlerMockRecorder) handlePacket(arg0 interface{}) *
|
|||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "handlePacket", reflect.TypeOf((*MockUnknownPacketHandler)(nil).handlePacket), arg0)
|
||||
}
|
||||
|
||||
// setCloseError mocks base method
|
||||
func (m *MockUnknownPacketHandler) setCloseError(arg0 error) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "setCloseError", arg0)
|
||||
}
|
||||
|
||||
// setCloseError indicates an expected call of setCloseError
|
||||
func (mr *MockUnknownPacketHandlerMockRecorder) setCloseError(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setCloseError", reflect.TypeOf((*MockUnknownPacketHandler)(nil).setCloseError), arg0)
|
||||
}
|
||||
|
|
|
@ -141,7 +141,6 @@ func (h *packetHandlerMap) close(e error) error {
|
|||
h.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
h.closed = true
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, handler := range h.handlers {
|
||||
|
@ -153,8 +152,9 @@ func (h *packetHandlerMap) close(e error) error {
|
|||
}
|
||||
|
||||
if h.server != nil {
|
||||
h.server.closeWithError(e)
|
||||
h.server.setCloseError(e)
|
||||
}
|
||||
h.closed = true
|
||||
h.mutex.Unlock()
|
||||
wg.Wait()
|
||||
return getMultiplexer().RemoveConn(h.conn)
|
||||
|
|
13
server.go
13
server.go
|
@ -28,7 +28,7 @@ type packetHandler interface {
|
|||
|
||||
type unknownPacketHandler interface {
|
||||
handlePacket(*receivedPacket)
|
||||
closeWithError(error) error
|
||||
setCloseError(error)
|
||||
}
|
||||
|
||||
type packetHandlerManager interface {
|
||||
|
@ -293,10 +293,6 @@ func (s *server) Close() error {
|
|||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
return s.closeWithMutex()
|
||||
}
|
||||
|
||||
func (s *server) closeWithMutex() error {
|
||||
s.sessionHandler.CloseServer()
|
||||
if s.serverError == nil {
|
||||
s.serverError = errors.New("server closed")
|
||||
|
@ -312,14 +308,15 @@ func (s *server) closeWithMutex() error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *server) closeWithError(e error) error {
|
||||
func (s *server) setCloseError(e error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if s.closed {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
s.closed = true
|
||||
s.serverError = e
|
||||
return s.closeWithMutex()
|
||||
close(s.errorChan)
|
||||
}
|
||||
|
||||
// Addr returns the server's network address
|
||||
|
|
|
@ -451,13 +451,13 @@ var _ = Describe("Server", func() {
|
|||
close(done)
|
||||
}()
|
||||
|
||||
Expect(serv.closeWithError(testErr)).To(Succeed())
|
||||
serv.setCloseError(testErr)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("returns immediately, if an error occurred before", func() {
|
||||
testErr := errors.New("test err")
|
||||
Expect(serv.closeWithError(testErr)).To(Succeed())
|
||||
serv.setCloseError(testErr)
|
||||
for i := 0; i < 3; i++ {
|
||||
_, err := serv.Accept()
|
||||
Expect(err).To(MatchError(testErr))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue