remove closed clients from the multiplexer

This commit is contained in:
Marten Seemann 2018-07-02 14:56:12 +07:00
parent 7658f10a51
commit 0928e91e4d
5 changed files with 90 additions and 27 deletions

View file

@ -41,6 +41,7 @@ type client struct {
version protocol.VersionNumber
handshakeChan chan struct{}
closeCallback func(protocol.ConnectionID)
session quicSession
@ -81,7 +82,7 @@ func DialAddrContext(
if err != nil {
return nil, err
}
c, err := newClient(udpConn, udpAddr, config, tlsConf, addr)
c, err := newClient(udpConn, udpAddr, config, tlsConf, addr, nil)
if err != nil {
return nil, err
}
@ -114,18 +115,29 @@ func DialContext(
tlsConf *tls.Config,
config *Config,
) (Session, error) {
c, err := newClient(pconn, remoteAddr, config, tlsConf, host)
multiplexer := getClientMultiplexer()
manager := multiplexer.AddConn(pconn)
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove)
if err != nil {
return nil, err
}
getClientMultiplexer().Add(pconn, c.srcConnID, c)
if err := multiplexer.AddHandler(pconn, c.srcConnID, c); err != nil {
return nil, err
}
if err := c.dial(ctx); err != nil {
return nil, err
}
return c.session, nil
}
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, host string) (*client, error) {
func newClient(
pconn net.PacketConn,
remoteAddr net.Addr,
config *Config,
tlsConf *tls.Config,
host string,
closeCallback func(protocol.ConnectionID),
) (*client, error) {
clientConfig := populateClientConfig(config)
version := clientConfig.Versions[0]
srcConnID, err := generateConnectionID()
@ -159,6 +171,10 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
}
}
}
onClose := func(protocol.ConnectionID) {}
if closeCallback != nil {
onClose = closeCallback
}
return &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
srcConnID: srcConnID,
@ -168,6 +184,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
config: clientConfig,
version: version,
handshakeChan: make(chan struct{}),
closeCallback: onClose,
logger: utils.DefaultLogger.WithPrefix("client"),
}, nil
}
@ -508,7 +525,7 @@ func (c *client) createNewGQUICSession() (err error) {
defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
removeConnectionIDImpl: func(protocol.ConnectionID) {},
removeConnectionIDImpl: c.closeCallback,
}
c.session, err = newClientSession(
c.conn,
@ -533,7 +550,7 @@ func (c *client) createNewTLSSession(
defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
removeConnectionIDImpl: func(protocol.ConnectionID) {},
removeConnectionIDImpl: c.closeCallback,
}
c.session, err = newTLSClientSession(
c.conn,

View file

@ -2,6 +2,7 @@ package quic
import (
"bytes"
"errors"
"net"
"strings"
"sync"
@ -22,7 +23,8 @@ var (
type clientMultiplexer struct {
mutex sync.Mutex
conns map[net.PacketConn]packetHandlerManager
conns map[net.PacketConn]packetHandlerManager
newPacketHandlerManager func() packetHandlerManager // so it can be replaced in the tests
logger utils.Logger
}
@ -30,29 +32,35 @@ type clientMultiplexer struct {
func getClientMultiplexer() *clientMultiplexer {
clientMuxerOnce.Do(func() {
clientMuxer = &clientMultiplexer{
conns: make(map[net.PacketConn]packetHandlerManager),
logger: utils.DefaultLogger.WithPrefix("client muxer"),
conns: make(map[net.PacketConn]packetHandlerManager),
logger: utils.DefaultLogger.WithPrefix("client muxer"),
newPacketHandlerManager: newPacketHandlerMap,
}
})
return clientMuxer
}
func (m *clientMultiplexer) Add(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) {
func (m *clientMultiplexer) AddConn(c net.PacketConn) packetHandlerManager {
m.mutex.Lock()
defer m.mutex.Unlock()
sessions, ok := m.conns[c]
if !ok {
sessions = newPacketHandlerMap()
sessions = m.newPacketHandlerManager()
m.conns[c] = sessions
// If we didn't know this packet conn before, listen for incoming packets
// and dispatch them to the right sessions.
go m.listen(c, sessions)
}
return sessions
}
func (m *clientMultiplexer) AddHandler(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) error {
sessions, ok := m.conns[c]
if !ok {
return errors.New("unknown packet conn %s")
}
sessions.Add(connID, handler)
if ok {
return
}
// If we didn't know this packet conn before, listen for incoming packets
// and dispatch them to the right sessions.
go m.listen(c, sessions)
return nil
}
func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManager) {
@ -83,6 +91,10 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag
m.logger.Debugf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
continue
}
if client == nil {
// Late packet for closed session
continue
}
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, client.GetVersion())
if err != nil {
m.logger.Debugf("error parsing header from %s: %s", addr, err)

View file

@ -27,20 +27,28 @@ var _ = Describe("Client Multiplexer", func() {
It("adds a new packet conn and handles packets", func() {
conn := newMockPacketConn()
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
conn.dataToRead <- getPacket(connID)
packetHandler := NewMockQuicSession(mockCtrl)
handledPacket := make(chan struct{})
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(_ *receivedPacket) {
close(handledPacket)
})
packetHandler.EXPECT().GetVersion()
getClientMultiplexer().Add(conn, connID, packetHandler)
getClientMultiplexer().AddConn(conn)
err := getClientMultiplexer().AddHandler(conn, connID, packetHandler)
Expect(err).ToNot(HaveOccurred())
conn.dataToRead <- getPacket(connID)
Eventually(handledPacket).Should(BeClosed())
// makes the listen go routine return
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
close(conn.dataToRead)
})
It("errors when adding a handler for an unknown conn", func() {
conn := newMockPacketConn()
err := getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4}, NewMockQuicSession(mockCtrl))
Expect(err).ToNot(MatchError("unknown packet conn"))
})
It("handles packets for different packet handlers on the same packet conn", func() {
conn := newMockPacketConn()
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
@ -59,8 +67,9 @@ var _ = Describe("Client Multiplexer", func() {
close(handledPacket2)
})
packetHandler2.EXPECT().GetVersion()
getClientMultiplexer().Add(conn, connID1, packetHandler1)
getClientMultiplexer().Add(conn, connID2, packetHandler2)
getClientMultiplexer().AddConn(conn)
Expect(getClientMultiplexer().AddHandler(conn, connID1, packetHandler1)).To(Succeed())
Expect(getClientMultiplexer().AddHandler(conn, connID2, packetHandler2)).To(Succeed())
conn.dataToRead <- getPacket(connID1)
conn.dataToRead <- getPacket(connID2)
@ -78,17 +87,39 @@ var _ = Describe("Client Multiplexer", func() {
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
conn.dataToRead <- []byte("invalid header")
packetHandler := NewMockQuicSession(mockCtrl)
getClientMultiplexer().Add(conn, connID, packetHandler)
getClientMultiplexer().AddConn(conn)
Expect(getClientMultiplexer().AddHandler(conn, connID, packetHandler)).To(Succeed())
time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
close(conn.dataToRead)
})
It("ignores packets arriving late for closed sessions", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
origNewPacketHandlerManager := getClientMultiplexer().newPacketHandlerManager
defer func() {
getClientMultiplexer().newPacketHandlerManager = origNewPacketHandlerManager
}()
getClientMultiplexer().newPacketHandlerManager = func() packetHandlerManager { return manager }
conn := newMockPacketConn()
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
done := make(chan struct{})
manager.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(done) }).Return(nil, true)
getClientMultiplexer().AddConn(conn)
conn.dataToRead <- getPacket(connID)
Eventually(done).Should(BeClosed())
// makes the listen go routine return
manager.EXPECT().Close(gomock.Any()).AnyTimes()
close(conn.dataToRead)
})
It("drops packets for unknown receivers", func() {
conn := newMockPacketConn()
conn.dataToRead <- getPacket(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})
packetHandler := NewMockQuicSession(mockCtrl)
getClientMultiplexer().Add(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler)
getClientMultiplexer().AddConn(conn)
Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler)).To(Succeed())
time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet
// makes the listen go routine return
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
@ -104,7 +135,8 @@ var _ = Describe("Client Multiplexer", func() {
packetHandler.EXPECT().Close(testErr).Do(func(error) {
close(done)
})
getClientMultiplexer().Add(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)
getClientMultiplexer().AddConn(conn)
Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)).To(Succeed())
Eventually(done).Should(BeClosed())
})
})

View file

@ -251,7 +251,7 @@ var _ = Describe("Client", func() {
It("errors when the Config contains an invalid version", func() {
version := protocol.VersionNumber(0x1234)
_, err := Dial(nil, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}})
_, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}})
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
})

View file

@ -536,7 +536,9 @@ runLoop:
s.logger.Infof("Handling close error failed: %s", err)
}
s.logger.Infof("Connection %s closed.", s.srcConnID)
s.sessionRunner.removeConnectionID(s.srcConnID)
if closeErr.err != handshake.ErrCloseSessionForRetry {
s.sessionRunner.removeConnectionID(s.srcConnID)
}
return closeErr.err
}