mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-03 20:17:36 +03:00
[dev.boringcrypto] all: merge master into dev.boringcrypto
Change-Id: Iba19903f0565b11c648e1fa6effc07b8f97dc322
This commit is contained in:
commit
aad1dae3d1
12 changed files with 311 additions and 148 deletions
|
@ -162,7 +162,7 @@ type cipherSuite struct {
|
|||
// flags is a bitmask of the suite* values, above.
|
||||
flags int
|
||||
cipher func(key, iv []byte, isRead bool) interface{}
|
||||
mac func(version uint16, macKey []byte) macFunction
|
||||
mac func(key []byte) hash.Hash
|
||||
aead func(key, fixedNonce []byte) aead
|
||||
}
|
||||
|
||||
|
@ -249,30 +249,21 @@ func cipherAES(key, iv []byte, isRead bool) interface{} {
|
|||
return cipher.NewCBCEncrypter(block, iv)
|
||||
}
|
||||
|
||||
// macSHA1 returns a macFunction for the given protocol version.
|
||||
func macSHA1(version uint16, key []byte) macFunction {
|
||||
// macSHA1 returns a SHA-1 based constant time MAC.
|
||||
func macSHA1(key []byte) hash.Hash {
|
||||
h := sha1.New
|
||||
// The BoringCrypto SHA1 does not have a constant-time
|
||||
// checksum function, so don't try to use it.
|
||||
if !boring.Enabled {
|
||||
h = newConstantTimeHash(h)
|
||||
}
|
||||
return tls10MAC{h: hmac.New(h, key)}
|
||||
return hmac.New(h, key)
|
||||
}
|
||||
|
||||
// macSHA256 returns a SHA-256 based MAC. These are only supported in TLS 1.2
|
||||
// so the given version is ignored.
|
||||
func macSHA256(version uint16, key []byte) macFunction {
|
||||
return tls10MAC{h: hmac.New(sha256.New, key)}
|
||||
}
|
||||
|
||||
type macFunction interface {
|
||||
// Size returns the length of the MAC.
|
||||
Size() int
|
||||
// MAC appends the MAC of (seq, header, data) to out. The extra data is fed
|
||||
// into the MAC after obtaining the result to normalize timing. The result
|
||||
// is only valid until the next invocation of MAC as the buffer is reused.
|
||||
MAC(seq, header, data, extra []byte) []byte
|
||||
// macSHA256 returns a SHA-256 based MAC. This is only supported in TLS 1.2 and
|
||||
// is currently only used in disabled-by-default cipher suites.
|
||||
func macSHA256(key []byte) hash.Hash {
|
||||
return hmac.New(sha256.New, key)
|
||||
}
|
||||
|
||||
type aead interface {
|
||||
|
@ -431,26 +422,14 @@ func newConstantTimeHash(h func() hash.Hash) func() hash.Hash {
|
|||
}
|
||||
|
||||
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3.
|
||||
type tls10MAC struct {
|
||||
h hash.Hash
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func (s tls10MAC) Size() int {
|
||||
return s.h.Size()
|
||||
}
|
||||
|
||||
// MAC is guaranteed to take constant time, as long as
|
||||
// len(seq)+len(header)+len(data)+len(extra) is constant. extra is not fed into
|
||||
// the MAC, but is only provided to make the timing profile constant.
|
||||
func (s tls10MAC) MAC(seq, header, data, extra []byte) []byte {
|
||||
s.h.Reset()
|
||||
s.h.Write(seq)
|
||||
s.h.Write(header)
|
||||
s.h.Write(data)
|
||||
res := s.h.Sum(s.buf[:0])
|
||||
func tls10MAC(h hash.Hash, out, seq, header, data, extra []byte) []byte {
|
||||
h.Reset()
|
||||
h.Write(seq)
|
||||
h.Write(header)
|
||||
h.Write(data)
|
||||
res := h.Sum(out)
|
||||
if extra != nil {
|
||||
s.h.Write(extra)
|
||||
h.Write(extra)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
|
41
common.go
41
common.go
|
@ -7,6 +7,7 @@ package tls
|
|||
import (
|
||||
"bytes"
|
||||
"container/list"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
|
@ -294,10 +295,26 @@ func (cs *ConnectionState) ExportKeyingMaterial(label string, context []byte, le
|
|||
type ClientAuthType int
|
||||
|
||||
const (
|
||||
// NoClientCert indicates that no client certificate should be requested
|
||||
// during the handshake, and if any certificates are sent they will not
|
||||
// be verified.
|
||||
NoClientCert ClientAuthType = iota
|
||||
// RequestClientCert indicates that a client certificate should be requested
|
||||
// during the handshake, but does not require that the client send any
|
||||
// certificates.
|
||||
RequestClientCert
|
||||
// RequireAnyClientCert indicates that a client certificate should be requested
|
||||
// during the handshake, and that at least one certificate is required to be
|
||||
// sent by the client, but that certificate is not required to be valid.
|
||||
RequireAnyClientCert
|
||||
// VerifyClientCertIfGiven indicates that a client certificate should be requested
|
||||
// during the handshake, but does not require that the client sends a
|
||||
// certificate. If the client does send a certificate it is required to be
|
||||
// valid.
|
||||
VerifyClientCertIfGiven
|
||||
// RequireAndVerifyClientCert indicates that a client certificate should be requested
|
||||
// during the handshake, and that at least one valid certificate is required
|
||||
// to be sent by the client.
|
||||
RequireAndVerifyClientCert
|
||||
)
|
||||
|
||||
|
@ -428,6 +445,16 @@ type ClientHelloInfo struct {
|
|||
// config is embedded by the GetCertificate or GetConfigForClient caller,
|
||||
// for use with SupportsCertificate.
|
||||
config *Config
|
||||
|
||||
// ctx is the context of the handshake that is in progress.
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Context returns the context of the handshake that is in progress.
|
||||
// This context is a child of the context passed to HandshakeContext,
|
||||
// if any, and is canceled when the handshake concludes.
|
||||
func (c *ClientHelloInfo) Context() context.Context {
|
||||
return c.ctx
|
||||
}
|
||||
|
||||
// CertificateRequestInfo contains information from a server's
|
||||
|
@ -446,6 +473,16 @@ type CertificateRequestInfo struct {
|
|||
|
||||
// Version is the TLS version that was negotiated for this connection.
|
||||
Version uint16
|
||||
|
||||
// ctx is the context of the handshake that is in progress.
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Context returns the context of the handshake that is in progress.
|
||||
// This context is a child of the context passed to HandshakeContext,
|
||||
// if any, and is canceled when the handshake concludes.
|
||||
func (c *CertificateRequestInfo) Context() context.Context {
|
||||
return c.ctx
|
||||
}
|
||||
|
||||
// RenegotiationSupport enumerates the different levels of support for TLS
|
||||
|
@ -1256,7 +1293,9 @@ func (c *Config) BuildNameToCertificate() {
|
|||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if len(x509Cert.Subject.CommonName) > 0 {
|
||||
// If SANs are *not* present, some clients will consider the certificate
|
||||
// valid for the name in the Common Name.
|
||||
if x509Cert.Subject.CommonName != "" && len(x509Cert.DNSNames) == 0 {
|
||||
c.NameToCertificate[x509Cert.Subject.CommonName] = cert
|
||||
}
|
||||
for _, san := range x509Cert.DNSNames {
|
||||
|
|
166
conn.go
166
conn.go
|
@ -8,11 +8,13 @@ package tls
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"crypto/subtle"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
@ -26,7 +28,7 @@ type Conn struct {
|
|||
// constant
|
||||
conn net.Conn
|
||||
isClient bool
|
||||
handshakeFn func() error // (*Conn).clientHandshake or serverHandshake
|
||||
handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake
|
||||
|
||||
// handshakeStatus is 1 if the connection is currently transferring
|
||||
// application data (i.e. is not currently processing a handshake).
|
||||
|
@ -94,7 +96,6 @@ type Conn struct {
|
|||
rawInput bytes.Buffer // raw input, starting with a record header
|
||||
input bytes.Reader // application data waiting to be read, from rawInput.Next
|
||||
hand bytes.Buffer // handshake data waiting to be read
|
||||
outBuf []byte // scratch buffer used by out.encrypt
|
||||
buffering bool // whether records are buffered in sendBuf
|
||||
sendBuf []byte // a buffer of records waiting to be sent
|
||||
|
||||
|
@ -155,15 +156,16 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
|
|||
type halfConn struct {
|
||||
sync.Mutex
|
||||
|
||||
err error // first permanent error
|
||||
version uint16 // protocol version
|
||||
cipher interface{} // cipher algorithm
|
||||
mac macFunction
|
||||
seq [8]byte // 64-bit sequence number
|
||||
additionalData [13]byte // to avoid allocs; interface method args escape
|
||||
err error // first permanent error
|
||||
version uint16 // protocol version
|
||||
cipher interface{} // cipher algorithm
|
||||
mac hash.Hash
|
||||
seq [8]byte // 64-bit sequence number
|
||||
|
||||
scratchBuf [13]byte // to avoid allocs; interface method args escape
|
||||
|
||||
nextCipher interface{} // next encryption state
|
||||
nextMac macFunction // next MAC algorithm
|
||||
nextMac hash.Hash // next MAC algorithm
|
||||
|
||||
trafficSecret []byte // current TLS 1.3 traffic secret
|
||||
}
|
||||
|
@ -188,7 +190,7 @@ func (hc *halfConn) setErrorLocked(err error) error {
|
|||
|
||||
// prepareCipherSpec sets the encryption and MAC states
|
||||
// that a subsequent changeCipherSpec will use.
|
||||
func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
|
||||
func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) {
|
||||
hc.version = version
|
||||
hc.nextCipher = cipher
|
||||
hc.nextMac = mac
|
||||
|
@ -350,15 +352,14 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
|
|||
}
|
||||
payload = payload[explicitNonceLen:]
|
||||
|
||||
additionalData := hc.additionalData[:]
|
||||
var additionalData []byte
|
||||
if hc.version == VersionTLS13 {
|
||||
additionalData = record[:recordHeaderLen]
|
||||
} else {
|
||||
copy(additionalData, hc.seq[:])
|
||||
copy(additionalData[8:], record[:3])
|
||||
additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
|
||||
additionalData = append(additionalData, record[:3]...)
|
||||
n := len(payload) - c.Overhead()
|
||||
additionalData[11] = byte(n >> 8)
|
||||
additionalData[12] = byte(n)
|
||||
additionalData = append(additionalData, byte(n>>8), byte(n))
|
||||
}
|
||||
|
||||
var err error
|
||||
|
@ -424,7 +425,7 @@ func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
|
|||
record[3] = byte(n >> 8)
|
||||
record[4] = byte(n)
|
||||
remoteMAC := payload[n : n+macSize]
|
||||
localMAC := hc.mac.MAC(hc.seq[0:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
|
||||
localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
|
||||
|
||||
// This is equivalent to checking the MACs and paddingGood
|
||||
// separately, but in constant-time to prevent distinguishing
|
||||
|
@ -460,7 +461,7 @@ func sliceForAppend(in []byte, n int) (head, tail []byte) {
|
|||
}
|
||||
|
||||
// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and
|
||||
// appends it to record, which contains the record header.
|
||||
// appends it to record, which must already contain the record header.
|
||||
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
|
||||
if hc.cipher == nil {
|
||||
return append(record, payload...), nil
|
||||
|
@ -477,7 +478,7 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err
|
|||
// an 8 bytes nonce but its nonces must be unpredictable (see RFC
|
||||
// 5246, Appendix F.3), forcing us to use randomness. That's not
|
||||
// 3DES' biggest problem anyway because the birthday bound on block
|
||||
// collision is reached first due to its simlarly small block size
|
||||
// collision is reached first due to its similarly small block size
|
||||
// (see the Sweet32 attack).
|
||||
copy(explicitNonce, hc.seq[:])
|
||||
} else {
|
||||
|
@ -487,14 +488,10 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err
|
|||
}
|
||||
}
|
||||
|
||||
var mac []byte
|
||||
if hc.mac != nil {
|
||||
mac = hc.mac.MAC(hc.seq[:], record[:recordHeaderLen], payload, nil)
|
||||
}
|
||||
|
||||
var dst []byte
|
||||
switch c := hc.cipher.(type) {
|
||||
case cipher.Stream:
|
||||
mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
|
||||
record, dst = sliceForAppend(record, len(payload)+len(mac))
|
||||
c.XORKeyStream(dst[:len(payload)], payload)
|
||||
c.XORKeyStream(dst[len(payload):], mac)
|
||||
|
@ -518,11 +515,12 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err
|
|||
record = c.Seal(record[:recordHeaderLen],
|
||||
nonce, record[recordHeaderLen:], record[:recordHeaderLen])
|
||||
} else {
|
||||
copy(hc.additionalData[:], hc.seq[:])
|
||||
copy(hc.additionalData[8:], record)
|
||||
record = c.Seal(record, nonce, payload, hc.additionalData[:])
|
||||
additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
|
||||
additionalData = append(additionalData, record[:recordHeaderLen]...)
|
||||
record = c.Seal(record, nonce, payload, additionalData)
|
||||
}
|
||||
case cbcMode:
|
||||
mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
|
||||
blockSize := c.BlockSize()
|
||||
plaintextLen := len(payload) + len(mac)
|
||||
paddingLen := blockSize - plaintextLen%blockSize
|
||||
|
@ -928,9 +926,28 @@ func (c *Conn) flush() (int, error) {
|
|||
return n, err
|
||||
}
|
||||
|
||||
// outBufPool pools the record-sized scratch buffers used by writeRecordLocked.
|
||||
var outBufPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new([]byte)
|
||||
},
|
||||
}
|
||||
|
||||
// writeRecordLocked writes a TLS record with the given type and payload to the
|
||||
// connection and updates the record layer state.
|
||||
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
|
||||
outBufPtr := outBufPool.Get().(*[]byte)
|
||||
outBuf := *outBufPtr
|
||||
defer func() {
|
||||
// You might be tempted to simplify this by just passing &outBuf to Put,
|
||||
// but that would make the local copy of the outBuf slice header escape
|
||||
// to the heap, causing an allocation. Instead, we keep around the
|
||||
// pointer to the slice header returned by Get, which is already on the
|
||||
// heap, and overwrite and return that.
|
||||
*outBufPtr = outBuf
|
||||
outBufPool.Put(outBufPtr)
|
||||
}()
|
||||
|
||||
var n int
|
||||
for len(data) > 0 {
|
||||
m := len(data)
|
||||
|
@ -938,8 +955,8 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
|
|||
m = maxPayload
|
||||
}
|
||||
|
||||
_, c.outBuf = sliceForAppend(c.outBuf[:0], recordHeaderLen)
|
||||
c.outBuf[0] = byte(typ)
|
||||
_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
|
||||
outBuf[0] = byte(typ)
|
||||
vers := c.vers
|
||||
if vers == 0 {
|
||||
// Some TLS servers fail if the record version is
|
||||
|
@ -950,17 +967,17 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
|
|||
// See RFC 8446, Section 5.1.
|
||||
vers = VersionTLS12
|
||||
}
|
||||
c.outBuf[1] = byte(vers >> 8)
|
||||
c.outBuf[2] = byte(vers)
|
||||
c.outBuf[3] = byte(m >> 8)
|
||||
c.outBuf[4] = byte(m)
|
||||
outBuf[1] = byte(vers >> 8)
|
||||
outBuf[2] = byte(vers)
|
||||
outBuf[3] = byte(m >> 8)
|
||||
outBuf[4] = byte(m)
|
||||
|
||||
var err error
|
||||
c.outBuf, err = c.out.encrypt(c.outBuf, data[:m], c.config.rand())
|
||||
outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if _, err := c.write(c.outBuf); err != nil {
|
||||
if _, err := c.write(outBuf); err != nil {
|
||||
return n, err
|
||||
}
|
||||
n += m
|
||||
|
@ -1074,6 +1091,11 @@ var (
|
|||
)
|
||||
|
||||
// Write writes data to the connection.
|
||||
//
|
||||
// As Write calls Handshake, in order to prevent indefinite blocking a deadline
|
||||
// must be set for both Read and Write before Write is called when the handshake
|
||||
// has not yet completed. See SetDeadline, SetReadDeadline, and
|
||||
// SetWriteDeadline.
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
// interlock with Close below
|
||||
for {
|
||||
|
@ -1169,7 +1191,7 @@ func (c *Conn) handleRenegotiation() error {
|
|||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
atomic.StoreUint32(&c.handshakeStatus, 0)
|
||||
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
|
||||
if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
|
||||
c.handshakes++
|
||||
}
|
||||
return c.handshakeErr
|
||||
|
@ -1232,8 +1254,12 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||
// Read reads data from the connection.
|
||||
//
|
||||
// As Read calls Handshake, in order to prevent indefinite blocking a deadline
|
||||
// must be set for both Read and Write before Read is called when the handshake
|
||||
// has not yet completed. See SetDeadline, SetReadDeadline, and
|
||||
// SetWriteDeadline.
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
if err := c.Handshake(); err != nil {
|
||||
return 0, err
|
||||
|
@ -1301,9 +1327,10 @@ func (c *Conn) Close() error {
|
|||
}
|
||||
|
||||
var alertErr error
|
||||
|
||||
if c.handshakeComplete() {
|
||||
alertErr = c.closeNotify()
|
||||
if err := c.closeNotify(); err != nil {
|
||||
alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.conn.Close(); err != nil {
|
||||
|
@ -1330,8 +1357,12 @@ func (c *Conn) closeNotify() error {
|
|||
defer c.out.Unlock()
|
||||
|
||||
if !c.closeNotifySent {
|
||||
// Set a Write Deadline to prevent possibly blocking forever.
|
||||
c.SetWriteDeadline(time.Now().Add(time.Second * 5))
|
||||
c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
|
||||
c.closeNotifySent = true
|
||||
// Any subsequent writes will fail.
|
||||
c.SetWriteDeadline(time.Now())
|
||||
}
|
||||
return c.closeNotifyErr
|
||||
}
|
||||
|
@ -1343,8 +1374,61 @@ func (c *Conn) closeNotify() error {
|
|||
// first Read or Write will call it automatically.
|
||||
//
|
||||
// For control over canceling or setting a timeout on a handshake, use
|
||||
// the Dialer's DialContext method.
|
||||
// HandshakeContext or the Dialer's DialContext method instead.
|
||||
func (c *Conn) Handshake() error {
|
||||
return c.HandshakeContext(context.Background())
|
||||
}
|
||||
|
||||
// HandshakeContext runs the client or server handshake
|
||||
// protocol if it has not yet been run.
|
||||
//
|
||||
// The provided Context must be non-nil. If the context is canceled before
|
||||
// the handshake is complete, the handshake is interrupted and an error is returned.
|
||||
// Once the handshake has completed, cancellation of the context will not affect the
|
||||
// connection.
|
||||
//
|
||||
// Most uses of this package need not call HandshakeContext explicitly: the
|
||||
// first Read or Write will call it automatically.
|
||||
func (c *Conn) HandshakeContext(ctx context.Context) error {
|
||||
// Delegate to unexported method for named return
|
||||
// without confusing documented signature.
|
||||
return c.handshakeContext(ctx)
|
||||
}
|
||||
|
||||
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
|
||||
handshakeCtx, cancel := context.WithCancel(ctx)
|
||||
// Note: defer this before starting the "interrupter" goroutine
|
||||
// so that we can tell the difference between the input being canceled and
|
||||
// this cancellation. In the former case, we need to close the connection.
|
||||
defer cancel()
|
||||
|
||||
// Start the "interrupter" goroutine, if this context might be canceled.
|
||||
// (The background context cannot).
|
||||
//
|
||||
// The interrupter goroutine waits for the input context to be done and
|
||||
// closes the connection if this happens before the function returns.
|
||||
if ctx.Done() != nil {
|
||||
done := make(chan struct{})
|
||||
interruptRes := make(chan error, 1)
|
||||
defer func() {
|
||||
close(done)
|
||||
if ctxErr := <-interruptRes; ctxErr != nil {
|
||||
// Return context error to user.
|
||||
ret = ctxErr
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
select {
|
||||
case <-handshakeCtx.Done():
|
||||
// Close the connection, discarding the error
|
||||
_ = c.conn.Close()
|
||||
interruptRes <- handshakeCtx.Err()
|
||||
case <-done:
|
||||
interruptRes <- nil
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
|
@ -1358,7 +1442,7 @@ func (c *Conn) Handshake() error {
|
|||
c.in.Lock()
|
||||
defer c.in.Unlock()
|
||||
|
||||
c.handshakeErr = c.handshakeFn()
|
||||
c.handshakeErr = c.handshakeFn(handshakeCtx)
|
||||
if c.handshakeErr == nil {
|
||||
c.handshakes++
|
||||
} else {
|
||||
|
|
|
@ -6,6 +6,7 @@ package tls
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
|
@ -14,6 +15,7 @@ import (
|
|||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
@ -23,6 +25,7 @@ import (
|
|||
|
||||
type clientHandshakeState struct {
|
||||
c *Conn
|
||||
ctx context.Context
|
||||
serverHello *serverHelloMsg
|
||||
hello *clientHelloMsg
|
||||
suite *cipherSuite
|
||||
|
@ -136,7 +139,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, ecdheParameters, error) {
|
|||
return hello, params, nil
|
||||
}
|
||||
|
||||
func (c *Conn) clientHandshake() (err error) {
|
||||
func (c *Conn) clientHandshake(ctx context.Context) (err error) {
|
||||
if c.config == nil {
|
||||
c.config = defaultConfig()
|
||||
}
|
||||
|
@ -200,6 +203,7 @@ func (c *Conn) clientHandshake() (err error) {
|
|||
if c.vers == VersionTLS13 {
|
||||
hs := &clientHandshakeStateTLS13{
|
||||
c: c,
|
||||
ctx: ctx,
|
||||
serverHello: serverHello,
|
||||
hello: hello,
|
||||
ecdheParams: ecdheParams,
|
||||
|
@ -214,6 +218,7 @@ func (c *Conn) clientHandshake() (err error) {
|
|||
|
||||
hs := &clientHandshakeState{
|
||||
c: c,
|
||||
ctx: ctx,
|
||||
serverHello: serverHello,
|
||||
hello: hello,
|
||||
session: session,
|
||||
|
@ -542,7 +547,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
|||
certRequested = true
|
||||
hs.finishedHash.Write(certReq.marshal())
|
||||
|
||||
cri := certificateRequestInfoFromMsg(c.vers, certReq)
|
||||
cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq)
|
||||
if chainToSend, err = c.getClientCertificate(cri); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
|
@ -650,12 +655,12 @@ func (hs *clientHandshakeState) establishKeys() error {
|
|||
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
|
||||
keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
|
||||
var clientCipher, serverCipher interface{}
|
||||
var clientHash, serverHash macFunction
|
||||
var clientHash, serverHash hash.Hash
|
||||
if hs.suite.cipher != nil {
|
||||
clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */)
|
||||
clientHash = hs.suite.mac(c.vers, clientMAC)
|
||||
clientHash = hs.suite.mac(clientMAC)
|
||||
serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */)
|
||||
serverHash = hs.suite.mac(c.vers, serverMAC)
|
||||
serverHash = hs.suite.mac(serverMAC)
|
||||
} else {
|
||||
clientCipher = hs.suite.aead(clientKey, clientIV)
|
||||
serverCipher = hs.suite.aead(serverKey, serverIV)
|
||||
|
@ -884,10 +889,11 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
|
|||
|
||||
// certificateRequestInfoFromMsg generates a CertificateRequestInfo from a TLS
|
||||
// <= 1.2 CertificateRequest, making an effort to fill in missing information.
|
||||
func certificateRequestInfoFromMsg(vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
|
||||
func certificateRequestInfoFromMsg(ctx context.Context, vers uint16, certReq *certificateRequestMsg) *CertificateRequestInfo {
|
||||
cri := &CertificateRequestInfo{
|
||||
AcceptableCAs: certReq.certificateAuthorities,
|
||||
Version: vers,
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
var rsaAvail, ecAvail bool
|
||||
|
|
|
@ -6,6 +6,7 @@ package tls
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
|
@ -20,6 +21,7 @@ import (
|
|||
"os/exec"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -2511,3 +2513,37 @@ func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
|
|||
serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientHandshakeContextCancellation(t *testing.T) {
|
||||
c, s := localPipe(t)
|
||||
serverConfig := testConfig.Clone()
|
||||
serverErr := make(chan error, 1)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() {
|
||||
defer close(serverErr)
|
||||
defer s.Close()
|
||||
conn := Server(s, serverConfig)
|
||||
_, err := conn.readClientHello(ctx)
|
||||
cancel()
|
||||
serverErr <- err
|
||||
}()
|
||||
cli := Client(c, testConfig)
|
||||
err := cli.HandshakeContext(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("Client handshake did not error when the context was canceled")
|
||||
}
|
||||
if err != context.Canceled {
|
||||
t.Errorf("Unexpected client handshake error: %v", err)
|
||||
}
|
||||
if err := <-serverErr; err != nil {
|
||||
t.Errorf("Unexpected server error: %v", err)
|
||||
}
|
||||
if runtime.GOARCH == "wasm" {
|
||||
t.Skip("conn.Close does not error as expected when called multiple times on WASM")
|
||||
}
|
||||
err = cli.Close()
|
||||
if err == nil {
|
||||
t.Error("Client connection was not closed when the context was canceled")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ package tls
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"crypto/rsa"
|
||||
|
@ -17,6 +18,7 @@ import (
|
|||
|
||||
type clientHandshakeStateTLS13 struct {
|
||||
c *Conn
|
||||
ctx context.Context
|
||||
serverHello *serverHelloMsg
|
||||
hello *clientHelloMsg
|
||||
ecdheParams ecdheParameters
|
||||
|
@ -553,6 +555,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
|
|||
AcceptableCAs: hs.certReq.certificateAuthorities,
|
||||
SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
|
||||
Version: c.vers,
|
||||
ctx: hs.ctx,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
|
@ -13,6 +14,7 @@ import (
|
|||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -22,6 +24,7 @@ import (
|
|||
// It's discarded once the handshake has completed.
|
||||
type serverHandshakeState struct {
|
||||
c *Conn
|
||||
ctx context.Context
|
||||
clientHello *clientHelloMsg
|
||||
hello *serverHelloMsg
|
||||
suite *cipherSuite
|
||||
|
@ -36,8 +39,8 @@ type serverHandshakeState struct {
|
|||
}
|
||||
|
||||
// serverHandshake performs a TLS handshake as a server.
|
||||
func (c *Conn) serverHandshake() error {
|
||||
clientHello, err := c.readClientHello()
|
||||
func (c *Conn) serverHandshake(ctx context.Context) error {
|
||||
clientHello, err := c.readClientHello(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -45,6 +48,7 @@ func (c *Conn) serverHandshake() error {
|
|||
if c.vers == VersionTLS13 {
|
||||
hs := serverHandshakeStateTLS13{
|
||||
c: c,
|
||||
ctx: ctx,
|
||||
clientHello: clientHello,
|
||||
}
|
||||
return hs.handshake()
|
||||
|
@ -52,6 +56,7 @@ func (c *Conn) serverHandshake() error {
|
|||
|
||||
hs := serverHandshakeState{
|
||||
c: c,
|
||||
ctx: ctx,
|
||||
clientHello: clientHello,
|
||||
}
|
||||
return hs.handshake()
|
||||
|
@ -123,7 +128,7 @@ func (hs *serverHandshakeState) handshake() error {
|
|||
}
|
||||
|
||||
// readClientHello reads a ClientHello message and selects the protocol version.
|
||||
func (c *Conn) readClientHello() (*clientHelloMsg, error) {
|
||||
func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -137,7 +142,7 @@ func (c *Conn) readClientHello() (*clientHelloMsg, error) {
|
|||
var configForClient *Config
|
||||
originalConfig := c.config
|
||||
if c.config.GetConfigForClient != nil {
|
||||
chi := clientHelloInfo(c, clientHello)
|
||||
chi := clientHelloInfo(ctx, c, clientHello)
|
||||
if configForClient, err = c.config.GetConfigForClient(chi); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return nil, err
|
||||
|
@ -219,7 +224,7 @@ func (hs *serverHandshakeState) processClientHello() error {
|
|||
}
|
||||
}
|
||||
|
||||
hs.cert, err = c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
|
||||
hs.cert, err = c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
|
||||
if err != nil {
|
||||
if err == errNoCertificates {
|
||||
c.sendAlert(alertUnrecognizedName)
|
||||
|
@ -641,13 +646,13 @@ func (hs *serverHandshakeState) establishKeys() error {
|
|||
keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
|
||||
|
||||
var clientCipher, serverCipher interface{}
|
||||
var clientHash, serverHash macFunction
|
||||
var clientHash, serverHash hash.Hash
|
||||
|
||||
if hs.suite.aead == nil {
|
||||
clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */)
|
||||
clientHash = hs.suite.mac(c.vers, clientMAC)
|
||||
clientHash = hs.suite.mac(clientMAC)
|
||||
serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */)
|
||||
serverHash = hs.suite.mac(c.vers, serverMAC)
|
||||
serverHash = hs.suite.mac(serverMAC)
|
||||
} else {
|
||||
clientCipher = hs.suite.aead(clientKey, clientIV)
|
||||
serverCipher = hs.suite.aead(serverKey, serverIV)
|
||||
|
@ -815,7 +820,7 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
|
||||
func clientHelloInfo(ctx context.Context, c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
|
||||
supportedVersions := clientHello.supportedVersions
|
||||
if len(clientHello.supportedVersions) == 0 {
|
||||
supportedVersions = supportedVersionsFromMax(clientHello.vers)
|
||||
|
@ -831,5 +836,6 @@ func clientHelloInfo(c *Conn, clientHello *clientHelloMsg) *ClientHelloInfo {
|
|||
SupportedVersions: supportedVersions,
|
||||
Conn: c.conn,
|
||||
config: c.config,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ package tls
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/elliptic"
|
||||
"crypto/x509"
|
||||
|
@ -17,6 +18,7 @@ import (
|
|||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -36,10 +38,12 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
|
|||
cli.writeRecord(recordTypeHandshake, m.marshal())
|
||||
c.Close()
|
||||
}()
|
||||
ctx := context.Background()
|
||||
conn := Server(s, serverConfig)
|
||||
ch, err := conn.readClientHello()
|
||||
ch, err := conn.readClientHello(ctx)
|
||||
hs := serverHandshakeState{
|
||||
c: conn,
|
||||
ctx: ctx,
|
||||
clientHello: ch,
|
||||
}
|
||||
if err == nil {
|
||||
|
@ -1418,9 +1422,11 @@ func TestSNIGivenOnFailure(t *testing.T) {
|
|||
c.Close()
|
||||
}()
|
||||
conn := Server(s, serverConfig)
|
||||
ch, err := conn.readClientHello()
|
||||
ctx := context.Background()
|
||||
ch, err := conn.readClientHello(ctx)
|
||||
hs := serverHandshakeState{
|
||||
c: conn,
|
||||
ctx: ctx,
|
||||
clientHello: ch,
|
||||
}
|
||||
if err == nil {
|
||||
|
@ -1673,3 +1679,43 @@ func TestMultipleCertificates(t *testing.T) {
|
|||
t.Errorf("expected RSA certificate, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerHandshakeContextCancellation(t *testing.T) {
|
||||
c, s := localPipe(t)
|
||||
clientConfig := testConfig.Clone()
|
||||
clientErr := make(chan error, 1)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go func() {
|
||||
defer close(clientErr)
|
||||
defer c.Close()
|
||||
clientHello := &clientHelloMsg{
|
||||
vers: VersionTLS10,
|
||||
random: make([]byte, 32),
|
||||
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
|
||||
compressionMethods: []uint8{compressionNone},
|
||||
}
|
||||
cli := Client(c, clientConfig)
|
||||
_, err := cli.writeRecord(recordTypeHandshake, clientHello.marshal())
|
||||
cancel()
|
||||
clientErr <- err
|
||||
}()
|
||||
conn := Server(s, testConfig)
|
||||
err := conn.HandshakeContext(ctx)
|
||||
if err == nil {
|
||||
t.Fatal("Server handshake did not error when the context was canceled")
|
||||
}
|
||||
if err != context.Canceled {
|
||||
t.Errorf("Unexpected server handshake error: %v", err)
|
||||
}
|
||||
if err := <-clientErr; err != nil {
|
||||
t.Errorf("Unexpected client error: %v", err)
|
||||
}
|
||||
if runtime.GOARCH == "wasm" {
|
||||
t.Skip("conn.Close does not error as expected when called multiple times on WASM")
|
||||
}
|
||||
err = conn.Close()
|
||||
if err == nil {
|
||||
t.Error("Server connection was not closed when the context was canceled")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ package tls
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"crypto/rsa"
|
||||
|
@ -23,6 +24,7 @@ const maxClientPSKIdentities = 5
|
|||
|
||||
type serverHandshakeStateTLS13 struct {
|
||||
c *Conn
|
||||
ctx context.Context
|
||||
clientHello *clientHelloMsg
|
||||
hello *serverHelloMsg
|
||||
sentDummyCCS bool
|
||||
|
@ -365,7 +367,7 @@ func (hs *serverHandshakeStateTLS13) pickCertificate() error {
|
|||
return c.sendAlert(alertMissingExtension)
|
||||
}
|
||||
|
||||
certificate, err := c.config.getCertificate(clientHelloInfo(c, hs.clientHello))
|
||||
certificate, err := c.config.getCertificate(clientHelloInfo(hs.ctx, c, hs.clientHello))
|
||||
if err != nil {
|
||||
if err == errNoCertificates {
|
||||
c.sendAlert(alertUnrecognizedName)
|
||||
|
|
|
@ -86,7 +86,7 @@ func checkOpenSSLVersion() error {
|
|||
println("to update the test data.")
|
||||
println("")
|
||||
println("Configure it with:")
|
||||
println("./Configure enable-weak-ssl-ciphers")
|
||||
println("./Configure enable-weak-ssl-ciphers no-shared")
|
||||
println("and then add the apps/ directory at the front of your PATH.")
|
||||
println("***********************************************")
|
||||
|
||||
|
@ -403,7 +403,7 @@ func testHandshake(t *testing.T, clientConfig, serverConfig *Config) (serverStat
|
|||
}
|
||||
defer cli.Close()
|
||||
clientState = cli.ConnectionState()
|
||||
buf, err := ioutil.ReadAll(cli)
|
||||
buf, err := io.ReadAll(cli)
|
||||
if err != nil {
|
||||
t.Errorf("failed to call cli.Read: %v", err)
|
||||
}
|
||||
|
|
55
tls.go
55
tls.go
|
@ -25,7 +25,6 @@ import (
|
|||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Server returns a new TLS server side connection
|
||||
|
@ -116,28 +115,16 @@ func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*
|
|||
}
|
||||
|
||||
func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
|
||||
// We want the Timeout and Deadline values from dialer to cover the
|
||||
// whole process: TCP connection and TLS handshake. This means that we
|
||||
// also need to start our own timers now.
|
||||
timeout := netDialer.Timeout
|
||||
if netDialer.Timeout != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
if !netDialer.Deadline.IsZero() {
|
||||
deadlineTimeout := time.Until(netDialer.Deadline)
|
||||
if timeout == 0 || deadlineTimeout < timeout {
|
||||
timeout = deadlineTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// hsErrCh is non-nil if we might not wait for Handshake to complete.
|
||||
var hsErrCh chan error
|
||||
if timeout != 0 || ctx.Done() != nil {
|
||||
hsErrCh = make(chan error, 2)
|
||||
}
|
||||
if timeout != 0 {
|
||||
timer := time.AfterFunc(timeout, func() {
|
||||
hsErrCh <- timeoutError{}
|
||||
})
|
||||
defer timer.Stop()
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
rawConn, err := netDialer.DialContext(ctx, network, addr)
|
||||
|
@ -164,34 +151,10 @@ func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, conf
|
|||
}
|
||||
|
||||
conn := Client(rawConn, config)
|
||||
|
||||
if hsErrCh == nil {
|
||||
err = conn.Handshake()
|
||||
} else {
|
||||
go func() {
|
||||
hsErrCh <- conn.Handshake()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err = ctx.Err()
|
||||
case err = <-hsErrCh:
|
||||
if err != nil {
|
||||
// If the error was due to the context
|
||||
// closing, prefer the context's error, rather
|
||||
// than some random network teardown error.
|
||||
if e := ctx.Err(); e != nil {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err := conn.HandshakeContext(ctx); err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"fmt"
|
||||
"internal/testenv"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
|
@ -594,7 +593,7 @@ func TestConnCloseWrite(t *testing.T) {
|
|||
}
|
||||
defer srv.Close()
|
||||
|
||||
data, err := ioutil.ReadAll(srv)
|
||||
data, err := io.ReadAll(srv)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -635,7 +634,7 @@ func TestConnCloseWrite(t *testing.T) {
|
|||
return fmt.Errorf("CloseWrite error = %v; want errShutdown", err)
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadAll(conn)
|
||||
data, err := io.ReadAll(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -698,7 +697,7 @@ func TestWarningAlertFlood(t *testing.T) {
|
|||
}
|
||||
defer srv.Close()
|
||||
|
||||
_, err = ioutil.ReadAll(srv)
|
||||
_, err = io.ReadAll(srv)
|
||||
if err == nil {
|
||||
return errors.New("unexpected lack of error from server")
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue