diff --git a/Changelog.md b/Changelog.md index d537032a..0e3e1d5b 100644 --- a/Changelog.md +++ b/Changelog.md @@ -6,6 +6,7 @@ - Add a `quic.Config` option for the maximum number of incoming streams. - Add support for QUIC 42 and 43. - Add dial functions that use a context. +- Multiplex clients on a net.PacketConn, when using Dial(conn). ## v0.7.0 (2018-02-03) diff --git a/client.go b/client.go index b3b596d0..65da2541 100644 --- a/client.go +++ b/client.go @@ -47,6 +47,8 @@ type client struct { logger utils.Logger } +var _ packetHandler = &client{} + var ( // make it possible to mock connection ID generation in the tests generateConnectionID = protocol.GenerateConnectionID @@ -79,7 +81,15 @@ func DialAddrContext( if err != nil { return nil, err } - return DialContext(ctx, udpConn, udpAddr, addr, tlsConf, config) + c, err := newClient(udpConn, udpAddr, config, tlsConf, addr) + if err != nil { + return nil, err + } + go c.listen() + if err := c.dial(ctx); err != nil { + return nil, err + } + return c.session, nil } // Dial establishes a new QUIC connection to a server using a net.PacketConn. @@ -104,6 +114,18 @@ func DialContext( tlsConf *tls.Config, config *Config, ) (Session, error) { + c, err := newClient(pconn, remoteAddr, config, tlsConf, host) + if err != nil { + return nil, err + } + getClientMultiplexer().Add(pconn, c.srcConnID, c) + if err := c.dial(ctx); err != nil { + return nil, err + } + return c.session, nil +} + +func newClient(pconn net.PacketConn, remoteAddr net.Addr, config *Config, tlsConf *tls.Config, host string) (*client, error) { clientConfig := populateClientConfig(config) version := clientConfig.Versions[0] srcConnID, err := generateConnectionID() @@ -137,8 +159,7 @@ func DialContext( } } } - - c := &client{ + return &client{ conn: &conn{pconn: pconn, currentAddr: remoteAddr}, srcConnID: srcConnID, destConnID: destConnID, @@ -148,14 +169,7 @@ func DialContext( version: version, handshakeChan: make(chan struct{}), logger: utils.DefaultLogger.WithPrefix("client"), - } - - c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) - - if err := c.dial(ctx); err != nil { - return nil, err - } - return c.session, nil + }, nil } // populateClientConfig populates fields in the quic.Config with their default values, if none are set @@ -213,6 +227,8 @@ func populateClientConfig(config *Config) *Config { } func (c *client) dial(ctx context.Context) error { + c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + var err error if c.version.UsesTLS() { err = c.dialTLS(ctx) @@ -229,7 +245,6 @@ func (c *client) dialGQUIC(ctx context.Context) error { if err := c.createNewGQUICSession(); err != nil { return err } - go c.listen() return c.establishSecureConnection(ctx) } @@ -255,7 +270,6 @@ func (c *client) dialTLS(ctx context.Context) error { if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { return err } - go c.listen() if err := c.establishSecureConnection(ctx); err != nil { if err != handshake.ErrCloseSessionForRetry { return err @@ -530,3 +544,12 @@ func (c *client) createNewTLSSession( ) return err } + +func (c *client) Close(err error) error { + c.mutex.Lock() + defer c.mutex.Unlock() + if c.session == nil { + return nil + } + return c.session.Close(err) +} diff --git a/client_multiplexer.go b/client_multiplexer.go new file mode 100644 index 00000000..2cc1197a --- /dev/null +++ b/client_multiplexer.go @@ -0,0 +1,96 @@ +package quic + +import ( + "bytes" + "net" + "strings" + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/internal/wire" +) + +var ( + clientMuxerOnce sync.Once + clientMuxer *clientMultiplexer +) + +// The clientMultiplexer listens on multiple net.PacketConns and dispatches +// incoming packets to the session handler. +type clientMultiplexer struct { + mutex sync.Mutex + + conns map[net.PacketConn]packetHandlerManager + + logger utils.Logger +} + +func getClientMultiplexer() *clientMultiplexer { + clientMuxerOnce.Do(func() { + clientMuxer = &clientMultiplexer{ + conns: make(map[net.PacketConn]packetHandlerManager), + logger: utils.DefaultLogger.WithPrefix("client muxer"), + } + }) + return clientMuxer +} + +func (m *clientMultiplexer) Add(c net.PacketConn, connID protocol.ConnectionID, handler packetHandler) { + m.mutex.Lock() + defer m.mutex.Unlock() + sessions, ok := m.conns[c] + if !ok { + sessions = newPacketHandlerMap() + m.conns[c] = sessions + } + sessions.Add(connID, handler) + if ok { + return + } + + // If we didn't know this packet conn before, listen for incoming packets + // and dispatch them to the right sessions. + go m.listen(c, sessions) +} + +func (m *clientMultiplexer) listen(c net.PacketConn, sessions packetHandlerManager) { + for { + data := *getPacketBuffer() + data = data[:protocol.MaxReceivePacketSize] + // The packet size should not exceed protocol.MaxReceivePacketSize bytes + // If it does, we only read a truncated packet, which will then end up undecryptable + n, addr, err := c.ReadFrom(data) + if err != nil { + if !strings.HasSuffix(err.Error(), "use of closed network connection") { + sessions.Close(err) + } + return + } + data = data[:n] + rcvTime := time.Now() + + r := bytes.NewReader(data) + hdr, err := wire.ParseHeaderSentByServer(r) + // drop the packet if we can't parse the header + if err != nil { + m.logger.Debugf("error parsing packet from %s: %s", addr, err) + continue + } + hdr.Raw = data[:len(data)-r.Len()] + packetData := data[len(data)-r.Len():] + + client, ok := sessions.Get(hdr.DestConnectionID) + if !ok { + m.logger.Debugf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID) + continue + } + client.handlePacket(&receivedPacket{ + remoteAddr: addr, + header: hdr, + data: packetData, + rcvTime: rcvTime, + }) + } +} diff --git a/client_multiplexer_test.go b/client_multiplexer_test.go new file mode 100644 index 00000000..b848c7d4 --- /dev/null +++ b/client_multiplexer_test.go @@ -0,0 +1,106 @@ +package quic + +import ( + "bytes" + "errors" + "time" + + "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Client Multiplexer", func() { + getPacket := func(connID protocol.ConnectionID) []byte { + buf := &bytes.Buffer{} + err := (&wire.Header{ + SrcConnectionID: connID, + DestConnectionID: connID, + PacketNumberLen: protocol.PacketNumberLen1, + }).Write(buf, protocol.PerspectiveServer, protocol.VersionWhatever) + Expect(err).ToNot(HaveOccurred()) + return buf.Bytes() + } + + It("adds a new packet conn and handles packets", func() { + conn := newMockPacketConn() + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + conn.dataToRead <- getPacket(connID) + packetHandler := NewMockQuicSession(mockCtrl) + handledPacket := make(chan struct{}) + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(_ *receivedPacket) { + close(handledPacket) + }) + getClientMultiplexer().Add(conn, connID, packetHandler) + Eventually(handledPacket).Should(BeClosed()) + // makes the listen go routine return + packetHandler.EXPECT().Close(gomock.Any()).AnyTimes() + close(conn.dataToRead) + }) + + It("handles packets for different packet handlers on the same packet conn", func() { + conn := newMockPacketConn() + connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + conn.dataToRead <- getPacket(connID1) + conn.dataToRead <- getPacket(connID2) + packetHandler1 := NewMockQuicSession(mockCtrl) + packetHandler2 := NewMockQuicSession(mockCtrl) + handledPacket1 := make(chan struct{}) + handledPacket2 := make(chan struct{}) + packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.header.DestConnectionID).To(Equal(connID1)) + close(handledPacket1) + }) + packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + Expect(p.header.DestConnectionID).To(Equal(connID2)) + close(handledPacket2) + }) + getClientMultiplexer().Add(conn, connID1, packetHandler1) + getClientMultiplexer().Add(conn, connID2, packetHandler2) + Eventually(handledPacket1).Should(BeClosed()) + Eventually(handledPacket2).Should(BeClosed()) + // makes the listen go routine return + packetHandler1.EXPECT().Close(gomock.Any()).AnyTimes() + packetHandler2.EXPECT().Close(gomock.Any()).AnyTimes() + close(conn.dataToRead) + }) + + It("drops unparseable packets", func() { + conn := newMockPacketConn() + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + conn.dataToRead <- []byte("invalid header") + packetHandler := NewMockQuicSession(mockCtrl) + getClientMultiplexer().Add(conn, connID, packetHandler) + time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet + packetHandler.EXPECT().Close(gomock.Any()).AnyTimes() + close(conn.dataToRead) + }) + + It("drops packets for unknown receivers", func() { + conn := newMockPacketConn() + conn.dataToRead <- getPacket(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}) + packetHandler := NewMockQuicSession(mockCtrl) + getClientMultiplexer().Add(conn, protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, packetHandler) + time.Sleep(100 * time.Millisecond) // give the listen go routine some time to process the packet + // makes the listen go routine return + packetHandler.EXPECT().Close(gomock.Any()).AnyTimes() + close(conn.dataToRead) + }) + + It("closes the packet handlers when reading from the conn fails", func() { + conn := newMockPacketConn() + testErr := errors.New("test error") + conn.readErr = testErr + done := make(chan struct{}) + packetHandler := NewMockQuicSession(mockCtrl) + packetHandler.EXPECT().Close(testErr).Do(func(error) { + close(done) + }) + getClientMultiplexer().Add(conn, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) + Eventually(done).Should(BeClosed()) + }) +}) diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 120bd8a4..6bb1b7b4 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -45,13 +45,13 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom } // Close mocks base method -func (m *MockPacketHandlerManager) Close() { - m.ctrl.Call(m, "Close") +func (m *MockPacketHandlerManager) Close(arg0 error) { + m.ctrl.Call(m, "Close", arg0) } // Close indicates an expected call of Close -func (mr *MockPacketHandlerManagerMockRecorder) Close() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close)) +func (mr *MockPacketHandlerManagerMockRecorder) Close(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandlerManager)(nil).Close), arg0) } // Get mocks base method diff --git a/packet_handler_map.go b/packet_handler_map.go index ea8334e6..b8e5038f 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -7,6 +7,10 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" ) +// The packetHandlerMap stores packetHandlers, identified by connection ID. +// It is used: +// * by the server to store sessions +// * when multiplexing outgoing connections to store clients type packetHandlerMap struct { mutex sync.RWMutex @@ -50,7 +54,7 @@ func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { }) } -func (h *packetHandlerMap) Close() { +func (h *packetHandlerMap) Close(err error) { h.mutex.Lock() if h.closed { h.mutex.Unlock() @@ -64,7 +68,7 @@ func (h *packetHandlerMap) Close() { wg.Add(1) go func(handler packetHandler) { // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped - _ = handler.Close(nil) + _ = handler.Close(err) wg.Done() }(handler) } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index a08c8e2b..380132ea 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -1,6 +1,7 @@ package quic import ( + "errors" "time" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -45,12 +46,13 @@ var _ = Describe("Packet Handler Map", func() { }) It("closes", func() { + testErr := errors.New("test error") sess1 := NewMockQuicSession(mockCtrl) - sess1.EXPECT().Close(nil) + sess1.EXPECT().Close(testErr) sess2 := NewMockQuicSession(mockCtrl) - sess2.EXPECT().Close(nil) + sess2.EXPECT().Close(testErr) handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1) handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2) - handler.Close() + handler.Close(testErr) }) }) diff --git a/server.go b/server.go index ba6970a4..426a7df8 100644 --- a/server.go +++ b/server.go @@ -26,7 +26,7 @@ type packetHandlerManager interface { Add(protocol.ConnectionID, packetHandler) Get(protocol.ConnectionID) (packetHandler, bool) Remove(protocol.ConnectionID) - Close() + Close(error) } type quicSession interface { @@ -288,7 +288,7 @@ func (s *server) Accept() (Session, error) { // Close the server func (s *server) Close() error { - s.sessionHandler.Close() + s.sessionHandler.Close(nil) err := s.conn.Close() <-s.errorChan // wait for serve() to return return err diff --git a/server_test.go b/server_test.go index fd2d8500..78f4f226 100644 --- a/server_test.go +++ b/server_test.go @@ -212,7 +212,7 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) Consistently(done).ShouldNot(BeClosed()) // make the go routine return - sessionHandler.EXPECT().Close() + sessionHandler.EXPECT().Close(nil) close(serv.errorChan) serv.Close() Eventually(done).Should(BeClosed()) @@ -233,7 +233,7 @@ var _ = Describe("Server", func() { serv.serve() }() // close the server - sessionHandler.EXPECT().Close().AnyTimes() + sessionHandler.EXPECT().Close(nil).AnyTimes() Expect(serv.Close()).To(Succeed()) Expect(conn.closed).To(BeTrue()) }) @@ -270,7 +270,7 @@ var _ = Describe("Server", func() { It("errors when encountering a connection error", func() { testErr := errors.New("connection error") conn.readErr = testErr - sessionHandler.EXPECT().Close() + sessionHandler.EXPECT().Close(nil) done := make(chan struct{}) go func() { defer GinkgoRecover()