add a function to close the packet handler map

Close will close the underlying connection and wait until listen has
returned. While not strictly necessary in production use, this will fix
a few race conditions in our tests.
This commit is contained in:
Marten Seemann 2019-01-24 18:09:46 +07:00
parent 6dc4be9f4e
commit bb185a3ad2
5 changed files with 32 additions and 7 deletions

View file

@ -271,7 +271,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
go func() {
err := c.session.run() // returns as soon as the session is closed
if err != errCloseForRecreating && c.createdPacketConn {
c.conn.Close()
c.packetHandlers.Close()
}
errorChan <- err
}()

View file

@ -131,6 +131,7 @@ var _ = Describe("Client", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
manager.EXPECT().Close()
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil)
remoteAddrChan := make(chan string, 1)
@ -162,6 +163,7 @@ var _ = Describe("Client", func() {
It("uses the tls.Config.ServerName as the hostname, if present", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(gomock.Any(), gomock.Any())
manager.EXPECT().Close()
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil)
hostnameChan := make(chan string, 1)
@ -403,12 +405,9 @@ var _ = Describe("Client", func() {
// check that the connection is not closed
Expect(conn.Write([]byte("foobar"))).To(Succeed())
manager.EXPECT().Close()
close(run)
time.Sleep(50 * time.Millisecond)
// check that the connection is closed
err := conn.Write([]byte("foobar"))
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
Eventually(done).Should(BeClosed())
})

View file

@ -44,6 +44,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1)
}
// Close mocks base method
func (m *MockPacketHandlerManager) Close() error {
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockPacketHandlerManagerMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close))
}
// CloseServer mocks base method
func (m *MockPacketHandlerManager) CloseServer() {
m.ctrl.Call(m, "CloseServer")

View file

@ -31,7 +31,9 @@ type packetHandlerMap struct {
handlers map[string] /* string(ConnectionID)*/ packetHandlerEntry
resetTokens map[[16]byte] /* stateless reset token */ packetHandler
server unknownPacketHandler
closed bool
listening chan struct{} // is closed when listen returns
closed bool
deleteRetiredSessionsAfter time.Duration
@ -44,6 +46,7 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger
m := &packetHandlerMap{
conn: conn,
connIDLen: connIDLen,
listening: make(chan struct{}),
handlers: make(map[string]packetHandlerEntry),
resetTokens: make(map[[16]byte]packetHandler),
deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout,
@ -117,6 +120,15 @@ func (h *packetHandlerMap) CloseServer() {
wg.Wait()
}
// Close the underlying connection and wait until listen() has returned.
func (h *packetHandlerMap) Close() error {
if err := h.conn.Close(); err != nil {
return err
}
<-h.listening // wait until listening returns
return nil
}
func (h *packetHandlerMap) close(e error) error {
h.mutex.Lock()
if h.closed {
@ -143,6 +155,7 @@ func (h *packetHandlerMap) close(e error) error {
}
func (h *packetHandlerMap) listen() {
defer close(h.listening)
for {
buffer := getPacketBuffer()
data := buffer.Slice

View file

@ -32,6 +32,7 @@ type unknownPacketHandler interface {
}
type packetHandlerManager interface {
io.Closer
Add(protocol.ConnectionID, packetHandler)
Retire(protocol.ConnectionID)
Remove(protocol.ConnectionID)
@ -300,7 +301,7 @@ func (s *server) closeWithMutex() error {
// If the server was started with ListenAddr, we created the packet conn.
// We need to close it in order to make the go routine reading from that conn return.
if s.createdPacketConn {
err = s.conn.Close()
err = s.sessionHandler.Close()
}
s.closed = true
close(s.errorChan)