[dev.boringcrypto] all: merge master into dev.boringcrypto

Change-Id: Ia068dac1677bfc44c41e35d1f46e6499911cfae0
This commit is contained in:
Filippo Valsorda 2018-11-14 15:28:13 -05:00
commit e7b501c673
11 changed files with 978 additions and 119 deletions

View file

@ -43,9 +43,9 @@ func TestBoringServerProtocolVersion(t *testing.T) {
fipstls.Force() fipstls.Force()
defer fipstls.Abandon() defer fipstls.Abandon()
test("VersionSSL30", VersionSSL30, "unsupported, maximum protocol version") test("VersionSSL30", VersionSSL30, "client offered only unsupported versions")
test("VersionTLS10", VersionTLS10, "unsupported, maximum protocol version") test("VersionTLS10", VersionTLS10, "client offered only unsupported versions")
test("VersionTLS11", VersionTLS11, "unsupported, maximum protocol version") test("VersionTLS11", VersionTLS11, "client offered only unsupported versions")
test("VersionTLS12", VersionTLS12, "") test("VersionTLS12", VersionTLS12, "")
} }

105
common.go
View file

@ -40,9 +40,6 @@ const (
recordHeaderLen = 5 // record header length recordHeaderLen = 5 // record header length
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
maxUselessRecords = 5 // maximum number of consecutive non-advancing records maxUselessRecords = 5 // maximum number of consecutive non-advancing records
minVersion = VersionTLS10
maxVersion = VersionTLS12
) )
// TLS record types. // TLS record types.
@ -57,19 +54,23 @@ const (
// TLS handshake message types. // TLS handshake message types.
const ( const (
typeHelloRequest uint8 = 0 typeHelloRequest uint8 = 0
typeClientHello uint8 = 1 typeClientHello uint8 = 1
typeServerHello uint8 = 2 typeServerHello uint8 = 2
typeNewSessionTicket uint8 = 4 typeNewSessionTicket uint8 = 4
typeCertificate uint8 = 11 typeEndOfEarlyData uint8 = 5
typeServerKeyExchange uint8 = 12 typeEncryptedExtensions uint8 = 8
typeCertificateRequest uint8 = 13 typeCertificate uint8 = 11
typeServerHelloDone uint8 = 14 typeServerKeyExchange uint8 = 12
typeCertificateVerify uint8 = 15 typeCertificateRequest uint8 = 13
typeClientKeyExchange uint8 = 16 typeServerHelloDone uint8 = 14
typeFinished uint8 = 20 typeCertificateVerify uint8 = 15
typeCertificateStatus uint8 = 22 typeClientKeyExchange uint8 = 16
typeNextProtocol uint8 = 67 // Not IANA assigned typeFinished uint8 = 20
typeCertificateStatus uint8 = 22
typeKeyUpdate uint8 = 24
typeNextProtocol uint8 = 67 // Not IANA assigned
typeMessageHash uint8 = 254 // synthetic message
) )
// TLS compression types. // TLS compression types.
@ -88,6 +89,7 @@ const (
extensionSCT uint16 = 18 extensionSCT uint16 = 18
extensionSessionTicket uint16 = 35 extensionSessionTicket uint16 = 35
extensionPreSharedKey uint16 = 41 extensionPreSharedKey uint16 = 41
extensionEarlyData uint16 = 42
extensionSupportedVersions uint16 = 43 extensionSupportedVersions uint16 = 43
extensionCookie uint16 = 44 extensionCookie uint16 = 44
extensionPSKModes uint16 = 45 extensionPSKModes uint16 = 45
@ -713,24 +715,46 @@ func (c *Config) cipherSuites() []uint16 {
return s return s
} }
func (c *Config) minVersion() uint16 { var supportedVersions = []uint16{
if needFIPS() { VersionTLS12,
return fipsMinVersion(c) VersionTLS11,
} VersionTLS10,
if c == nil || c.MinVersion == 0 { VersionSSL30,
return minVersion
}
return c.MinVersion
} }
func (c *Config) maxVersion() uint16 { func (c *Config) supportedVersions(isClient bool) []uint16 {
if needFIPS() { versions := make([]uint16, 0, len(supportedVersions))
return fipsMaxVersion(c) for _, v := range supportedVersions {
if needFIPS() && (v < fipsMinVersion(c) || v > fipsMaxVersion(c)) {
continue
}
if c != nil && c.MinVersion != 0 && v < c.MinVersion {
continue
}
if c != nil && c.MaxVersion != 0 && v > c.MaxVersion {
continue
}
// TLS 1.0 is the minimum version supported as a client.
if isClient && v < VersionTLS10 {
continue
}
versions = append(versions, v)
} }
if c == nil || c.MaxVersion == 0 { return versions
return maxVersion }
// supportedVersionsFromMax returns a list of supported versions derived from a
// legacy maximum version value. Note that only versions supported by this
// library are returned. Any newer peer will use supportedVersions anyway.
func supportedVersionsFromMax(maxVersion uint16) []uint16 {
versions := make([]uint16, 0, len(supportedVersions))
for _, v := range supportedVersions {
if v > maxVersion {
continue
}
versions = append(versions, v)
} }
return c.MaxVersion return versions
} }
var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521}
@ -746,18 +770,17 @@ func (c *Config) curvePreferences() []CurveID {
} }
// mutualVersion returns the protocol version to use given the advertised // mutualVersion returns the protocol version to use given the advertised
// version of the peer. // versions of the peer. Priority is given to the peer preference order.
func (c *Config) mutualVersion(vers uint16) (uint16, bool) { func (c *Config) mutualVersion(isClient bool, peerVersions []uint16) (uint16, bool) {
minVersion := c.minVersion() supportedVersions := c.supportedVersions(isClient)
maxVersion := c.maxVersion() for _, peerVersion := range peerVersions {
for _, v := range supportedVersions {
if vers < minVersion { if v == peerVersion {
return 0, false return v, true
}
}
} }
if vers > maxVersion { return 0, false
vers = maxVersion
}
return vers, true
} }
// getCertificate returns the best certificate for the given ClientHelloInfo, // getCertificate returns the best certificate for the given ClientHelloInfo,

