implement dial functions that use a context

This commit is contained in:
Marten Seemann 2018-06-14 16:14:06 +07:00
parent 7356fc05d5
commit f26a68d45c
3 changed files with 86 additions and 15 deletions

View file

@ -5,6 +5,7 @@
- Add support for unidirectional streams (for IETF QUIC). - Add support for unidirectional streams (for IETF QUIC).
- Add a `quic.Config` option for the maximum number of incoming streams. - Add a `quic.Config` option for the maximum number of incoming streams.
- Add support for QUIC 42 and 43. - Add support for QUIC 42 and 43.
- Add dial functions that use a context.
## v0.7.0 (2018-02-03) ## v0.7.0 (2018-02-03)

View file

@ -2,6 +2,7 @@ package quic
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -52,7 +53,22 @@ var (
// DialAddr establishes a new QUIC connection to a server. // DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address. // The hostname for SNI is taken from the given address.
func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) { func DialAddr(
addr string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
return DialAddrContext(context.Background(), addr, tlsConf, config)
}
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
// The hostname for SNI is taken from the given address.
func DialAddrContext(
ctx context.Context,
addr string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr) udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -61,7 +77,7 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return Dial(udpConn, udpAddr, addr, tlsConf, config) return DialContext(ctx, udpConn, udpAddr, addr, tlsConf, config)
} }
// Dial establishes a new QUIC connection to a server using a net.PacketConn. // Dial establishes a new QUIC connection to a server using a net.PacketConn.
@ -72,6 +88,19 @@ func Dial(
host string, host string,
tlsConf *tls.Config, tlsConf *tls.Config,
config *Config, config *Config,
) (Session, error) {
return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
}
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
// The host parameter is used for SNI.
func DialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) { ) (Session, error) {
clientConfig := populateClientConfig(config) clientConfig := populateClientConfig(config)
version := clientConfig.Versions[0] version := clientConfig.Versions[0]
@ -106,6 +135,7 @@ func Dial(
} }
} }
} }
c := &client{ c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr}, conn: &conn{pconn: pconn, currentAddr: remoteAddr},
srcConnID: srcConnID, srcConnID: srcConnID,
@ -120,7 +150,7 @@ func Dial(
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
if err := c.dial(); err != nil { if err := c.dial(ctx); err != nil {
return nil, err return nil, err
} }
return c.session, nil return c.session, nil
@ -180,28 +210,28 @@ func populateClientConfig(config *Config) *Config {
} }
} }
func (c *client) dial() error { func (c *client) dial(ctx context.Context) error {
var err error var err error
if c.version.UsesTLS() { if c.version.UsesTLS() {
err = c.dialTLS() err = c.dialTLS(ctx)
} else { } else {
err = c.dialGQUIC() err = c.dialGQUIC(ctx)
} }
if err == errCloseSessionForNewVersion { if err == errCloseSessionForNewVersion {
return c.dial() return c.dial(ctx)
} }
return err return err
} }
func (c *client) dialGQUIC() error { func (c *client) dialGQUIC(ctx context.Context) error {
if err := c.createNewGQUICSession(); err != nil { if err := c.createNewGQUICSession(); err != nil {
return err return err
} }
go c.listen() go c.listen()
return c.establishSecureConnection() return c.establishSecureConnection(ctx)
} }
func (c *client) dialTLS() error { func (c *client) dialTLS(ctx context.Context) error {
params := &handshake.TransportParameters{ params := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow, StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow, ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
@ -224,7 +254,7 @@ func (c *client) dialTLS() error {
return err return err
} }
go c.listen() go c.listen()
if err := c.establishSecureConnection(); err != nil { if err := c.establishSecureConnection(ctx); err != nil {
if err != handshake.ErrCloseSessionForRetry { if err != handshake.ErrCloseSessionForRetry {
return err return err
} }
@ -232,7 +262,7 @@ func (c *client) dialTLS() error {
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err return err
} }
if err := c.establishSecureConnection(); err != nil { if err := c.establishSecureConnection(ctx); err != nil {
return err return err
} }
} }
@ -245,7 +275,7 @@ func (c *client) dialTLS() error {
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC) // - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
// - any other error that might occur // - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC) // - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
func (c *client) establishSecureConnection() error { func (c *client) establishSecureConnection(ctx context.Context) error {
errorChan := make(chan error, 1) errorChan := make(chan error, 1)
go func() { go func() {
@ -254,6 +284,10 @@ func (c *client) establishSecureConnection() error {
}() }()
select { select {
case <-ctx.Done():
// The session sending a PeerGoingAway error to the server.
c.session.Close(nil)
return ctx.Err()
case err := <-errorChan: case err := <-errorChan:
return err return err
case <-c.handshakeChan: case <-c.handshakeChan:

View file

@ -2,6 +2,7 @@ package quic
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -198,6 +199,41 @@ var _ = Describe("Client", func() {
Eventually(handledPacket).Should(BeClosed()) Eventually(handledPacket).Should(BeClosed())
}) })
It("closes the session when the context is canceledd", func() {
sessionRunning := make(chan struct{})
defer close(sessionRunning)
sess := NewMockPacketHandler(mockCtrl)
sess.EXPECT().run().Do(func() {
<-sessionRunning
})
newClientSession = func(
conn connection,
_ sessionRunner,
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ protocol.VersionNumber,
_ []protocol.VersionNumber,
_ utils.Logger,
) (packetHandler, error) {
return sess, nil
}
ctx, cancel := context.WithCancel(context.Background())
dialed := make(chan struct{})
go func() {
defer GinkgoRecover()
_, err := DialContext(ctx, packetConn, addr, "quic.clemnte.io:1337", nil, nil)
Expect(err).To(MatchError(context.Canceled))
close(dialed)
}()
Consistently(dialed).ShouldNot(BeClosed())
sess.EXPECT().Close(nil)
cancel()
Eventually(dialed).Should(BeClosed())
})
Context("quic.Config", func() { Context("quic.Config", func() {
It("setups with the right values", func() { It("setups with the right values", func() {
config := &Config{ config := &Config{
@ -398,7 +434,7 @@ var _ = Describe("Client", func() {
dialed := make(chan struct{}) dialed := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := cl.dial() err := cl.dial(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
close(dialed) close(dialed)
}() }()
@ -442,7 +478,7 @@ var _ = Describe("Client", func() {
dialed := make(chan struct{}) dialed := make(chan struct{})
go func() { go func() {
defer GinkgoRecover() defer GinkgoRecover()
err := cl.dial() err := cl.dial(context.Background())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
close(dialed) close(dialed)
}() }()