mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-05 13:17:36 +03:00
implement dial functions that use a context
This commit is contained in:
parent
7356fc05d5
commit
f26a68d45c
3 changed files with 86 additions and 15 deletions
|
@ -5,6 +5,7 @@
|
||||||
- Add support for unidirectional streams (for IETF QUIC).
|
- Add support for unidirectional streams (for IETF QUIC).
|
||||||
- Add a `quic.Config` option for the maximum number of incoming streams.
|
- Add a `quic.Config` option for the maximum number of incoming streams.
|
||||||
- Add support for QUIC 42 and 43.
|
- Add support for QUIC 42 and 43.
|
||||||
|
- Add dial functions that use a context.
|
||||||
|
|
||||||
## v0.7.0 (2018-02-03)
|
## v0.7.0 (2018-02-03)
|
||||||
|
|
||||||
|
|
60
client.go
60
client.go
|
@ -2,6 +2,7 @@ package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -52,7 +53,22 @@ var (
|
||||||
|
|
||||||
// DialAddr establishes a new QUIC connection to a server.
|
// DialAddr establishes a new QUIC connection to a server.
|
||||||
// The hostname for SNI is taken from the given address.
|
// The hostname for SNI is taken from the given address.
|
||||||
func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) {
|
func DialAddr(
|
||||||
|
addr string,
|
||||||
|
tlsConf *tls.Config,
|
||||||
|
config *Config,
|
||||||
|
) (Session, error) {
|
||||||
|
return DialAddrContext(context.Background(), addr, tlsConf, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
|
||||||
|
// The hostname for SNI is taken from the given address.
|
||||||
|
func DialAddrContext(
|
||||||
|
ctx context.Context,
|
||||||
|
addr string,
|
||||||
|
tlsConf *tls.Config,
|
||||||
|
config *Config,
|
||||||
|
) (Session, error) {
|
||||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -61,7 +77,7 @@ func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return Dial(udpConn, udpAddr, addr, tlsConf, config)
|
return DialContext(ctx, udpConn, udpAddr, addr, tlsConf, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
|
||||||
|
@ -72,6 +88,19 @@ func Dial(
|
||||||
host string,
|
host string,
|
||||||
tlsConf *tls.Config,
|
tlsConf *tls.Config,
|
||||||
config *Config,
|
config *Config,
|
||||||
|
) (Session, error) {
|
||||||
|
return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
|
||||||
|
// The host parameter is used for SNI.
|
||||||
|
func DialContext(
|
||||||
|
ctx context.Context,
|
||||||
|
pconn net.PacketConn,
|
||||||
|
remoteAddr net.Addr,
|
||||||
|
host string,
|
||||||
|
tlsConf *tls.Config,
|
||||||
|
config *Config,
|
||||||
) (Session, error) {
|
) (Session, error) {
|
||||||
clientConfig := populateClientConfig(config)
|
clientConfig := populateClientConfig(config)
|
||||||
version := clientConfig.Versions[0]
|
version := clientConfig.Versions[0]
|
||||||
|
@ -106,6 +135,7 @@ func Dial(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &client{
|
c := &client{
|
||||||
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
|
||||||
srcConnID: srcConnID,
|
srcConnID: srcConnID,
|
||||||
|
@ -120,7 +150,7 @@ func Dial(
|
||||||
|
|
||||||
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
|
||||||
|
|
||||||
if err := c.dial(); err != nil {
|
if err := c.dial(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return c.session, nil
|
return c.session, nil
|
||||||
|
@ -180,28 +210,28 @@ func populateClientConfig(config *Config) *Config {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) dial() error {
|
func (c *client) dial(ctx context.Context) error {
|
||||||
var err error
|
var err error
|
||||||
if c.version.UsesTLS() {
|
if c.version.UsesTLS() {
|
||||||
err = c.dialTLS()
|
err = c.dialTLS(ctx)
|
||||||
} else {
|
} else {
|
||||||
err = c.dialGQUIC()
|
err = c.dialGQUIC(ctx)
|
||||||
}
|
}
|
||||||
if err == errCloseSessionForNewVersion {
|
if err == errCloseSessionForNewVersion {
|
||||||
return c.dial()
|
return c.dial(ctx)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) dialGQUIC() error {
|
func (c *client) dialGQUIC(ctx context.Context) error {
|
||||||
if err := c.createNewGQUICSession(); err != nil {
|
if err := c.createNewGQUICSession(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
go c.listen()
|
go c.listen()
|
||||||
return c.establishSecureConnection()
|
return c.establishSecureConnection(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) dialTLS() error {
|
func (c *client) dialTLS(ctx context.Context) error {
|
||||||
params := &handshake.TransportParameters{
|
params := &handshake.TransportParameters{
|
||||||
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
|
||||||
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
|
||||||
|
@ -224,7 +254,7 @@ func (c *client) dialTLS() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
go c.listen()
|
go c.listen()
|
||||||
if err := c.establishSecureConnection(); err != nil {
|
if err := c.establishSecureConnection(ctx); err != nil {
|
||||||
if err != handshake.ErrCloseSessionForRetry {
|
if err != handshake.ErrCloseSessionForRetry {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -232,7 +262,7 @@ func (c *client) dialTLS() error {
|
||||||
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := c.establishSecureConnection(); err != nil {
|
if err := c.establishSecureConnection(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -245,7 +275,7 @@ func (c *client) dialTLS() error {
|
||||||
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
|
||||||
// - any other error that might occur
|
// - any other error that might occur
|
||||||
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
|
||||||
func (c *client) establishSecureConnection() error {
|
func (c *client) establishSecureConnection(ctx context.Context) error {
|
||||||
errorChan := make(chan error, 1)
|
errorChan := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -254,6 +284,10 @@ func (c *client) establishSecureConnection() error {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// The session sending a PeerGoingAway error to the server.
|
||||||
|
c.session.Close(nil)
|
||||||
|
return ctx.Err()
|
||||||
case err := <-errorChan:
|
case err := <-errorChan:
|
||||||
return err
|
return err
|
||||||
case <-c.handshakeChan:
|
case <-c.handshakeChan:
|
||||||
|
|
|
@ -2,6 +2,7 @@ package quic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -198,6 +199,41 @@ var _ = Describe("Client", func() {
|
||||||
Eventually(handledPacket).Should(BeClosed())
|
Eventually(handledPacket).Should(BeClosed())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("closes the session when the context is canceledd", func() {
|
||||||
|
sessionRunning := make(chan struct{})
|
||||||
|
defer close(sessionRunning)
|
||||||
|
sess := NewMockPacketHandler(mockCtrl)
|
||||||
|
sess.EXPECT().run().Do(func() {
|
||||||
|
<-sessionRunning
|
||||||
|
})
|
||||||
|
newClientSession = func(
|
||||||
|
conn connection,
|
||||||
|
_ sessionRunner,
|
||||||
|
_ string,
|
||||||
|
_ protocol.VersionNumber,
|
||||||
|
_ protocol.ConnectionID,
|
||||||
|
_ *tls.Config,
|
||||||
|
_ *Config,
|
||||||
|
_ protocol.VersionNumber,
|
||||||
|
_ []protocol.VersionNumber,
|
||||||
|
_ utils.Logger,
|
||||||
|
) (packetHandler, error) {
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
dialed := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
_, err := DialContext(ctx, packetConn, addr, "quic.clemnte.io:1337", nil, nil)
|
||||||
|
Expect(err).To(MatchError(context.Canceled))
|
||||||
|
close(dialed)
|
||||||
|
}()
|
||||||
|
Consistently(dialed).ShouldNot(BeClosed())
|
||||||
|
sess.EXPECT().Close(nil)
|
||||||
|
cancel()
|
||||||
|
Eventually(dialed).Should(BeClosed())
|
||||||
|
})
|
||||||
|
|
||||||
Context("quic.Config", func() {
|
Context("quic.Config", func() {
|
||||||
It("setups with the right values", func() {
|
It("setups with the right values", func() {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
|
@ -398,7 +434,7 @@ var _ = Describe("Client", func() {
|
||||||
dialed := make(chan struct{})
|
dialed := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
err := cl.dial()
|
err := cl.dial(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
close(dialed)
|
close(dialed)
|
||||||
}()
|
}()
|
||||||
|
@ -442,7 +478,7 @@ var _ = Describe("Client", func() {
|
||||||
dialed := make(chan struct{})
|
dialed := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
err := cl.dial()
|
err := cl.dial(context.Background())
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
close(dialed)
|
close(dialed)
|
||||||
}()
|
}()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue