mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-06 21:57:36 +03:00
remove closed clients from the multiplexer
This commit is contained in:
parent
7658f10a51
commit
0928e91e4d
5 changed files with 90 additions and 27 deletions
29
client.go
29
client.go
|
@ -41,6 +41,7 @@ type client struct {
|
|||
version protocol.VersionNumber
|
||||
|
||||
handshakeChan chan struct{}
|
||||
closeCallback func(protocol.ConnectionID)
|
||||
|
||||
session quicSession
|
||||
|
||||
|
@ -81,7 +82,7 @@ func DialAddrContext(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c, err := newClient(udpConn, udpAddr, config, tlsConf, addr)
|
||||
c, err := newClient(udpConn, udpAddr, config, tlsConf, addr, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -114,18 +115,29 @@ func DialContext(
|
|||
tlsConf *tls.Config,
|
||||
config *Config,
|
||||
) (Session, error) {
|
||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host)
|
||||
multiplexer := getClientMultiplexer()
|
||||
manager := multiplexer.AddConn(pconn)
|
||||
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
getClientMultiplexer().Add(pconn, c.srcConnID, c)
|
||||
if err := multiplexer.AddHandler(pconn, c.srcConnID, c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.dial(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, host string) (*client, error) {
|
||||
func newClient(
|
||||
pconn net.PacketConn,
|
||||
remoteAddr net.Addr,
|
||||
config *Config,
|
||||
tlsConf *tls.Config,
|
||||
host string,
|
||||
closeCallback func(protocol.ConnectionID),
|
||||
) (*client, error) {
|
||||
clientConfig := populateClientConfig(config)
|
||||
version := clientConfig.Versions[0]
|
||||
srcConnID, err := generateConnectionID()
|
||||
|
@ -159,6 +171,10 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
|
|||
}
|
||||
}
|
||||
}
|
||||
onClose := func(protocol.ConnectionID) {}
|
||||
if closeCallback != nil {
|
||||
onClose = closeCallback
|
||||
}
|
||||
return &client{
|
||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||
srcConnID: srcConnID,
|
||||
|
@ -168,6 +184,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon
|
|||
config: clientConfig,
|
||||
version: version,
|
||||
handshakeChan: make(chan struct{}),
|
||||
closeCallback: onClose,
|
||||
logger: utils.DefaultLogger.WithPrefix("client"),
|
||||
}, nil
|
||||
}
|
||||
|
@ -508,7 +525,7 @@ func (c *client) createNewGQUICSession() (err error) {
|
|||
defer c.mutex.Unlock()
|
||||
runner := &runner{
|
||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||
removeConnectionIDImpl: func(protocol.ConnectionID) {},
|
||||
removeConnectionIDImpl: c.closeCallback,
|
||||
}
|
||||
c.session, err = newClientSession(
|
||||
c.conn,
|
||||
|
@ -533,7 +550,7 @@ func (c *client) createNewTLSSession(
|
|||
defer c.mutex.Unlock()
|
||||
runner := &runner{
|
||||
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
|
||||
removeConnectionIDImpl: func(protocol.ConnectionID) {},
|
||||
removeConnectionIDImpl: c.closeCallback,
|
||||
}
|
||||
c.session, err = newTLSClientSession(
|
||||
c.conn,
|
||||
|
|
|
@ -2,6 +2,7 @@ package quic
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -22,7 +23,8 @@ var (
|
|||
type clientMultiplexer struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conns map[net.PacketConn]packetHandlerManager
|
||||
conns map[net.PacketConn]packetHandlerManager
|
||||
newPacketHandlerManager func() packetHandlerManager // so it can be replaced in the tests
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
@ -30,29 +32,35 @@ type clientMultiplexer struct {
|
|||
func getClientMultiplexer() *clientMultiplexer {
|
||||
clientMuxerOnce.Do(func() {
|
||||
clientMuxer = &clientMultiplexer{
|
||||
conns: make(map[net.PacketConn]packetHandlerManager),
|
||||
logger: utils.DefaultLogger.WithPrefix("client muxer"),
|
||||
conns: make(map[net.PacketConn]packetHandlerManager),
|
||||
logger: utils.DefaultLogger.WithPrefix("client muxer"),
|
||||
newPacketHandlerManager: newPacketHandlerMap,
|
||||
}
|
||||
})
|
||||
return clientMuxer
|
||||
}
|
||||
|
||||
func (m *clientMultiplexer) Add(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) {
|
||||
func (m *clientMultiplexer) AddConn(c net.PacketConn) packetHandlerManager {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
sessions, ok := m.conns[c]
|
||||
if !ok {
|
||||
sessions = newPacketHandlerMap()
|
||||
sessions = m.newPacketHandlerManager()
|
||||
m.conns[c] = sessions
|
||||
// If we didn't know this packet conn before, listen for incoming packets
|
||||
// and dispatch them to the right sessions.
|
||||
go m.listen(c, sessions)
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
|
||||
func (m *clientMultiplexer) AddHandler(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) error {
|
||||
sessions, ok := m.conns[c]
|
||||
if !ok {
|
||||
return errors.New("unknown packet conn %s")
|
||||
}
|
||||
sessions.Add(connID, handler)
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
|
||||
// If we didn't know this packet conn before, listen for incoming packets
|
||||
// and dispatch them to the right sessions.
|
||||
go m.listen(c, sessions)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManager) {
|
||||
|
@ -83,6 +91,10 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag
|
|||
m.logger.Debugf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
|
||||
continue
|
||||
}
|
||||
if client == nil {
|
||||
// Late packet for closed session
|
||||
continue
|
||||
}
|
||||
hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, client.GetVersion())
|
||||
if err != nil {
|
||||
m.logger.Debugf("error parsing header from %s: %s", addr, err)
|
||||
|
|
|
@ -27,20 +27,28 @@ var _ = Describe("Client Multiplexer", func() {
|
|||
It("adds a new packet conn and handles packets", func() {
|
||||
conn := newMockPacketConn()
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
conn.dataToRead <- getPacket(connID)
|
||||
packetHandler := NewMockQuicSession(mockCtrl)
|
||||
handledPacket := make(chan struct{})
|
||||
packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(_ *receivedPacket) {
|
||||
close(handledPacket)
|
||||
})
|
||||
packetHandler.EXPECT().GetVersion()
|
||||
getClientMultiplexer().Add(conn, connID, packetHandler)
|
||||
getClientMultiplexer().AddConn(conn)
|
||||
err := getClientMultiplexer().AddHandler(conn, connID, packetHandler)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
conn.dataToRead <- getPacket(connID)
|
||||
Eventually(handledPacket).Should(BeClosed())
|
||||
// makes the listen go routine return
|
||||
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
|
||||
close(conn.dataToRead)
|
||||
})
|
||||
|
||||
It("errors when adding a handler for an unknown conn", func() {
|
||||
conn := newMockPacketConn()
|
||||
err := getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4}, NewMockQuicSession(mockCtrl))
|
||||
Expect(err).ToNot(MatchError("unknown packet conn"))
|
||||
})
|
||||
|
||||
It("handles packets for different packet handlers on the same packet conn", func() {
|
||||
conn := newMockPacketConn()
|
||||
connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
|
@ -59,8 +67,9 @@ var _ = Describe("Client Multiplexer", func() {
|
|||
close(handledPacket2)
|
||||
})
|
||||
packetHandler2.EXPECT().GetVersion()
|
||||
getClientMultiplexer().Add(conn, connID1, packetHandler1)
|
||||
getClientMultiplexer().Add(conn, connID2, packetHandler2)
|
||||
getClientMultiplexer().AddConn(conn)
|
||||
Expect(getClientMultiplexer().AddHandler(conn, connID1, packetHandler1)).To(Succeed())
|
||||
Expect(getClientMultiplexer().AddHandler(conn, connID2, packetHandler2)).To(Succeed())
|
||||
|
||||
conn.dataToRead <- getPacket(connID1)
|
||||
conn.dataToRead <- getPacket(connID2)
|
||||
|
@ -78,17 +87,39 @@ var _ = Describe("Client Multiplexer", func() {
|
|||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
conn.dataToRead <- []byte("invalid header")
|
||||
packetHandler := NewMockQuicSession(mockCtrl)
|
||||
getClientMultiplexer().Add(conn, connID, packetHandler)
|
||||
getClientMultiplexer().AddConn(conn)
|
||||
Expect(getClientMultiplexer().AddHandler(conn, connID, packetHandler)).To(Succeed())
|
||||
time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet
|
||||
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
|
||||
close(conn.dataToRead)
|
||||
})
|
||||
|
||||
It("ignores packets arriving late for closed sessions", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
origNewPacketHandlerManager := getClientMultiplexer().newPacketHandlerManager
|
||||
defer func() {
|
||||
getClientMultiplexer().newPacketHandlerManager = origNewPacketHandlerManager
|
||||
}()
|
||||
getClientMultiplexer().newPacketHandlerManager = func() packetHandlerManager { return manager }
|
||||
|
||||
conn := newMockPacketConn()
|
||||
connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
done := make(chan struct{})
|
||||
manager.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(done) }).Return(nil, true)
|
||||
getClientMultiplexer().AddConn(conn)
|
||||
conn.dataToRead <- getPacket(connID)
|
||||
Eventually(done).Should(BeClosed())
|
||||
// makes the listen go routine return
|
||||
manager.EXPECT().Close(gomock.Any()).AnyTimes()
|
||||
close(conn.dataToRead)
|
||||
})
|
||||
|
||||
It("drops packets for unknown receivers", func() {
|
||||
conn := newMockPacketConn()
|
||||
conn.dataToRead <- getPacket(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})
|
||||
packetHandler := NewMockQuicSession(mockCtrl)
|
||||
getClientMultiplexer().Add(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler)
|
||||
getClientMultiplexer().AddConn(conn)
|
||||
Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler)).To(Succeed())
|
||||
time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet
|
||||
// makes the listen go routine return
|
||||
packetHandler.EXPECT().Close(gomock.Any()).AnyTimes()
|
||||
|
@ -104,7 +135,8 @@ var _ = Describe("Client Multiplexer", func() {
|
|||
packetHandler.EXPECT().Close(testErr).Do(func(error) {
|
||||
close(done)
|
||||
})
|
||||
getClientMultiplexer().Add(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)
|
||||
getClientMultiplexer().AddConn(conn)
|
||||
Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)).To(Succeed())
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
|
|
@ -251,7 +251,7 @@ var _ = Describe("Client", func() {
|
|||
|
||||
It("errors when the Config contains an invalid version", func() {
|
||||
version := protocol.VersionNumber(0x1234)
|
||||
_, err := Dial(nil, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}})
|
||||
_, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}})
|
||||
Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
|
||||
})
|
||||
|
||||
|
|
|
@ -536,7 +536,9 @@ runLoop:
|
|||
s.logger.Infof("Handling close error failed: %s", err)
|
||||
}
|
||||
s.logger.Infof("Connection %s closed.", s.srcConnID)
|
||||
s.sessionRunner.removeConnectionID(s.srcConnID)
|
||||
if closeErr.err != handshake.ErrCloseSessionForRetry {
|
||||
s.sessionRunner.removeConnectionID(s.srcConnID)
|
||||
}
|
||||
return closeErr.err
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue