[dev.boringcrypto] all: merge master into dev.boringcrypto

Change-Id: Ice4172e2058a45b1a24da561fd420244ab2a97bd
This commit is contained in:
Filippo Valsorda 2018-11-13 13:58:50 -05:00
commit ca4966e4f0
9 changed files with 497 additions and 970 deletions

View file

@ -28,6 +28,7 @@ func TestBoringServerProtocolVersion(t *testing.T) {
serverConfig.MinVersion = VersionSSL30 serverConfig.MinVersion = VersionSSL30
clientHello := &clientHelloMsg{ clientHello := &clientHelloMsg{
vers: v, vers: v,
random: make([]byte, 32),
cipherSuites: allCipherSuites(), cipherSuites: allCipherSuites(),
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
} }
@ -110,6 +111,7 @@ func TestBoringServerCipherSuites(t *testing.T) {
t.Run(fmt.Sprintf("suite=%#x", id), func(t *testing.T) { t.Run(fmt.Sprintf("suite=%#x", id), func(t *testing.T) {
clientHello := &clientHelloMsg{ clientHello := &clientHelloMsg{
vers: VersionTLS12, vers: VersionTLS12,
random: make([]byte, 32),
cipherSuites: []uint16{id}, cipherSuites: []uint16{id},
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
supportedCurves: defaultCurvePreferences, supportedCurves: defaultCurvePreferences,
@ -141,6 +143,7 @@ func TestBoringServerCurves(t *testing.T) {
t.Run(fmt.Sprintf("curve=%d", curveid), func(t *testing.T) { t.Run(fmt.Sprintf("curve=%d", curveid), func(t *testing.T) {
clientHello := &clientHelloMsg{ clientHello := &clientHelloMsg{
vers: VersionTLS12, vers: VersionTLS12,
random: make([]byte, 32),
cipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, cipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
supportedCurves: []CurveID{curveid}, supportedCurves: []CurveID{curveid},

29
conn.go
View file

@ -478,12 +478,18 @@ type RecordHeaderError struct {
// RecordHeader contains the five bytes of TLS record header that // RecordHeader contains the five bytes of TLS record header that
// triggered the error. // triggered the error.
RecordHeader [5]byte RecordHeader [5]byte
// Conn provides the underlying net.Conn in the case that a client
// sent an initial handshake that didn't look like TLS.
// It is nil if there's already been a handshake or a TLS alert has
// been written to the connection.
Conn net.Conn
} }
func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) { func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
err.Msg = msg err.Msg = msg
err.Conn = conn
copy(err.RecordHeader[:], c.rawInput.Bytes()) copy(err.RecordHeader[:], c.rawInput.Bytes())
return err return err
} }
@ -535,7 +541,7 @@ func (c *Conn) readRecord(want recordType) error {
// an SSLv2 client. // an SSLv2 client.
if want == recordTypeHandshake && typ == 0x80 { if want == recordTypeHandshake && typ == 0x80 {
c.sendAlert(alertProtocolVersion) c.sendAlert(alertProtocolVersion)
return c.in.setErrorLocked(c.newRecordHeaderError("unsupported SSLv2 handshake received")) return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
} }
vers := uint16(hdr[1])<<8 | uint16(hdr[2]) vers := uint16(hdr[1])<<8 | uint16(hdr[2])
@ -543,12 +549,7 @@ func (c *Conn) readRecord(want recordType) error {
if c.haveVers && vers != c.vers { if c.haveVers && vers != c.vers {
c.sendAlert(alertProtocolVersion) c.sendAlert(alertProtocolVersion)
msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
return c.in.setErrorLocked(c.newRecordHeaderError(msg)) return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
}
if n > maxCiphertext {
c.sendAlert(alertRecordOverflow)
msg := fmt.Sprintf("oversized record received with length %d", n)
return c.in.setErrorLocked(c.newRecordHeaderError(msg))
} }
if !c.haveVers { if !c.haveVers {
// First message, be extra suspicious: this might not be a TLS // First message, be extra suspicious: this might not be a TLS
@ -556,10 +557,14 @@ func (c *Conn) readRecord(want recordType) error {
// The current max version is 3.3 so if the version is >= 16.0, // The current max version is 3.3 so if the version is >= 16.0,
// it's probably not real. // it's probably not real.
if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 { if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 {
c.sendAlert(alertUnexpectedMessage) return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
return c.in.setErrorLocked(c.newRecordHeaderError("first record does not look like a TLS handshake"))
} }
} }
if n > maxCiphertext {
c.sendAlert(alertRecordOverflow)
msg := fmt.Sprintf("oversized record received with length %d", n)
return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
}
if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil { if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
if e, ok := err.(net.Error); !ok || !e.Temporary() { if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.in.setErrorLocked(err) c.in.setErrorLocked(err)
@ -894,7 +899,7 @@ func (c *Conn) readHandshake() (interface{}, error) {
m = new(certificateMsg) m = new(certificateMsg)
case typeCertificateRequest: case typeCertificateRequest:
m = &certificateRequestMsg{ m = &certificateRequestMsg{
hasSignatureAndHash: c.vers >= VersionTLS12, hasSignatureAlgorithm: c.vers >= VersionTLS12,
} }
case typeCertificateStatus: case typeCertificateStatus:
m = new(certificateStatusMsg) m = new(certificateStatusMsg)
@ -906,7 +911,7 @@ func (c *Conn) readHandshake() (interface{}, error) {
m = new(clientKeyExchangeMsg) m = new(clientKeyExchangeMsg)
case typeCertificateVerify: case typeCertificateVerify:
m = &certificateVerifyMsg{ m = &certificateVerifyMsg{
hasSignatureAndHash: c.vers >= VersionTLS12, hasSignatureAlgorithm: c.vers >= VersionTLS12,
} }
case typeNextProtocol: case typeNextProtocol:
m = new(nextProtoMsg) m = new(nextProtoMsg)

View file

@ -476,7 +476,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
if chainToSend != nil && len(chainToSend.Certificate) > 0 { if chainToSend != nil && len(chainToSend.Certificate) > 0 {
certVerify := &certificateVerifyMsg{ certVerify := &certificateVerifyMsg{
hasSignatureAndHash: c.vers >= VersionTLS12, hasSignatureAlgorithm: c.vers >= VersionTLS12,
} }
key, ok := chainToSend.PrivateKey.(crypto.Signer) key, ok := chainToSend.PrivateKey.(crypto.Signer)
@ -491,7 +491,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
return err return err
} }
// SignatureAndHashAlgorithm was introduced in TLS 1.2. // SignatureAndHashAlgorithm was introduced in TLS 1.2.
if certVerify.hasSignatureAndHash { if certVerify.hasSignatureAlgorithm {
certVerify.signatureAlgorithm = signatureAlgorithm certVerify.signatureAlgorithm = signatureAlgorithm
} }
digest, err := hs.finishedHash.hashForClientCertificate(sigType, hashFunc, hs.masterSecret) digest, err := hs.finishedHash.hashForClientCertificate(sigType, hashFunc, hs.masterSecret)
@ -744,7 +744,7 @@ func (hs *clientHandshakeState) getCertificate(certReq *certificateRequestMsg) (
if c.config.GetClientCertificate != nil { if c.config.GetClientCertificate != nil {
var signatureSchemes []SignatureScheme var signatureSchemes []SignatureScheme
if !certReq.hasSignatureAndHash { if !certReq.hasSignatureAlgorithm {
// Prior to TLS 1.2, the signature schemes were not // Prior to TLS 1.2, the signature schemes were not
// included in the certificate request message. In this // included in the certificate request message. In this
// case we use a plausible list based on the acceptable // case we use a plausible list based on the acceptable

View file

@ -384,10 +384,12 @@ func (test *clientTest) run(t *testing.T, write bool) {
} }
for i, b := range flows { for i, b := range flows {
if i%2 == 1 { if i%2 == 1 {
serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
serverConn.Write(b) serverConn.Write(b)
continue continue
} }
bb := make([]byte, len(b)) bb := make([]byte, len(b))
serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
_, err := io.ReadFull(serverConn, bb) _, err := io.ReadFull(serverConn, bb)
if err != nil { if err != nil {
t.Fatalf("%s #%d: %s", test.name, i, err) t.Fatalf("%s #%d: %s", test.name, i, err)
@ -1644,7 +1646,7 @@ func TestCloseClientConnectionOnIdleServer(t *testing.T) {
serverConn.Read(b[:]) serverConn.Read(b[:])
client.Close() client.Close()
}() }()
client.SetWriteDeadline(time.Now().Add(time.Second)) client.SetWriteDeadline(time.Now().Add(time.Minute))
err := client.Handshake() err := client.Handshake()
if err != nil { if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() { if err, ok := err.(net.Error); ok && err.Timeout() {

File diff suppressed because it is too large Load diff

View file

@ -20,7 +20,9 @@ var tests = []interface{}{
&certificateMsg{}, &certificateMsg{},
&certificateRequestMsg{}, &certificateRequestMsg{},
&certificateVerifyMsg{}, &certificateVerifyMsg{
hasSignatureAlgorithm: true,
},
&certificateStatusMsg{}, &certificateStatusMsg{},
&clientKeyExchangeMsg{}, &clientKeyExchangeMsg{},
&nextProtoMsg{}, &nextProtoMsg{},
@ -28,12 +30,6 @@ var tests = []interface{}{
&sessionState{}, &sessionState{},
} }
type testMessage interface {
marshal() []byte
unmarshal([]byte) bool
equal(interface{}) bool
}
func TestMarshalUnmarshal(t *testing.T) { func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0)) rand := rand.New(rand.NewSource(0))
@ -51,16 +47,16 @@ func TestMarshalUnmarshal(t *testing.T) {
break break
} }
m1 := v.Interface().(testMessage) m1 := v.Interface().(handshakeMessage)
marshaled := m1.marshal() marshaled := m1.marshal()
m2 := iface.(testMessage) m2 := iface.(handshakeMessage)
if !m2.unmarshal(marshaled) { if !m2.unmarshal(marshaled) {
t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
break break
} }
m2.marshal() // to fill any marshal cache in the message m2.marshal() // to fill any marshal cache in the message
if !m1.equal(m2) { if !reflect.DeepEqual(m1, m2) {
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
break break
} }
@ -85,7 +81,7 @@ func TestMarshalUnmarshal(t *testing.T) {
func TestFuzz(t *testing.T) { func TestFuzz(t *testing.T) {
rand := rand.New(rand.NewSource(0)) rand := rand.New(rand.NewSource(0))
for _, iface := range tests { for _, iface := range tests {
m := iface.(testMessage) m := iface.(handshakeMessage)
for j := 0; j < 1000; j++ { for j := 0; j < 1000; j++ {
len := rand.Intn(100) len := rand.Intn(100)
@ -142,18 +138,23 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m.ticketSupported = true m.ticketSupported = true
if rand.Intn(10) > 5 { if rand.Intn(10) > 5 {
m.sessionTicket = randomBytes(rand.Intn(300), rand) m.sessionTicket = randomBytes(rand.Intn(300), rand)
} else {
m.sessionTicket = make([]byte, 0)
} }
} }
if rand.Intn(10) > 5 { if rand.Intn(10) > 5 {
m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
} }
m.alpnProtocols = make([]string, rand.Intn(5)) for i := 0; i < rand.Intn(5); i++ {
for i := range m.alpnProtocols { m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
} }
if rand.Intn(10) > 5 { if rand.Intn(10) > 5 {
m.scts = true m.scts = true
} }
if rand.Intn(10) > 5 {
m.secureRenegotiationSupported = true
m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
}
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
@ -168,11 +169,8 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
if rand.Intn(10) > 5 { if rand.Intn(10) > 5 {
m.nextProtoNeg = true m.nextProtoNeg = true
for i := 0; i < rand.Intn(10); i++ {
n := rand.Intn(10) m.nextProtos = append(m.nextProtos, randomString(20, rand))
m.nextProtos = make([]string, n)
for i := 0; i < n; i++ {
m.nextProtos[i] = randomString(20, rand)
} }
} }
@ -184,12 +182,13 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
} }
m.alpnProtocol = randomString(rand.Intn(32)+1, rand) m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
for i := 0; i < rand.Intn(4); i++ {
m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
}
if rand.Intn(10) > 5 { if rand.Intn(10) > 5 {
numSCTs := rand.Intn(4) m.secureRenegotiationSupported = true
m.scts = make([][]byte, numSCTs) m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
for i := range m.scts {
m.scts[i] = randomBytes(rand.Intn(500)+1, rand)
}
} }
return reflect.ValueOf(m) return reflect.ValueOf(m)
@ -208,16 +207,16 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateRequestMsg{} m := &certificateRequestMsg{}
m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
numCAs := rand.Intn(100) for i := 0; i < rand.Intn(100); i++ {
m.certificateAuthorities = make([][]byte, numCAs) m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
for i := 0; i < numCAs; i++ {
m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
} }
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateVerifyMsg{} m := &certificateVerifyMsg{}
m.hasSignatureAlgorithm = true
m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
m.signature = randomBytes(rand.Intn(15)+1, rand) m.signature = randomBytes(rand.Intn(15)+1, rand)
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }

View file

@ -418,7 +418,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
byte(certTypeECDSASign), byte(certTypeECDSASign),
} }
if c.vers >= VersionTLS12 { if c.vers >= VersionTLS12 {
certReq.hasSignatureAndHash = true certReq.hasSignatureAlgorithm = true
certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
} }

View file

@ -101,13 +101,17 @@ var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x020
func TestRejectBadProtocolVersion(t *testing.T) { func TestRejectBadProtocolVersion(t *testing.T) {
for _, v := range badProtocolVersions { for _, v := range badProtocolVersions {
testClientHelloFailure(t, testConfig, &clientHelloMsg{vers: v}, "unsupported, maximum protocol version") testClientHelloFailure(t, testConfig, &clientHelloMsg{
vers: v,
random: make([]byte, 32),
}, "unsupported, maximum protocol version")
} }
} }
func TestNoSuiteOverlap(t *testing.T) { func TestNoSuiteOverlap(t *testing.T) {
clientHello := &clientHelloMsg{ clientHello := &clientHelloMsg{
vers: VersionTLS10, vers: VersionTLS10,
random: make([]byte, 32),
cipherSuites: []uint16{0xff00}, cipherSuites: []uint16{0xff00},
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
} }
@ -117,6 +121,7 @@ func TestNoSuiteOverlap(t *testing.T) {
func TestNoCompressionOverlap(t *testing.T) { func TestNoCompressionOverlap(t *testing.T) {
clientHello := &clientHelloMsg{ clientHello := &clientHelloMsg{
vers: VersionTLS10, vers: VersionTLS10,
random: make([]byte, 32),
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
compressionMethods: []uint8{0xff}, compressionMethods: []uint8{0xff},
} }
@ -126,6 +131,7 @@ func TestNoCompressionOverlap(t *testing.T) {
func TestNoRC4ByDefault(t *testing.T) { func TestNoRC4ByDefault(t *testing.T) {
clientHello := &clientHelloMsg{ clientHello := &clientHelloMsg{
vers: VersionTLS10, vers: VersionTLS10,
random: make([]byte, 32),
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
} }
@ -137,7 +143,11 @@ func TestNoRC4ByDefault(t *testing.T) {
} }
func TestRejectSNIWithTrailingDot(t *testing.T) { func TestRejectSNIWithTrailingDot(t *testing.T) {
testClientHelloFailure(t, testConfig, &clientHelloMsg{vers: VersionTLS12, serverName: "foo.com."}, "unexpected message") testClientHelloFailure(t, testConfig, &clientHelloMsg{
vers: VersionTLS12,
random: make([]byte, 32),
serverName: "foo.com.",
}, "unexpected message")
} }
func TestDontSelectECDSAWithRSAKey(t *testing.T) { func TestDontSelectECDSAWithRSAKey(t *testing.T) {
@ -145,6 +155,7 @@ func TestDontSelectECDSAWithRSAKey(t *testing.T) {
// won't be selected if the server's private key doesn't support it. // won't be selected if the server's private key doesn't support it.
clientHello := &clientHelloMsg{ clientHello := &clientHelloMsg{
vers: VersionTLS10, vers: VersionTLS10,
random: make([]byte, 32),
cipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, cipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
supportedCurves: []CurveID{CurveP256}, supportedCurves: []CurveID{CurveP256},
@ -170,6 +181,7 @@ func TestDontSelectRSAWithECDSAKey(t *testing.T) {
// won't be selected if the server's private key doesn't support it. // won't be selected if the server's private key doesn't support it.
clientHello := &clientHelloMsg{ clientHello := &clientHelloMsg{
vers: VersionTLS10, vers: VersionTLS10,
random: make([]byte, 32),
cipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA}, cipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
supportedCurves: []CurveID{CurveP256}, supportedCurves: []CurveID{CurveP256},
@ -242,11 +254,9 @@ func TestRenegotiationExtension(t *testing.T) {
func TestTLS12OnlyCipherSuites(t *testing.T) { func TestTLS12OnlyCipherSuites(t *testing.T) {
// Test that a Server doesn't select a TLS 1.2-only cipher suite when // Test that a Server doesn't select a TLS 1.2-only cipher suite when
// the client negotiates TLS 1.1. // the client negotiates TLS 1.1.
var zeros [32]byte
clientHello := &clientHelloMsg{ clientHello := &clientHelloMsg{
vers: VersionTLS11, vers: VersionTLS11,
random: zeros[:], random: make([]byte, 32),
cipherSuites: []uint16{ cipherSuites: []uint16{
// The Server, by default, will use the client's // The Server, by default, will use the client's
// preference order. So the GCM cipher suite // preference order. So the GCM cipher suite
@ -615,10 +625,12 @@ func (test *serverTest) run(t *testing.T, write bool) {
} }
for i, b := range flows { for i, b := range flows {
if i%2 == 0 { if i%2 == 0 {
clientConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
clientConn.Write(b) clientConn.Write(b)
continue continue
} }
bb := make([]byte, len(b)) bb := make([]byte, len(b))
clientConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
n, err := io.ReadFull(clientConn, bb) n, err := io.ReadFull(clientConn, bb)
if err != nil { if err != nil {
t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b) t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b)
@ -876,6 +888,7 @@ func TestHandshakeServerSNIGetCertificateError(t *testing.T) {
clientHello := &clientHelloMsg{ clientHello := &clientHelloMsg{
vers: VersionTLS10, vers: VersionTLS10,
random: make([]byte, 32),
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
serverName: "test", serverName: "test",
@ -896,6 +909,7 @@ func TestHandshakeServerEmptyCertificates(t *testing.T) {
clientHello := &clientHelloMsg{ clientHello := &clientHelloMsg{
vers: VersionTLS10, vers: VersionTLS10,
random: make([]byte, 32),
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
} }
@ -907,6 +921,7 @@ func TestHandshakeServerEmptyCertificates(t *testing.T) {
clientHello = &clientHelloMsg{ clientHello = &clientHelloMsg{
vers: VersionTLS10, vers: VersionTLS10,
random: make([]byte, 32),
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
} }
@ -1210,6 +1225,7 @@ func TestSNIGivenOnFailure(t *testing.T) {
clientHello := &clientHelloMsg{ clientHello := &clientHelloMsg{
vers: VersionTLS10, vers: VersionTLS10,
random: make([]byte, 32),
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
compressionMethods: []uint8{compressionNone}, compressionMethods: []uint8{compressionNone},
serverName: expectedServerName, serverName: expectedServerName,
@ -1432,7 +1448,7 @@ func TestCloseServerConnectionOnIdleClient(t *testing.T) {
clientConn.Write([]byte{'0'}) clientConn.Write([]byte{'0'})
server.Close() server.Close()
}() }()
server.SetReadDeadline(time.Now().Add(time.Second)) server.SetReadDeadline(time.Now().Add(time.Minute))
err := server.Handshake() err := server.Handshake()
if err != nil { if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() { if err, ok := err.(net.Error); ok && err.Timeout() {

View file

@ -27,31 +27,6 @@ type sessionState struct {
usedOldKey bool usedOldKey bool
} }
func (s *sessionState) equal(i interface{}) bool {
s1, ok := i.(*sessionState)
if !ok {
return false
}
if s.vers != s1.vers ||
s.cipherSuite != s1.cipherSuite ||
!bytes.Equal(s.masterSecret, s1.masterSecret) {
return false
}
if len(s.certificates) != len(s1.certificates) {
return false
}
for i := range s.certificates {
if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
return false
}
}
return true
}
func (s *sessionState) marshal() []byte { func (s *sessionState) marshal() []byte {
length := 2 + 2 + 2 + len(s.masterSecret) + 2 length := 2 + 2 + 2 + len(s.masterSecret) + 2
for _, cert := range s.certificates { for _, cert := range s.certificates {