From f26a68d45c6198bf09b24b87ecba72c86db9c285 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 14 Jun 2018 16:14:06 +0700 Subject: [PATCH] implement dial functions that use a context --- Changelog.md | 1 + client.go | 60 +++++++++++++++++++++++++++++++++++++++----------- client_test.go | 40 +++++++++++++++++++++++++++++++-- 3 files changed, 86 insertions(+), 15 deletions(-) diff --git a/Changelog.md b/Changelog.md index 33eddd7f..d537032a 100644 --- a/Changelog.md +++ b/Changelog.md @@ -5,6 +5,7 @@ - Add support for unidirectional streams (for IETF QUIC). - Add a `quic.Config` option for the maximum number of incoming streams. - Add support for QUIC 42 and 43. +- Add dial functions that use a context. ## v0.7.0 (2018-02-03) diff --git a/client.go b/client.go index 5fbeb856..a343ffea 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -52,7 +53,22 @@ var ( // DialAddr establishes a new QUIC connection to a server. // The hostname for SNI is taken from the given address. -func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) { +func DialAddr( + addr string, + tlsConf *tls.Config, + config *Config, +) (Session, error) { + return DialAddrContext(context.Background(), addr, tlsConf, config) +} + +// DialAddrContext establishes a new QUIC connection to a server using the provided context. +// The hostname for SNI is taken from the given address. +func DialAddrContext( + ctx context.Context, + addr string, + tlsConf *tls.Config, + config *Config, +) (Session, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err @@ -61,7 +77,7 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) if err != nil { return nil, err } - return Dial(udpConn, udpAddr, addr, tlsConf, config) + return DialContext(ctx, udpConn, udpAddr, addr, tlsConf, config) } // Dial establishes a new QUIC connection to a server using a net.PacketConn. @@ -72,6 +88,19 @@ func Dial( host string, tlsConf *tls.Config, config *Config, +) (Session, error) { + return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config) +} + +// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context. +// The host parameter is used for SNI. +func DialContext( + ctx context.Context, + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, ) (Session, error) { clientConfig := populateClientConfig(config) version := clientConfig.Versions[0] @@ -106,6 +135,7 @@ func Dial( } } } + c := &client{ conn: &conn{pconn: pconn, currentAddr: remoteAddr}, srcConnID: srcConnID, @@ -120,7 +150,7 @@ func Dial( c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) - if err := c.dial(); err != nil { + if err := c.dial(ctx); err != nil { return nil, err } return c.session, nil @@ -180,28 +210,28 @@ func populateClientConfig(config *Config) *Config { } } -func (c *client) dial() error { +func (c *client) dial(ctx context.Context) error { var err error if c.version.UsesTLS() { - err = c.dialTLS() + err = c.dialTLS(ctx) } else { - err = c.dialGQUIC() + err = c.dialGQUIC(ctx) } if err == errCloseSessionForNewVersion { - return c.dial() + return c.dial(ctx) } return err } -func (c *client) dialGQUIC() error { +func (c *client) dialGQUIC(ctx context.Context) error { if err := c.createNewGQUICSession(); err != nil { return err } go c.listen() - return c.establishSecureConnection() + return c.establishSecureConnection(ctx) } -func (c *client) dialTLS() error { +func (c *client) dialTLS(ctx context.Context) error { params := &handshake.TransportParameters{ StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, @@ -224,7 +254,7 @@ func (c *client) dialTLS() error { return err } go c.listen() - if err := c.establishSecureConnection(); err != nil { + if err := c.establishSecureConnection(ctx); err != nil { if err != handshake.ErrCloseSessionForRetry { return err } @@ -232,7 +262,7 @@ func (c *client) dialTLS() error { if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { return err } - if err := c.establishSecureConnection(); err != nil { + if err := c.establishSecureConnection(ctx); err != nil { return err } } @@ -245,7 +275,7 @@ func (c *client) dialTLS() error { // - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC) // - any other error that might occur // - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC) -func (c *client) establishSecureConnection() error { +func (c *client) establishSecureConnection(ctx context.Context) error { errorChan := make(chan error, 1) go func() { @@ -254,6 +284,10 @@ func (c *client) establishSecureConnection() error { }() select { + case <-ctx.Done(): + // The session sending a PeerGoingAway error to the server. + c.session.Close(nil) + return ctx.Err() case err := <-errorChan: return err case <-c.handshakeChan: diff --git a/client_test.go b/client_test.go index b4e12ce3..f402bdfe 100644 --- a/client_test.go +++ b/client_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -198,6 +199,41 @@ var _ = Describe("Client", func() { Eventually(handledPacket).Should(BeClosed()) }) + It("closes the session when the context is canceledd", func() { + sessionRunning := make(chan struct{}) + defer close(sessionRunning) + sess := NewMockPacketHandler(mockCtrl) + sess.EXPECT().run().Do(func() { + <-sessionRunning + }) + newClientSession = func( + conn connection, + _ sessionRunner, + _ string, + _ protocol.VersionNumber, + _ protocol.ConnectionID, + _ *tls.Config, + _ *Config, + _ protocol.VersionNumber, + _ []protocol.VersionNumber, + _ utils.Logger, + ) (packetHandler, error) { + return sess, nil + } + ctx, cancel := context.WithCancel(context.Background()) + dialed := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := DialContext(ctx, packetConn, addr, "quic.clemnte.io:1337", nil, nil) + Expect(err).To(MatchError(context.Canceled)) + close(dialed) + }() + Consistently(dialed).ShouldNot(BeClosed()) + sess.EXPECT().Close(nil) + cancel() + Eventually(dialed).Should(BeClosed()) + }) + Context("quic.Config", func() { It("setups with the right values", func() { config := &Config{ @@ -398,7 +434,7 @@ var _ = Describe("Client", func() { dialed := make(chan struct{}) go func() { defer GinkgoRecover() - err := cl.dial() + err := cl.dial(context.Background()) Expect(err).ToNot(HaveOccurred()) close(dialed) }() @@ -442,7 +478,7 @@ var _ = Describe("Client", func() { dialed := make(chan struct{}) go func() { defer GinkgoRecover() - err := cl.dial() + err := cl.dial(context.Background()) Expect(err).ToNot(HaveOccurred()) close(dialed) }()