26
conn.go
View file

@ -990,12 +990,24 @@ func (c *Conn) readHandshake() (interface{}, error) {
case typeServerHello: case typeServerHello:
m = new(serverHelloMsg) m = new(serverHelloMsg)
case typeNewSessionTicket: case typeNewSessionTicket:
m = new(newSessionTicketMsg) if c.vers == VersionTLS13 {
m = new(newSessionTicketMsgTLS13)
} else {
m = new(newSessionTicketMsg)
}
case typeCertificate: case typeCertificate:
m = new(certificateMsg) if c.vers == VersionTLS13 {
m = new(certificateMsgTLS13)
} else {
m = new(certificateMsg)
}
case typeCertificateRequest: case typeCertificateRequest:
m = &certificateRequestMsg{ if c.vers == VersionTLS13 {
hasSignatureAlgorithm: c.vers >= VersionTLS12, m = new(certificateRequestMsgTLS13)
} else {
m = &certificateRequestMsg{
hasSignatureAlgorithm: c.vers >= VersionTLS12,
}
} }
case typeCertificateStatus: case typeCertificateStatus:
m = new(certificateStatusMsg) m = new(certificateStatusMsg)
@ -1013,6 +1025,12 @@ func (c *Conn) readHandshake() (interface{}, error) {
m = new(nextProtoMsg) m = new(nextProtoMsg)
case typeFinished: case typeFinished:
m = new(finishedMsg) m = new(finishedMsg)
case typeEncryptedExtensions:
m = new(encryptedExtensionsMsg)
case typeEndOfEarlyData:
m = new(endOfEarlyDataMsg)
case typeKeyUpdate:
m = new(keyUpdateMsg)
default: default:
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
} }

View file

