diff --git a/client.go b/client.go index fb05e62b..e8e92ddf 100644 --- a/client.go +++ b/client.go @@ -41,6 +41,7 @@ type client struct { version protocol.VersionNumber handshakeChan chan struct{} + closeCallback func(protocol.ConnectionID) session quicSession @@ -81,7 +82,7 @@ func DialAddrContext( if err != nil { return nil, err } - c, err := newClient(udpConn, udpAddr, config, tlsConf, addr) + c, err := newClient(udpConn, udpAddr, config, tlsConf, addr, nil) if err != nil { return nil, err } @@ -114,18 +115,29 @@ func DialContext( tlsConf *tls.Config, config *Config, ) (Session, error) { - c, err := newClient(pconn, remoteAddr, config, tlsConf, host) + multiplexer := getClientMultiplexer() + manager := multiplexer.AddConn(pconn) + c, err := newClient(pconn, remoteAddr, config, tlsConf, host, manager.Remove) if err != nil { return nil, err } - getClientMultiplexer().Add(pconn, c.srcConnID, c) + if err := multiplexer.AddHandler(pconn, c.srcConnID, c); err != nil { + return nil, err + } if err := c.dial(ctx); err != nil { return nil, err } return c.session, nil } -func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, host string) (*client, error) { +func newClient( + pconn net.PacketConn, + remoteAddr net.Addr, + config *Config, + tlsConf *tls.Config, + host string, + closeCallback func(protocol.ConnectionID), +) (*client, error) { clientConfig := populateClientConfig(config) version := clientConfig.Versions[0] srcConnID, err := generateConnectionID() @@ -159,6 +171,10 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon } } } + onClose := func(protocol.ConnectionID) {} + if closeCallback != nil { + onClose = closeCallback + } return &client{ conn: &conn{pconn: pconn, currentAddr: remoteAddr}, srcConnID: srcConnID, @@ -168,6 +184,7 @@ func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsCon config: clientConfig, version: version, handshakeChan: make(chan struct{}), + closeCallback: onClose, logger: utils.DefaultLogger.WithPrefix("client"), }, nil } @@ -508,7 +525,7 @@ func (c *client) createNewGQUICSession() (err error) { defer c.mutex.Unlock() runner := &runner{ onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) }, - removeConnectionIDImpl: func(protocol.ConnectionID) {}, + removeConnectionIDImpl: c.closeCallback, } c.session, err = newClientSession( c.conn, @@ -533,7 +550,7 @@ func (c *client) createNewTLSSession( defer c.mutex.Unlock() runner := &runner{ onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) }, - removeConnectionIDImpl: func(protocol.ConnectionID) {}, + removeConnectionIDImpl: c.closeCallback, } c.session, err = newTLSClientSession( c.conn, diff --git a/client_multiplexer.go b/client_multiplexer.go index a25f57be..5b541e7e 100644 --- a/client_multiplexer.go +++ b/client_multiplexer.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "errors" "net" "strings" "sync" @@ -22,7 +23,8 @@ var ( type clientMultiplexer struct { mutex sync.Mutex - conns map[net.PacketConn]packetHandlerManager + conns map[net.PacketConn]packetHandlerManager + newPacketHandlerManager func() packetHandlerManager // so it can be replaced in the tests logger utils.Logger } @@ -30,29 +32,35 @@ type clientMultiplexer struct { func getClientMultiplexer() *clientMultiplexer { clientMuxerOnce.Do(func() { clientMuxer = &clientMultiplexer{ - conns: make(map[net.PacketConn]packetHandlerManager), - logger: utils.DefaultLogger.WithPrefix("client muxer"), + conns: make(map[net.PacketConn]packetHandlerManager), + logger: utils.DefaultLogger.WithPrefix("client muxer"), + newPacketHandlerManager: newPacketHandlerMap, } }) return clientMuxer } -func (m *clientMultiplexer) Add(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) { +func (m *clientMultiplexer) AddConn(c net.PacketConn) packetHandlerManager { m.mutex.Lock() defer m.mutex.Unlock() sessions, ok := m.conns[c] if !ok { - sessions = newPacketHandlerMap() + sessions = m.newPacketHandlerManager() m.conns[c] = sessions + // If we didn't know this packet conn before, listen for incoming packets + // and dispatch them to the right sessions. + go m.listen(c, sessions) + } + return sessions +} + +func (m *clientMultiplexer) AddHandler(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) error { + sessions, ok := m.conns[c] + if !ok { + return errors.New("unknown packet conn %s") } sessions.Add(connID, handler) - if ok { - return - } - - // If we didn't know this packet conn before, listen for incoming packets - // and dispatch them to the right sessions. - go m.listen(c, sessions) + return nil } func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManager) { @@ -83,6 +91,10 @@ func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManag m.logger.Debugf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID) continue } + if client == nil { + // Late packet for closed session + continue + } hdr, err := iHdr.Parse(r, protocol.PerspectiveServer, client.GetVersion()) if err != nil { m.logger.Debugf("error parsing header from %s: %s", addr, err) diff --git a/client_multiplexer_test.go b/client_multiplexer_test.go index 76493f66..77ada895 100644 --- a/client_multiplexer_test.go +++ b/client_multiplexer_test.go @@ -27,20 +27,28 @@ var _ = Describe("Client Multiplexer", func() { It("adds a new packet conn and handles packets", func() { conn := newMockPacketConn() connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - conn.dataToRead <- getPacket(connID) packetHandler := NewMockQuicSession(mockCtrl) handledPacket := make(chan struct{}) packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(_ *receivedPacket) { close(handledPacket) }) packetHandler.EXPECT().GetVersion() - getClientMultiplexer().Add(conn, connID, packetHandler) + getClientMultiplexer().AddConn(conn) + err := getClientMultiplexer().AddHandler(conn, connID, packetHandler) + Expect(err).ToNot(HaveOccurred()) + conn.dataToRead <- getPacket(connID) Eventually(handledPacket).Should(BeClosed()) // makes the listen go routine return packetHandler.EXPECT().Close(gomock.Any()).AnyTimes() close(conn.dataToRead) }) + It("errors when adding a handler for an unknown conn", func() { + conn := newMockPacketConn() + err := getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4}, NewMockQuicSession(mockCtrl)) + Expect(err).ToNot(MatchError("unknown packet conn")) + }) + It("handles packets for different packet handlers on the same packet conn", func() { conn := newMockPacketConn() connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} @@ -59,8 +67,9 @@ var _ = Describe("Client Multiplexer", func() { close(handledPacket2) }) packetHandler2.EXPECT().GetVersion() - getClientMultiplexer().Add(conn, connID1, packetHandler1) - getClientMultiplexer().Add(conn, connID2, packetHandler2) + getClientMultiplexer().AddConn(conn) + Expect(getClientMultiplexer().AddHandler(conn, connID1, packetHandler1)).To(Succeed()) + Expect(getClientMultiplexer().AddHandler(conn, connID2, packetHandler2)).To(Succeed()) conn.dataToRead <- getPacket(connID1) conn.dataToRead <- getPacket(connID2) @@ -78,17 +87,39 @@ var _ = Describe("Client Multiplexer", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} conn.dataToRead <- []byte("invalid header") packetHandler := NewMockQuicSession(mockCtrl) - getClientMultiplexer().Add(conn, connID, packetHandler) + getClientMultiplexer().AddConn(conn) + Expect(getClientMultiplexer().AddHandler(conn, connID, packetHandler)).To(Succeed()) time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet packetHandler.EXPECT().Close(gomock.Any()).AnyTimes() close(conn.dataToRead) }) + It("ignores packets arriving late for closed sessions", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + origNewPacketHandlerManager := getClientMultiplexer().newPacketHandlerManager + defer func() { + getClientMultiplexer().newPacketHandlerManager = origNewPacketHandlerManager + }() + getClientMultiplexer().newPacketHandlerManager = func() packetHandlerManager { return manager } + + conn := newMockPacketConn() + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + done := make(chan struct{}) + manager.EXPECT().Get(connID).Do(func(protocol.ConnectionID) { close(done) }).Return(nil, true) + getClientMultiplexer().AddConn(conn) + conn.dataToRead <- getPacket(connID) + Eventually(done).Should(BeClosed()) + // makes the listen go routine return + manager.EXPECT().Close(gomock.Any()).AnyTimes() + close(conn.dataToRead) + }) + It("drops packets for unknown receivers", func() { conn := newMockPacketConn() conn.dataToRead <- getPacket(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}) packetHandler := NewMockQuicSession(mockCtrl) - getClientMultiplexer().Add(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler) + getClientMultiplexer().AddConn(conn) + Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler)).To(Succeed()) time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet // makes the listen go routine return packetHandler.EXPECT().Close(gomock.Any()).AnyTimes() @@ -104,7 +135,8 @@ var _ = Describe("Client Multiplexer", func() { packetHandler.EXPECT().Close(testErr).Do(func(error) { close(done) }) - getClientMultiplexer().Add(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) + getClientMultiplexer().AddConn(conn) + Expect(getClientMultiplexer().AddHandler(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler)).To(Succeed()) Eventually(done).Should(BeClosed()) }) }) diff --git a/client_test.go b/client_test.go index 7b4284ab..e910521c 100644 --- a/client_test.go +++ b/client_test.go @@ -251,7 +251,7 @@ var _ = Describe("Client", func() { It("errors when the Config contains an invalid version", func() { version := protocol.VersionNumber(0x1234) - _, err := Dial(nil, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}}) + _, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}}) Expect(err).To(MatchError("0x1234 is not a valid QUIC version")) }) diff --git a/session.go b/session.go index c73e9953..347e8170 100644 --- a/session.go +++ b/session.go @@ -536,7 +536,9 @@ runLoop: s.logger.Infof("Handling close error failed: %s", err) } s.logger.Infof("Connection %s closed.", s.srcConnID) - s.sessionRunner.removeConnectionID(s.srcConnID) + if closeErr.err != handshake.ErrCloseSessionForRetry { + s.sessionRunner.removeConnectionID(s.srcConnID) + } return closeErr.err }