diff --git a/cipher_suites.go b/cipher_suites.go index 2475906..20e45e5 100644 --- a/cipher_suites.go +++ b/cipher_suites.go @@ -5,6 +5,7 @@ package tls import ( + "crypto" "crypto/aes" "crypto/cipher" "crypto/des" @@ -59,8 +60,7 @@ const ( suiteDefaultOff ) -// A cipherSuite is a specific combination of key agreement, cipher and MAC -// function. All cipher suites currently assume RSA key agreement. +// A cipherSuite is a specific combination of key agreement, cipher and MAC function. type cipherSuite struct { id uint16 // the lengths, in bytes, of the key material needed for each component. @@ -72,7 +72,7 @@ type cipherSuite struct { flags int cipher func(key, iv []byte, isRead bool) interface{} mac func(version uint16, macKey []byte) macFunction - aead func(key, fixedNonce []byte) cipher.AEAD + aead func(key, fixedNonce []byte) aead } var cipherSuites = []*cipherSuite{ @@ -104,6 +104,21 @@ var cipherSuites = []*cipherSuite{ {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteDefaultOff, cipherRC4, macSHA1, nil}, } +// A cipherSuiteTLS13 defines only the pair of the AEAD algorithm and hash +// algorithm to be used with HKDF. See RFC 8446, Appendix B.4. +type cipherSuiteTLS13 struct { + id uint16 + keyLen int + aead func(key, fixedNonce []byte) aead + hash crypto.Hash +} + +var cipherSuitesTLS13 = []*cipherSuiteTLS13{ + {TLS_AES_128_GCM_SHA256, 16, aeadAESGCMTLS13, crypto.SHA256}, + {TLS_CHACHA20_POLY1305_SHA256, 32, aeadChaCha20Poly1305, crypto.SHA256}, + {TLS_AES_256_GCM_SHA384, 32, aeadAESGCMTLS13, crypto.SHA384}, +} + func cipherRC4(key, iv []byte, isRead bool) interface{} { cipher, _ := rc4.NewCipher(key) return cipher @@ -166,36 +181,41 @@ type aead interface { explicitNonceLen() int } -// fixedNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to +const ( + aeadNonceLength = 12 + noncePrefixLength = 4 +) + +// prefixNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to // each call. -type fixedNonceAEAD struct { +type prefixNonceAEAD struct { // nonce contains the fixed part of the nonce in the first four bytes. - nonce [12]byte + nonce [aeadNonceLength]byte aead cipher.AEAD } -func (f *fixedNonceAEAD) NonceSize() int { return 8 } -func (f *fixedNonceAEAD) Overhead() int { return f.aead.Overhead() } -func (f *fixedNonceAEAD) explicitNonceLen() int { return 8 } +func (f *prefixNonceAEAD) NonceSize() int { return aeadNonceLength - noncePrefixLength } +func (f *prefixNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *prefixNonceAEAD) explicitNonceLen() int { return f.NonceSize() } -func (f *fixedNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { +func (f *prefixNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { copy(f.nonce[4:], nonce) return f.aead.Seal(out, f.nonce[:], plaintext, additionalData) } -func (f *fixedNonceAEAD) Open(out, nonce, plaintext, additionalData []byte) ([]byte, error) { +func (f *prefixNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { copy(f.nonce[4:], nonce) - return f.aead.Open(out, f.nonce[:], plaintext, additionalData) + return f.aead.Open(out, f.nonce[:], ciphertext, additionalData) } // xoredNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce // before each call. type xorNonceAEAD struct { - nonceMask [12]byte + nonceMask [aeadNonceLength]byte aead cipher.AEAD } -func (f *xorNonceAEAD) NonceSize() int { return 8 } +func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } @@ -211,11 +231,11 @@ func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte return result } -func (f *xorNonceAEAD) Open(out, nonce, plaintext, additionalData []byte) ([]byte, error) { +func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { for i, b := range nonce { f.nonceMask[4+i] ^= b } - result, err := f.aead.Open(out, f.nonceMask[:], plaintext, additionalData) + result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) for i, b := range nonce { f.nonceMask[4+i] ^= b } @@ -227,7 +247,10 @@ type gcmtls interface { NewGCMTLS() (cipher.AEAD, error) } -func aeadAESGCM(key, fixedNonce []byte) cipher.AEAD { +func aeadAESGCM(key, noncePrefix []byte) aead { + if len(noncePrefix) != noncePrefixLength { + panic("tls: internal error: wrong nonce length") + } aes, err := aes.NewCipher(key) if err != nil { panic(err) @@ -243,19 +266,40 @@ func aeadAESGCM(key, fixedNonce []byte) cipher.AEAD { panic(err) } - ret := &fixedNonceAEAD{aead: aead} - copy(ret.nonce[:], fixedNonce) + ret := &prefixNonceAEAD{aead: aead} + copy(ret.nonce[:], noncePrefix) return ret } -func aeadChaCha20Poly1305(key, fixedNonce []byte) cipher.AEAD { +func aeadAESGCMTLS13(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +func aeadChaCha20Poly1305(key, nonceMask []byte) aead { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } aead, err := chacha20poly1305.New(key) if err != nil { panic(err) } ret := &xorNonceAEAD{aead: aead} - copy(ret.nonceMask[:], fixedNonce) + copy(ret.nonceMask[:], nonceMask) return ret } @@ -391,6 +435,7 @@ func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { // // Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml const ( + // TLS 1.0 - 1.2 cipher suites. TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f @@ -414,6 +459,11 @@ const ( TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 uint16 = 0xcca8 TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 uint16 = 0xcca9 + // TLS 1.3 cipher suites. + TLS_AES_128_GCM_SHA256 uint16 = 0x1301 + TLS_AES_256_GCM_SHA384 uint16 = 0x1302 + TLS_CHACHA20_POLY1305_SHA256 uint16 = 0x1303 + // TLS_FALLBACK_SCSV isn't a standard cipher suite but an indicator // that the client is doing version fallback. See RFC 7507. TLS_FALLBACK_SCSV uint16 = 0x5600 diff --git a/common.go b/common.go index 98d13b0..a3cfe05 100644 --- a/common.go +++ b/common.go @@ -27,14 +27,19 @@ const ( VersionTLS10 = 0x0301 VersionTLS11 = 0x0302 VersionTLS12 = 0x0303 + + // VersionTLS13 is under development in this library and can't be selected + // nor negotiated yet on either side. + VersionTLS13 = 0x0304 ) const ( - maxPlaintext = 16384 // maximum plaintext payload length - maxCiphertext = 16384 + 2048 // maximum ciphertext payload length - recordHeaderLen = 5 // record header length - maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) - maxWarnAlertCount = 5 // maximum number of consecutive warning alerts + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3 + recordHeaderLen = 5 // record header length + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + maxUselessRecords = 5 // maximum number of consecutive non-advancing records minVersion = VersionTLS10 maxVersion = VersionTLS12 @@ -74,16 +79,22 @@ const ( // TLS extension numbers const ( - extensionServerName uint16 = 0 - extensionStatusRequest uint16 = 5 - extensionSupportedCurves uint16 = 10 - extensionSupportedPoints uint16 = 11 - extensionSignatureAlgorithms uint16 = 13 - extensionALPN uint16 = 16 - extensionSCT uint16 = 18 // RFC 6962, Section 6 - extensionSessionTicket uint16 = 35 - extensionNextProtoNeg uint16 = 13172 // not IANA assigned - extensionRenegotiationInfo uint16 = 0xff01 + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 + extensionSessionTicket uint16 = 35 + extensionPreSharedKey uint16 = 41 + extensionSupportedVersions uint16 = 43 + extensionCookie uint16 = 44 + extensionPSKModes uint16 = 45 + extensionSignatureAlgorithmsCert uint16 = 50 + extensionKeyShare uint16 = 51 + extensionNextProtoNeg uint16 = 13172 // not IANA assigned + extensionRenegotiationInfo uint16 = 0xff01 ) // TLS signaling cipher suite values @@ -92,7 +103,10 @@ const ( ) // CurveID is the type of a TLS identifier for an elliptic curve. See -// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8 +// https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8. +// +// In TLS 1.3, this type is called NamedGroup, but at this time this library +// only supports Elliptic Curve based groups. See RFC 8446, Section 4.2.7. type CurveID uint16 const ( @@ -102,6 +116,25 @@ const ( X25519 CurveID = 29 ) +// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8. +type keyShare struct { + group CurveID + data []byte +} + +// TLS 1.3 PSK Key Exchange Modes. See RFC 8446, Section 4.2.9. +const ( + pskModePlain uint8 = 0 + pskModeDHE uint8 = 1 +) + +// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved +// session. See RFC 8446, Section 4.2.11. +type pskIdentity struct { + label []byte + obfuscatedTicketAge uint32 +} + // TLS Elliptic Curve Point Formats // https://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 const ( @@ -425,7 +458,8 @@ type Config struct { // If RootCAs is nil, TLS uses the host's root CA set. RootCAs *x509.CertPool - // NextProtos is a list of supported, application level protocols. + // NextProtos is a list of supported application level protocols, in + // order of preference. NextProtos []string // ServerName is used to verify the hostname on the returned @@ -777,10 +811,14 @@ func (c *Config) BuildNameToCertificate() { c.NameToCertificate = make(map[string]*Certificate) for i := range c.Certificates { cert := &c.Certificates[i] - x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) - if err != nil { - continue + if cert.Leaf == nil { + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + continue + } + cert.Leaf = x509Cert } + x509Cert := cert.Leaf if len(x509Cert.Subject.CommonName) > 0 { c.NameToCertificate[x509Cert.Subject.CommonName] = cert } @@ -922,8 +960,9 @@ func defaultConfig() *Config { } var ( - once sync.Once - varDefaultCipherSuites []uint16 + once sync.Once + varDefaultCipherSuites []uint16 + varDefaultCipherSuitesTLS13 []uint16 ) func defaultCipherSuites() []uint16 { @@ -931,19 +970,24 @@ func defaultCipherSuites() []uint16 { return varDefaultCipherSuites } +func defaultCipherSuitesTLS13() []uint16 { + once.Do(initDefaultCipherSuites) + return varDefaultCipherSuitesTLS13 +} + func initDefaultCipherSuites() { var topCipherSuites []uint16 // Check the cpu flags for each platform that has optimized GCM implementations. - // Worst case, these variables will just all be false - hasGCMAsmAMD64 := cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + // Worst case, these variables will just all be false. + var ( + hasGCMAsmAMD64 = cpu.X86.HasAES && cpu.X86.HasPCLMULQDQ + hasGCMAsmARM64 = cpu.ARM64.HasAES && cpu.ARM64.HasPMULL + // Keep in sync with crypto/aes/cipher_s390x.go. + hasGCMAsmS390X = cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) - hasGCMAsmARM64 := cpu.ARM64.HasAES && cpu.ARM64.HasPMULL - - // Keep in sync with crypto/aes/cipher_s390x.go. - hasGCMAsmS390X := cpu.S390X.HasAES && cpu.S390X.HasAESCBC && cpu.S390X.HasAESCTR && (cpu.S390X.HasGHASH || cpu.S390X.HasAESGCM) - - hasGCMAsm := hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X + hasGCMAsm = hasGCMAsmAMD64 || hasGCMAsmARM64 || hasGCMAsmS390X + ) if hasGCMAsm || boring.Enabled { // If BoringCrypto is enabled, always prioritize AES-GCM. @@ -957,6 +1001,11 @@ func initDefaultCipherSuites() { TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, } + varDefaultCipherSuitesTLS13 = []uint16{ + TLS_AES_128_GCM_SHA256, + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_256_GCM_SHA384, + } } else { // Without AES-GCM hardware, we put the ChaCha20-Poly1305 // cipher suites first. @@ -968,6 +1017,11 @@ func initDefaultCipherSuites() { TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, } + varDefaultCipherSuitesTLS13 = []uint16{ + TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + } } varDefaultCipherSuites = make([]uint16, 0, len(cipherSuites)) diff --git a/conn.go b/conn.go index dae5fd1..5af1413 100644 --- a/conn.go +++ b/conn.go @@ -94,9 +94,9 @@ type Conn struct { bytesSent int64 packetsSent int64 - // warnCount counts the number of consecutive warning alerts received + // retryCount counts the number of consecutive warning alerts received // by Conn.readRecord. Protected by in.Mutex. - warnCount int + retryCount int // activeCall is an atomic int32; the low bit is whether Close has // been called. the rest of the bits are the number of goroutines @@ -291,9 +291,17 @@ type cbcMode interface { // decrypt authenticates and decrypts the record if protection is active at // this stage. The returned plaintext might overlap with the input. -func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) { +func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) { + var plaintext []byte + typ := recordType(record[0]) payload := record[recordHeaderLen:] + // In TLS 1.3, change_cipher_spec messages are to be ignored without being + // decrypted. See RFC 8446, Appendix D.4. + if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec { + return payload, typ, nil + } + paddingGood := byte(255) paddingLen := 0 @@ -305,7 +313,7 @@ func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) { c.XORKeyStream(payload, payload) case aead: if len(payload) < explicitNonceLen { - return nil, alertBadRecordMAC + return nil, 0, alertBadRecordMAC } nonce := payload[:explicitNonceLen] if len(nonce) == 0 { @@ -313,22 +321,27 @@ func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) { } payload = payload[explicitNonceLen:] - copy(hc.additionalData[:], hc.seq[:]) - copy(hc.additionalData[8:], record[:3]) - n := len(payload) - c.Overhead() - hc.additionalData[11] = byte(n >> 8) - hc.additionalData[12] = byte(n) + additionalData := hc.additionalData[:] + if hc.version == VersionTLS13 { + additionalData = record[:recordHeaderLen] + } else { + copy(additionalData, hc.seq[:]) + copy(additionalData[8:], record[:3]) + n := len(payload) - c.Overhead() + additionalData[11] = byte(n >> 8) + additionalData[12] = byte(n) + } var err error - plaintext, err = c.Open(payload[:0], nonce, payload, hc.additionalData[:]) + plaintext, err = c.Open(payload[:0], nonce, payload, additionalData) if err != nil { - return nil, alertBadRecordMAC + return nil, 0, alertBadRecordMAC } case cbcMode: blockSize := c.BlockSize() minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize) // TODO: vuln? if len(payload)%blockSize != 0 || len(payload) < minPayload { - return nil, alertBadRecordMAC + return nil, 0, alertBadRecordMAC } if explicitNonceLen > 0 { @@ -351,6 +364,26 @@ func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) { default: panic("unknown cipher type") } + + if hc.version == VersionTLS13 { + if typ != recordTypeApplicationData { + return nil, 0, alertUnexpectedMessage + } + if len(plaintext) > maxPlaintext+1 { + return nil, 0, alertRecordOverflow + } + // Remove padding and find the ContentType scanning from the end. + for i := len(plaintext) - 1; i >= 0; i-- { + if plaintext[i] != 0 { + typ = recordType(plaintext[i]) + plaintext = plaintext[:i] + break + } + if i == 0 { + return nil, 0, alertUnexpectedMessage + } + } + } } else { plaintext = payload } @@ -358,7 +391,7 @@ func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) { if hc.mac != nil { macSize := hc.mac.Size() if len(payload) < macSize { - return nil, alertBadRecordMAC + return nil, 0, alertBadRecordMAC } n := len(payload) - macSize - paddingLen @@ -369,14 +402,14 @@ func (hc *halfConn) decrypt(record []byte) (plaintext []byte, err error) { localMAC := hc.mac.MAC(hc.seq[0:], record[:recordHeaderLen], payload[:n], payload[n+macSize:]) if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 { - return nil, alertBadRecordMAC + return nil, 0, alertBadRecordMAC } plaintext = payload[:n] } hc.incSeq() - return plaintext, nil + return plaintext, typ, nil } // sliceForAppend extends the input slice by n bytes. head is the full extended @@ -438,12 +471,24 @@ func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, err nonce = hc.seq[:] } - copy(hc.additionalData[:], hc.seq[:]) - copy(hc.additionalData[8:], record[:3]) - hc.additionalData[11] = byte(len(payload) >> 8) - hc.additionalData[12] = byte(len(payload)) + if hc.version == VersionTLS13 { + record = append(record, payload...) - record = c.Seal(record, nonce, payload, hc.additionalData[:]) + // Encrypt the actual ContentType and replace the plaintext one. + record = append(record, record[0]) + record[0] = byte(recordTypeApplicationData) + + n := len(payload) + 1 + c.Overhead() + record[3] = byte(n >> 8) + record[4] = byte(n) + + record = c.Seal(record[:recordHeaderLen], + nonce, record[recordHeaderLen:], record[:recordHeaderLen]) + } else { + copy(hc.additionalData[:], hc.seq[:]) + copy(hc.additionalData[8:], record) + record = c.Seal(record, nonce, payload, hc.additionalData[:]) + } case cbcMode: blockSize := c.BlockSize() plaintextLen := len(payload) + len(mac) @@ -546,7 +591,7 @@ func (c *Conn) readRecord(want recordType) error { vers := uint16(hdr[1])<<8 | uint16(hdr[2]) n := int(hdr[3])<<8 | int(hdr[4]) - if c.haveVers && vers != c.vers { + if c.haveVers && c.vers != VersionTLS13 && vers != c.vers { c.sendAlert(alertProtocolVersion) msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) @@ -560,7 +605,7 @@ func (c *Conn) readRecord(want recordType) error { return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) } } - if n > maxCiphertext { + if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext { c.sendAlert(alertRecordOverflow) msg := fmt.Sprintf("oversized record received with length %d", n) return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) @@ -574,7 +619,7 @@ func (c *Conn) readRecord(want recordType) error { // Process message. record := c.rawInput.Next(recordHeaderLen + n) - data, err := c.in.decrypt(record) + data, typ, err := c.in.decrypt(record) if err != nil { return c.in.setErrorLocked(c.sendAlert(err.(alert))) } @@ -582,9 +627,31 @@ func (c *Conn) readRecord(want recordType) error { return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow)) } - if typ != recordTypeAlert && len(data) > 0 { - // this is a valid non-alert message: reset the count of alerts - c.warnCount = 0 + // Application Data messages are always protected. + if c.in.cipher == nil && typ == recordTypeApplicationData { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 { + // This is a state-advancing message: reset the retry count. + c.retryCount = 0 + } + + // Handshake messages MUST NOT be interleaved with other record types in TLS 1.3. + if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + // In TLS 1.3, change_cipher_spec records are ignored until the Finished. + // See RFC 8446, Appendix D.4. Note that according to Section 5, a server + // can send a ChangeCipherSpec before its ServerHello, when c.vers is still + // unset. That's not useful though and suspicious if the server then selects + // a lower protocol version, so don't allow that. + if c.vers == VersionTLS13 && typ == recordTypeChangeCipherSpec { + if len(data) != 1 || data[0] != 1 || c.handshakeComplete() { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + return c.retryReadRecord(want) } switch typ { @@ -598,14 +665,12 @@ func (c *Conn) readRecord(want recordType) error { if alert(data[1]) == alertCloseNotify { return c.in.setErrorLocked(io.EOF) } + if c.vers == VersionTLS13 { + return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + } switch data[0] { case alertLevelWarning: - c.warnCount++ - if c.warnCount > maxWarnAlertCount { - c.sendAlert(alertUnexpectedMessage) - return c.in.setErrorLocked(errors.New("tls: too many warn alerts")) - } - return c.readRecord(want) // Drop the record on the floor and retry. + return c.retryReadRecord(want) // Drop the record on the floor and retry. case alertLevelError: return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) default: @@ -629,6 +694,9 @@ func (c *Conn) readRecord(want recordType) error { if typ != want { return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } + if len(data) == 0 { + return c.retryReadRecord(want) + } // Note that data is owned by c.rawInput, following the Next call above, // to avoid copying the plaintext. This is safe because c.rawInput is // not read from or written to until c.input is drained. @@ -636,14 +704,35 @@ func (c *Conn) readRecord(want recordType) error { return nil case recordTypeHandshake: - if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) { + if typ != want && !c.isRenegotiationAcceptable() { return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation)) } + if len(data) == 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } c.hand.Write(data) return nil } } +// retryReadRecord recurses into readRecord to drop a non-advancing record, like +// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3. +func (c *Conn) retryReadRecord(want recordType) error { + c.retryCount++ + if c.retryCount > maxUselessRecords { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many ignored records")) + } + return c.readRecord(want) +} + +func (c *Conn) isRenegotiationAcceptable() bool { + return c.isClient && + c.vers != VersionTLS13 && + c.handshakeComplete() && + c.config.Renegotiation != RenegotiateNever +} + // atLeastReader reads from R, stopping with EOF once at least N bytes have been // read. It is different from an io.LimitedReader in that it doesn't cut short // the last Read call, and in that it considers an early EOF an error. @@ -767,6 +856,9 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType) int { panic("unknown cipher type") } } + if c.vers == VersionTLS13 { + payloadBytes-- // encrypted ContentType + } // Allow packet growth in arithmetic progression up to max. pkt := c.packetsSent @@ -822,6 +914,10 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { // Some TLS servers fail if the record version is // greater than TLS 1.0 for the initial ClientHello. vers = VersionTLS10 + } else if vers == VersionTLS13 { + // TLS 1.3 froze the record layer version to 1.2. + // See RFC 8446, Section 5.1. + vers = VersionTLS12 } c.outBuf[1] = byte(vers >> 8) c.outBuf[2] = byte(vers) diff --git a/handshake_messages.go b/handshake_messages.go index d678555..d04efc9 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -49,24 +49,31 @@ func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { } type clientHelloMsg struct { - raw []byte - vers uint16 - random []byte - sessionId []byte - cipherSuites []uint16 - compressionMethods []uint8 - nextProtoNeg bool - serverName string - ocspStapling bool - scts bool - supportedCurves []CurveID - supportedPoints []uint8 - ticketSupported bool - sessionTicket []uint8 - supportedSignatureAlgorithms []SignatureScheme - secureRenegotiation []byte - secureRenegotiationSupported bool - alpnProtocols []string + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuites []uint16 + compressionMethods []uint8 + nextProtoNeg bool + serverName string + ocspStapling bool + supportedCurves []CurveID + supportedPoints []uint8 + ticketSupported bool + sessionTicket []uint8 + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme + secureRenegotiationSupported bool + secureRenegotiation []byte + alpnProtocols []string + scts bool + supportedVersions []uint16 + cookie []byte + keyShares []keyShare + pskModes []uint8 + pskIdentities []pskIdentity + pskBinders [][]byte } func (m *clientHelloMsg) marshal() []byte { @@ -123,7 +130,7 @@ func (m *clientHelloMsg) marshal() []byte { }) } if len(m.supportedCurves) > 0 { - // RFC 4492, Section 5.1.1 + // RFC 4492, Section 5.1.1 and RFC 8446, Section 4.2.7 b.AddUint16(extensionSupportedCurves) b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { @@ -160,6 +167,17 @@ func (m *clientHelloMsg) marshal() []byte { }) }) } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + // RFC 8446, Section 4.2.3 + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } if m.secureRenegotiationSupported { // RFC 5746, Section 3.2 b.AddUint16(extensionRenegotiationInfo) @@ -187,6 +205,70 @@ func (m *clientHelloMsg) marshal() []byte { b.AddUint16(extensionSCT) b.AddUint16(0) // empty extension_data } + if len(m.supportedVersions) > 0 { + // RFC 8446, Section 4.2.1 + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + for _, vers := range m.supportedVersions { + b.AddUint16(vers) + } + }) + }) + } + if len(m.cookie) > 0 { + // RFC 8446, Section 4.2.2 + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if len(m.keyShares) > 0 { + // RFC 8446, Section 4.2.8 + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, ks := range m.keyShares { + b.AddUint16(uint16(ks.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(ks.data) + }) + } + }) + }) + } + if len(m.pskModes) > 0 { + // RFC 8446, Section 4.2.9 + b.AddUint16(extensionPSKModes) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.pskModes) + }) + }) + } + if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension + // RFC 8446, Section 4.2.11 + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(psk.label) + }) + b.AddUint32(psk.obfuscatedTicketAge) + } + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(binder) + }) + } + }) + }) + } extensionsPresent = len(b.BytesOrPanic()) > 2 }) @@ -291,7 +373,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } m.ocspStapling = statusType == statusTypeOCSP case extensionSupportedCurves: - // RFC 4492, Section 5.1.1 + // RFC 4492, Section 5.1.1 and RFC 8446, Section 4.2.7 var curves cryptobyte.String if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() { return false @@ -327,6 +409,20 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.supportedSignatureAlgorithms = append( m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) } + case extensionSignatureAlgorithmsCert: + // RFC 8446, Section 4.2.3 + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } case extensionRenegotiationInfo: // RFC 5746, Section 3.2 if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { @@ -349,6 +445,74 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { case extensionSCT: // RFC 6962, Section 3.3.1 m.scts = true + case extensionSupportedVersions: + // RFC 8446, Section 4.2.1 + var versList cryptobyte.String + if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() { + return false + } + for !versList.Empty() { + var vers uint16 + if !versList.ReadUint16(&vers) { + return false + } + m.supportedVersions = append(m.supportedVersions, vers) + } + case extensionCookie: + // RFC 8446, Section 4.2.2 + if !readUint16LengthPrefixed(&extData, &m.cookie) { + return false + } + case extensionKeyShare: + // RFC 8446, Section 4.2.8 + var clientShares cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&clientShares) { + return false + } + for !clientShares.Empty() { + var ks keyShare + if !clientShares.ReadUint16((*uint16)(&ks.group)) || + !readUint16LengthPrefixed(&clientShares, &ks.data) || + len(ks.data) == 0 { + return false + } + m.keyShares = append(m.keyShares, ks) + } + case extensionPSKModes: + // RFC 8446, Section 4.2.9 + if !readUint8LengthPrefixed(&extData, &m.pskModes) { + return false + } + case extensionPreSharedKey: + // RFC 8446, Section 4.2.11 + if !extensions.Empty() { + return false // pre_shared_key must be the last extension + } + var identities cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() { + return false + } + for !identities.Empty() { + var psk pskIdentity + if !readUint16LengthPrefixed(&identities, &psk.label) || + !identities.ReadUint32(&psk.obfuscatedTicketAge) || + len(psk.label) == 0 { + return false + } + m.pskIdentities = append(m.pskIdentities, psk) + } + var binders cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() { + return false + } + for !binders.Empty() { + var binder []byte + if !readUint8LengthPrefixed(&binders, &binder) || + len(binder) == 0 { + return false + } + m.pskBinders = append(m.pskBinders, binder) + } default: // Ignore unknown extensions. continue @@ -372,11 +536,19 @@ type serverHelloMsg struct { nextProtoNeg bool nextProtos []string ocspStapling bool - scts [][]byte ticketSupported bool - secureRenegotiation []byte secureRenegotiationSupported bool + secureRenegotiation []byte alpnProtocol string + scts [][]byte + supportedVersion uint16 + serverShare keyShare + selectedIdentityPresent bool + selectedIdentity uint16 + + // HelloRetryRequest extensions + cookie []byte + selectedGroup CurveID } func (m *serverHelloMsg) marshal() []byte { @@ -448,6 +620,42 @@ func (m *serverHelloMsg) marshal() []byte { }) }) } + if m.supportedVersion != 0 { + b.AddUint16(extensionSupportedVersions) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.supportedVersion) + }) + } + if m.serverShare.group != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.serverShare.group)) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.serverShare.data) + }) + }) + } + if m.selectedIdentityPresent { + b.AddUint16(extensionPreSharedKey) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.selectedIdentity) + }) + } + + if len(m.cookie) > 0 { + b.AddUint16(extensionCookie) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.cookie) + }) + }) + } + if m.selectedGroup != 0 { + b.AddUint16(extensionKeyShare) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(uint16(m.selectedGroup)) + }) + } extensionsPresent = len(b.BytesOrPanic()) > 2 }) @@ -535,6 +743,32 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { } m.scts = append(m.scts, sct) } + case extensionSupportedVersions: + if !extData.ReadUint16(&m.supportedVersion) { + return false + } + case extensionCookie: + if !readUint16LengthPrefixed(&extData, &m.cookie) { + return false + } + case extensionKeyShare: + // This extension has different formats in SH and HRR, accept either + // and let the handshake logic decide. See RFC 8446, Section 4.2.8. + if len(extData) == 2 { + if !extData.ReadUint16((*uint16)(&m.selectedGroup)) { + return false + } + } else { + if !extData.ReadUint16((*uint16)(&m.serverShare.group)) || + !readUint16LengthPrefixed(&extData, &m.serverShare.data) { + return false + } + } + case extensionPreSharedKey: + m.selectedIdentityPresent = true + if !extData.ReadUint16(&m.selectedIdentity) { + return false + } default: // Ignore unknown extensions. continue diff --git a/handshake_messages_test.go b/handshake_messages_test.go index 08622ea..fdf096b 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -11,6 +11,7 @@ import ( "strings" "testing" "testing/quick" + "time" ) var tests = []interface{}{ @@ -31,7 +32,7 @@ var tests = []interface{}{ } func TestMarshalUnmarshal(t *testing.T) { - rand := rand.New(rand.NewSource(0)) + rand := rand.New(rand.NewSource(time.Now().UnixNano())) for i, iface := range tests { ty := reflect.ValueOf(iface).Type() @@ -132,7 +133,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) m.supportedCurves = make([]CurveID, rand.Intn(5)+1) for i := range m.supportedCurves { - m.supportedCurves[i] = CurveID(rand.Intn(30000)) + m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1) } if rand.Intn(10) > 5 { m.ticketSupported = true @@ -145,6 +146,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { if rand.Intn(10) > 5 { m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() } + if rand.Intn(10) > 5 { + m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() + } for i := 0; i < rand.Intn(5); i++ { m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) } @@ -155,6 +159,31 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m.secureRenegotiationSupported = true m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) } + for i := 0; i < rand.Intn(5); i++ { + m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1)) + } + if rand.Intn(10) > 5 { + m.cookie = randomBytes(rand.Intn(500)+1, rand) + } + for i := 0; i < rand.Intn(5); i++ { + var ks keyShare + ks.group = CurveID(rand.Intn(30000) + 1) + ks.data = randomBytes(rand.Intn(200)+1, rand) + m.keyShares = append(m.keyShares, ks) + } + switch rand.Intn(3) { + case 1: + m.pskModes = []uint8{pskModeDHE} + case 2: + m.pskModes = []uint8{pskModeDHE, pskModePlain} + } + for i := 0; i < rand.Intn(5); i++ { + var psk pskIdentity + psk.obfuscatedTicketAge = uint32(rand.Intn(500000)) + psk.label = randomBytes(rand.Intn(500)+1, rand) + m.pskIdentities = append(m.pskIdentities, psk) + m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) + } return reflect.ValueOf(m) } @@ -190,6 +219,24 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m.secureRenegotiationSupported = true m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) } + if rand.Intn(10) > 5 { + m.supportedVersion = uint16(rand.Intn(0xffff) + 1) + } + if rand.Intn(10) > 5 { + m.cookie = randomBytes(rand.Intn(500)+1, rand) + } + if rand.Intn(10) > 5 { + for i := 0; i < rand.Intn(5); i++ { + m.serverShare.group = CurveID(rand.Intn(30000) + 1) + m.serverShare.data = randomBytes(rand.Intn(200)+1, rand) + } + } else if rand.Intn(10) > 5 { + m.selectedGroup = CurveID(rand.Intn(30000) + 1) + } + if rand.Intn(10) > 5 { + m.selectedIdentityPresent = true + m.selectedIdentity = uint16(rand.Intn(0xffff)) + } return reflect.ValueOf(m) } diff --git a/tls_test.go b/tls_test.go index 7542699..e9abe01 100644 --- a/tls_test.go +++ b/tls_test.go @@ -589,7 +589,7 @@ func TestWarningAlertFlood(t *testing.T) { if err == nil { return errors.New("unexpected lack of error from server") } - const expected = "too many warn" + const expected = "too many ignored" if str := err.Error(); !strings.Contains(str, expected) { return fmt.Errorf("expected error containing %q, but saw: %s", expected, str) } @@ -610,7 +610,7 @@ func TestWarningAlertFlood(t *testing.T) { t.Fatal(err) } - for i := 0; i < maxWarnAlertCount+1; i++ { + for i := 0; i < maxUselessRecords+1; i++ { conn.sendAlert(alertNoRenegotiation) }