@ -43,13 +43,25 @@ func makeClientHello(config *Config) (*clientHelloMsg, error) {
nextProtosLength += 1 + l nextProtosLength += 1 + l
} }
} }
if nextProtosLength > 0xffff { if nextProtosLength > 0xffff {
return nil, errors.New("tls: NextProtos values too large") return nil, errors.New("tls: NextProtos values too large")
} }
supportedVersions := config.supportedVersions(true)
if len(supportedVersions) == 0 {
return nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion")
}
clientHelloVersion := supportedVersions[0]
// The version at the beginning of the ClientHello was capped at TLS 1.2
// for compatibility reasons. The supported_versions extension is used
// to negotiate versions now. See RFC 8446, Section 4.2.1.
if clientHelloVersion > VersionTLS12 {
clientHelloVersion = VersionTLS12
}
hello := &clientHelloMsg{ hello := &clientHelloMsg{
vers: config.maxVersion(), vers: clientHelloVersion,
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
random: make([]byte, 32), random: make([]byte, 32),
ocspStapling: true, ocspStapling: true,
@ -60,6 +72,7 @@ func makeClientHello(config *Config) (*clientHelloMsg, error) {
nextProtoNeg: len(config.NextProtos) > 0, nextProtoNeg: len(config.NextProtos) > 0,
secureRenegotiationSupported: true, secureRenegotiationSupported: true,
alpnProtocols: config.NextProtos, alpnProtocols: config.NextProtos,
supportedVersions: supportedVersions,
} }
possibleCipherSuites := config.cipherSuites() possibleCipherSuites := config.cipherSuites()
hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites)) hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites))
@ -143,8 +156,14 @@ func (c *Conn) clientHandshake() error {
} }
} }
versOk := candidateSession.vers >= c.config.minVersion() && versOk := false
candidateSession.vers <= c.config.maxVersion() for _, v := range c.config.supportedVersions(true) {
if v == candidateSession.vers {
versOk = true
break
}
}
if versOk && cipherSuiteOk { if versOk && cipherSuiteOk {
session = candidateSession session = candidateSession
} }
@ -276,11 +295,15 @@ func (hs *clientHandshakeState) handshake() error {
} }
func (hs *clientHandshakeState) pickTLSVersion() error { func (hs *clientHandshakeState) pickTLSVersion() error {
vers, ok := hs.c.config.mutualVersion(hs.serverHello.vers) peerVersion := hs.serverHello.vers
if !ok || vers < VersionTLS10 { if hs.serverHello.supportedVersion != 0 {
// TLS 1.0 is the minimum version supported as a client. peerVersion = hs.serverHello.supportedVersion
}
vers, ok := hs.c.config.mutualVersion(true, []uint16{peerVersion})
if !ok {
hs.c.sendAlert(alertProtocolVersion) hs.c.sendAlert(alertProtocolVersion)
return fmt.Errorf("tls: server selected unsupported protocol version %x", hs.serverHello.vers) return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion)
} }
hs.c.vers = vers hs.c.vers = vers
@ -398,9 +421,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
} }
hs.finishedHash.Write(cs.marshal()) hs.finishedHash.Write(cs.marshal())
if cs.statusType == statusTypeOCSP { c.ocspResponse = cs.response
c.ocspResponse = cs.response
}
msg, err = c.readHandshake() msg, err = c.readHandshake()
if err != nil { if err != nil {

View file

@ -279,6 +279,12 @@ func (test *clientTest) loadData() (flows [][]byte, err error) {
func (test *clientTest) run(t *testing.T, write bool) { func (test *clientTest) run(t *testing.T, write bool) {
checkOpenSSLVersion(t) checkOpenSSLVersion(t)
// TODO(filippo): regenerate client tests all at once after CL 146217,
// RSA-PSS and client-side TLS 1.3 are landed.
if !write {
t.Skip("recorded client tests are out of date")
}
var clientConn, serverConn net.Conn var clientConn, serverConn net.Conn
var recordingConn *recordingConn var recordingConn *recordingConn
var childProcess *exec.Cmd var childProcess *exec.Cmd

View file

@ -71,6 +71,7 @@ type clientHelloMsg struct {
supportedVersions []uint16 supportedVersions []uint16
cookie []byte cookie []byte
keyShares []keyShare keyShares []keyShare
earlyData bool
pskModes []uint8 pskModes []uint8
pskIdentities []pskIdentity pskIdentities []pskIdentity
pskBinders [][]byte pskBinders [][]byte
@ -239,6 +240,11 @@ func (m *clientHelloMsg) marshal() []byte {
}) })
}) })
} }
if m.earlyData {
// RFC 8446, Section 4.2.10
b.AddUint16(extensionEarlyData)
b.AddUint16(0) // empty extension_data
}
if len(m.pskModes) > 0 { if len(m.pskModes) > 0 {
// RFC 8446, Section 4.2.9 // RFC 8446, Section 4.2.9
b.AddUint16(extensionPSKModes) b.AddUint16(extensionPSKModes)
@ -478,6 +484,9 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
} }
m.keyShares = append(m.keyShares, ks) m.keyShares = append(m.keyShares, ks)
} }
case extensionEarlyData:
// RFC 8446, Section 4.2.10
m.earlyData = true
case extensionPSKModes: case extensionPSKModes:
// RFC 8446, Section 4.2.9 // RFC 8446, Section 4.2.9
if !readUint8LengthPrefixed(&extData, &m.pskModes) { if !readUint8LengthPrefixed(&extData, &m.pskModes) {
@ -782,6 +791,342 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
return true return true
} }
type encryptedExtensionsMsg struct {
raw []byte
alpnProtocol string
}
func (m *encryptedExtensionsMsg) marshal() []byte {
if m.raw != nil {
return m.raw
}
var b cryptobyte.Builder
b.AddUint8(typeEncryptedExtensions)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if len(m.alpnProtocol) > 0 {
b.AddUint16(extensionALPN)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte(m.alpnProtocol))
})
})
})
}
})
})
m.raw = b.BytesOrPanic()
return m.raw
}
func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
*m = encryptedExtensionsMsg{raw: data}
s := cryptobyte.String(data)
var extensions cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
return false
}
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
switch extension {
case extensionALPN:
var protoList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
return false
}
var proto cryptobyte.String
if !protoList.ReadUint8LengthPrefixed(&proto) ||
proto.Empty() || !protoList.Empty() {
return false
}
m.alpnProtocol = string(proto)
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
type endOfEarlyDataMsg struct{}
func (m *endOfEarlyDataMsg) marshal() []byte {
x := make([]byte, 4)
x[0] = typeEndOfEarlyData
return x
}
func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
return len(data) == 4
}
type keyUpdateMsg struct {
raw []byte
updateRequested bool
}
func (m *keyUpdateMsg) marshal() []byte {
if m.raw != nil {
return m.raw
}
var b cryptobyte.Builder
b.AddUint8(typeKeyUpdate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
if m.updateRequested {
b.AddUint8(1)
} else {
b.AddUint8(0)
}
})
m.raw = b.BytesOrPanic()
return m.raw
}
func (m *keyUpdateMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data)
var updateRequested uint8
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8(&updateRequested) || !s.Empty() {
return false
}
switch updateRequested {
case 0:
m.updateRequested = false
case 1:
m.updateRequested = true
default:
return false
}
return true
}
type newSessionTicketMsgTLS13 struct {
raw []byte
lifetime uint32
ageAdd uint32
nonce []byte
label []byte
maxEarlyData uint32
}
func (m *newSessionTicketMsgTLS13) marshal() []byte {
if m.raw != nil {
return m.raw
}
var b cryptobyte.Builder
b.AddUint8(typeNewSessionTicket)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint32(m.lifetime)
b.AddUint32(m.ageAdd)
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.nonce)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.label)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if m.maxEarlyData > 0 {
b.AddUint16(extensionEarlyData)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint32(m.maxEarlyData)
})
}
})
})
m.raw = b.BytesOrPanic()
return m.raw
}
func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
*m = newSessionTicketMsgTLS13{raw: data}
s := cryptobyte.String(data)
var extensions cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint32(&m.lifetime) ||
!s.ReadUint32(&m.ageAdd) ||
!readUint8LengthPrefixed(&s, &m.nonce) ||
!readUint16LengthPrefixed(&s, &m.label) ||
!s.ReadUint16LengthPrefixed(&extensions) ||
!s.Empty() {
return false
}
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
switch extension {
case extensionEarlyData:
if !extData.ReadUint32(&m.maxEarlyData) {
return false
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
type certificateRequestMsgTLS13 struct {
raw []byte
ocspStapling bool
scts bool
supportedSignatureAlgorithms []SignatureScheme
supportedSignatureAlgorithmsCert []SignatureScheme
}
func (m *certificateRequestMsgTLS13) marshal() []byte {
if m.raw != nil {
return m.raw
}
var b cryptobyte.Builder
b.AddUint8(typeCertificateRequest)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
// certificate_request_context (SHALL be zero length unless used for
// post-handshake authentication)
b.AddUint8(0)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if m.ocspStapling {
b.AddUint16(extensionStatusRequest)
b.AddUint16(0) // empty extension_data
}
if m.scts {
// RFC 8446, Section 4.4.2.1 makes no mention of
// signed_certificate_timestamp in CertificateRequest, but
// "Extensions in the Certificate message from the client MUST
// correspond to extensions in the CertificateRequest message
// from the server." and it appears in the table in Section 4.2.
b.AddUint16(extensionSCT)
b.AddUint16(0) // empty extension_data
}
if len(m.supportedSignatureAlgorithms) > 0 {
b.AddUint16(extensionSignatureAlgorithms)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, sigAlgo := range m.supportedSignatureAlgorithms {
b.AddUint16(uint16(sigAlgo))
}
})
})
}
if len(m.supportedSignatureAlgorithmsCert) > 0 {
b.AddUint16(extensionSignatureAlgorithmsCert)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
b.AddUint16(uint16(sigAlgo))
}
})
})
}
})
})
m.raw = b.BytesOrPanic()
return m.raw
}
func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
*m = certificateRequestMsgTLS13{raw: data}
s := cryptobyte.String(data)
var context, extensions cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
!s.ReadUint16LengthPrefixed(&extensions) ||
!s.Empty() {
return false
}
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
switch extension {
case extensionStatusRequest:
m.ocspStapling = true
case extensionSCT:
m.scts = true
case extensionSignatureAlgorithms:
var sigAndAlgs cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
return false
}
for !sigAndAlgs.Empty() {
var sigAndAlg uint16
if !sigAndAlgs.ReadUint16(&sigAndAlg) {
return false
}
m.supportedSignatureAlgorithms = append(
m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
}
case extensionSignatureAlgorithmsCert:
var sigAndAlgs cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
return false
}
for !sigAndAlgs.Empty() {
var sigAndAlg uint16
if !sigAndAlgs.ReadUint16(&sigAndAlg) {
return false
}
m.supportedSignatureAlgorithmsCert = append(
m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
return true
}
type certificateMsg struct { type certificateMsg struct {
raw []byte raw []byte
certificates [][]byte certificates [][]byte
@ -859,6 +1204,131 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
return true return true
} }
type certificateMsgTLS13 struct {
raw []byte
certificate Certificate
ocspStapling bool
scts bool
}
func (m *certificateMsgTLS13) marshal() []byte {
if m.raw != nil {
return m.raw
}
var b cryptobyte.Builder
b.AddUint8(typeCertificate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(0) // certificate_request_context
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
for i, cert := range m.certificate.Certificate {
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(cert)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
if i > 0 {
// This library only supports OCSP and SCT for leaf certificates.
return
}
if m.ocspStapling {
b.AddUint16(extensionStatusRequest)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint8(statusTypeOCSP)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.certificate.OCSPStaple)
})
})
}
if m.scts {
b.AddUint16(extensionSCT)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, sct := range m.certificate.SignedCertificateTimestamps {
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(sct)
})
}
})
})
}
})
}
})
})
m.raw = b.BytesOrPanic()
return m.raw
}
func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
*m = certificateMsgTLS13{raw: data}
s := cryptobyte.String(data)
var context, certList cryptobyte.String
if !s.Skip(4) || // message type and uint24 length field
!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
!s.ReadUint24LengthPrefixed(&certList) ||
!s.Empty() {
return false
}
for !certList.Empty() {
var cert []byte
var extensions cryptobyte.String
if !readUint24LengthPrefixed(&certList, &cert) ||
!certList.ReadUint16LengthPrefixed(&extensions) {
return false
}
m.certificate.Certificate = append(m.certificate.Certificate, cert)
for !extensions.Empty() {
var extension uint16
var extData cryptobyte.String
if !extensions.ReadUint16(&extension) ||
!extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
if len(m.certificate.Certificate) > 1 {
// This library only supports OCSP and SCT for leaf certificates.
continue
}
switch extension {
case extensionStatusRequest:
m.ocspStapling = true
var statusType uint8
if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
!readUint24LengthPrefixed(&extData, &m.certificate.OCSPStaple) ||
len(m.certificate.OCSPStaple) == 0 {
return false
}
case extensionSCT:
m.scts = true
var sctList cryptobyte.String
if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
return false
}
for !sctList.Empty() {
var sct []byte
if !readUint16LengthPrefixed(&sctList, &sct) ||
len(sct) == 0 {
return false
}
m.certificate.SignedCertificateTimestamps = append(
m.certificate.SignedCertificateTimestamps, sct)
}
default:
// Ignore unknown extensions.
continue
}
if !extData.Empty() {
return false
}
}
}
return true
}
type serverKeyExchangeMsg struct { type serverKeyExchangeMsg struct {
raw []byte raw []byte
key []byte key []byte
@ -890,9 +1360,8 @@ func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
} }
type certificateStatusMsg struct { type certificateStatusMsg struct {
raw []byte raw []byte
statusType uint8 response []byte
response []byte
} }
func (m *certificateStatusMsg) marshal() []byte { func (m *certificateStatusMsg) marshal() []byte {
@ -900,46 +1369,29 @@ func (m *certificateStatusMsg) marshal() []byte {
return m.raw return m.raw
} }
var x []byte var b cryptobyte.Builder
if m.statusType == statusTypeOCSP { b.AddUint8(typeCertificateStatus)
x = make([]byte, 4+4+len(m.response)) b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
x[0] = typeCertificateStatus b.AddUint8(statusTypeOCSP)
l := len(m.response) + 4 b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
x[1] = byte(l >> 16) b.AddBytes(m.response)
x[2] = byte(l >> 8) })
x[3] = byte(l) })
x[4] = statusTypeOCSP
l -= 4 m.raw = b.BytesOrPanic()
x[5] = byte(l >> 16) return m.raw
x[6] = byte(l >> 8)
x[7] = byte(l)
copy(x[8:], m.response)
} else {
x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
}
m.raw = x
return x
} }
func (m *certificateStatusMsg) unmarshal(data []byte) bool { func (m *certificateStatusMsg) unmarshal(data []byte) bool {
m.raw = data m.raw = data
if len(data) < 5 { s := cryptobyte.String(data)
return false
}
m.statusType = data[4]
m.response = nil var statusType uint8
if m.statusType == statusTypeOCSP { if !s.Skip(4) || // message type and uint24 length field
if len(data) < 8 { !s.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
return false !readUint24LengthPrefixed(&s, &m.response) ||
} len(m.response) == 0 || !s.Empty() {
respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) return false
if uint32(len(data)) != 4+4+respLen {
return false
}
m.response = data[8:]
} }
return true return true
} }

View file

@ -29,6 +29,12 @@ var tests = []interface{}{
&nextProtoMsg{}, &nextProtoMsg{},
&newSessionTicketMsg{}, &newSessionTicketMsg{},
&sessionState{}, &sessionState{},
&encryptedExtensionsMsg{},
&endOfEarlyDataMsg{},
&keyUpdateMsg{},
&newSessionTicketMsgTLS13{},
&certificateRequestMsgTLS13{},
&certificateMsgTLS13{},
} }
func TestMarshalUnmarshal(t *testing.T) { func TestMarshalUnmarshal(t *testing.T) {
@ -184,6 +190,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m.pskIdentities = append(m.pskIdentities, psk) m.pskIdentities = append(m.pskIdentities, psk)
m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
} }
if rand.Intn(10) > 5 {
m.earlyData = true
}
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
@ -209,7 +218,9 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
if rand.Intn(10) > 5 { if rand.Intn(10) > 5 {
m.ticketSupported = true m.ticketSupported = true
} }
m.alpnProtocol = randomString(rand.Intn(32)+1, rand) if rand.Intn(10) > 5 {
m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
}
for i := 0; i < rand.Intn(4); i++ { for i := 0; i < rand.Intn(4); i++ {
m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
@ -241,6 +252,16 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &encryptedExtensionsMsg{}
if rand.Intn(10) > 5 {
m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
}
return reflect.ValueOf(m)
}
func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateMsg{} m := &certificateMsg{}
numCerts := rand.Intn(20) numCerts := rand.Intn(20)
@ -270,12 +291,7 @@ func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateStatusMsg{} m := &certificateStatusMsg{}
if rand.Intn(10) > 5 { m.response = randomBytes(rand.Intn(10)+1, rand)
m.statusType = statusTypeOCSP
m.response = randomBytes(rand.Intn(10)+1, rand)
} else {
m.statusType = 42
}
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
@ -316,6 +332,66 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.ValueOf(s) return reflect.ValueOf(s)
} }
func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &endOfEarlyDataMsg{}
return reflect.ValueOf(m)
}
func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &keyUpdateMsg{}
m.updateRequested = rand.Intn(10) > 5
return reflect.ValueOf(m)
}
func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &newSessionTicketMsgTLS13{}
m.lifetime = uint32(rand.Intn(500000))
m.ageAdd = uint32(rand.Intn(500000))
m.nonce = randomBytes(rand.Intn(100), rand)
m.label = randomBytes(rand.Intn(1000), rand)
if rand.Intn(10) > 5 {
m.maxEarlyData = uint32(rand.Intn(500000))
}
return reflect.ValueOf(m)
}
func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateRequestMsgTLS13{}
if rand.Intn(10) > 5 {
m.ocspStapling = true
}
if rand.Intn(10) > 5 {
m.scts = true
}
if rand.Intn(10) > 5 {
m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
}
if rand.Intn(10) > 5 {
m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
}
return reflect.ValueOf(m)
}
func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateMsgTLS13{}
for i := 0; i < rand.Intn(2)+1; i++ {
m.certificate.Certificate = append(
m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
}
if rand.Intn(10) > 5 {
m.ocspStapling = true
m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
}
if rand.Intn(10) > 5 {
m.scts = true
for i := 0; i < rand.Intn(2)+1; i++ {
m.certificate.SignedCertificateTimestamps = append(
m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
}
}
return reflect.ValueOf(m)
}
func TestRejectEmptySCTList(t *testing.T) { func TestRejectEmptySCTList(t *testing.T) {
// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.

View file

@ -135,14 +135,19 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
} }
} }
c.vers, ok = c.config.mutualVersion(hs.clientHello.vers) clientVersions := hs.clientHello.supportedVersions
if len(hs.clientHello.supportedVersions) == 0 {
clientVersions = supportedVersionsFromMax(hs.clientHello.vers)
}
c.vers, ok = c.config.mutualVersion(false, clientVersions)
if !ok { if !ok {
c.sendAlert(alertProtocolVersion) c.sendAlert(alertProtocolVersion)
return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers) return false, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions)
} }
c.haveVers = true c.haveVers = true
hs.hello = new(serverHelloMsg) hs.hello = new(serverHelloMsg)
hs.hello.vers = c.vers
supportedCurve := false supportedCurve := false
preferredCurves := c.config.curvePreferences() preferredCurves := c.config.curvePreferences()
@ -179,7 +184,6 @@ Curves:
return false, errors.New("tls: client does not support uncompressed connections") return false, errors.New("tls: client does not support uncompressed connections")
} }
hs.hello.vers = c.vers
hs.hello.random = make([]byte, 32) hs.hello.random = make([]byte, 32)
_, err = io.ReadFull(c.config.rand(), hs.hello.random) _, err = io.ReadFull(c.config.rand(), hs.hello.random)
if err != nil { if err != nil {
@ -272,7 +276,7 @@ Curves:
for _, id := range hs.clientHello.cipherSuites { for _, id := range hs.clientHello.cipherSuites {
if id == TLS_FALLBACK_SCSV { if id == TLS_FALLBACK_SCSV {
// The client is doing a fallback connection. // The client is doing a fallback connection.
if hs.clientHello.vers < c.config.maxVersion() { if hs.clientHello.vers < c.config.supportedVersions(false)[0] {
c.sendAlert(alertInappropriateFallback) c.sendAlert(alertInappropriateFallback)
return false, errors.New("tls: client using inappropriate protocol fallback") return false, errors.New("tls: client using inappropriate protocol fallback")
} }
@ -389,7 +393,6 @@ func (hs *serverHandshakeState) doFullHandshake() error {
if hs.hello.ocspStapling { if hs.hello.ocspStapling {
certStatus := new(certificateStatusMsg) certStatus := new(certificateStatusMsg)
certStatus.statusType = statusTypeOCSP
certStatus.response = hs.cert.OCSPStaple certStatus.response = hs.cert.OCSPStaple
hs.finishedHash.Write(certStatus.marshal()) hs.finishedHash.Write(certStatus.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
@ -765,19 +768,14 @@ func (hs *serverHandshakeState) setCipherSuite(id uint16, supportedCipherSuites
return false return false
} }
// suppVersArray is the backing array of ClientHelloInfo.SupportedVersions
var suppVersArray = [...]uint16{VersionTLS12, VersionTLS11, VersionTLS10, VersionSSL30}
func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo { func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo {
if hs.cachedClientHelloInfo != nil { if hs.cachedClientHelloInfo != nil {
return hs.cachedClientHelloInfo return hs.cachedClientHelloInfo
} }
var supportedVersions []uint16 supportedVersions := hs.clientHello.supportedVersions
if hs.clientHello.vers > VersionTLS12 { if len(hs.clientHello.supportedVersions) == 0 {
supportedVersions = suppVersArray[:] supportedVersions = supportedVersionsFromMax(hs.clientHello.vers)
} else if hs.clientHello.vers >= VersionSSL30 {
supportedVersions = suppVersArray[VersionTLS12-hs.clientHello.vers:]
} }
hs.cachedClientHelloInfo = &ClientHelloInfo{ hs.cachedClientHelloInfo = &ClientHelloInfo{

View file

@ -104,8 +104,13 @@ func TestRejectBadProtocolVersion(t *testing.T) {
testClientHelloFailure(t, testConfig, &clientHelloMsg{ testClientHelloFailure(t, testConfig, &clientHelloMsg{
vers: v, vers: v,
random: make([]byte, 32), random: make([]byte, 32),
}, "unsupported, maximum protocol version") }, "unsupported versions")
} }
testClientHelloFailure(t, testConfig, &clientHelloMsg{
vers: VersionTLS12,
supportedVersions: badProtocolVersions,
random: make([]byte, 32),
}, "unsupported versions")
} }
func TestNoSuiteOverlap(t *testing.T) { func TestNoSuiteOverlap(t *testing.T) {
@ -1289,11 +1294,11 @@ var getConfigForClientTests = []struct {
func(clientHello *ClientHelloInfo) (*Config, error) { func(clientHello *ClientHelloInfo) (*Config, error) {
config := testConfig.Clone() config := testConfig.Clone()
// Setting a maximum version of TLS 1.1 should cause // Setting a maximum version of TLS 1.1 should cause
// the handshake to fail. // the handshake to fail, as the client MinVersion is TLS 1.2.
config.MaxVersion = VersionTLS11 config.MaxVersion = VersionTLS11
return config, nil return config, nil
}, },
"version 301 when expecting version 302", "client offered only unsupported versions",
nil, nil,
}, },
{ {

85
key_schedule.go Normal file
View file

@ -0,0 +1,85 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"golang_org/x/crypto/cryptobyte"
"golang_org/x/crypto/hkdf"
"hash"
)
// This file contains the functions necessary to compute the TLS 1.3 key
// schedule. See RFC 8446, Section 7.
const (
resumptionBinderLabel = "res binder"
clientHandshakeTrafficLabel = "c hs traffic"
serverHandshakeTrafficLabel = "s hs traffic"
clientApplicationTrafficLabel = "c ap traffic"
serverApplicationTrafficLabel = "s ap traffic"
exporterLabel = "exp master"
resumptionLabel = "res master"
trafficUpdateLabel = "traffic upd"
)
// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte {
var hkdfLabel cryptobyte.Builder
hkdfLabel.AddUint16(uint16(length))
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes([]byte("tls13 "))
b.AddBytes([]byte(label))
})
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(context)
})
out := make([]byte, length)
n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out)
if err != nil || n != length {
panic("tls: HKDF-Expand-Label invocation failed unexpectedly")
}
return out
}
// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1.
func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte {
if transcript == nil {
transcript = c.hash.New()
}
return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size())
}
// extract implements HKDF-Extract with the cipher suite hash.
func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte {
if newSecret == nil {
newSecret = make([]byte, c.hash.Size())
}
return hkdf.Extract(c.hash.New, newSecret, currentSecret)
}
// nextTrafficSecret generates the next traffic secret, given the current one,
// according to RFC 8446, Section 7.2.
func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte {
return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size())
}
// trafficKey generates traffic keys according to RFC 8446, Section 7.3.
func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) {
key = c.expandLabel(trafficSecret, "key", nil, c.keyLen)
iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength)
return
}
// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to
// RFC 8446, Section 7.5.
func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) {
expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript)
return func(label string, context []byte, length int) ([]byte, error) {
secret := c.deriveSecret(expMasterSecret, label, nil)
h := c.hash.New()
h.Write(context)
return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil
}
}

