crypto/tls: add HandshakeContext method to Conn

Adds the (*tls.Conn).HandshakeContext method. This allows
us to pass the context provided down the call stack to
eventually reach the tls.ClientHelloInfo and
tls.CertificateRequestInfo structs.
These contexts are exposed to the user as read-only via Context()
methods.

This allows users of (*tls.Config).GetCertificate and
(*tls.Config).GetClientCertificate to use the context for
request scoped parameters and cancellation.

Replace uses of (*tls.Conn).Handshake with (*tls.Conn).HandshakeContext
where appropriate, to propagate existing contexts.

Fixes #32406

Change-Id: I259939c744bdc9b805bf51a845a8bc462c042483
Reviewed-on: https://go-review.googlesource.com/c/go/+/295370
Run-TryBot: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Katie Hockman <katie@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
This commit is contained in:
Johan Brandhorst 2020-08-01 12:18:31 +01:00 committed by Johan Brandhorst-Satzkorn
parent 2708f2d5a3
commit 93cad92f83
9 changed files with 266 additions and 62 deletions

55
tls.go
View file

@ -25,7 +25,6 @@ import (
"net"
"os"
"strings"
"time"
)
// Server returns a new TLS server side connection
@ -119,28 +118,16 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
}
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
// We want the Timeout and Deadline values from dialer to cover the
// whole process: TCP connection and TLS handshake. This means that we
// also need to start our own timers now.
timeout := netDialer.Timeout
if netDialer.Timeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
defer cancel()
}
if !netDialer.Deadline.IsZero() {
deadlineTimeout := time.Until(netDialer.Deadline)
if timeout == 0 || deadlineTimeout < timeout {
timeout = deadlineTimeout
}
}
// hsErrCh is non-nil if we might not wait for Handshake to complete.
var hsErrCh chan error
if timeout != 0 || ctx.Done() != nil {
hsErrCh = make(chan error, 2)
}
if timeout != 0 {
timer := time.AfterFunc(timeout, func() {
hsErrCh <- timeoutError{}
})
defer timer.Stop()
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
defer cancel()
}
rawConn, err := netDialer.DialContext(ctx, network, addr)
@ -167,34 +154,10 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf
}
conn := Client(rawConn, config)
if hsErrCh == nil {
err = conn.Handshake()
} else {
go func() {
hsErrCh <- conn.Handshake()
}()
select {
case <-ctx.Done():
err = ctx.Err()
case err = <-hsErrCh:
if err != nil {
// If the error was due to the context
// closing, prefer the context's error, rather
// than some random network teardown error.
if e := ctx.Err(); e != nil {
err = e
}
}
}
}
if err != nil {
if err := conn.HandshakeContext(ctx); err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}