crypto/tls: revert "add HandshakeContext method to Conn"

This reverts CL 246338.

Reason for revert: waiting for 1.17 release cycle

Updates #32406

Change-Id: I074379039041e086c62271d689b4b7f442281663
Reviewed-on: https://go-review.googlesource.com/c/go/+/269697
Run-TryBot: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
Run-TryBot: Katie Hockman <katie@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Katie Hockman <katie@golang.org>
Trust: Katie Hockman <katie@golang.org>
Trust: Roland Shoemaker <roland@golang.org>
This commit is contained in:
Johan Brandhorst 2020-11-12 20:34:51 +00:00 committed by Katie Hockman
parent a2ca1d5330
commit 8649b4ade4
9 changed files with 62 additions and 197 deletions

View file

@ -7,7 +7,6 @@ package tls
import ( import (
"bytes" "bytes"
"container/list" "container/list"
"context"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
@ -444,16 +443,6 @@ type ClientHelloInfo struct {
// config is embedded by the GetCertificate or GetConfigForClient caller, // config is embedded by the GetCertificate or GetConfigForClient caller,
// for use with SupportsCertificate. // for use with SupportsCertificate.
config *Config config *Config
// ctx is the context of the handshake that is in progress.
ctx context.Context
}
// Context returns the context of the handshake that is in progress.
// This context is a child of the context passed to HandshakeContext,
// if any, and is canceled when the handshake concludes.
func (c *ClientHelloInfo) Context() context.Context {
return c.ctx
} }
// CertificateRequestInfo contains information from a server's // CertificateRequestInfo contains information from a server's
@ -472,16 +461,6 @@ type CertificateRequestInfo struct {
// Version is the TLS version that was negotiated for this connection. // Version is the TLS version that was negotiated for this connection.
Version uint16 Version uint16
// ctx is the context of the handshake that is in progress.
ctx context.Context
}
// Context returns the context of the handshake that is in progress.
// This context is a child of the context passed to HandshakeContext,
// if any, and is canceled when the handshake concludes.
func (c *CertificateRequestInfo) Context() context.Context {
return c.ctx
} }
// RenegotiationSupport enumerates the different levels of support for TLS // RenegotiationSupport enumerates the different levels of support for TLS

62
conn.go
View file

@ -8,7 +8,6 @@ package tls
import ( import (
"bytes" "bytes"
"context"
"crypto/cipher" "crypto/cipher"
"crypto/subtle" "crypto/subtle"
"crypto/x509" "crypto/x509"
@ -28,7 +27,7 @@ type Conn struct {
// constant // constant
conn net.Conn conn net.Conn
isClient bool isClient bool
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake handshakeFn func() error // (*Conn).clientHandshake or serverHandshake
// handshakeStatus is 1 if the connection is currently transferring // handshakeStatus is 1 if the connection is currently transferring
// application data (i.e. is not currently processing a handshake). // application data (i.e. is not currently processing a handshake).
@ -1191,7 +1190,7 @@ func (c *Conn) handleRenegotiation() error {
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
atomic.StoreUint32(&c.handshakeStatus, 0) atomic.StoreUint32(&c.handshakeStatus, 0)
if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil { if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
c.handshakes++ c.handshakes++
} }
return c.handshakeErr return c.handshakeErr
@ -1374,61 +1373,8 @@ func (c *Conn) closeNotify() error {
// first Read or Write will call it automatically. // first Read or Write will call it automatically.
// //
// For control over canceling or setting a timeout on a handshake, use // For control over canceling or setting a timeout on a handshake, use
// HandshakeContext or the Dialer's DialContext method instead. // the Dialer's DialContext method.
func (c *Conn) Handshake() error { func (c *Conn) Handshake() error {
return c.HandshakeContext(context.Background())
}
// HandshakeContext runs the client or server handshake
// protocol if it has not yet been run.
//
// The provided Context must be non-nil. If the context is canceled before
// the handshake is complete, the handshake is interrupted and an error is returned.
// Once the handshake has completed, cancellation of the context will not affect the
// connection.
//
// Most uses of this package need not call HandshakeContext explicitly: the
// first Read or Write will call it automatically.
func (c *Conn) HandshakeContext(ctx context.Context) error {
// Delegate to unexported method for named return
// without confusing documented signature.
return c.handshakeContext(ctx)
}
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
handshakeCtx, cancel := context.WithCancel(ctx)
// Note: defer this before starting the "interrupter" goroutine
// so that we can tell the difference between the input being canceled and
// this cancellation. In the former case, we need to close the connection.
defer cancel()
// Start the "interrupter" goroutine, if this context might be canceled.
// (The background context cannot).
//
// The interrupter goroutine waits for the input context to be done and
// closes the connection if this happens before the function returns.
if ctx.Done() != nil {
done := make(chan struct{})
interruptRes := make(chan error, 1)
defer func() {
close(done)
if ctxErr := <-interruptRes; ctxErr != nil {
// Return context error to user.
ret = ctxErr
}
}()
go func() {
select {
case <-handshakeCtx.Done():
// Close the connection, discarding the error
_ = c.conn.Close()
interruptRes <- handshakeCtx.Err()
case <-done:
interruptRes <- nil
}
}()
}
c.handshakeMutex.Lock() c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
@ -1442,7 +1388,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
c.in.Lock() c.in.Lock()
defer c.in.Unlock() defer c.in.Unlock()
c.handshakeErr = c.handshakeFn(handshakeCtx) c.handshakeErr = c.handshakeFn()
if c.handshakeErr == nil { if c.handshakeErr == nil {
c.handshakes++ c.handshakes++
} else { } else {

View file

@ -6,7 +6,6 @@ package tls
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
@ -25,7 +24,6 @@ import (
type clientHandshakeState struct { type clientHandshakeState struct {
c *Conn c *Conn
ctx context.Context
serverHello *serverHelloMsg serverHello *serverHelloMsg
hello *clientHelloMsg hello *clientHelloMsg
suite *cipherSuite suite *cipherSuite
@ -136,7 +134,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) {
return hello, params, nil return hello, params, nil
} }
func (c *Conn) clientHandshake(ctx context.Context) (err error) { func (c *Conn) clientHandshake() (err error) {
if c.config == nil { if c.config == nil {
c.config = defaultConfig() c.config = defaultConfig()
} }
@ -200,7 +198,6 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
if c.vers == VersionTLS13 { if c.vers == VersionTLS13 {
hs := &clientHandshakeStateTLS13{ hs := &clientHandshakeStateTLS13{
c: c, c: c,
ctx: ctx,
serverHello: serverHello, serverHello: serverHello,
hello: hello, hello: hello,
ecdheParams: ecdheParams, ecdheParams: ecdheParams,
@ -215,7 +212,6 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
hs := &clientHandshakeState{ hs := &clientHandshakeState{
c: c, c: c,
ctx: ctx,
serverHello: serverHello, serverHello: serverHello,
hello: hello, hello: hello,
session: session, session: session,
@ -544,7 +540,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
certRequested = true certRequested = true
hs.finishedHash.Write(certReq.marshal()) hs.finishedHash.Write(certReq.marshal())
cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) cri := certificateRequestInfoFromMsg(c.vers, certReq)
if chainToSend, err = c.getClientCertificate(cri); err != nil { if chainToSend, err = c.getClientCertificate(cri); err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return err return err
@ -884,11 +880,10 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS // certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS
// <= 1.2 CertificateRequest, making an effort to fill in missing information. // <= 1.2 CertificateRequest, making an effort to fill in missing information.
func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo { func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
cri := &CertificateRequestInfo{ cri := &CertificateRequestInfo{
AcceptableCAs: certReq.certificateAuthorities, AcceptableCAs: certReq.certificateAuthorities,
Version: vers, Version: vers,
ctx: ctx,
} }
var rsaAvail, ecAvail bool var rsaAvail, ecAvail bool

View file

@ -6,7 +6,6 @@ package tls
import ( import (
"bytes" "bytes"
"context"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
@ -21,7 +20,6 @@ import (
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"reflect" "reflect"
"runtime"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
@ -2513,37 +2511,3 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
} }
} }
func TestClientHandshakeContextCancellation(t *testing.T) {
c, s := localPipe(t)
serverConfig := testConfig.Clone()
serverErr := make(chan error, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
defer close(serverErr)
defer s.Close()
conn := Server(s, serverConfig)
_, err := conn.readClientHello(ctx)
cancel()
serverErr <- err
}()
cli := Client(c, testConfig)
err := cli.HandshakeContext(ctx)
if err == nil {
t.Fatal("Client handshake did not error when the context was canceled")
}
if err != context.Canceled {
t.Errorf("Unexpected client handshake error: %v", err)
}
if err := <-serverErr; err != nil {
t.Errorf("Unexpected server error: %v", err)
}
if runtime.GOARCH == "wasm" {
t.Skip("conn.Close does not error as expected when called multiple times on WASM")
}
err = cli.Close()
if err == nil {
t.Error("Client connection was not closed when the context was canceled")
}
}

View file

@ -6,7 +6,6 @@ package tls
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"crypto/hmac" "crypto/hmac"
"crypto/rsa" "crypto/rsa"
@ -18,7 +17,6 @@ import (
type clientHandshakeStateTLS13 struct { type clientHandshakeStateTLS13 struct {
c *Conn c *Conn
ctx context.Context
serverHello *serverHelloMsg serverHello *serverHelloMsg
hello *clientHelloMsg hello *clientHelloMsg
ecdheParams ecdheParameters ecdheParams ecdheParameters
@ -557,7 +555,6 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
AcceptableCAs: hs.certReq.certificateAuthorities, AcceptableCAs: hs.certReq.certificateAuthorities,
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms, SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
Version: c.vers, Version: c.vers,
ctx: hs.ctx,
}) })
if err != nil { if err != nil {
return err return err

View file

@ -5,7 +5,6 @@
package tls package tls
import ( import (
"context"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
@ -24,7 +23,6 @@ import (
// It's discarded once the handshake has completed. // It's discarded once the handshake has completed.
type serverHandshakeState struct { type serverHandshakeState struct {
c *Conn c *Conn
ctx context.Context
clientHello *clientHelloMsg clientHello *clientHelloMsg
hello *serverHelloMsg hello *serverHelloMsg
suite *cipherSuite suite *cipherSuite
@ -39,8 +37,8 @@ type serverHandshakeState struct {
} }
// serverHandshake performs a TLS handshake as a server. // serverHandshake performs a TLS handshake as a server.
func (c *Conn) serverHandshake(ctx context.Context) error { func (c *Conn) serverHandshake() error {
clientHello, err := c.readClientHello(ctx) clientHello, err := c.readClientHello()
if err != nil { if err != nil {
return err return err
} }
@ -48,7 +46,6 @@ func (c *Conn) serverHandshake(ctx context.Context) error {
if c.vers == VersionTLS13 { if c.vers == VersionTLS13 {
hs := serverHandshakeStateTLS13{ hs := serverHandshakeStateTLS13{
c: c, c: c,
ctx: ctx,
clientHello: clientHello, clientHello: clientHello,
} }
return hs.handshake() return hs.handshake()
@ -56,7 +53,6 @@ func (c *Conn) serverHandshake(ctx context.Context) error {
hs := serverHandshakeState{ hs := serverHandshakeState{
c: c, c: c,
ctx: ctx,
clientHello: clientHello, clientHello: clientHello,
} }
return hs.handshake() return hs.handshake()
@ -128,7 +124,7 @@ func (hs *serverHandshakeState) handshake() error {
} }
// readClientHello reads a ClientHello message and selects the protocol version. // readClientHello reads a ClientHello message and selects the protocol version.
func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { func (c *Conn) readClientHello() (*clientHelloMsg, error) {
msg, err := c.readHandshake() msg, err := c.readHandshake()
if err != nil { if err != nil {
return nil, err return nil, err
@ -142,7 +138,7 @@ func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
var configForClient *Config var configForClient *Config
originalConfig := c.config originalConfig := c.config
if c.config.GetConfigForClient != nil { if c.config.GetConfigForClient != nil {
chi := clientHelloInfo(ctx, c, clientHello) chi := clientHelloInfo(c, clientHello)
if configForClient, err = c.config.GetConfigForClient(chi); err != nil { if configForClient, err = c.config.GetConfigForClient(chi); err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return nil, err return nil, err
@ -224,7 +220,7 @@ func (hs *serverHandshakeState) processClientHello() error {
} }
} }
hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
if err != nil { if err != nil {
if err == errNoCertificates { if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName) c.sendAlert(alertUnrecognizedName)
@ -832,7 +828,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
return nil return nil
} }
func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo { func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
supportedVersions := clientHello.supportedVersions supportedVersions := clientHello.supportedVersions
if len(clientHello.supportedVersions) == 0 { if len(clientHello.supportedVersions) == 0 {
supportedVersions = supportedVersionsFromMax(clientHello.vers) supportedVersions = supportedVersionsFromMax(clientHello.vers)
@ -848,6 +844,5 @@ func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg)
SupportedVersions: supportedVersions, SupportedVersions: supportedVersions,
Conn: c.conn, Conn: c.conn,
config: c.config, config: c.config,
ctx: ctx,
} }
} }

View file

@ -6,7 +6,6 @@ package tls
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"crypto/elliptic" "crypto/elliptic"
"crypto/x509" "crypto/x509"
@ -18,7 +17,6 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -40,12 +38,10 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
cli.writeRecord(recordTypeHandshake, m.marshal()) cli.writeRecord(recordTypeHandshake, m.marshal())
c.Close() c.Close()
}() }()
ctx := context.Background()
conn := Server(s, serverConfig) conn := Server(s, serverConfig)
ch, err := conn.readClientHello(ctx) ch, err := conn.readClientHello()
hs := serverHandshakeState{ hs := serverHandshakeState{
c: conn, c: conn,
ctx: ctx,
clientHello: ch, clientHello: ch,
} }
if err == nil { if err == nil {
@ -1425,11 +1421,9 @@ func TestSNIGivenOnFailure(t *testing.T) {
c.Close() c.Close()
}() }()
conn := Server(s, serverConfig) conn := Server(s, serverConfig)
ctx := context.Background() ch, err := conn.readClientHello()
ch, err := conn.readClientHello(ctx)
hs := serverHandshakeState{ hs := serverHandshakeState{
c: conn, c: conn,
ctx: ctx,
clientHello: ch, clientHello: ch,
} }
if err == nil { if err == nil {
@ -1683,46 +1677,6 @@ func TestMultipleCertificates(t *testing.T) {
} }
} }
func TestServerHandshakeContextCancellation(t *testing.T) {
c, s := localPipe(t)
clientConfig := testConfig.Clone()
clientErr := make(chan error, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
defer close(clientErr)
defer c.Close()
clientHello := &clientHelloMsg{
vers: VersionTLS10,
random: make([]byte, 32),
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
compressionMethods: []uint8{compressionNone},
}
cli := Client(c, clientConfig)
_, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal())
cancel()
clientErr <- err
}()
conn := Server(s, testConfig)
err := conn.HandshakeContext(ctx)
if err == nil {
t.Fatal("Server handshake did not error when the context was canceled")
}
if err != context.Canceled {
t.Errorf("Unexpected server handshake error: %v", err)
}
if err := <-clientErr; err != nil {
t.Errorf("Unexpected client error: %v", err)
}
if runtime.GOARCH == "wasm" {
t.Skip("conn.Close does not error as expected when called multiple times on WASM")
}
err = conn.Close()
if err == nil {
t.Error("Server connection was not closed when the context was canceled")
}
}
func TestAESCipherReordering(t *testing.T) { func TestAESCipherReordering(t *testing.T) {
currentAESSupport := hasAESGCMHardwareSupport currentAESSupport := hasAESGCMHardwareSupport
defer func() { hasAESGCMHardwareSupport = currentAESSupport; initDefaultCipherSuites() }() defer func() { hasAESGCMHardwareSupport = currentAESSupport; initDefaultCipherSuites() }()

View file

@ -6,7 +6,6 @@ package tls
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"crypto/hmac" "crypto/hmac"
"crypto/rsa" "crypto/rsa"
@ -24,7 +23,6 @@ const maxClientPSKIdentities = 5
type serverHandshakeStateTLS13 struct { type serverHandshakeStateTLS13 struct {
c *Conn c *Conn
ctx context.Context
clientHello *clientHelloMsg clientHello *clientHelloMsg
hello *serverHelloMsg hello *serverHelloMsg
sentDummyCCS bool sentDummyCCS bool
@ -376,7 +374,7 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error {
return c.sendAlert(alertMissingExtension) return c.sendAlert(alertMissingExtension)
} }
certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello)) certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
if err != nil { if err != nil {
if err == errNoCertificates { if err == errNoCertificates {
c.sendAlert(alertUnrecognizedName) c.sendAlert(alertUnrecognizedName)

55
tls.go
View file

@ -25,6 +25,7 @@ import (
"net" "net"
"os" "os"
"strings" "strings"
"time"
) )
// Server returns a new TLS server side connection // Server returns a new TLS server side connection
@ -115,16 +116,28 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
} }
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
if netDialer.Timeout != 0 { // We want the Timeout and Deadline values from dialer to cover the
var cancel context.CancelFunc // whole process: TCP connection and TLS handshake. This means that we
ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout) // also need to start our own timers now.
defer cancel() timeout := netDialer.Timeout
}
if !netDialer.Deadline.IsZero() { if !netDialer.Deadline.IsZero() {
var cancel context.CancelFunc deadlineTimeout := time.Until(netDialer.Deadline)
ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline) if timeout == 0 || deadlineTimeout < timeout {
defer cancel() timeout = deadlineTimeout
}
}
// hsErrCh is non-nil if we might not wait for Handshake to complete.
var hsErrCh chan error
if timeout != 0 || ctx.Done() != nil {
hsErrCh = make(chan error, 2)
}
if timeout != 0 {
timer := time.AfterFunc(timeout, func() {
hsErrCh <- timeoutError{}
})
defer timer.Stop()
} }
rawConn, err := netDialer.DialContext(ctx, network, addr) rawConn, err := netDialer.DialContext(ctx, network, addr)
@ -151,10 +164,34 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf
} }
conn := Client(rawConn, config) conn := Client(rawConn, config)
if err := conn.HandshakeContext(ctx); err != nil {
if hsErrCh == nil {
err = conn.Handshake()
} else {
go func() {
hsErrCh <- conn.Handshake()
}()
select {
case <-ctx.Done():
err = ctx.Err()
case err = <-hsErrCh:
if err != nil {
// If the error was due to the context
// closing, prefer the context's error, rather
// than some random network teardown error.
if e := ctx.Err(); e != nil {
err = e
}
}
}
}
if err != nil {
rawConn.Close() rawConn.Close()
return nil, err return nil, err
} }
return conn, nil return conn, nil
} }