175
key_schedule_test.go Normal file
View file

@ -0,0 +1,175 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"encoding/hex"
"hash"
"strings"
"testing"
"unicode"
)
// This file contains tests derived from draft-ietf-tls-tls13-vectors-07.
func parseVector(v string) []byte {
v = strings.Map(func(c rune) rune {
if unicode.IsSpace(c) {
return -1
}
return c
}, v)
parts := strings.Split(v, ":")
v = parts[len(parts)-1]
res, err := hex.DecodeString(v)
if err != nil {
panic(err)
}
return res
}
func TestDeriveSecret(t *testing.T) {
chTranscript := cipherSuitesTLS13[0].hash.New()
chTranscript.Write(parseVector(`
payload (512 octets): 01 00 01 fc 03 03 1b c3 ce b6 bb e3 9c ff
93 83 55 b5 a5 0a db 6d b2 1b 7a 6a f6 49 d7 b4 bc 41 9d 78 76
48 7d 95 00 00 06 13 01 13 03 13 02 01 00 01 cd 00 00 00 0b 00
09 00 00 06 73 65 72 76 65 72 ff 01 00 01 00 00 0a 00 14 00 12
00 1d 00 17 00 18 00 19 01 00 01 01 01 02 01 03 01 04 00 33 00
26 00 24 00 1d 00 20 e4 ff b6 8a c0 5f 8d 96 c9 9d a2 66 98 34
6c 6b e1 64 82 ba dd da fe 05 1a 66 b4 f1 8d 66 8f 0b 00 2a 00
00 00 2b 00 03 02 03 04 00 0d 00 20 00 1e 04 03 05 03 06 03 02
03 08 04 08 05 08 06 04 01 05 01 06 01 02 01 04 02 05 02 06 02
02 02 00 2d 00 02 01 01 00 1c 00 02 40 01 00 15 00 57 00 00 00
00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
00 29 00 dd 00 b8 00 b2 2c 03 5d 82 93 59 ee 5f f7 af 4e c9 00
00 00 00 26 2a 64 94 dc 48 6d 2c 8a 34 cb 33 fa 90 bf 1b 00 70
ad 3c 49 88 83 c9 36 7c 09 a2 be 78 5a bc 55 cd 22 60 97 a3 a9
82 11 72 83 f8 2a 03 a1 43 ef d3 ff 5d d3 6d 64 e8 61 be 7f d6
1d 28 27 db 27 9c ce 14 50 77 d4 54 a3 66 4d 4e 6d a4 d2 9e e0
37 25 a6 a4 da fc d0 fc 67 d2 ae a7 05 29 51 3e 3d a2 67 7f a5
90 6c 5b 3f 7d 8f 92 f2 28 bd a4 0d da 72 14 70 f9 fb f2 97 b5
ae a6 17 64 6f ac 5c 03 27 2e 97 07 27 c6 21 a7 91 41 ef 5f 7d
e6 50 5e 5b fb c3 88 e9 33 43 69 40 93 93 4a e4 d3 57 fa d6 aa
cb 00 21 20 3a dd 4f b2 d8 fd f8 22 a0 ca 3c f7 67 8e f5 e8 8d
ae 99 01 41 c5 92 4d 57 bb 6f a3 1b 9e 5f 9d`))
type args struct {
secret []byte
label string
transcript hash.Hash
}
tests := []struct {
name string
args args
want []byte
}{
{
`derive secret for handshake "tls13 derived"`,
args{
parseVector(`PRK (32 octets): 33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c e2
10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a`),
"derived",
nil,
},
parseVector(`expanded (32 octets): 6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba
b6 97 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba`),
},
{
`derive secret "tls13 c e traffic"`,
args{
parseVector(`PRK (32 octets): 9b 21 88 e9 b2 fc 6d 64 d7 1d c3 29 90 0e 20 bb
41 91 50 00 f6 78 aa 83 9c bb 79 7c b7 d8 33 2c`),
"c e traffic",
chTranscript,
},
parseVector(`expanded (32 octets): 3f bb e6 a6 0d eb 66 c3 0a 32 79 5a ba 0e
ff 7e aa 10 10 55 86 e7 be 5c 09 67 8d 63 b6 ca ab 62`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := cipherSuitesTLS13[0]
if got := c.deriveSecret(tt.args.secret, tt.args.label, tt.args.transcript); !bytes.Equal(got, tt.want) {
t.Errorf("cipherSuiteTLS13.deriveSecret() = % x, want % x", got, tt.want)
}
})
}
}
func TestTrafficKey(t *testing.T) {
trafficSecret := parseVector(
`PRK (32 octets): b6 7b 7d 69 0c c1 6c 4e 75 e5 42 13 cb 2d 37 b4
e9 c9 12 bc de d9 10 5d 42 be fd 59 d3 91 ad 38`)
wantKey := parseVector(
`key expanded (16 octets): 3f ce 51 60 09 c2 17 27 d0 f2 e4 e8 6e
e4 03 bc`)
wantIV := parseVector(
`iv expanded (12 octets): 5d 31 3e b2 67 12 76 ee 13 00 0b 30`)
c := cipherSuitesTLS13[0]
gotKey, gotIV := c.trafficKey(trafficSecret)
if !bytes.Equal(gotKey, wantKey) {
t.Errorf("cipherSuiteTLS13.trafficKey() gotKey = % x, want % x", gotKey, wantKey)
}
if !bytes.Equal(gotIV, wantIV) {
t.Errorf("cipherSuiteTLS13.trafficKey() gotIV = % x, want % x", gotIV, wantIV)
}
}
func TestExtract(t *testing.T) {
type args struct {
newSecret []byte
currentSecret []byte
}
tests := []struct {
name string
args args
want []byte
}{
{
`extract secret "early"`,
args{
nil,
nil,
},
parseVector(`secret (32 octets): 33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c
e2 10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a`),
},
{
`extract secret "master"`,
args{
nil,
parseVector(`salt (32 octets): 43 de 77 e0 c7 77 13 85 9a 94 4d b9 db 25 90 b5
31 90 a6 5b 3e e2 e4 f1 2d d7 a0 bb 7c e2 54 b4`),
},
parseVector(`secret (32 octets): 18 df 06 84 3d 13 a0 8b f2 a4 49 84 4c 5f 8a
47 80 01 bc 4d 4c 62 79 84 d5 a4 1d a8 d0 40 29 19`),
},
{
`extract secret "handshake"`,
args{
parseVector(`IKM (32 octets): 8b d4 05 4f b5 5b 9d 63 fd fb ac f9 f0 4b 9f 0d
35 e6 d6 3f 53 75 63 ef d4 62 72 90 0f 89 49 2d`),
parseVector(`salt (32 octets): 6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba b6 97
16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba`),
},
parseVector(`secret (32 octets): 1d c8 26 e9 36 06 aa 6f dc 0a ad c1 2f 74 1b
01 04 6a a6 b9 9f 69 1e d2 21 a9 f0 ca 04 3f be ac`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := cipherSuitesTLS13[0]
if got := c.extract(tt.args.newSecret, tt.args.currentSecret); !bytes.Equal(got, tt.want) {
t.Errorf("cipherSuiteTLS13.extract() = % x, want % x", got, tt.want)
}
})
}
}