crypto/tls: add Dialer

Fixes #18482

Change-Id: I99d65dc5d824c00093ea61e7445fc121314af87f
Reviewed-on: https://go-review.googlesource.com/c/go/+/214977
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
This commit is contained in:
Brad Fitzpatrick 2020-01-15 19:27:32 +00:00
parent 2fcb91d134
commit 2f2a543ff4
3 changed files with 126 additions and 13 deletions

89
tls.go
View file

@ -13,6 +13,7 @@ package tls
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
@ -111,29 +112,35 @@ func (timeoutError) Temporary() bool { return true }
// DialWithDialer interprets a nil configuration as equivalent to the zero
// configuration; see the documentation of Config for the defaults.
func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
return dial(context.Background(), dialer, network, addr, config)
}
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
// We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now.
timeout := dialer.Timeout
timeout := netDialer.Timeout
if !dialer.Deadline.IsZero() {
deadlineTimeout := time.Until(dialer.Deadline)
if !netDialer.Deadline.IsZero() {
deadlineTimeout := time.Until(netDialer.Deadline)
if timeout == 0 || deadlineTimeout < timeout {
timeout = deadlineTimeout
}
}
var errChannel chan error
// hsErrCh is non-nil if we might not wait for Handshake to complete.
var hsErrCh chan error
if timeout != 0 || ctx.Done() != nil {
hsErrCh = make(chan error, 2)
}
if timeout != 0 {
errChannel = make(chan error, 2)
timer := time.AfterFunc(timeout, func() {
errChannel <- timeoutError{}
hsErrCh <- timeoutError{}
})
defer timer.Stop()
}
rawConn, err := dialer.Dial(network, addr)
rawConn, err := netDialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
@ -158,14 +165,26 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
conn := Client(rawConn, config)
if timeout == 0 {
if hsErrCh == nil {
err = conn.Handshake()
} else {
go func() {
errChannel <- conn.Handshake()
hsErrCh <- conn.Handshake()
}()
err = <-errChannel
select {
case <-ctx.Done():
err = ctx.Err()
case err = <-hsErrCh:
if err != nil {
// If the error was due to the context
// closing, prefer the context's error, rather
// than some random network teardown error.
if e := ctx.Err(); e != nil {
err = e
}
}
}
}
if err != nil {
@ -186,6 +205,54 @@ func Dial(network, addr string, config *Config) (*Conn, error) {
return DialWithDialer(new(net.Dialer), network, addr, config)
}
// Dialer dials TLS connections given a configuration and a Dialer for the
// underlying connection.
type Dialer struct {
// NetDialer is the optional dialer to use for the TLS connections'
// underlying TCP connections.
// A nil NetDialer is equivalent to the net.Dialer zero value.
NetDialer *net.Dialer
// Config is the TLS configuration to use for new connections.
// A nil configuration is equivalent to the zero
// configuration; see the documentation of Config for the
// defaults.
Config *Config
}
// Dial connects to the given network address and initiates a TLS
// handshake, returning the resulting TLS connection.
//
// The returned Conn, if any, will always be of type *Conn.
func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
return d.DialContext(context.Background(), network, addr)
}
func (d *Dialer) netDialer() *net.Dialer {
if d.NetDialer != nil {
return d.NetDialer
}
return new(net.Dialer)
}
// Dial connects to the given network address and initiates a TLS
// handshake, returning the resulting TLS connection.
//
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
//
// The returned Conn, if any, will always be of type *Conn.
func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
if err != nil {
// Don't return c (a typed nil) in an interface.
return nil, err
}
return c, nil
}
// LoadX509KeyPair reads and parses a public/private key pair from a pair
// of files. The files must contain PEM encoded data. The certificate file
// may contain intermediate certificates following the leaf certificate to