mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-06 13:47:35 +03:00
remove closed clients from the multiplexer
This commit is contained in:
parent
7658f10a51
commit
0928e91e4d
5 changed files with 90 additions and 27 deletions
29
client.go
29
client.go
|
@ -41,6 +41,7 @@ type client struct {
|
||||||
version protocol.VersionNumber
|
version protocol.VersionNumber
|
||||||
|
|
||||||
handshakeChan chan struct{}
|
handshakeChan chan struct{}
|
||||||
|
closeCallback func(protocol.ConnectionID)
|
||||||
|
|
||||||
session quicSession
|
session quicSession
|
||||||
|
|
||||||
|
@ -81,7 +82,7 @@ func DialAddrContext(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c, err := newClient(udpConn, udpAddr, config, tlsConf, addr)
|
c, err := newClient(udpConn, udpAddr, config, tlsConf, addr, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -114,18 +115,29 @@ func DialContext(
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
config *Config,
|
config *Config,
|
||||||
) (Session, error) {
|
) (Session, error) {
|
||||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host)
|
multiplexer := getClientMultiplexer()
|
||||||
|
manager := multiplexer.AddConn(pconn)
|
||||||
|
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
getClientMultiplexer().Add(pconn, c.srcConnID, c)
|
if err := multiplexer.AddHandler(pconn, c.srcConnID, c); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if err := c.dial(ctx); err != nil {
|
if err := c.dial(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return c.session, nil
|
return c.session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, host string) (*client, error) {
|
func newClient(
|
||||||
|
pconn net.PacketConn,
|
||||||
|
remoteAddr net.Addr,
|
||||||
|
config *Config,
|
||||||
|
tlsConf *tls.Config,
|
||||||
|
host string,
|
||||||
|
closeCallback func(protocol.ConnectionID),
|
||||||
|
) (*client, error) {
|
||||||
clientConfig := populateClientConfig(config)
|
clientConfig := populateClientConfig(config)
|
||||||
version := clientConfig.Versions[0]
|
version := clientConfig.Versions[0]
|
||||||
srcConnID, err := generateConnectionID()
|
srcConnID, err := generateConnectionID()
|
||||||
|
@ -159,6 +171,10 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
onClose := func(protocol.ConnectionID) {}
|
||||||
|
if closeCallback != nil {
|
||||||
|
onClose = closeCallback
|
||||||
|
}
|
||||||
return &client{
|
return &client{
|
||||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||||
srcConnID: srcConnID,
|
srcConnID: srcConnID,
|
||||||
|
@ -168,6 +184,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
|
||||||
config: clientConfig,
|
config: clientConfig,
|
||||||
version: version,
|
version: version,
|
||||||
handshakeChan: make(chan struct{}),
|
handshakeChan: make(chan struct{}),
|
||||||
|
closeCallback: onClose,
|
||||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -508,7 +525,7 @@ func (c *client) createNewGQUICSession() (err error) {
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
runner := &runner{
|
runner := &runner{
|
||||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||||
removeConnectionIDImpl: func(protocol.ConnectionID) {},
|
removeConnectionIDImpl: c.closeCallback,
|
||||||
}
|
}
|
||||||
c.session, err = newClientSession(
|
c.session, err = newClientSession(
|
||||||
c.conn,
|
c.conn,
|
||||||
|
@ -533,7 +550,7 @@ func (c *client) createNewTLSSession(
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
runner := &runner{
|
runner := &runner{
|
||||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||||
removeConnectionIDImpl: func(protocol.ConnectionID) {},
|
removeConnectionIDImpl: c.closeCallback,
|
||||||
}
|
}
|
||||||
c.session, err = newTLSClientSession(
|
c.session, err = newTLSClientSession(
|
||||||
c.conn,
|
c.conn,
|
||||||
|
|
|
@ -2,6 +2,7 @@ package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -22,7 +23,8 @@ var (
|
||||||
type clientMultiplexer struct {
|
type clientMultiplexer struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
conns map[net.PacketConn]packetHandlerManager
|
conns map[net.PacketConn]packetHandlerManager
|
||||||
|
newPacketHandlerManager func() packetHandlerManager // so it can be replaced in the tests
|
||||||
|
|
||||||
logger utils.Logger
|
logger utils.Logger
|
||||||
}
|
}
|
||||||
|
@ -30,29 +32,35 @@ type clientMultiplexer struct {
|
||||||
func getClientMultiplexer() *clientMultiplexer {
|
func getClientMultiplexer() *clientMultiplexer {
|
||||||
clientMuxerOnce.Do(func() {
|
clientMuxerOnce.Do(func() {
|
||||||
clientMuxer = &clientMultiplexer{
|
clientMuxer = &clientMultiplexer{
|
||||||
conns: make(map[net.PacketConn]packetHandlerManager),
|
conns: make(map[net.PacketConn]packetHandlerManager),
|
||||||
logger: utils.DefaultLogger.WithPrefix("client muxer"),
|
logger: utils.DefaultLogger.WithPrefix("client muxer"),
|
||||||
|
newPacketHandlerManager: newPacketHandlerMap,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return clientMuxer
|
return clientMuxer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *clientMultiplexer) Add(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) {
|
func (m *clientMultiplexer) AddConn(c net.PacketConn) packetHandlerManager {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
sessions, ok := m.conns[c]
|
sessions, ok := m.conns[c]
|
||||||
if !ok {
|
if !ok {
|
||||||
sessions = newPacketHandlerMap()
|
sessions = m.newPacketHandlerManager()
|
||||||
m.conns[c] = sessions
|
m.conns[c] = sessions
|
||||||
|
// If we didn't know this packet conn before, listen for incoming packets
|
||||||
|
// and dispatch them to the right sessions.
|
||||||
|
go m.listen(c, sessions)
|
||||||
|
}
|
||||||
|
return sessions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *clientMultiplexer) AddHandler(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) error {
|
||||||
|
sessions, ok := m.conns[c]
|
||||||
|
if !ok {
|
||||||
|
return errors.New("unknown packet conn %s")
|
||||||
}
|
}
|
||||||
sessions.Add(connID, handler)
|
sessions.Add(connID, handler)
|
||||||
if ok {
|
return nil
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we didn't know this packet conn before, listen for incoming packets
|
|
||||||
// and dispatch them to the right sessions.
|
|
||||||
go m.listen(c, sessions)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManager) {
|
func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManager) {
|
||||||
|
@ -83,6 +91,10 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag
|
||||||
m.logger.Debugf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
|
m.logger.Debugf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if client == nil {
|
||||||
|
// Late packet for closed session
|
||||||
|
continue
|
||||||
|
}
|
||||||
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, client.GetVersion())
|
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, client.GetVersion())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.logger.Debugf("error parsing header from %s: %s", addr, err)
|
m.logger.Debugf("error parsing header from %s: %s", addr, err)
|
||||||
|
|
|
@ -27,20 +27,28 @@ var _ = Describe("Client Multiplexer", func() {
|
||||||
It("adds a new packet conn and handles packets", func() {
|
It("adds a new packet conn and handles packets", func() {
|
||||||
conn := newMockPacketConn()
|
conn := newMockPacketConn()
|
||||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||||
conn.dataToRead <- getPacket(connID)
|
|
||||||
packetHandler := NewMockQuicSession(mockCtrl)
|
packetHandler := NewMockQuicSession(mockCtrl)
|
||||||
handledPacket := make(chan struct{})
|
handledPacket := make(chan struct{})
|
||||||
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(_ *receivedPacket) {
|
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(_ *receivedPacket) {
|
||||||
close(handledPacket)
|
close(handledPacket)
|
||||||
})
|
})
|
||||||
packetHandler.EXPECT().GetVersion()
|
packetHandler.EXPECT().GetVersion()
|
||||||
getClientMultiplexer().Add(conn, connID, packetHandler)
|
getClientMultiplexer().AddConn(conn)
|
||||||
|
err := getClientMultiplexer().AddHandler(conn, connID, packetHandler)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
conn.dataToRead <- getPacket(connID)
|
||||||
Eventually(handledPacket).Should(BeClosed())
|
Eventually(handledPacket).Should(BeClosed())
|
||||||
// makes the listen go routine return
|
// makes the listen go routine return
|
||||||
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
|
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
|
||||||
close(conn.dataToRead)
|
close(conn.dataToRead)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("errors when adding a handler for an unknown conn", func() {
|
||||||
|
conn := newMockPacketConn()
|
||||||
|
err := getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4}, NewMockQuicSession(mockCtrl))
|
||||||
|
Expect(err).ToNot(MatchError("unknown packet conn"))
|
||||||
|
})
|
||||||
|
|
||||||
It("handles packets for different packet handlers on the same packet conn", func() {
|
It("handles packets for different packet handlers on the same packet conn", func() {
|
||||||
conn := newMockPacketConn()
|
conn := newMockPacketConn()
|
||||||
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||||
|
@ -59,8 +67,9 @@ var _ = Describe("Client Multiplexer", func() {
|
||||||
close(handledPacket2)
|
close(handledPacket2)
|
||||||
})
|
})
|
||||||
packetHandler2.EXPECT().GetVersion()
|
packetHandler2.EXPECT().GetVersion()
|
||||||
getClientMultiplexer().Add(conn, connID1, packetHandler1)
|
getClientMultiplexer().AddConn(conn)
|
||||||
getClientMultiplexer().Add(conn, connID2, packetHandler2)
|
Expect(getClientMultiplexer().AddHandler(conn, connID1, packetHandler1)).To(Succeed())
|
||||||
|
Expect(getClientMultiplexer().AddHandler(conn, connID2, packetHandler2)).To(Succeed())
|
||||||
|
|
||||||
conn.dataToRead <- getPacket(connID1)
|
conn.dataToRead <- getPacket(connID1)
|
||||||
conn.dataToRead <- getPacket(connID2)
|
conn.dataToRead <- getPacket(connID2)
|
||||||
|
@ -78,17 +87,39 @@ var _ = Describe("Client Multiplexer", func() {
|
||||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||||
conn.dataToRead <- []byte("invalid header")
|
conn.dataToRead <- []byte("invalid header")
|
||||||
packetHandler := NewMockQuicSession(mockCtrl)
|
packetHandler := NewMockQuicSession(mockCtrl)
|
||||||
getClientMultiplexer().Add(conn, connID, packetHandler)
|
getClientMultiplexer().AddConn(conn)
|
||||||
|
Expect(getClientMultiplexer().AddHandler(conn, connID, packetHandler)).To(Succeed())
|
||||||
time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet
|
time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet
|
||||||
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
|
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
|
||||||
close(conn.dataToRead)
|
close(conn.dataToRead)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("ignores packets arriving late for closed sessions", func() {
|
||||||
|
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||||
|
origNewPacketHandlerManager := getClientMultiplexer().newPacketHandlerManager
|
||||||
|
defer func() {
|
||||||
|
getClientMultiplexer().newPacketHandlerManager = origNewPacketHandlerManager
|
||||||
|
}()
|
||||||
|
getClientMultiplexer().newPacketHandlerManager = func() packetHandlerManager { return manager }
|
||||||
|
|
||||||
|
conn := newMockPacketConn()
|
||||||
|
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||||
|
done := make(chan struct{})
|
||||||
|
manager.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(done) }).Return(nil, true)
|
||||||
|
getClientMultiplexer().AddConn(conn)
|
||||||
|
conn.dataToRead <- getPacket(connID)
|
||||||
|
Eventually(done).Should(BeClosed())
|
||||||
|
// makes the listen go routine return
|
||||||
|
manager.EXPECT().Close(gomock.Any()).AnyTimes()
|
||||||
|
close(conn.dataToRead)
|
||||||
|
})
|
||||||
|
|
||||||
It("drops packets for unknown receivers", func() {
|
It("drops packets for unknown receivers", func() {
|
||||||
conn := newMockPacketConn()
|
conn := newMockPacketConn()
|
||||||
conn.dataToRead <- getPacket(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})
|
conn.dataToRead <- getPacket(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})
|
||||||
packetHandler := NewMockQuicSession(mockCtrl)
|
packetHandler := NewMockQuicSession(mockCtrl)
|
||||||
getClientMultiplexer().Add(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler)
|
getClientMultiplexer().AddConn(conn)
|
||||||
|
Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler)).To(Succeed())
|
||||||
time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet
|
time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet
|
||||||
// makes the listen go routine return
|
// makes the listen go routine return
|
||||||
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
|
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
|
||||||
|
@ -104,7 +135,8 @@ var _ = Describe("Client Multiplexer", func() {
|
||||||
packetHandler.EXPECT().Close(testErr).Do(func(error) {
|
packetHandler.EXPECT().Close(testErr).Do(func(error) {
|
||||||
close(done)
|
close(done)
|
||||||
})
|
})
|
||||||
getClientMultiplexer().Add(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)
|
getClientMultiplexer().AddConn(conn)
|
||||||
|
Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)).To(Succeed())
|
||||||
Eventually(done).Should(BeClosed())
|
Eventually(done).Should(BeClosed())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -251,7 +251,7 @@ var _ = Describe("Client", func() {
|
||||||
|
|
||||||
It("errors when the Config contains an invalid version", func() {
|
It("errors when the Config contains an invalid version", func() {
|
||||||
version := protocol.VersionNumber(0x1234)
|
version := protocol.VersionNumber(0x1234)
|
||||||
_, err := Dial(nil, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}})
|
_, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}})
|
||||||
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
|
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -536,7 +536,9 @@ runLoop:
|
||||||
s.logger.Infof("Handling close error failed: %s", err)
|
s.logger.Infof("Handling close error failed: %s", err)
|
||||||
}
|
}
|
||||||
s.logger.Infof("Connection %s closed.", s.srcConnID)
|
s.logger.Infof("Connection %s closed.", s.srcConnID)
|
||||||
s.sessionRunner.removeConnectionID(s.srcConnID)
|
if closeErr.err != handshake.ErrCloseSessionForRetry {
|
||||||
|
s.sessionRunner.removeConnectionID(s.srcConnID)
|
||||||
|
}
|
||||||
return closeErr.err
|
return closeErr.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue