[dev.boringcrypto] all: merge master into dev.boringcrypto

Change-Id: Ia661c871e14445672b7d36a443455302e47cc2a1
This commit is contained in:
Filippo Valsorda 2018-10-25 19:30:18 -04:00
commit a1f6c3f0cb
15 changed files with 443 additions and 495 deletions

View file

@ -23,10 +23,9 @@ import (
func pickSignatureAlgorithm(pubkey crypto.PublicKey, peerSigAlgs, ourSigAlgs []SignatureScheme, tlsVersion uint16) (sigAlg SignatureScheme, sigType uint8, hashFunc crypto.Hash, err error) { func pickSignatureAlgorithm(pubkey crypto.PublicKey, peerSigAlgs, ourSigAlgs []SignatureScheme, tlsVersion uint16) (sigAlg SignatureScheme, sigType uint8, hashFunc crypto.Hash, err error) {
if tlsVersion < VersionTLS12 || len(peerSigAlgs) == 0 { if tlsVersion < VersionTLS12 || len(peerSigAlgs) == 0 {
// For TLS 1.1 and before, the signature algorithm could not be // For TLS 1.1 and before, the signature algorithm could not be
// negotiated and the hash is fixed based on the signature type. // negotiated and the hash is fixed based on the signature type. For TLS
// For TLS 1.2, if the client didn't send signature_algorithms // 1.2, if the client didn't send signature_algorithms extension then we
// extension then we can assume that it supports SHA1. See // can assume that it supports SHA1. See RFC 5246, Section 7.4.1.4.1.
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
switch pubkey.(type) { switch pubkey.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
if tlsVersion < VersionTLS12 { if tlsVersion < VersionTLS12 {

View file

@ -14,9 +14,8 @@ import (
"crypto/sha1" "crypto/sha1"
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"hash"
"golang_org/x/crypto/chacha20poly1305" "golang_org/x/crypto/chacha20poly1305"
"hash"
) )
// a keyAgreement implements the client and server side of a TLS key agreement // a keyAgreement implements the client and server side of a TLS key agreement
@ -140,25 +139,29 @@ func macSHA1(version uint16, key []byte) macFunction {
if !boring.Enabled { if !boring.Enabled {
h = newConstantTimeHash(h) h = newConstantTimeHash(h)
} }
return tls10MAC{hmac.New(h, key)} return tls10MAC{h: hmac.New(h, key)}
} }
// macSHA256 returns a SHA-256 based MAC. These are only supported in TLS 1.2 // macSHA256 returns a SHA-256 based MAC. These are only supported in TLS 1.2
// so the given version is ignored. // so the given version is ignored.
func macSHA256(version uint16, key []byte) macFunction { func macSHA256(version uint16, key []byte) macFunction {
return tls10MAC{hmac.New(sha256.New, key)} return tls10MAC{h: hmac.New(sha256.New, key)}
} }
type macFunction interface { type macFunction interface {
// Size returns the length of the MAC.
Size() int Size() int
MAC(digestBuf, seq, header, data, extra []byte) []byte // MAC appends the MAC of (seq, header, data) to out. The extra data is fed
// into the MAC after obtaining the result to normalize timing. The result
// is only valid until the next invocation of MAC as the buffer is reused.
MAC(seq, header, data, extra []byte) []byte
} }
type aead interface { type aead interface {
cipher.AEAD cipher.AEAD
// explicitIVLen returns the number of bytes used by the explicit nonce // explicitNonceLen returns the number of bytes of explicit nonce
// that is included in the record. This is eight for older AEADs and // included in each record. This is eight for older AEADs and
// zero for modern ones. // zero for modern ones.
explicitNonceLen() int explicitNonceLen() int
} }
@ -261,6 +264,7 @@ func aeadChaCha20Poly1305(key, fixedNonce []byte) cipher.AEAD {
type ssl30MAC struct { type ssl30MAC struct {
h hash.Hash h hash.Hash
key []byte key []byte
buf []byte
} }
func (s ssl30MAC) Size() int { func (s ssl30MAC) Size() int {
@ -273,7 +277,7 @@ var ssl30Pad2 = [48]byte{0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0
// MAC does not offer constant timing guarantees for SSL v3.0, since it's deemed // MAC does not offer constant timing guarantees for SSL v3.0, since it's deemed
// useless considering the similar, protocol-level POODLE vulnerability. // useless considering the similar, protocol-level POODLE vulnerability.
func (s ssl30MAC) MAC(digestBuf, seq, header, data, extra []byte) []byte { func (s ssl30MAC) MAC(seq, header, data, extra []byte) []byte {
padLength := 48 padLength := 48
if s.h.Size() == 20 { if s.h.Size() == 20 {
padLength = 40 padLength = 40
@ -286,13 +290,13 @@ func (s ssl30MAC) MAC(digestBuf, seq, header, data, extra []byte) []byte {
s.h.Write(header[:1]) s.h.Write(header[:1])
s.h.Write(header[3:5]) s.h.Write(header[3:5])
s.h.Write(data) s.h.Write(data)
digestBuf = s.h.Sum(digestBuf[:0]) s.buf = s.h.Sum(s.buf[:0])
s.h.Reset() s.h.Reset()
s.h.Write(s.key) s.h.Write(s.key)
s.h.Write(ssl30Pad2[:padLength]) s.h.Write(ssl30Pad2[:padLength])
s.h.Write(digestBuf) s.h.Write(s.buf)
return s.h.Sum(digestBuf[:0]) return s.h.Sum(s.buf[:0])
} }
type constantTimeHash interface { type constantTimeHash interface {
@ -323,9 +327,10 @@ func newConstantTimeHash(h func() hash.Hash) func() hash.Hash {
} }
} }
// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, section 6.2.3. // tls10MAC implements the TLS 1.0 MAC function. RFC 2246, Section 6.2.3.
type tls10MAC struct { type tls10MAC struct {
h hash.Hash h hash.Hash
buf []byte
} }
func (s tls10MAC) Size() int { func (s tls10MAC) Size() int {
@ -335,12 +340,12 @@ func (s tls10MAC) Size() int {
// MAC is guaranteed to take constant time, as long as // MAC is guaranteed to take constant time, as long as
// len(seq)+len(header)+len(data)+len(extra) is constant. extra is not fed into // len(seq)+len(header)+len(data)+len(extra) is constant. extra is not fed into
// the MAC, but is only provided to make the timing profile constant. // the MAC, but is only provided to make the timing profile constant.
func (s tls10MAC) MAC(digestBuf, seq, header, data, extra []byte) []byte { func (s tls10MAC) MAC(seq, header, data, extra []byte) []byte {
s.h.Reset() s.h.Reset()
s.h.Write(seq) s.h.Write(seq)
s.h.Write(header) s.h.Write(header)
s.h.Write(data) s.h.Write(data)
res := s.h.Sum(digestBuf[:0]) res := s.h.Sum(s.buf[:0])
if extra != nil { if extra != nil {
s.h.Write(extra) s.h.Write(extra)
} }
@ -410,7 +415,6 @@ const (
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 uint16 = 0xcca9 TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 uint16 = 0xcca9
// TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator
// that the client is doing version fallback. See // that the client is doing version fallback. See RFC 7507.
// https://tools.ietf.org/html/rfc7507.
TLS_FALLBACK_SCSV uint16 = 0x5600 TLS_FALLBACK_SCSV uint16 = 0x5600
) )

View file

@ -80,7 +80,7 @@ const (
extensionSupportedPoints uint16 = 11 extensionSupportedPoints uint16 = 11
extensionSignatureAlgorithms uint16 = 13 extensionSignatureAlgorithms uint16 = 13
extensionALPN uint16 = 16 extensionALPN uint16 = 16
extensionSCT uint16 = 18 // https://tools.ietf.org/html/rfc6962#section-6 extensionSCT uint16 = 18 // RFC 6962, Section 6
extensionSessionTicket uint16 = 35 extensionSessionTicket uint16 = 35
extensionNextProtoNeg uint16 = 13172 // not IANA assigned extensionNextProtoNeg uint16 = 13172 // not IANA assigned
extensionRenegotiationInfo uint16 = 0xff01 extensionRenegotiationInfo uint16 = 0xff01
@ -129,7 +129,7 @@ const (
) )
// Signature algorithms (for internal signaling use). Starting at 16 to avoid overlap with // Signature algorithms (for internal signaling use). Starting at 16 to avoid overlap with
// TLS 1.2 codepoints (RFC 5246, section A.4.1), with which these have nothing to do. // TLS 1.2 codepoints (RFC 5246, Appendix A.4.1), with which these have nothing to do.
const ( const (
signaturePKCS1v15 uint8 = iota + 16 signaturePKCS1v15 uint8 = iota + 16
signatureECDSA signatureECDSA
@ -178,9 +178,9 @@ type ConnectionState struct {
} }
// ExportKeyingMaterial returns length bytes of exported key material in a new // ExportKeyingMaterial returns length bytes of exported key material in a new
// slice as defined in https://tools.ietf.org/html/rfc5705. If context is nil, // slice as defined in RFC 5705. If context is nil, it is not used as part of
// it is not used as part of the seed. If the connection was set to allow // the seed. If the connection was set to allow renegotiation via
// renegotiation via Config.Renegotiation, this function will return an error. // Config.Renegotiation, this function will return an error.
func (cs *ConnectionState) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) { func (cs *ConnectionState) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
return cs.ekm(label, context, length) return cs.ekm(label, context, length)
} }
@ -223,7 +223,7 @@ type ClientSessionCache interface {
} }
// SignatureScheme identifies a signature algorithm supported by TLS. See // SignatureScheme identifies a signature algorithm supported by TLS. See
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.3. // RFC 8446, Section 4.2.3.
type SignatureScheme uint16 type SignatureScheme uint16
const ( const (
@ -253,32 +253,27 @@ type ClientHelloInfo struct {
// ServerName indicates the name of the server requested by the client // ServerName indicates the name of the server requested by the client
// in order to support virtual hosting. ServerName is only set if the // in order to support virtual hosting. ServerName is only set if the
// client is using SNI (see // client is using SNI (see RFC 4366, Section 3.1).
// https://tools.ietf.org/html/rfc4366#section-3.1).
ServerName string ServerName string
// SupportedCurves lists the elliptic curves supported by the client. // SupportedCurves lists the elliptic curves supported by the client.
// SupportedCurves is set only if the Supported Elliptic Curves // SupportedCurves is set only if the Supported Elliptic Curves
// Extension is being used (see // Extension is being used (see RFC 4492, Section 5.1.1).
// https://tools.ietf.org/html/rfc4492#section-5.1.1).
SupportedCurves []CurveID SupportedCurves []CurveID
// SupportedPoints lists the point formats supported by the client. // SupportedPoints lists the point formats supported by the client.
// SupportedPoints is set only if the Supported Point Formats Extension // SupportedPoints is set only if the Supported Point Formats Extension
// is being used (see // is being used (see RFC 4492, Section 5.1.2).
// https://tools.ietf.org/html/rfc4492#section-5.1.2).
SupportedPoints []uint8 SupportedPoints []uint8
// SignatureSchemes lists the signature and hash schemes that the client // SignatureSchemes lists the signature and hash schemes that the client
// is willing to verify. SignatureSchemes is set only if the Signature // is willing to verify. SignatureSchemes is set only if the Signature
// Algorithms Extension is being used (see // Algorithms Extension is being used (see RFC 5246, Section 7.4.1.4.1).
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1).
SignatureSchemes []SignatureScheme SignatureSchemes []SignatureScheme
// SupportedProtos lists the application protocols supported by the client. // SupportedProtos lists the application protocols supported by the client.
// SupportedProtos is set only if the Application-Layer Protocol // SupportedProtos is set only if the Application-Layer Protocol
// Negotiation Extension is being used (see // Negotiation Extension is being used (see RFC 7301, Section 3.1).
// https://tools.ietf.org/html/rfc7301#section-3.1).
// //
// Servers can select a protocol by setting Config.NextProtos in a // Servers can select a protocol by setting Config.NextProtos in a
// GetConfigForClient return value. // GetConfigForClient return value.

598
conn.go
View file

@ -82,9 +82,10 @@ type Conn struct {
// input/output // input/output
in, out halfConn in, out halfConn
rawInput *block // raw input, right off the wire rawInput bytes.Buffer // raw input, starting with a record header
input *block // application data waiting to be read input bytes.Reader // application data waiting to be read, from rawInput.Next
hand bytes.Buffer // handshake data waiting to be read hand bytes.Buffer // handshake data waiting to be read
outBuf []byte // scratch buffer used by out.encrypt
buffering bool // whether records are buffered in sendBuf buffering bool // whether records are buffered in sendBuf
sendBuf []byte // a buffer of records waiting to be sent sendBuf []byte // a buffer of records waiting to be sent
@ -149,14 +150,10 @@ type halfConn struct {
cipher interface{} // cipher algorithm cipher interface{} // cipher algorithm
mac macFunction mac macFunction
seq [8]byte // 64-bit sequence number seq [8]byte // 64-bit sequence number
bfree *block // list of free blocks
additionalData [13]byte // to avoid allocs; interface method args escape additionalData [13]byte // to avoid allocs; interface method args escape
nextCipher interface{} // next encryption state nextCipher interface{} // next encryption state
nextMac macFunction // next MAC algorithm nextMac macFunction // next MAC algorithm
// used to save allocating a new buffer for each MAC.
inDigestBuf, outDigestBuf []byte
} }
func (hc *halfConn) setErrorLocked(err error) error { func (hc *halfConn) setErrorLocked(err error) error {
@ -203,9 +200,33 @@ func (hc *halfConn) incSeq() {
panic("TLS: sequence number wraparound") panic("TLS: sequence number wraparound")
} }
// explicitNonceLen returns the number of bytes of explicit nonce or IV included
// in each record. Explicit nonces are present only in CBC modes after TLS 1.0
// and in certain AEAD modes in TLS 1.2.
func (hc *halfConn) explicitNonceLen() int {
if hc.cipher == nil {
return 0
}
switch c := hc.cipher.(type) {
case cipher.Stream:
return 0
case aead:
return c.explicitNonceLen()
case cbcMode:
// TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack.
if hc.version >= VersionTLS11 {
return c.BlockSize()
}
return 0
default:
panic("unknown cipher type")
}
}
// extractPadding returns, in constant time, the length of the padding to remove // extractPadding returns, in constant time, the length of the padding to remove
// from the end of payload. It also returns a byte which is equal to 255 if the // from the end of payload. It also returns a byte which is equal to 255 if the
// padding was valid and 0 otherwise. See RFC 2246, section 6.2.3.2 // padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2.
func extractPadding(payload []byte) (toRemove int, good byte) { func extractPadding(payload []byte) (toRemove int, good byte) {
if len(payload) < 1 { if len(payload) < 1 {
return 0, 0 return 0, 0
@ -268,283 +289,189 @@ type cbcMode interface {
SetIV([]byte) SetIV([]byte)
} }
// decrypt checks and strips the mac and decrypts the data in b. Returns a // decrypt authenticates and decrypts the record if protection is active at
// success boolean, the number of bytes to skip from the start of the record in // this stage. The returned plaintext might overlap with the input.
// order to get the application payload, and an optional alert value. func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) {
func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) { payload := record[recordHeaderLen:]
// pull out payload
payload := b.data[recordHeaderLen:]
macSize := 0
if hc.mac != nil {
macSize = hc.mac.Size()
}
paddingGood := byte(255) paddingGood := byte(255)
paddingLen := 0 paddingLen := 0
explicitIVLen := 0
// decrypt explicitNonceLen := hc.explicitNonceLen()
if hc.cipher != nil { if hc.cipher != nil {
switch c := hc.cipher.(type) { switch c := hc.cipher.(type) {
case cipher.Stream: case cipher.Stream:
c.XORKeyStream(payload, payload) c.XORKeyStream(payload, payload)
case aead: case aead:
explicitIVLen = c.explicitNonceLen() if len(payload) < explicitNonceLen {
if len(payload) < explicitIVLen { return nil, alertBadRecordMAC
return false, 0, alertBadRecordMAC
} }
nonce := payload[:explicitIVLen] nonce := payload[:explicitNonceLen]
payload = payload[explicitIVLen:]
if len(nonce) == 0 { if len(nonce) == 0 {
nonce = hc.seq[:] nonce = hc.seq[:]
} }
payload = payload[explicitNonceLen:]
copy(hc.additionalData[:], hc.seq[:]) copy(hc.additionalData[:], hc.seq[:])
copy(hc.additionalData[8:], b.data[:3]) copy(hc.additionalData[8:], record[:3])
n := len(payload) - c.Overhead() n := len(payload) - c.Overhead()
hc.additionalData[11] = byte(n >> 8) hc.additionalData[11] = byte(n >> 8)
hc.additionalData[12] = byte(n) hc.additionalData[12] = byte(n)
var err error var err error
payload, err = c.Open(payload[:0], nonce, payload, hc.additionalData[:]) plaintext, err = c.Open(payload[:0], nonce, payload, hc.additionalData[:])
if err != nil { if err != nil {
return false, 0, alertBadRecordMAC return nil, alertBadRecordMAC
} }
b.resize(recordHeaderLen + explicitIVLen + len(payload))
case cbcMode: case cbcMode:
blockSize := c.BlockSize() blockSize := c.BlockSize()
if hc.version >= VersionTLS11 { minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) // TODO: vuln?
explicitIVLen = blockSize if len(payload)%blockSize != 0 || len(payload) < minPayload {
return nil, alertBadRecordMAC
} }
if len(payload)%blockSize != 0 || len(payload) < roundUp(explicitIVLen+macSize+1, blockSize) { if explicitNonceLen > 0 {
return false, 0, alertBadRecordMAC c.SetIV(payload[:explicitNonceLen])
} payload = payload[explicitNonceLen:]
if explicitIVLen > 0 {
c.SetIV(payload[:explicitIVLen])
payload = payload[explicitIVLen:]
} }
c.CryptBlocks(payload, payload) c.CryptBlocks(payload, payload)
// In a limited attempt to protect against CBC padding oracles like
// Lucky13, the data past paddingLen (which is secret) is passed to
// the MAC function as extra data, to be fed into the HMAC after
// computing the digest. This makes the MAC roughly constant time as
// long as the digest computation is constant time and does not
// affect the subsequent write, modulo cache effects.
if hc.version == VersionSSL30 { if hc.version == VersionSSL30 {
paddingLen, paddingGood = extractPaddingSSL30(payload) paddingLen, paddingGood = extractPaddingSSL30(payload)
} else { } else {
paddingLen, paddingGood = extractPadding(payload) paddingLen, paddingGood = extractPadding(payload)
// To protect against CBC padding oracles like Lucky13, the data
// past paddingLen (which is secret) is passed to the MAC
// function as extra data, to be fed into the HMAC after
// computing the digest. This makes the MAC constant time as
// long as the digest computation is constant time and does not
// affect the subsequent write.
} }
default: default:
panic("unknown cipher type") panic("unknown cipher type")
} }
} else {
plaintext = payload
} }
// check, strip mac
if hc.mac != nil { if hc.mac != nil {
macSize := hc.mac.Size()
if len(payload) < macSize { if len(payload) < macSize {
return false, 0, alertBadRecordMAC return nil, alertBadRecordMAC
} }
// strip mac off payload, b.data
n := len(payload) - macSize - paddingLen n := len(payload) - macSize - paddingLen
n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 }
b.data[3] = byte(n >> 8) record[3] = byte(n >> 8)
b.data[4] = byte(n) record[4] = byte(n)
remoteMAC := payload[n : n+macSize] remoteMAC := payload[n : n+macSize]
localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n], payload[n+macSize:]) localMAC := hc.mac.MAC(hc.seq[0:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])
if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 { if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
return false, 0, alertBadRecordMAC return nil, alertBadRecordMAC
} }
hc.inDigestBuf = localMAC
b.resize(recordHeaderLen + explicitIVLen + n) plaintext = payload[:n]
} }
hc.incSeq() hc.incSeq()
return plaintext, nil
return true, recordHeaderLen + explicitIVLen, 0
} }
// padToBlockSize calculates the needed padding block, if any, for a payload. // sliceForAppend extends the input slice by n bytes. head is the full extended
// On exit, prefix aliases payload and extends to the end of the last full // slice, while tail is the appended part. If the original slice has sufficient
// block of payload. finalBlock is a fresh slice which contains the contents of // capacity no allocation is performed.
// any suffix of payload as well as the needed padding to make finalBlock a func sliceForAppend(in []byte, n int) (head, tail []byte) {
// full block. if total := len(in) + n; cap(in) >= total {
func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) { head = in[:total]
overrun := len(payload) % blockSize } else {
paddingLen := blockSize - overrun head = make([]byte, total)
prefix = payload[:len(payload)-overrun] copy(head, in)
finalBlock = make([]byte, blockSize)
copy(finalBlock, payload[len(payload)-overrun:])
for i := overrun; i < blockSize; i++ {
finalBlock[i] = byte(paddingLen - 1)
} }
tail = head[len(in):]
return return
} }
// encrypt encrypts and macs the data in b. // encrypt encrypts payload, adding the appropriate nonce and/or MAC, and
func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) { // appends it to record, which contains the record header.
// mac func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
if hc.mac != nil { if hc.cipher == nil {
mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:], nil) return append(record, payload...), nil
n := len(b.data)
b.resize(n + len(mac))
copy(b.data[n:], mac)
hc.outDigestBuf = mac
} }
payload := b.data[recordHeaderLen:] var explicitNonce []byte
if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
record, explicitNonce = sliceForAppend(record, explicitNonceLen)
if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
// The AES-GCM construction in TLS has an explicit nonce so that the
// nonce can be random. However, the nonce is only 8 bytes which is
// too small for a secure, random nonce. Therefore we use the
// sequence number as the nonce. The 3DES-CBC construction also has
// an 8 bytes nonce but its nonces must be unpredictable (see RFC
// 5246, Appendix F.3), forcing us to use randomness. That's not
// 3DES' biggest problem anyway because the birthday bound on block
// collision is reached first due to its simlarly small block size
// (see the Sweet32 attack).
copy(explicitNonce, hc.seq[:])
} else {
if _, err := io.ReadFull(rand, explicitNonce); err != nil {
return nil, err
}
}
}
// encrypt var mac []byte
if hc.cipher != nil { if hc.mac != nil {
mac = hc.mac.MAC(hc.seq[:], record[:recordHeaderLen], payload, nil)
}
var dst []byte
switch c := hc.cipher.(type) { switch c := hc.cipher.(type) {
case cipher.Stream: case cipher.Stream:
c.XORKeyStream(payload, payload) record, dst = sliceForAppend(record, len(payload)+len(mac))
c.XORKeyStream(dst[:len(payload)], payload)
c.XORKeyStream(dst[len(payload):], mac)
case aead: case aead:
payloadLen := len(b.data) - recordHeaderLen - explicitIVLen nonce := explicitNonce
b.resize(len(b.data) + c.Overhead())
nonce := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
if len(nonce) == 0 { if len(nonce) == 0 {
nonce = hc.seq[:] nonce = hc.seq[:]
} }
payload := b.data[recordHeaderLen+explicitIVLen:]
payload = payload[:payloadLen]
copy(hc.additionalData[:], hc.seq[:]) copy(hc.additionalData[:], hc.seq[:])
copy(hc.additionalData[8:], b.data[:3]) copy(hc.additionalData[8:], record[:3])
hc.additionalData[11] = byte(payloadLen >> 8) hc.additionalData[11] = byte(len(payload) >> 8)
hc.additionalData[12] = byte(payloadLen) hc.additionalData[12] = byte(len(payload))
c.Seal(payload[:0], nonce, payload, hc.additionalData[:]) record = c.Seal(record, nonce, payload, hc.additionalData[:])
case cbcMode: case cbcMode:
blockSize := c.BlockSize() blockSize := c.BlockSize()
if explicitIVLen > 0 { plaintextLen := len(payload) + len(mac)
c.SetIV(payload[:explicitIVLen]) paddingLen := blockSize - plaintextLen%blockSize
payload = payload[explicitIVLen:] record, dst = sliceForAppend(record, plaintextLen+paddingLen)
copy(dst, payload)
copy(dst[len(payload):], mac)
for i := plaintextLen; i < len(dst); i++ {
dst[i] = byte(paddingLen - 1)
} }
prefix, finalBlock := padToBlockSize(payload, blockSize) if len(explicitNonce) > 0 {
b.resize(recordHeaderLen + explicitIVLen + len(prefix) + len(finalBlock)) c.SetIV(explicitNonce)
c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen:], prefix) }
c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen+len(prefix):], finalBlock) c.CryptBlocks(dst, dst)
default: default:
panic("unknown cipher type") panic("unknown cipher type")
} }
}
// update length to include MAC and any block padding needed. // Update length to include nonce, MAC and any block padding needed.
n := len(b.data) - recordHeaderLen n := len(record) - recordHeaderLen
b.data[3] = byte(n >> 8) record[3] = byte(n >> 8)
b.data[4] = byte(n) record[4] = byte(n)
hc.incSeq() hc.incSeq()
return true, 0 return record, nil
} }
// A block is a simple data buffer. // RecordHeaderError is returned when a TLS record header is invalid.
type block struct {
data []byte
off int // index for Read
link *block
}
// resize resizes block to be n bytes, growing if necessary.
func (b *block) resize(n int) {
if n > cap(b.data) {
b.reserve(n)
}
b.data = b.data[0:n]
}
// reserve makes sure that block contains a capacity of at least n bytes.
func (b *block) reserve(n int) {
if cap(b.data) >= n {
return
}
m := cap(b.data)
if m == 0 {
m = 1024
}
for m < n {
m *= 2
}
data := make([]byte, len(b.data), m)
copy(data, b.data)
b.data = data
}
// readFromUntil reads from r into b until b contains at least n bytes
// or else returns an error.
func (b *block) readFromUntil(r io.Reader, n int) error {
// quick case
if len(b.data) >= n {
return nil
}
// read until have enough.
b.reserve(n)
for {
m, err := r.Read(b.data[len(b.data):cap(b.data)])
b.data = b.data[0 : len(b.data)+m]
if len(b.data) >= n {
// TODO(bradfitz,agl): slightly suspicious
// that we're throwing away r.Read's err here.
break
}
if err != nil {
return err
}
}
return nil
}
func (b *block) Read(p []byte) (n int, err error) {
n = copy(p, b.data[b.off:])
b.off += n
return
}
// newBlock allocates a new block, from hc's free list if possible.
func (hc *halfConn) newBlock() *block {
b := hc.bfree
if b == nil {
return new(block)
}
hc.bfree = b.link
b.link = nil
b.resize(0)
return b
}
// freeBlock returns a block to hc's free list.
// The protocol is such that each side only has a block or two on
// its free list at a time, so there's no need to worry about
// trimming the list, etc.
func (hc *halfConn) freeBlock(b *block) {
b.link = hc.bfree
hc.bfree = b
}
// splitBlock splits a block after the first n bytes,
// returning a block with those n bytes and a
// block with the remainder. the latter may be nil.
func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
if len(b.data) <= n {
return b, nil
}
bb := hc.newBlock()
bb.resize(len(b.data) - n)
copy(bb.data, b.data[n:])
b.data = b.data[0:n]
return b, bb
}
// RecordHeaderError results when a TLS record header is invalid.
type RecordHeaderError struct { type RecordHeaderError struct {
// Msg contains a human readable string that describes the error. // Msg contains a human readable string that describes the error.
Msg string Msg string
@ -557,7 +484,7 @@ func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) { func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) {
err.Msg = msg err.Msg = msg
copy(err.RecordHeader[:], c.rawInput.data) copy(err.RecordHeader[:], c.rawInput.Bytes())
return err return err
} }
@ -569,40 +496,38 @@ func (c *Conn) readRecord(want recordType) error {
// else application data. // else application data.
switch want { switch want {
default: default:
c.sendAlert(alertInternalError) panic("tls: unknown record type requested")
return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
case recordTypeHandshake, recordTypeChangeCipherSpec: case recordTypeHandshake, recordTypeChangeCipherSpec:
if c.handshakeComplete() { if c.handshakeComplete() {
c.sendAlert(alertInternalError) panic("tls: handshake or ChangeCipherSpec requested while not in handshake")
return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake"))
} }
case recordTypeApplicationData: case recordTypeApplicationData:
if !c.handshakeComplete() { if !c.handshakeComplete() {
c.sendAlert(alertInternalError) panic("tls: application data record requested while in handshake")
return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake"))
} }
} }
Again: // This function modifies c.rawInput, which owns the c.input memory.
if c.rawInput == nil { if c.input.Len() != 0 {
c.rawInput = c.in.newBlock() panic("tls: attempted to read record with pending application data")
} }
b := c.rawInput c.input.Reset(nil)
// Read header, payload. // Read header, payload.
if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil { if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
// RFC suggests that EOF without an alertCloseNotify is // RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
// an error, but popular web sites seem to do this, // is an error, but popular web sites seem to do this, so we accept it
// so we can't make it an error. // if and only if at the record boundary.
// if err == io.EOF { if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
// err = io.ErrUnexpectedEOF err = io.EOF
// } }
if e, ok := err.(net.Error); !ok || !e.Temporary() { if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.in.setErrorLocked(err) c.in.setErrorLocked(err)
} }
return err return err
} }
typ := recordType(b.data[0]) hdr := c.rawInput.Bytes()[:recordHeaderLen]
typ := recordType(hdr[0])
// No valid TLS record has a type of 0x80, however SSLv2 handshakes // No valid TLS record has a type of 0x80, however SSLv2 handshakes
// start with a uint16 length where the MSB is set and the first record // start with a uint16 length where the MSB is set and the first record
@ -613,8 +538,8 @@ Again:
return c.in.setErrorLocked(c.newRecordHeaderError("unsupported SSLv2 handshake received")) return c.in.setErrorLocked(c.newRecordHeaderError("unsupported SSLv2 handshake received"))
} }
vers := uint16(b.data[1])<<8 | uint16(b.data[2]) vers := uint16(hdr[1])<<8 | uint16(hdr[2])
n := int(b.data[3])<<8 | int(b.data[4]) n := int(hdr[3])<<8 | int(hdr[4])
if c.haveVers && vers != c.vers { if c.haveVers && vers != c.vers {
c.sendAlert(alertProtocolVersion) c.sendAlert(alertProtocolVersion)
msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
@ -635,10 +560,7 @@ Again:
return c.in.setErrorLocked(c.newRecordHeaderError("first record does not look like a TLS handshake")) return c.in.setErrorLocked(c.newRecordHeaderError("first record does not look like a TLS handshake"))
} }
} }
if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil { if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
if e, ok := err.(net.Error); !ok || !e.Temporary() { if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.in.setErrorLocked(err) c.in.setErrorLocked(err)
} }
@ -646,18 +568,13 @@ Again:
} }
// Process message. // Process message.
b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n) record := c.rawInput.Next(recordHeaderLen + n)
ok, off, alertValue := c.in.decrypt(b) data, err := c.in.decrypt(record)
if !ok { if err != nil {
c.in.freeBlock(b) return c.in.setErrorLocked(c.sendAlert(err.(alert)))
return c.in.setErrorLocked(c.sendAlert(alertValue))
} }
b.off = off
data := b.data[b.off:]
if len(data) > maxPlaintext { if len(data) > maxPlaintext {
err := c.sendAlert(alertRecordOverflow) return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
c.in.freeBlock(b)
return c.in.setErrorLocked(err)
} }
if typ != recordTypeAlert && len(data) > 0 { if typ != recordTypeAlert && len(data) > 0 {
@ -667,70 +584,97 @@ Again:
switch typ { switch typ {
default: default:
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
case recordTypeAlert: case recordTypeAlert:
if len(data) != 2 { if len(data) != 2 {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
} }
if alert(data[1]) == alertCloseNotify { if alert(data[1]) == alertCloseNotify {
c.in.setErrorLocked(io.EOF) return c.in.setErrorLocked(io.EOF)
break
} }
switch data[0] { switch data[0] {
case alertLevelWarning: case alertLevelWarning:
// drop on the floor
c.in.freeBlock(b)
c.warnCount++ c.warnCount++
if c.warnCount > maxWarnAlertCount { if c.warnCount > maxWarnAlertCount {
c.sendAlert(alertUnexpectedMessage) c.sendAlert(alertUnexpectedMessage)
return c.in.setErrorLocked(errors.New("tls: too many warn alerts")) return c.in.setErrorLocked(errors.New("tls: too many warn alerts"))
} }
return c.readRecord(want) // Drop the record on the floor and retry.
goto Again
case alertLevelError: case alertLevelError:
c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
default: default:
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
} }
case recordTypeChangeCipherSpec: case recordTypeChangeCipherSpec:
if typ != want || len(data) != 1 || data[0] != 1 { if typ != want || len(data) != 1 || data[0] != 1 {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
} }
// Handshake messages are not allowed to fragment across the CCS // Handshake messages are not allowed to fragment across the CCS.
if c.hand.Len() > 0 { if c.hand.Len() > 0 {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
} }
err := c.in.changeCipherSpec() if err := c.in.changeCipherSpec(); err != nil {
if err != nil { return c.in.setErrorLocked(c.sendAlert(err.(alert)))
c.in.setErrorLocked(c.sendAlert(err.(alert)))
} }
return nil
case recordTypeApplicationData: case recordTypeApplicationData:
if typ != want { if typ != want {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
} }
c.input = b // Note that data is owned by c.rawInput, following the Next call above,
b = nil // to avoid copying the plaintext. This is safe because c.rawInput is
// not read from or written to until c.input is drained.
c.input.Reset(data)
return nil
case recordTypeHandshake: case recordTypeHandshake:
// TODO(rsc): Should at least pick off connection close.
if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) { if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) {
return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation)) return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
} }
c.hand.Write(data) c.hand.Write(data)
return nil
} }
}
if b != nil { // atLeastReader reads from R, stopping with EOF once at least N bytes have been
c.in.freeBlock(b) // read. It is different from an io.LimitedReader in that it doesn't cut short
// the last Read call, and in that it considers an early EOF an error.
type atLeastReader struct {
R io.Reader
N int64
}
func (r *atLeastReader) Read(p []byte) (int, error) {
if r.N <= 0 {
return 0, io.EOF
} }
return c.in.err n, err := r.R.Read(p)
r.N -= int64(n) // won't underflow unless len(p) >= n > 9223372036854775809
if r.N > 0 && err == io.EOF {
return n, io.ErrUnexpectedEOF
}
if r.N <= 0 && err == nil {
return n, io.EOF
}
return n, err
}
// readFromUntil reads from r into c.rawInput until c.rawInput contains
// at least n bytes or else returns an error.
func (c *Conn) readFromUntil(r io.Reader, n int) error {
if c.rawInput.Len() >= n {
return nil
}
needs := n - c.rawInput.Len()
// There might be extra input waiting on the wire. Make a best effort
// attempt to fetch it so that it can be used in (*Conn).Read to
// "predict" closeNotify alerts.
c.rawInput.Grow(needs + bytes.MinRead)
_, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
return err
} }
// sendAlert sends a TLS alert message. // sendAlert sends a TLS alert message.
@ -789,7 +733,7 @@ const (
// //
// In the interests of simplicity and determinism, this code does not attempt // In the interests of simplicity and determinism, this code does not attempt
// to reset the record size once the connection is idle, however. // to reset the record size once the connection is idle, however.
func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int { func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData { if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
return maxPlaintext return maxPlaintext
} }
@ -799,16 +743,11 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int {
} }
// Subtract TLS overheads to get the maximum payload size. // Subtract TLS overheads to get the maximum payload size.
macSize := 0 payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
if c.out.mac != nil {
macSize = c.out.mac.Size()
}
payloadBytes := tcpMSSEstimate - recordHeaderLen - explicitIVLen
if c.out.cipher != nil { if c.out.cipher != nil {
switch ciph := c.out.cipher.(type) { switch ciph := c.out.cipher.(type) {
case cipher.Stream: case cipher.Stream:
payloadBytes -= macSize payloadBytes -= c.out.mac.Size()
case cipher.AEAD: case cipher.AEAD:
payloadBytes -= ciph.Overhead() payloadBytes -= ciph.Overhead()
case cbcMode: case cbcMode:
@ -818,7 +757,7 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int {
payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1 payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
// The MAC is appended before padding so affects the // The MAC is appended before padding so affects the
// payload size directly. // payload size directly.
payloadBytes -= macSize payloadBytes -= c.out.mac.Size()
default: default:
panic("unknown cipher type") panic("unknown cipher type")
} }
@ -864,63 +803,32 @@ func (c *Conn) flush() (int, error) {
// writeRecordLocked writes a TLS record with the given type and payload to the // writeRecordLocked writes a TLS record with the given type and payload to the
// connection and updates the record layer state. // connection and updates the record layer state.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
b := c.out.newBlock()
defer c.out.freeBlock(b)
var n int var n int
for len(data) > 0 { for len(data) > 0 {
explicitIVLen := 0
explicitIVIsSeq := false
var cbc cbcMode
if c.out.version >= VersionTLS11 {
var ok bool
if cbc, ok = c.out.cipher.(cbcMode); ok {
explicitIVLen = cbc.BlockSize()
}
}
if explicitIVLen == 0 {
if c, ok := c.out.cipher.(aead); ok {
explicitIVLen = c.explicitNonceLen()
// The AES-GCM construction in TLS has an
// explicit nonce so that the nonce can be
// random. However, the nonce is only 8 bytes
// which is too small for a secure, random
// nonce. Therefore we use the sequence number
// as the nonce.
explicitIVIsSeq = explicitIVLen > 0
}
}
m := len(data) m := len(data)
if maxPayload := c.maxPayloadSizeForWrite(typ, explicitIVLen); m > maxPayload { if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
m = maxPayload m = maxPayload
} }
b.resize(recordHeaderLen + explicitIVLen + m)
b.data[0] = byte(typ) _, c.outBuf = sliceForAppend(c.outBuf[:0], recordHeaderLen)
c.outBuf[0] = byte(typ)
vers := c.vers vers := c.vers
if vers == 0 { if vers == 0 {
// Some TLS servers fail if the record version is // Some TLS servers fail if the record version is
// greater than TLS 1.0 for the initial ClientHello. // greater than TLS 1.0 for the initial ClientHello.
vers = VersionTLS10 vers = VersionTLS10
} }
b.data[1] = byte(vers >> 8) c.outBuf[1] = byte(vers >> 8)
b.data[2] = byte(vers) c.outBuf[2] = byte(vers)
b.data[3] = byte(m >> 8) c.outBuf[3] = byte(m >> 8)
b.data[4] = byte(m) c.outBuf[4] = byte(m)
if explicitIVLen > 0 {
explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen] var err error
if explicitIVIsSeq { c.outBuf, err = c.out.encrypt(c.outBuf, data[:m], c.config.rand())
copy(explicitIV, c.out.seq[:]) if err != nil {
} else {
if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil {
return n, err return n, err
} }
} if _, err := c.write(c.outBuf); err != nil {
}
copy(b.data[recordHeaderLen+explicitIVLen:], data)
c.out.encrypt(b, explicitIVLen)
if _, err := c.write(b.data); err != nil {
return n, err return n, err
} }
n += m n += m
@ -1124,14 +1032,14 @@ func (c *Conn) handleRenegotiation() error {
// Read can be made to time out and return a net.Error with Timeout() == true // Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline. // after a fixed time limit; see SetDeadline and SetReadDeadline.
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (int, error) {
if err = c.Handshake(); err != nil { if err := c.Handshake(); err != nil {
return return 0, err
} }
if len(b) == 0 { if len(b) == 0 {
// Put this after Handshake, in case people were calling // Put this after Handshake, in case people were calling
// Read(nil) for the side effect of the Handshake. // Read(nil) for the side effect of the Handshake.
return return 0, nil
} }
c.in.Lock() c.in.Lock()
@ -1141,9 +1049,8 @@ func (c *Conn) Read(b []byte) (n int, err error) {
// CBC IV. So this loop ignores a limited number of empty records. // CBC IV. So this loop ignores a limited number of empty records.
const maxConsecutiveEmptyRecords = 100 const maxConsecutiveEmptyRecords = 100
for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ { for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ {
for c.input == nil && c.in.err == nil { for c.input.Len() == 0 && c.in.err == nil {
if err := c.readRecord(recordTypeApplicationData); err != nil { if err := c.readRecord(recordTypeApplicationData); err != nil {
// Soft error, like EAGAIN
return 0, err return 0, err
} }
if c.hand.Len() > 0 { if c.hand.Len() > 0 {
@ -1158,33 +1065,24 @@ func (c *Conn) Read(b []byte) (n int, err error) {
return 0, err return 0, err
} }
n, err = c.input.Read(b) n, _ := c.input.Read(b)
if c.input.off >= len(c.input.data) {
c.in.freeBlock(c.input)
c.input = nil
}
// If a close-notify alert is waiting, read it so that // If a close-notify alert is waiting, read it so that we can return (n,
// we can return (n, EOF) instead of (n, nil), to signal // EOF) instead of (n, nil), to signal to the HTTP response reading
// to the HTTP response reading goroutine that the // goroutine that the connection is now closed. This eliminates a race
// connection is now closed. This eliminates a race // where the HTTP response reading goroutine would otherwise not observe
// where the HTTP response reading goroutine would // the EOF until its next read, by which time a client goroutine might
// otherwise not observe the EOF until its next read, // have already tried to reuse the HTTP connection for a new request.
// by which time a client goroutine might have already // See https://golang.org/cl/76400046 and https://golang.org/issue/3514
// tried to reuse the HTTP connection for a new if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
// request. recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
// See https://codereview.appspot.com/76400046 if err := c.readRecord(recordTypeApplicationData); err != nil {
// and https://golang.org/issue/3514 return n, err // will be io.EOF on closeNotify
if ri := c.rawInput; ri != nil &&
n != 0 && err == nil &&
c.input == nil && len(ri.data) > 0 && recordType(ri.data[0]) == recordTypeAlert {
if recErr := c.readRecord(recordTypeApplicationData); recErr != nil {
err = recErr // will be io.EOF on closeNotify
} }
} }
if n != 0 || err != nil { if n != 0 {
return n, err return n, nil
} }
} }

View file

@ -134,12 +134,13 @@ func TestCertificateSelection(t *testing.T) {
// Run with multiple crypto configs to test the logic for computing TLS record overheads. // Run with multiple crypto configs to test the logic for computing TLS record overheads.
func runDynamicRecordSizingTest(t *testing.T, config *Config) { func runDynamicRecordSizingTest(t *testing.T, config *Config) {
clientConn, serverConn := net.Pipe() clientConn, serverConn := localPipe(t)
serverConfig := config.Clone() serverConfig := config.Clone()
serverConfig.DynamicRecordSizingDisabled = false serverConfig.DynamicRecordSizingDisabled = false
tlsConn := Server(serverConn, serverConfig) tlsConn := Server(serverConn, serverConfig)
handshakeDone := make(chan struct{})
recordSizesChan := make(chan []int, 1) recordSizesChan := make(chan []int, 1)
go func() { go func() {
// This goroutine performs a TLS handshake over clientConn and // This goroutine performs a TLS handshake over clientConn and
@ -153,6 +154,7 @@ func runDynamicRecordSizingTest(t *testing.T, config *Config) {
t.Errorf("Error from client handshake: %v", err) t.Errorf("Error from client handshake: %v", err)
return return
} }
close(handshakeDone)
var recordHeader [recordHeaderLen]byte var recordHeader [recordHeaderLen]byte
var record []byte var record []byte
@ -192,6 +194,7 @@ func runDynamicRecordSizingTest(t *testing.T, config *Config) {
if err := tlsConn.Handshake(); err != nil { if err := tlsConn.Handshake(); err != nil {
t.Fatalf("Error from server handshake: %s", err) t.Fatalf("Error from server handshake: %s", err)
} }
<-handshakeDone
// The server writes these plaintexts in order. // The server writes these plaintexts in order.
plaintext := bytes.Join([][]byte{ plaintext := bytes.Join([][]byte{
@ -269,7 +272,7 @@ func (conn *hairpinConn) Close() error {
func TestHairpinInClose(t *testing.T) { func TestHairpinInClose(t *testing.T) {
// This tests that the underlying net.Conn can call back into the // This tests that the underlying net.Conn can call back into the
// tls.Conn when being closed without deadlocking. // tls.Conn when being closed without deadlocking.
client, server := net.Pipe() client, server := localPipe(t)
defer server.Close() defer server.Close()
defer client.Close() defer client.Close()

View file

@ -850,7 +850,7 @@ func mutualProtocol(protos, preferenceProtos []string) (string, bool) {
// hostnameInSNI converts name into an approriate hostname for SNI. // hostnameInSNI converts name into an approriate hostname for SNI.
// Literal IP addresses and absolute FQDNs are not permitted as SNI values. // Literal IP addresses and absolute FQDNs are not permitted as SNI values.
// See https://tools.ietf.org/html/rfc6066#section-3. // See RFC 6066, Section 3.
func hostnameInSNI(name string) string { func hostnameInSNI(name string) string {
host := name host := name
if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' {

View file

@ -179,7 +179,7 @@ func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd,
var pemOut bytes.Buffer var pemOut bytes.Buffer
pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes}) pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes})
keyPath := tempFile(string(pemOut.Bytes())) keyPath := tempFile(pemOut.String())
defer os.Remove(keyPath) defer os.Remove(keyPath)
var command []string var command []string
@ -293,7 +293,7 @@ func (test *clientTest) run(t *testing.T, write bool) {
} }
clientConn = recordingConn clientConn = recordingConn
} else { } else {
clientConn, serverConn = net.Pipe() clientConn, serverConn = localPipe(t)
} }
config := test.config config := test.config
@ -682,7 +682,7 @@ func TestClientResumption(t *testing.T) {
} }
testResumeState := func(test string, didResume bool) { testResumeState := func(test string, didResume bool) {
_, hs, err := testHandshake(clientConfig, serverConfig) _, hs, err := testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("%s: handshake failed: %s", test, err) t.Fatalf("%s: handshake failed: %s", test, err)
} }
@ -800,7 +800,7 @@ func TestKeyLog(t *testing.T) {
serverConfig := testConfig.Clone() serverConfig := testConfig.Clone()
serverConfig.KeyLogWriter = &serverBuf serverConfig.KeyLogWriter = &serverBuf
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan bool) done := make(chan bool)
go func() { go func() {
@ -838,8 +838,8 @@ func TestKeyLog(t *testing.T) {
} }
} }
checkKeylogLine("client", string(clientBuf.Bytes())) checkKeylogLine("client", clientBuf.String())
checkKeylogLine("server", string(serverBuf.Bytes())) checkKeylogLine("server", serverBuf.String())
} }
func TestHandshakeClientALPNMatch(t *testing.T) { func TestHandshakeClientALPNMatch(t *testing.T) {
@ -1021,7 +1021,7 @@ var hostnameInSNITests = []struct {
func TestHostnameInSNI(t *testing.T) { func TestHostnameInSNI(t *testing.T) {
for _, tt := range hostnameInSNITests { for _, tt := range hostnameInSNITests {
c, s := net.Pipe() c, s := localPipe(t)
go func(host string) { go func(host string) {
Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake() Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
@ -1059,7 +1059,7 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
// This checks that the server can't select a cipher suite that the // This checks that the server can't select a cipher suite that the
// client didn't offer. See #13174. // client didn't offer. See #13174.
c, s := net.Pipe() c, s := localPipe(t)
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
@ -1228,7 +1228,7 @@ func TestVerifyPeerCertificate(t *testing.T) {
} }
for i, test := range tests { for i, test := range tests {
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan error) done := make(chan error)
var clientCalled, serverCalled bool var clientCalled, serverCalled bool
@ -1287,7 +1287,7 @@ func (b *brokenConn) Write(data []byte) (int, error) {
func TestFailedWrite(t *testing.T) { func TestFailedWrite(t *testing.T) {
// Test that a write error during the handshake is returned. // Test that a write error during the handshake is returned.
for _, breakAfter := range []int{0, 1} { for _, breakAfter := range []int{0, 1} {
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan bool) done := make(chan bool)
go func() { go func() {
@ -1321,7 +1321,7 @@ func (wcc *writeCountingConn) Write(data []byte) (int, error) {
} }
func TestBuffering(t *testing.T) { func TestBuffering(t *testing.T) {
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan bool) done := make(chan bool)
clientWCC := &writeCountingConn{Conn: c} clientWCC := &writeCountingConn{Conn: c}
@ -1350,7 +1350,7 @@ func TestBuffering(t *testing.T) {
} }
func TestAlertFlushing(t *testing.T) { func TestAlertFlushing(t *testing.T) {
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan bool) done := make(chan bool)
clientWCC := &writeCountingConn{Conn: c} clientWCC := &writeCountingConn{Conn: c}
@ -1399,7 +1399,7 @@ func TestHandshakeRace(t *testing.T) {
// order to provide some evidence that there are no races or deadlocks // order to provide some evidence that there are no races or deadlocks
// in the handshake locking. // in the handshake locking.
for i := 0; i < 32; i++ { for i := 0; i < 32; i++ {
c, s := net.Pipe() c, s := localPipe(t)
go func() { go func() {
server := Server(s, testConfig) server := Server(s, testConfig)
@ -1430,7 +1430,7 @@ func TestHandshakeRace(t *testing.T) {
go func() { go func() {
<-startRead <-startRead
var reply [1]byte var reply [1]byte
if n, err := client.Read(reply[:]); err != nil || n != 1 { if _, err := io.ReadFull(client, reply[:]); err != nil {
panic(err) panic(err)
} }
c.Close() c.Close()
@ -1559,7 +1559,7 @@ func TestGetClientCertificate(t *testing.T) {
err error err error
} }
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan serverResult) done := make(chan serverResult)
go func() { go func() {
@ -1637,7 +1637,7 @@ RwBA9Xk1KBNF
} }
func TestCloseClientConnectionOnIdleServer(t *testing.T) { func TestCloseClientConnectionOnIdleServer(t *testing.T) {
clientConn, serverConn := net.Pipe() clientConn, serverConn := localPipe(t)
client := Client(clientConn, testConfig.Clone()) client := Client(clientConn, testConfig.Clone())
go func() { go func() {
var b [1]byte var b [1]byte
@ -1647,8 +1647,8 @@ func TestCloseClientConnectionOnIdleServer(t *testing.T) {
client.SetWriteDeadline(time.Now().Add(time.Second)) client.SetWriteDeadline(time.Now().Add(time.Second))
err := client.Handshake() err := client.Handshake()
if err != nil { if err != nil {
if !strings.Contains(err.Error(), "read/write on closed pipe") { if err, ok := err.(net.Error); ok && err.Timeout() {
t.Errorf("Error expected containing 'read/write on closed pipe' but got '%s'", err.Error()) t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
} }
} else { } else {
t.Errorf("Error expected, but no error returned") t.Errorf("Error expected, but no error returned")

View file

@ -155,7 +155,7 @@ func (m *clientHelloMsg) marshal() []byte {
z[3] = byte(l) z[3] = byte(l)
z = z[4:] z = z[4:]
// RFC 3546, section 3.1 // RFC 3546, Section 3.1
// //
// struct { // struct {
// NameType name_type; // NameType name_type;
@ -182,7 +182,7 @@ func (m *clientHelloMsg) marshal() []byte {
z = z[l:] z = z[l:]
} }
if m.ocspStapling { if m.ocspStapling {
// RFC 4366, section 3.6 // RFC 4366, Section 3.6
z[0] = byte(extensionStatusRequest >> 8) z[0] = byte(extensionStatusRequest >> 8)
z[1] = byte(extensionStatusRequest) z[1] = byte(extensionStatusRequest)
z[2] = 0 z[2] = 0
@ -192,7 +192,7 @@ func (m *clientHelloMsg) marshal() []byte {
z = z[9:] z = z[9:]
} }
if len(m.supportedCurves) > 0 { if len(m.supportedCurves) > 0 {
// https://tools.ietf.org/html/rfc4492#section-5.5.1 // RFC 4492, Section 5.5.1
z[0] = byte(extensionSupportedCurves >> 8) z[0] = byte(extensionSupportedCurves >> 8)
z[1] = byte(extensionSupportedCurves) z[1] = byte(extensionSupportedCurves)
l := 2 + 2*len(m.supportedCurves) l := 2 + 2*len(m.supportedCurves)
@ -209,7 +209,7 @@ func (m *clientHelloMsg) marshal() []byte {
} }
} }
if len(m.supportedPoints) > 0 { if len(m.supportedPoints) > 0 {
// https://tools.ietf.org/html/rfc4492#section-5.5.2 // RFC 4492, Section 5.5.2
z[0] = byte(extensionSupportedPoints >> 8) z[0] = byte(extensionSupportedPoints >> 8)
z[1] = byte(extensionSupportedPoints) z[1] = byte(extensionSupportedPoints)
l := 1 + len(m.supportedPoints) l := 1 + len(m.supportedPoints)
@ -224,7 +224,7 @@ func (m *clientHelloMsg) marshal() []byte {
} }
} }
if m.ticketSupported { if m.ticketSupported {
// https://tools.ietf.org/html/rfc5077#section-3.2 // RFC 5077, Section 3.2
z[0] = byte(extensionSessionTicket >> 8) z[0] = byte(extensionSessionTicket >> 8)
z[1] = byte(extensionSessionTicket) z[1] = byte(extensionSessionTicket)
l := len(m.sessionTicket) l := len(m.sessionTicket)
@ -235,7 +235,7 @@ func (m *clientHelloMsg) marshal() []byte {
z = z[len(m.sessionTicket):] z = z[len(m.sessionTicket):]
} }
if len(m.supportedSignatureAlgorithms) > 0 { if len(m.supportedSignatureAlgorithms) > 0 {
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 // RFC 5246, Section 7.4.1.4.1
z[0] = byte(extensionSignatureAlgorithms >> 8) z[0] = byte(extensionSignatureAlgorithms >> 8)
z[1] = byte(extensionSignatureAlgorithms) z[1] = byte(extensionSignatureAlgorithms)
l := 2 + 2*len(m.supportedSignatureAlgorithms) l := 2 + 2*len(m.supportedSignatureAlgorithms)
@ -285,7 +285,7 @@ func (m *clientHelloMsg) marshal() []byte {
lengths[1] = byte(stringsLength) lengths[1] = byte(stringsLength)
} }
if m.scts { if m.scts {
// https://tools.ietf.org/html/rfc6962#section-3.3.1 // RFC 6962, Section 3.3.1
z[0] = byte(extensionSCT >> 8) z[0] = byte(extensionSCT >> 8)
z[1] = byte(extensionSCT) z[1] = byte(extensionSCT)
// zero uint16 for the zero-length extension_data // zero uint16 for the zero-length extension_data
@ -396,9 +396,8 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
} }
if nameType == 0 { if nameType == 0 {
m.serverName = string(d[:nameLen]) m.serverName = string(d[:nameLen])
// An SNI value may not include a // An SNI value may not include a trailing dot.
// trailing dot. See // See RFC 6066, Section 3.
// https://tools.ietf.org/html/rfc6066#section-3.
if strings.HasSuffix(m.serverName, ".") { if strings.HasSuffix(m.serverName, ".") {
return false return false
} }
@ -414,7 +413,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
case extensionStatusRequest: case extensionStatusRequest:
m.ocspStapling = length > 0 && data[0] == statusTypeOCSP m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
case extensionSupportedCurves: case extensionSupportedCurves:
// https://tools.ietf.org/html/rfc4492#section-5.5.1 // RFC 4492, Section 5.5.1
if length < 2 { if length < 2 {
return false return false
} }
@ -430,7 +429,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
d = d[2:] d = d[2:]
} }
case extensionSupportedPoints: case extensionSupportedPoints:
// https://tools.ietf.org/html/rfc4492#section-5.5.2 // RFC 4492, Section 5.5.2
if length < 1 { if length < 1 {
return false return false
} }
@ -441,11 +440,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.supportedPoints = make([]uint8, l) m.supportedPoints = make([]uint8, l)
copy(m.supportedPoints, data[1:]) copy(m.supportedPoints, data[1:])
case extensionSessionTicket: case extensionSessionTicket:
// https://tools.ietf.org/html/rfc5077#section-3.2 // RFC 5077, Section 3.2
m.ticketSupported = true m.ticketSupported = true
m.sessionTicket = data[:length] m.sessionTicket = data[:length]
case extensionSignatureAlgorithms: case extensionSignatureAlgorithms:
// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 // RFC 5246, Section 7.4.1.4.1
if length < 2 || length&1 != 0 { if length < 2 || length&1 != 0 {
return false return false
} }
@ -1224,7 +1223,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
return m.raw return m.raw
} }
// See https://tools.ietf.org/html/rfc4346#section-7.4.4 // See RFC 4346, Section 7.4.4.
length := 1 + len(m.certificateTypes) + 2 length := 1 + len(m.certificateTypes) + 2
casLength := 0 casLength := 0
for _, ca := range m.certificateAuthorities { for _, ca := range m.certificateAuthorities {
@ -1374,7 +1373,7 @@ func (m *certificateVerifyMsg) marshal() (x []byte) {
return m.raw return m.raw
} }
// See https://tools.ietf.org/html/rfc4346#section-7.4.8 // See RFC 4346, Section 7.4.8.
siglength := len(m.signature) siglength := len(m.signature)
length := 2 + siglength length := 2 + siglength
if m.hasSignatureAndHash { if m.hasSignatureAndHash {
@ -1452,7 +1451,7 @@ func (m *newSessionTicketMsg) marshal() (x []byte) {
return m.raw return m.raw
} }
// See https://tools.ietf.org/html/rfc5077#section-3.3 // See RFC 5077, Section 3.3.
ticketLen := len(m.ticket) ticketLen := len(m.ticket)
length := 2 + 4 + ticketLen length := 2 + 4 + ticketLen
x = make([]byte, 4+length) x = make([]byte, 4+length)

View file

@ -188,7 +188,7 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
numSCTs := rand.Intn(4) numSCTs := rand.Intn(4)
m.scts = make([][]byte, numSCTs) m.scts = make([][]byte, numSCTs)
for i := range m.scts { for i := range m.scts {
m.scts[i] = randomBytes(rand.Intn(500), rand) m.scts[i] = randomBytes(rand.Intn(500)+1, rand)
} }
} }
@ -271,8 +271,7 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
} }
func TestRejectEmptySCTList(t *testing.T) { func TestRejectEmptySCTList(t *testing.T) {
// https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.
// empty SCT lists are invalid.
var random [32]byte var random [32]byte
sct := []byte{0x42, 0x42, 0x42, 0x42} sct := []byte{0x42, 0x42, 0x42, 0x42}

View file

@ -49,7 +49,7 @@ func (c *Conn) serverHandshake() error {
return err return err
} }
// For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3 // For an overview of TLS handshaking, see RFC 5246, Section 7.3.
c.buffering = true c.buffering = true
if isResume { if isResume {
// The client has included a session ticket and so we do an abbreviated handshake. // The client has included a session ticket and so we do an abbreviated handshake.
@ -268,7 +268,7 @@ Curves:
return false, errors.New("tls: no cipher suite supported by both client and server") return false, errors.New("tls: no cipher suite supported by both client and server")
} }
// See https://tools.ietf.org/html/rfc7507. // See RFC 7507.
for _, id := range hs.clientHello.cipherSuites { for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV { if id == TLS_FALLBACK_SCSV {
// The client is doing a fallback connection. // The client is doing a fallback connection.

View file

@ -70,10 +70,7 @@ func testClientHello(t *testing.T, serverConfig *Config, m handshakeMessage) {
} }
func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessage, expectedSubStr string) { func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessage, expectedSubStr string) {
// Create in-memory network connection, c, s := localPipe(t)
// send message to server. Should return
// expected error.
c, s := net.Pipe()
go func() { go func() {
cli := Client(c, testConfig) cli := Client(c, testConfig)
if ch, ok := m.(*clientHelloMsg); ok { if ch, ok := m.(*clientHelloMsg); ok {
@ -201,25 +198,26 @@ func TestRenegotiationExtension(t *testing.T) {
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
} }
var buf []byte bufChan := make(chan []byte)
c, s := net.Pipe() c, s := localPipe(t)
go func() { go func() {
cli := Client(c, testConfig) cli := Client(c, testConfig)
cli.vers = clientHello.vers cli.vers = clientHello.vers
cli.writeRecord(recordTypeHandshake, clientHello.marshal()) cli.writeRecord(recordTypeHandshake, clientHello.marshal())
buf = make([]byte, 1024) buf := make([]byte, 1024)
n, err := c.Read(buf) n, err := c.Read(buf)
if err != nil { if err != nil {
t.Errorf("Server read returned error: %s", err) t.Errorf("Server read returned error: %s", err)
return return
} }
buf = buf[:n]
c.Close() c.Close()
bufChan <- buf[:n]
}() }()
Server(s, testConfig).Handshake() Server(s, testConfig).Handshake()
buf := <-bufChan
if len(buf) < 5+4 { if len(buf) < 5+4 {
t.Fatalf("Server returned short message of length %d", len(buf)) t.Fatalf("Server returned short message of length %d", len(buf))
@ -262,22 +260,27 @@ func TestTLS12OnlyCipherSuites(t *testing.T) {
supportedPoints: []uint8{pointFormatUncompressed}, supportedPoints: []uint8{pointFormatUncompressed},
} }
c, s := net.Pipe() c, s := localPipe(t)
var reply interface{} replyChan := make(chan interface{})
var clientErr error
go func() { go func() {
cli := Client(c, testConfig) cli := Client(c, testConfig)
cli.vers = clientHello.vers cli.vers = clientHello.vers
cli.writeRecord(recordTypeHandshake, clientHello.marshal()) cli.writeRecord(recordTypeHandshake, clientHello.marshal())
reply, clientErr = cli.readHandshake() reply, err := cli.readHandshake()
c.Close() c.Close()
if err != nil {
replyChan <- err
} else {
replyChan <- reply
}
}() }()
config := testConfig.Clone() config := testConfig.Clone()
config.CipherSuites = clientHello.cipherSuites config.CipherSuites = clientHello.cipherSuites
Server(s, config).Handshake() Server(s, config).Handshake()
s.Close() s.Close()
if clientErr != nil { reply := <-replyChan
t.Fatal(clientErr) if err, ok := reply.(error); ok {
t.Fatal(err)
} }
serverHello, ok := reply.(*serverHelloMsg) serverHello, ok := reply.(*serverHelloMsg)
if !ok { if !ok {
@ -289,7 +292,7 @@ func TestTLS12OnlyCipherSuites(t *testing.T) {
} }
func TestAlertForwarding(t *testing.T) { func TestAlertForwarding(t *testing.T) {
c, s := net.Pipe() c, s := localPipe(t)
go func() { go func() {
Client(c, testConfig).sendAlert(alertUnknownCA) Client(c, testConfig).sendAlert(alertUnknownCA)
c.Close() c.Close()
@ -303,7 +306,7 @@ func TestAlertForwarding(t *testing.T) {
} }
func TestClose(t *testing.T) { func TestClose(t *testing.T) {
c, s := net.Pipe() c, s := localPipe(t)
go c.Close() go c.Close()
err := Server(s, testConfig).Handshake() err := Server(s, testConfig).Handshake()
@ -313,8 +316,8 @@ func TestClose(t *testing.T) {
} }
} }
func testHandshake(clientConfig, serverConfig *Config) (serverState, clientState ConnectionState, err error) { func testHandshake(t *testing.T, clientConfig, serverConfig *Config) (serverState, clientState ConnectionState, err error) {
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan bool) done := make(chan bool)
go func() { go func() {
cli := Client(c, clientConfig) cli := Client(c, clientConfig)
@ -341,7 +344,7 @@ func TestVersion(t *testing.T) {
clientConfig := &Config{ clientConfig := &Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
} }
state, _, err := testHandshake(clientConfig, serverConfig) state, _, err := testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
@ -360,7 +363,7 @@ func TestCipherSuitePreference(t *testing.T) {
CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_RC4_128_SHA}, CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_RC4_128_SHA},
InsecureSkipVerify: true, InsecureSkipVerify: true,
} }
state, _, err := testHandshake(clientConfig, serverConfig) state, _, err := testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
@ -370,7 +373,7 @@ func TestCipherSuitePreference(t *testing.T) {
} }
serverConfig.PreferServerCipherSuites = true serverConfig.PreferServerCipherSuites = true
state, _, err = testHandshake(clientConfig, serverConfig) state, _, err = testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
@ -391,7 +394,7 @@ func TestSCTHandshake(t *testing.T) {
clientConfig := &Config{ clientConfig := &Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
} }
_, state, err := testHandshake(clientConfig, serverConfig) _, state, err := testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
@ -420,13 +423,13 @@ func TestCrossVersionResume(t *testing.T) {
// Establish a session at TLS 1.1. // Establish a session at TLS 1.1.
clientConfig.MaxVersion = VersionTLS11 clientConfig.MaxVersion = VersionTLS11
_, _, err := testHandshake(clientConfig, serverConfig) _, _, err := testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
// The client session cache now contains a TLS 1.1 session. // The client session cache now contains a TLS 1.1 session.
state, _, err := testHandshake(clientConfig, serverConfig) state, _, err := testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
@ -436,7 +439,7 @@ func TestCrossVersionResume(t *testing.T) {
// Test that the server will decline to resume at a lower version. // Test that the server will decline to resume at a lower version.
clientConfig.MaxVersion = VersionTLS10 clientConfig.MaxVersion = VersionTLS10
state, _, err = testHandshake(clientConfig, serverConfig) state, _, err = testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
@ -445,7 +448,7 @@ func TestCrossVersionResume(t *testing.T) {
} }
// The client session cache now contains a TLS 1.0 session. // The client session cache now contains a TLS 1.0 session.
state, _, err = testHandshake(clientConfig, serverConfig) state, _, err = testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
@ -455,7 +458,7 @@ func TestCrossVersionResume(t *testing.T) {
// Test that the server will decline to resume at a higher version. // Test that the server will decline to resume at a higher version.
clientConfig.MaxVersion = VersionTLS11 clientConfig.MaxVersion = VersionTLS11
state, _, err = testHandshake(clientConfig, serverConfig) state, _, err = testHandshake(t, clientConfig, serverConfig)
if err != nil { if err != nil {
t.Fatalf("handshake failed: %s", err) t.Fatalf("handshake failed: %s", err)
} }
@ -579,7 +582,7 @@ func (test *serverTest) run(t *testing.T, write bool) {
} }
serverConn = recordingConn serverConn = recordingConn
} else { } else {
clientConn, serverConn = net.Pipe() clientConn, serverConn = localPipe(t)
} }
config := test.config config := test.config
if config == nil { if config == nil {
@ -832,7 +835,7 @@ func TestHandshakeServerSNIGetCertificate(t *testing.T) {
nameToCert := config.NameToCertificate nameToCert := config.NameToCertificate
config.NameToCertificate = nil config.NameToCertificate = nil
config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
cert, _ := nameToCert[clientHello.ServerName] cert := nameToCert[clientHello.ServerName]
return cert, nil return cert, nil
} }
test := &serverTest{ test := &serverTest{
@ -1025,7 +1028,7 @@ func benchmarkHandshakeServer(b *testing.B, cipherSuite uint16, curve CurveID, c
config.Certificates[0].PrivateKey = key config.Certificates[0].PrivateKey = key
config.BuildNameToCertificate() config.BuildNameToCertificate()
clientConn, serverConn := net.Pipe() clientConn, serverConn := localPipe(b)
serverConn = &recordingConn{Conn: serverConn} serverConn = &recordingConn{Conn: serverConn}
go func() { go func() {
client := Client(clientConn, testConfig) client := Client(clientConn, testConfig)
@ -1039,7 +1042,7 @@ func benchmarkHandshakeServer(b *testing.B, cipherSuite uint16, curve CurveID, c
flows := serverConn.(*recordingConn).flows flows := serverConn.(*recordingConn).flows
feeder := make(chan struct{}) feeder := make(chan struct{})
clientConn, serverConn = net.Pipe() clientConn, serverConn = localPipe(b)
go func() { go func() {
for range feeder { for range feeder {
@ -1051,10 +1054,10 @@ func benchmarkHandshakeServer(b *testing.B, cipherSuite uint16, curve CurveID, c
ff := make([]byte, len(f)) ff := make([]byte, len(f))
n, err := io.ReadFull(clientConn, ff) n, err := io.ReadFull(clientConn, ff)
if err != nil { if err != nil {
b.Fatalf("#%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", i+1, err, n, len(ff), ff[:n], f) b.Errorf("#%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", i+1, err, n, len(ff), ff[:n], f)
} }
if !bytes.Equal(f, ff) { if !bytes.Equal(f, ff) {
b.Fatalf("#%d: mismatch on read: got:%x want:%x", i+1, ff, f) b.Errorf("#%d: mismatch on read: got:%x want:%x", i+1, ff, f)
} }
} }
} }
@ -1216,7 +1219,7 @@ func TestSNIGivenOnFailure(t *testing.T) {
// Erase the server's cipher suites to ensure the handshake fails. // Erase the server's cipher suites to ensure the handshake fails.
serverConfig.CipherSuites = nil serverConfig.CipherSuites = nil
c, s := net.Pipe() c, s := localPipe(t)
go func() { go func() {
cli := Client(c, testConfig) cli := Client(c, testConfig)
cli.vers = clientHello.vers cli.vers = clientHello.vers
@ -1346,7 +1349,7 @@ func TestGetConfigForClient(t *testing.T) {
configReturned = config configReturned = config
return config, err return config, err
} }
c, s := net.Pipe() c, s := localPipe(t)
done := make(chan error) done := make(chan error)
go func() { go func() {
@ -1423,7 +1426,7 @@ var testECDSAPrivateKey = &ecdsa.PrivateKey{
var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75")) var testP256PrivateKey, _ = x509.ParseECPrivateKey(fromHex("30770201010420012f3b52bc54c36ba3577ad45034e2e8efe1e6999851284cb848725cfe029991a00a06082a8648ce3d030107a14403420004c02c61c9b16283bbcc14956d886d79b358aa614596975f78cece787146abf74c2d5dc578c0992b4f3c631373479ebf3892efe53d21c4f4f1cc9a11c3536b7f75"))
func TestCloseServerConnectionOnIdleClient(t *testing.T) { func TestCloseServerConnectionOnIdleClient(t *testing.T) {
clientConn, serverConn := net.Pipe() clientConn, serverConn := localPipe(t)
server := Server(serverConn, testConfig.Clone()) server := Server(serverConn, testConfig.Clone())
go func() { go func() {
clientConn.Write([]byte{'0'}) clientConn.Write([]byte{'0'})
@ -1432,8 +1435,8 @@ func TestCloseServerConnectionOnIdleClient(t *testing.T) {
server.SetReadDeadline(time.Now().Add(time.Second)) server.SetReadDeadline(time.Now().Add(time.Second))
err := server.Handshake() err := server.Handshake()
if err != nil { if err != nil {
if !strings.Contains(err.Error(), "read/write on closed pipe") { if err, ok := err.(net.Error); ok && err.Timeout() {
t.Errorf("Error expected containing 'read/write on closed pipe' but got '%s'", err.Error()) t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
} }
} else { } else {
t.Errorf("Error expected, but no error returned") t.Errorf("Error expected, but no error returned")

View file

@ -13,6 +13,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"os"
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
@ -224,3 +225,45 @@ func tempFile(contents string) string {
file.Close() file.Close()
return path return path
} }
// localListener is set up by TestMain and used by localPipe to create Conn
// pairs like net.Pipe, but connected by an actual buffered TCP connection.
var localListener struct {
sync.Mutex
net.Listener
}
func localPipe(t testing.TB) (net.Conn, net.Conn) {
localListener.Lock()
defer localListener.Unlock()
c := make(chan net.Conn)
go func() {
conn, err := localListener.Accept()
if err != nil {
t.Errorf("Failed to accept local connection: %v", err)
}
c <- conn
}()
addr := localListener.Addr()
c1, err := net.Dial(addr.Network(), addr.String())
if err != nil {
t.Fatalf("Failed to dial local connection: %v", err)
}
c2 := <-c
return c1, c2
}
func TestMain(m *testing.M) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
l, err = net.Listen("tcp6", "[::1]:0")
}
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to open local listener: %v", err)
os.Exit(1)
}
localListener.Listener = l
exitCode := m.Run()
localListener.Close()
os.Exit(exitCode)
}

View file

@ -12,10 +12,9 @@ import (
"crypto/sha1" "crypto/sha1"
"crypto/x509" "crypto/x509"
"errors" "errors"
"golang_org/x/crypto/curve25519"
"io" "io"
"math/big" "math/big"
"golang_org/x/crypto/curve25519"
) )
var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
@ -200,7 +199,7 @@ NextCandidate:
ecdhePublic = elliptic.Marshal(curve, x, y) ecdhePublic = elliptic.Marshal(curve, x, y)
} }
// https://tools.ietf.org/html/rfc4492#section-5.4 // See RFC 4492, Section 5.4.
serverECDHParams := make([]byte, 1+2+1+len(ecdhePublic)) serverECDHParams := make([]byte, 1+2+1+len(ecdhePublic))
serverECDHParams[0] = 3 // named curve serverECDHParams[0] = 3 // named curve
serverECDHParams[1] = byte(ka.curveid >> 8) serverECDHParams[1] = byte(ka.curveid >> 8)

15
prf.go
View file

@ -16,14 +16,14 @@ import (
"hash" "hash"
) )
// Split a premaster secret in two as specified in RFC 4346, section 5. // Split a premaster secret in two as specified in RFC 4346, Section 5.
func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
s1 = secret[0 : (len(secret)+1)/2] s1 = secret[0 : (len(secret)+1)/2]
s2 = secret[len(secret)/2:] s2 = secret[len(secret)/2:]
return return
} }
// pHash implements the P_hash function, as defined in RFC 4346, section 5. // pHash implements the P_hash function, as defined in RFC 4346, Section 5.
func pHash(result, secret, seed []byte, hash func() hash.Hash) { func pHash(result, secret, seed []byte, hash func() hash.Hash) {
h := hmac.New(hash, secret) h := hmac.New(hash, secret)
h.Write(seed) h.Write(seed)
@ -44,7 +44,7 @@ func pHash(result, secret, seed []byte, hash func() hash.Hash) {
} }
} }
// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, section 5. // prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, Section 5.
func prf10(result, secret, label, seed []byte) { func prf10(result, secret, label, seed []byte) {
hashSHA1 := sha1.New hashSHA1 := sha1.New
hashMD5 := md5.New hashMD5 := md5.New
@ -63,7 +63,7 @@ func prf10(result, secret, label, seed []byte) {
} }
} }
// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, section 5. // prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, Section 5.
func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) {
return func(result, secret, label, seed []byte) { return func(result, secret, label, seed []byte) {
labelAndSeed := make([]byte, len(label)+len(seed)) labelAndSeed := make([]byte, len(label)+len(seed))
@ -140,7 +140,7 @@ func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, labe
} }
// masterFromPreMasterSecret generates the master secret from the pre-master // masterFromPreMasterSecret generates the master secret from the pre-master
// secret. See https://tools.ietf.org/html/rfc5246#section-8.1 // secret. See RFC 5246, Section 8.1.
func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte {
seed := make([]byte, 0, len(clientRandom)+len(serverRandom)) seed := make([]byte, 0, len(clientRandom)+len(serverRandom))
seed = append(seed, clientRandom...) seed = append(seed, clientRandom...)
@ -153,7 +153,7 @@ func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecr
// keysFromMasterSecret generates the connection keys from the master // keysFromMasterSecret generates the connection keys from the master
// secret, given the lengths of the MAC key, cipher key and IV, as defined in // secret, given the lengths of the MAC key, cipher key and IV, as defined in
// RFC 2246, section 6.3. // RFC 2246, Section 6.3.
func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
seed := make([]byte, 0, len(serverRandom)+len(clientRandom)) seed := make([]byte, 0, len(serverRandom)+len(clientRandom))
seed = append(seed, serverRandom...) seed = append(seed, serverRandom...)
@ -353,8 +353,7 @@ func noExportedKeyingMaterial(label string, context []byte, length int) ([]byte,
return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled") return nil, errors.New("crypto/tls: ExportKeyingMaterial is unavailable when renegotiation is enabled")
} }
// ekmFromMasterSecret generates exported keying material as defined in // ekmFromMasterSecret generates exported keying material as defined in RFC 5705.
// https://tools.ietf.org/html/rfc5705.
func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) { func ekmFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte) func(string, []byte, int) ([]byte, error) {
return func(label string, context []byte, length int) ([]byte, error) { return func(label string, context []byte, length int) ([]byte, error) {
switch label { switch label {

11
tls.go
View file

@ -11,6 +11,7 @@ package tls
// https://www.imperialviolet.org/2013/02/04/luckythirteen.html. // https://www.imperialviolet.org/2013/02/04/luckythirteen.html.
import ( import (
"bytes"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rsa" "crypto/rsa"
@ -29,7 +30,10 @@ import (
// The configuration config must be non-nil and must include // The configuration config must be non-nil and must include
// at least one certificate or else set GetCertificate. // at least one certificate or else set GetCertificate.
func Server(conn net.Conn, config *Config) *Conn { func Server(conn net.Conn, config *Config) *Conn {
return &Conn{conn: conn, config: config} return &Conn{
conn: conn, config: config,
input: *bytes.NewReader(nil), // Issue 28269
}
} }
// Client returns a new TLS client side connection // Client returns a new TLS client side connection
@ -37,7 +41,10 @@ func Server(conn net.Conn, config *Config) *Conn {
// The config cannot be nil: users must set either ServerName or // The config cannot be nil: users must set either ServerName or
// InsecureSkipVerify in the config. // InsecureSkipVerify in the config.
func Client(conn net.Conn, config *Config) *Conn { func Client(conn net.Conn, config *Config) *Conn {
return &Conn{conn: conn, config: config, isClient: true} return &Conn{
conn: conn, config: config, isClient: true,
input: *bytes.NewReader(nil), // Issue 28269
}
} }
// A listener implements a network listener (net.Listener) for TLS connections. // A listener implements a network listener (net.Listener) for TLS connections.