mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-04 12:37:35 +03:00
crypto/tls: revert "add HandshakeContext method to Conn"
This reverts CL 246338. Reason for revert: waiting for 1.17 release cycle Updates #32406 Change-Id: I074379039041e086c62271d689b4b7f442281663 Reviewed-on: https://go-review.googlesource.com/c/go/+/269697 Run-TryBot: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com> Run-TryBot: Katie Hockman <katie@golang.org> TryBot-Result: Go Bot <gobot@golang.org> Reviewed-by: Katie Hockman <katie@golang.org> Trust: Katie Hockman <katie@golang.org> Trust: Roland Shoemaker <roland@golang.org>
This commit is contained in:
parent
a2ca1d5330
commit
8649b4ade4
9 changed files with 62 additions and 197 deletions
21
common.go
21
common.go
|
@ -7,7 +7,6 @@ package tls
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"container/list"
|
"container/list"
|
||||||
"context"
|
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
|
@ -444,16 +443,6 @@ type ClientHelloInfo struct {
|
||||||
// config is embedded by the GetCertificate or GetConfigForClient caller,
|
// config is embedded by the GetCertificate or GetConfigForClient caller,
|
||||||
// for use with SupportsCertificate.
|
// for use with SupportsCertificate.
|
||||||
config *Config
|
config *Config
|
||||||
|
|
||||||
// ctx is the context of the handshake that is in progress.
|
|
||||||
ctx context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
// Context returns the context of the handshake that is in progress.
|
|
||||||
// This context is a child of the context passed to HandshakeContext,
|
|
||||||
// if any, and is canceled when the handshake concludes.
|
|
||||||
func (c *ClientHelloInfo) Context() context.Context {
|
|
||||||
return c.ctx
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CertificateRequestInfo contains information from a server's
|
// CertificateRequestInfo contains information from a server's
|
||||||
|
@ -472,16 +461,6 @@ type CertificateRequestInfo struct {
|
||||||
|
|
||||||
// Version is the TLS version that was negotiated for this connection.
|
// Version is the TLS version that was negotiated for this connection.
|
||||||
Version uint16
|
Version uint16
|
||||||
|
|
||||||
// ctx is the context of the handshake that is in progress.
|
|
||||||
ctx context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
// Context returns the context of the handshake that is in progress.
|
|
||||||
// This context is a child of the context passed to HandshakeContext,
|
|
||||||
// if any, and is canceled when the handshake concludes.
|
|
||||||
func (c *CertificateRequestInfo) Context() context.Context {
|
|
||||||
return c.ctx
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenegotiationSupport enumerates the different levels of support for TLS
|
// RenegotiationSupport enumerates the different levels of support for TLS
|
||||||
|
|
62
conn.go
62
conn.go
|
@ -8,7 +8,6 @@ package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
@ -28,7 +27,7 @@ type Conn struct {
|
||||||
// constant
|
// constant
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
isClient bool
|
isClient bool
|
||||||
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
|
handshakeFn func() error // (*Conn).clientHandshake or serverHandshake
|
||||||
|
|
||||||
// handshakeStatus is 1 if the connection is currently transferring
|
// handshakeStatus is 1 if the connection is currently transferring
|
||||||
// application data (i.e. is not currently processing a handshake).
|
// application data (i.e. is not currently processing a handshake).
|
||||||
|
@ -1191,7 +1190,7 @@ func (c *Conn) handleRenegotiation() error {
|
||||||
defer c.handshakeMutex.Unlock()
|
defer c.handshakeMutex.Unlock()
|
||||||
|
|
||||||
atomic.StoreUint32(&c.handshakeStatus, 0)
|
atomic.StoreUint32(&c.handshakeStatus, 0)
|
||||||
if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
|
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
|
||||||
c.handshakes++
|
c.handshakes++
|
||||||
}
|
}
|
||||||
return c.handshakeErr
|
return c.handshakeErr
|
||||||
|
@ -1374,61 +1373,8 @@ func (c *Conn) closeNotify() error {
|
||||||
// first Read or Write will call it automatically.
|
// first Read or Write will call it automatically.
|
||||||
//
|
//
|
||||||
// For control over canceling or setting a timeout on a handshake, use
|
// For control over canceling or setting a timeout on a handshake, use
|
||||||
// HandshakeContext or the Dialer's DialContext method instead.
|
// the Dialer's DialContext method.
|
||||||
func (c *Conn) Handshake() error {
|
func (c *Conn) Handshake() error {
|
||||||
return c.HandshakeContext(context.Background())
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandshakeContext runs the client or server handshake
|
|
||||||
// protocol if it has not yet been run.
|
|
||||||
//
|
|
||||||
// The provided Context must be non-nil. If the context is canceled before
|
|
||||||
// the handshake is complete, the handshake is interrupted and an error is returned.
|
|
||||||
// Once the handshake has completed, cancellation of the context will not affect the
|
|
||||||
// connection.
|
|
||||||
//
|
|
||||||
// Most uses of this package need not call HandshakeContext explicitly: the
|
|
||||||
// first Read or Write will call it automatically.
|
|
||||||
func (c *Conn) HandshakeContext(ctx context.Context) error {
|
|
||||||
// Delegate to unexported method for named return
|
|
||||||
// without confusing documented signature.
|
|
||||||
return c.handshakeContext(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
|
|
||||||
handshakeCtx, cancel := context.WithCancel(ctx)
|
|
||||||
// Note: defer this before starting the "interrupter" goroutine
|
|
||||||
// so that we can tell the difference between the input being canceled and
|
|
||||||
// this cancellation. In the former case, we need to close the connection.
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Start the "interrupter" goroutine, if this context might be canceled.
|
|
||||||
// (The background context cannot).
|
|
||||||
//
|
|
||||||
// The interrupter goroutine waits for the input context to be done and
|
|
||||||
// closes the connection if this happens before the function returns.
|
|
||||||
if ctx.Done() != nil {
|
|
||||||
done := make(chan struct{})
|
|
||||||
interruptRes := make(chan error, 1)
|
|
||||||
defer func() {
|
|
||||||
close(done)
|
|
||||||
if ctxErr := <-interruptRes; ctxErr != nil {
|
|
||||||
// Return context error to user.
|
|
||||||
ret = ctxErr
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-handshakeCtx.Done():
|
|
||||||
// Close the connection, discarding the error
|
|
||||||
_ = c.conn.Close()
|
|
||||||
interruptRes <- handshakeCtx.Err()
|
|
||||||
case <-done:
|
|
||||||
interruptRes <- nil
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
c.handshakeMutex.Lock()
|
c.handshakeMutex.Lock()
|
||||||
defer c.handshakeMutex.Unlock()
|
defer c.handshakeMutex.Unlock()
|
||||||
|
|
||||||
|
@ -1442,7 +1388,7 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
|
||||||
c.in.Lock()
|
c.in.Lock()
|
||||||
defer c.in.Unlock()
|
defer c.in.Unlock()
|
||||||
|
|
||||||
c.handshakeErr = c.handshakeFn(handshakeCtx)
|
c.handshakeErr = c.handshakeFn()
|
||||||
if c.handshakeErr == nil {
|
if c.handshakeErr == nil {
|
||||||
c.handshakes++
|
c.handshakes++
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -6,7 +6,6 @@ package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
|
@ -25,7 +24,6 @@ import (
|
||||||
|
|
||||||
type clientHandshakeState struct {
|
type clientHandshakeState struct {
|
||||||
c *Conn
|
c *Conn
|
||||||
ctx context.Context
|
|
||||||
serverHello *serverHelloMsg
|
serverHello *serverHelloMsg
|
||||||
hello *clientHelloMsg
|
hello *clientHelloMsg
|
||||||
suite *cipherSuite
|
suite *cipherSuite
|
||||||
|
@ -136,7 +134,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) {
|
||||||
return hello, params, nil
|
return hello, params, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) clientHandshake(ctx context.Context) (err error) {
|
func (c *Conn) clientHandshake() (err error) {
|
||||||
if c.config == nil {
|
if c.config == nil {
|
||||||
c.config = defaultConfig()
|
c.config = defaultConfig()
|
||||||
}
|
}
|
||||||
|
@ -200,7 +198,6 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
|
||||||
if c.vers == VersionTLS13 {
|
if c.vers == VersionTLS13 {
|
||||||
hs := &clientHandshakeStateTLS13{
|
hs := &clientHandshakeStateTLS13{
|
||||||
c: c,
|
c: c,
|
||||||
ctx: ctx,
|
|
||||||
serverHello: serverHello,
|
serverHello: serverHello,
|
||||||
hello: hello,
|
hello: hello,
|
||||||
ecdheParams: ecdheParams,
|
ecdheParams: ecdheParams,
|
||||||
|
@ -215,7 +212,6 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
|
||||||
|
|
||||||
hs := &clientHandshakeState{
|
hs := &clientHandshakeState{
|
||||||
c: c,
|
c: c,
|
||||||
ctx: ctx,
|
|
||||||
serverHello: serverHello,
|
serverHello: serverHello,
|
||||||
hello: hello,
|
hello: hello,
|
||||||
session: session,
|
session: session,
|
||||||
|
@ -544,7 +540,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
||||||
certRequested = true
|
certRequested = true
|
||||||
hs.finishedHash.Write(certReq.marshal())
|
hs.finishedHash.Write(certReq.marshal())
|
||||||
|
|
||||||
cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq)
|
cri := certificateRequestInfoFromMsg(c.vers, certReq)
|
||||||
if chainToSend, err = c.getClientCertificate(cri); err != nil {
|
if chainToSend, err = c.getClientCertificate(cri); err != nil {
|
||||||
c.sendAlert(alertInternalError)
|
c.sendAlert(alertInternalError)
|
||||||
return err
|
return err
|
||||||
|
@ -884,11 +880,10 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
|
||||||
|
|
||||||
// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS
|
// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS
|
||||||
// <= 1.2 CertificateRequest, making an effort to fill in missing information.
|
// <= 1.2 CertificateRequest, making an effort to fill in missing information.
|
||||||
func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
|
func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
|
||||||
cri := &CertificateRequestInfo{
|
cri := &CertificateRequestInfo{
|
||||||
AcceptableCAs: certReq.certificateAuthorities,
|
AcceptableCAs: certReq.certificateAuthorities,
|
||||||
Version: vers,
|
Version: vers,
|
||||||
ctx: ctx,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var rsaAvail, ecAvail bool
|
var rsaAvail, ecAvail bool
|
||||||
|
|
|
@ -6,7 +6,6 @@ package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
@ -21,7 +20,6 @@ import (
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -2513,37 +2511,3 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
|
||||||
serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
|
serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClientHandshakeContextCancellation(t *testing.T) {
|
|
||||||
c, s := localPipe(t)
|
|
||||||
serverConfig := testConfig.Clone()
|
|
||||||
serverErr := make(chan error, 1)
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
go func() {
|
|
||||||
defer close(serverErr)
|
|
||||||
defer s.Close()
|
|
||||||
conn := Server(s, serverConfig)
|
|
||||||
_, err := conn.readClientHello(ctx)
|
|
||||||
cancel()
|
|
||||||
serverErr <- err
|
|
||||||
}()
|
|
||||||
cli := Client(c, testConfig)
|
|
||||||
err := cli.HandshakeContext(ctx)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Client handshake did not error when the context was canceled")
|
|
||||||
}
|
|
||||||
if err != context.Canceled {
|
|
||||||
t.Errorf("Unexpected client handshake error: %v", err)
|
|
||||||
}
|
|
||||||
if err := <-serverErr; err != nil {
|
|
||||||
t.Errorf("Unexpected server error: %v", err)
|
|
||||||
}
|
|
||||||
if runtime.GOARCH == "wasm" {
|
|
||||||
t.Skip("conn.Close does not error as expected when called multiple times on WASM")
|
|
||||||
}
|
|
||||||
err = cli.Close()
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Client connection was not closed when the context was canceled")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -6,7 +6,6 @@ package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
@ -18,7 +17,6 @@ import (
|
||||||
|
|
||||||
type clientHandshakeStateTLS13 struct {
|
type clientHandshakeStateTLS13 struct {
|
||||||
c *Conn
|
c *Conn
|
||||||
ctx context.Context
|
|
||||||
serverHello *serverHelloMsg
|
serverHello *serverHelloMsg
|
||||||
hello *clientHelloMsg
|
hello *clientHelloMsg
|
||||||
ecdheParams ecdheParameters
|
ecdheParams ecdheParameters
|
||||||
|
@ -557,7 +555,6 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
|
||||||
AcceptableCAs: hs.certReq.certificateAuthorities,
|
AcceptableCAs: hs.certReq.certificateAuthorities,
|
||||||
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
|
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
|
||||||
Version: c.vers,
|
Version: c.vers,
|
||||||
ctx: hs.ctx,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
package tls
|
package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
|
@ -24,7 +23,6 @@ import (
|
||||||
// It's discarded once the handshake has completed.
|
// It's discarded once the handshake has completed.
|
||||||
type serverHandshakeState struct {
|
type serverHandshakeState struct {
|
||||||
c *Conn
|
c *Conn
|
||||||
ctx context.Context
|
|
||||||
clientHello *clientHelloMsg
|
clientHello *clientHelloMsg
|
||||||
hello *serverHelloMsg
|
hello *serverHelloMsg
|
||||||
suite *cipherSuite
|
suite *cipherSuite
|
||||||
|
@ -39,8 +37,8 @@ type serverHandshakeState struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// serverHandshake performs a TLS handshake as a server.
|
// serverHandshake performs a TLS handshake as a server.
|
||||||
func (c *Conn) serverHandshake(ctx context.Context) error {
|
func (c *Conn) serverHandshake() error {
|
||||||
clientHello, err := c.readClientHello(ctx)
|
clientHello, err := c.readClientHello()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -48,7 +46,6 @@ func (c *Conn) serverHandshake(ctx context.Context) error {
|
||||||
if c.vers == VersionTLS13 {
|
if c.vers == VersionTLS13 {
|
||||||
hs := serverHandshakeStateTLS13{
|
hs := serverHandshakeStateTLS13{
|
||||||
c: c,
|
c: c,
|
||||||
ctx: ctx,
|
|
||||||
clientHello: clientHello,
|
clientHello: clientHello,
|
||||||
}
|
}
|
||||||
return hs.handshake()
|
return hs.handshake()
|
||||||
|
@ -56,7 +53,6 @@ func (c *Conn) serverHandshake(ctx context.Context) error {
|
||||||
|
|
||||||
hs := serverHandshakeState{
|
hs := serverHandshakeState{
|
||||||
c: c,
|
c: c,
|
||||||
ctx: ctx,
|
|
||||||
clientHello: clientHello,
|
clientHello: clientHello,
|
||||||
}
|
}
|
||||||
return hs.handshake()
|
return hs.handshake()
|
||||||
|
@ -128,7 +124,7 @@ func (hs *serverHandshakeState) handshake() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// readClientHello reads a ClientHello message and selects the protocol version.
|
// readClientHello reads a ClientHello message and selects the protocol version.
|
||||||
func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
|
func (c *Conn) readClientHello() (*clientHelloMsg, error) {
|
||||||
msg, err := c.readHandshake()
|
msg, err := c.readHandshake()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -142,7 +138,7 @@ func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
|
||||||
var configForClient *Config
|
var configForClient *Config
|
||||||
originalConfig := c.config
|
originalConfig := c.config
|
||||||
if c.config.GetConfigForClient != nil {
|
if c.config.GetConfigForClient != nil {
|
||||||
chi := clientHelloInfo(ctx, c, clientHello)
|
chi := clientHelloInfo(c, clientHello)
|
||||||
if configForClient, err = c.config.GetConfigForClient(chi); err != nil {
|
if configForClient, err = c.config.GetConfigForClient(chi); err != nil {
|
||||||
c.sendAlert(alertInternalError)
|
c.sendAlert(alertInternalError)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -224,7 +220,7 @@ func (hs *serverHandshakeState) processClientHello() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
|
hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == errNoCertificates {
|
if err == errNoCertificates {
|
||||||
c.sendAlert(alertUnrecognizedName)
|
c.sendAlert(alertUnrecognizedName)
|
||||||
|
@ -832,7 +828,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
|
func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
|
||||||
supportedVersions := clientHello.supportedVersions
|
supportedVersions := clientHello.supportedVersions
|
||||||
if len(clientHello.supportedVersions) == 0 {
|
if len(clientHello.supportedVersions) == 0 {
|
||||||
supportedVersions = supportedVersionsFromMax(clientHello.vers)
|
supportedVersions = supportedVersionsFromMax(clientHello.vers)
|
||||||
|
@ -848,6 +844,5 @@ func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg)
|
||||||
SupportedVersions: supportedVersions,
|
SupportedVersions: supportedVersions,
|
||||||
Conn: c.conn,
|
Conn: c.conn,
|
||||||
config: c.config,
|
config: c.config,
|
||||||
ctx: ctx,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,6 @@ package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
@ -18,7 +17,6 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -40,12 +38,10 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
|
||||||
cli.writeRecord(recordTypeHandshake, m.marshal())
|
cli.writeRecord(recordTypeHandshake, m.marshal())
|
||||||
c.Close()
|
c.Close()
|
||||||
}()
|
}()
|
||||||
ctx := context.Background()
|
|
||||||
conn := Server(s, serverConfig)
|
conn := Server(s, serverConfig)
|
||||||
ch, err := conn.readClientHello(ctx)
|
ch, err := conn.readClientHello()
|
||||||
hs := serverHandshakeState{
|
hs := serverHandshakeState{
|
||||||
c: conn,
|
c: conn,
|
||||||
ctx: ctx,
|
|
||||||
clientHello: ch,
|
clientHello: ch,
|
||||||
}
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -1425,11 +1421,9 @@ func TestSNIGivenOnFailure(t *testing.T) {
|
||||||
c.Close()
|
c.Close()
|
||||||
}()
|
}()
|
||||||
conn := Server(s, serverConfig)
|
conn := Server(s, serverConfig)
|
||||||
ctx := context.Background()
|
ch, err := conn.readClientHello()
|
||||||
ch, err := conn.readClientHello(ctx)
|
|
||||||
hs := serverHandshakeState{
|
hs := serverHandshakeState{
|
||||||
c: conn,
|
c: conn,
|
||||||
ctx: ctx,
|
|
||||||
clientHello: ch,
|
clientHello: ch,
|
||||||
}
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -1683,46 +1677,6 @@ func TestMultipleCertificates(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerHandshakeContextCancellation(t *testing.T) {
|
|
||||||
c, s := localPipe(t)
|
|
||||||
clientConfig := testConfig.Clone()
|
|
||||||
clientErr := make(chan error, 1)
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
go func() {
|
|
||||||
defer close(clientErr)
|
|
||||||
defer c.Close()
|
|
||||||
clientHello := &clientHelloMsg{
|
|
||||||
vers: VersionTLS10,
|
|
||||||
random: make([]byte, 32),
|
|
||||||
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
|
|
||||||
compressionMethods: []uint8{compressionNone},
|
|
||||||
}
|
|
||||||
cli := Client(c, clientConfig)
|
|
||||||
_, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal())
|
|
||||||
cancel()
|
|
||||||
clientErr <- err
|
|
||||||
}()
|
|
||||||
conn := Server(s, testConfig)
|
|
||||||
err := conn.HandshakeContext(ctx)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Server handshake did not error when the context was canceled")
|
|
||||||
}
|
|
||||||
if err != context.Canceled {
|
|
||||||
t.Errorf("Unexpected server handshake error: %v", err)
|
|
||||||
}
|
|
||||||
if err := <-clientErr; err != nil {
|
|
||||||
t.Errorf("Unexpected client error: %v", err)
|
|
||||||
}
|
|
||||||
if runtime.GOARCH == "wasm" {
|
|
||||||
t.Skip("conn.Close does not error as expected when called multiple times on WASM")
|
|
||||||
}
|
|
||||||
err = conn.Close()
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Server connection was not closed when the context was canceled")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAESCipherReordering(t *testing.T) {
|
func TestAESCipherReordering(t *testing.T) {
|
||||||
currentAESSupport := hasAESGCMHardwareSupport
|
currentAESSupport := hasAESGCMHardwareSupport
|
||||||
defer func() { hasAESGCMHardwareSupport = currentAESSupport; initDefaultCipherSuites() }()
|
defer func() { hasAESGCMHardwareSupport = currentAESSupport; initDefaultCipherSuites() }()
|
||||||
|
|
|
@ -6,7 +6,6 @@ package tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
@ -24,7 +23,6 @@ const maxClientPSKIdentities = 5
|
||||||
|
|
||||||
type serverHandshakeStateTLS13 struct {
|
type serverHandshakeStateTLS13 struct {
|
||||||
c *Conn
|
c *Conn
|
||||||
ctx context.Context
|
|
||||||
clientHello *clientHelloMsg
|
clientHello *clientHelloMsg
|
||||||
hello *serverHelloMsg
|
hello *serverHelloMsg
|
||||||
sentDummyCCS bool
|
sentDummyCCS bool
|
||||||
|
@ -376,7 +374,7 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error {
|
||||||
return c.sendAlert(alertMissingExtension)
|
return c.sendAlert(alertMissingExtension)
|
||||||
}
|
}
|
||||||
|
|
||||||
certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
|
certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == errNoCertificates {
|
if err == errNoCertificates {
|
||||||
c.sendAlert(alertUnrecognizedName)
|
c.sendAlert(alertUnrecognizedName)
|
||||||
|
|
55
tls.go
55
tls.go
|
@ -25,6 +25,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server returns a new TLS server side connection
|
// Server returns a new TLS server side connection
|
||||||
|
@ -115,16 +116,28 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
|
||||||
}
|
}
|
||||||
|
|
||||||
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||||
if netDialer.Timeout != 0 {
|
// We want the Timeout and Deadline values from dialer to cover the
|
||||||
var cancel context.CancelFunc
|
// whole process: TCP connection and TLS handshake. This means that we
|
||||||
ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
|
// also need to start our own timers now.
|
||||||
defer cancel()
|
timeout := netDialer.Timeout
|
||||||
}
|
|
||||||
|
|
||||||
if !netDialer.Deadline.IsZero() {
|
if !netDialer.Deadline.IsZero() {
|
||||||
var cancel context.CancelFunc
|
deadlineTimeout := time.Until(netDialer.Deadline)
|
||||||
ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
|
if timeout == 0 || deadlineTimeout < timeout {
|
||||||
defer cancel()
|
timeout = deadlineTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hsErrCh is non-nil if we might not wait for Handshake to complete.
|
||||||
|
var hsErrCh chan error
|
||||||
|
if timeout != 0 || ctx.Done() != nil {
|
||||||
|
hsErrCh = make(chan error, 2)
|
||||||
|
}
|
||||||
|
if timeout != 0 {
|
||||||
|
timer := time.AfterFunc(timeout, func() {
|
||||||
|
hsErrCh <- timeoutError{}
|
||||||
|
})
|
||||||
|
defer timer.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
rawConn, err := netDialer.DialContext(ctx, network, addr)
|
rawConn, err := netDialer.DialContext(ctx, network, addr)
|
||||||
|
@ -151,10 +164,34 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := Client(rawConn, config)
|
conn := Client(rawConn, config)
|
||||||
if err := conn.HandshakeContext(ctx); err != nil {
|
|
||||||
|
if hsErrCh == nil {
|
||||||
|
err = conn.Handshake()
|
||||||
|
} else {
|
||||||
|
go func() {
|
||||||
|
hsErrCh <- conn.Handshake()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
err = ctx.Err()
|
||||||
|
case err = <-hsErrCh:
|
||||||
|
if err != nil {
|
||||||
|
// If the error was due to the context
|
||||||
|
// closing, prefer the context's error, rather
|
||||||
|
// than some random network teardown error.
|
||||||
|
if e := ctx.Err(); e != nil {
|
||||||
|
err = e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
rawConn.Close()
|
rawConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue