From e58787c05a05c47a975a6f2f41a19a35e8aa5544 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Wed, 24 Oct 2018 13:16:04 -0400 Subject: [PATCH] crypto/tls: replace custom equal implementations with reflect.DeepEqual The equal methods were only there for testing, and I remember regularly getting them wrong while developing tls-tris. Replace them with simple reflect.DeepEqual calls. The only special thing that equal() would do is ignore the difference between a nil and a zero-length slice. Fixed the Generate methods so that they create the same value that unmarshal will decode. The difference is not important: it wasn't tested, all checks are "len(slice) > 0", and all cases in which presence matters are accompanied by a boolean. Change-Id: Iaabf56ea17c2406b5107c808c32f6c85b611aaa8 Reviewed-on: https://go-review.googlesource.com/c/144114 Run-TryBot: Filippo Valsorda TryBot-Result: Gobot Gobot Reviewed-by: Adam Langley --- handshake_messages.go | 217 ------------------------------------- handshake_messages_test.go | 42 +++---- ticket.go | 25 ----- 3 files changed, 14 insertions(+), 270 deletions(-) diff --git a/handshake_messages.go b/handshake_messages.go index 27004b2..c5d9950 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -5,7 +5,6 @@ package tls import ( - "bytes" "strings" ) @@ -30,32 +29,6 @@ type clientHelloMsg struct { alpnProtocols []string } -func (m *clientHelloMsg) equal(i interface{}) bool { - m1, ok := i.(*clientHelloMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - m.vers == m1.vers && - bytes.Equal(m.random, m1.random) && - bytes.Equal(m.sessionId, m1.sessionId) && - eqUint16s(m.cipherSuites, m1.cipherSuites) && - bytes.Equal(m.compressionMethods, m1.compressionMethods) && - m.nextProtoNeg == m1.nextProtoNeg && - m.serverName == m1.serverName && - m.ocspStapling == m1.ocspStapling && - m.scts == m1.scts && - eqCurveIDs(m.supportedCurves, m1.supportedCurves) && - bytes.Equal(m.supportedPoints, m1.supportedPoints) && - m.ticketSupported == m1.ticketSupported && - bytes.Equal(m.sessionTicket, m1.sessionTicket) && - eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) && - m.secureRenegotiationSupported == m1.secureRenegotiationSupported && - bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && - eqStrings(m.alpnProtocols, m1.alpnProtocols) -} - func (m *clientHelloMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -519,36 +492,6 @@ type serverHelloMsg struct { alpnProtocol string } -func (m *serverHelloMsg) equal(i interface{}) bool { - m1, ok := i.(*serverHelloMsg) - if !ok { - return false - } - - if len(m.scts) != len(m1.scts) { - return false - } - for i, sct := range m.scts { - if !bytes.Equal(sct, m1.scts[i]) { - return false - } - } - - return bytes.Equal(m.raw, m1.raw) && - m.vers == m1.vers && - bytes.Equal(m.random, m1.random) && - bytes.Equal(m.sessionId, m1.sessionId) && - m.cipherSuite == m1.cipherSuite && - m.compressionMethod == m1.compressionMethod && - m.nextProtoNeg == m1.nextProtoNeg && - eqStrings(m.nextProtos, m1.nextProtos) && - m.ocspStapling == m1.ocspStapling && - m.ticketSupported == m1.ticketSupported && - m.secureRenegotiationSupported == m1.secureRenegotiationSupported && - bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && - m.alpnProtocol == m1.alpnProtocol -} - func (m *serverHelloMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -838,16 +781,6 @@ type certificateMsg struct { certificates [][]byte } -func (m *certificateMsg) equal(i interface{}) bool { - m1, ok := i.(*certificateMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - eqByteSlices(m.certificates, m1.certificates) -} - func (m *certificateMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -925,16 +858,6 @@ type serverKeyExchangeMsg struct { key []byte } -func (m *serverKeyExchangeMsg) equal(i interface{}) bool { - m1, ok := i.(*serverKeyExchangeMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - bytes.Equal(m.key, m1.key) -} - func (m *serverKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -966,17 +889,6 @@ type certificateStatusMsg struct { response []byte } -func (m *certificateStatusMsg) equal(i interface{}) bool { - m1, ok := i.(*certificateStatusMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - m.statusType == m1.statusType && - bytes.Equal(m.response, m1.response) -} - func (m *certificateStatusMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -1028,11 +940,6 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool { type serverHelloDoneMsg struct{} -func (m *serverHelloDoneMsg) equal(i interface{}) bool { - _, ok := i.(*serverHelloDoneMsg) - return ok -} - func (m *serverHelloDoneMsg) marshal() []byte { x := make([]byte, 4) x[0] = typeServerHelloDone @@ -1048,16 +955,6 @@ type clientKeyExchangeMsg struct { ciphertext []byte } -func (m *clientKeyExchangeMsg) equal(i interface{}) bool { - m1, ok := i.(*clientKeyExchangeMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - bytes.Equal(m.ciphertext, m1.ciphertext) -} - func (m *clientKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -1092,16 +989,6 @@ type finishedMsg struct { verifyData []byte } -func (m *finishedMsg) equal(i interface{}) bool { - m1, ok := i.(*finishedMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - bytes.Equal(m.verifyData, m1.verifyData) -} - func (m *finishedMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1129,16 +1016,6 @@ type nextProtoMsg struct { proto string } -func (m *nextProtoMsg) equal(i interface{}) bool { - m1, ok := i.(*nextProtoMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - m.proto == m1.proto -} - func (m *nextProtoMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -1206,18 +1083,6 @@ type certificateRequestMsg struct { certificateAuthorities [][]byte } -func (m *certificateRequestMsg) equal(i interface{}) bool { - m1, ok := i.(*certificateRequestMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - bytes.Equal(m.certificateTypes, m1.certificateTypes) && - eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) && - eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) -} - func (m *certificateRequestMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1356,18 +1221,6 @@ type certificateVerifyMsg struct { signature []byte } -func (m *certificateVerifyMsg) equal(i interface{}) bool { - m1, ok := i.(*certificateVerifyMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - m.hasSignatureAndHash == m1.hasSignatureAndHash && - m.signatureAlgorithm == m1.signatureAlgorithm && - bytes.Equal(m.signature, m1.signature) -} - func (m *certificateVerifyMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1436,16 +1289,6 @@ type newSessionTicketMsg struct { ticket []byte } -func (m *newSessionTicketMsg) equal(i interface{}) bool { - m1, ok := i.(*newSessionTicketMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - bytes.Equal(m.ticket, m1.ticket) -} - func (m *newSessionTicketMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1500,63 +1343,3 @@ func (*helloRequestMsg) marshal() []byte { func (*helloRequestMsg) unmarshal(data []byte) bool { return len(data) == 4 } - -func eqUint16s(x, y []uint16) bool { - if len(x) != len(y) { - return false - } - for i, v := range x { - if y[i] != v { - return false - } - } - return true -} - -func eqCurveIDs(x, y []CurveID) bool { - if len(x) != len(y) { - return false - } - for i, v := range x { - if y[i] != v { - return false - } - } - return true -} - -func eqStrings(x, y []string) bool { - if len(x) != len(y) { - return false - } - for i, v := range x { - if y[i] != v { - return false - } - } - return true -} - -func eqByteSlices(x, y [][]byte) bool { - if len(x) != len(y) { - return false - } - for i, v := range x { - if !bytes.Equal(v, y[i]) { - return false - } - } - return true -} - -func eqSignatureAlgorithms(x, y []SignatureScheme) bool { - if len(x) != len(y) { - return false - } - for i, v := range x { - if v != y[i] { - return false - } - } - return true -} diff --git a/handshake_messages_test.go b/handshake_messages_test.go index 52c5d30..c8cc0d6 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -28,12 +28,6 @@ var tests = []interface{}{ &sessionState{}, } -type testMessage interface { - marshal() []byte - unmarshal([]byte) bool - equal(interface{}) bool -} - func TestMarshalUnmarshal(t *testing.T) { rand := rand.New(rand.NewSource(0)) @@ -51,16 +45,16 @@ func TestMarshalUnmarshal(t *testing.T) { break } - m1 := v.Interface().(testMessage) + m1 := v.Interface().(handshakeMessage) marshaled := m1.marshal() - m2 := iface.(testMessage) + m2 := iface.(handshakeMessage) if !m2.unmarshal(marshaled) { t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) break } m2.marshal() // to fill any marshal cache in the message - if !m1.equal(m2) { + if !reflect.DeepEqual(m1, m2) { t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) break } @@ -85,7 +79,7 @@ func TestMarshalUnmarshal(t *testing.T) { func TestFuzz(t *testing.T) { rand := rand.New(rand.NewSource(0)) for _, iface := range tests { - m := iface.(testMessage) + m := iface.(handshakeMessage) for j := 0; j < 1000; j++ { len := rand.Intn(100) @@ -142,14 +136,15 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m.ticketSupported = true if rand.Intn(10) > 5 { m.sessionTicket = randomBytes(rand.Intn(300), rand) + } else { + m.sessionTicket = make([]byte, 0) } } if rand.Intn(10) > 5 { m.supportedSignatureAlgorithms = supportedSignatureAlgorithms } - m.alpnProtocols = make([]string, rand.Intn(5)) - for i := range m.alpnProtocols { - m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand) + for i := 0; i < rand.Intn(5); i++ { + m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) } if rand.Intn(10) > 5 { m.scts = true @@ -168,11 +163,8 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { if rand.Intn(10) > 5 { m.nextProtoNeg = true - - n := rand.Intn(10) - m.nextProtos = make([]string, n) - for i := 0; i < n; i++ { - m.nextProtos[i] = randomString(20, rand) + for i := 0; i < rand.Intn(10); i++ { + m.nextProtos = append(m.nextProtos, randomString(20, rand)) } } @@ -184,12 +176,8 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { } m.alpnProtocol = randomString(rand.Intn(32)+1, rand) - if rand.Intn(10) > 5 { - numSCTs := rand.Intn(4) - m.scts = make([][]byte, numSCTs) - for i := range m.scts { - m.scts[i] = randomBytes(rand.Intn(500)+1, rand) - } + for i := 0; i < rand.Intn(4); i++ { + m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) } return reflect.ValueOf(m) @@ -208,10 +196,8 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &certificateRequestMsg{} m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) - numCAs := rand.Intn(100) - m.certificateAuthorities = make([][]byte, numCAs) - for i := 0; i < numCAs; i++ { - m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand) + for i := 0; i < rand.Intn(100); i++ { + m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) } return reflect.ValueOf(m) } diff --git a/ticket.go b/ticket.go index 3e7aa93..c1077e5 100644 --- a/ticket.go +++ b/ticket.go @@ -27,31 +27,6 @@ type sessionState struct { usedOldKey bool } -func (s *sessionState) equal(i interface{}) bool { - s1, ok := i.(*sessionState) - if !ok { - return false - } - - if s.vers != s1.vers || - s.cipherSuite != s1.cipherSuite || - !bytes.Equal(s.masterSecret, s1.masterSecret) { - return false - } - - if len(s.certificates) != len(s1.certificates) { - return false - } - - for i := range s.certificates { - if !bytes.Equal(s.certificates[i], s1.certificates[i]) { - return false - } - } - - return true -} - func (s *sessionState) marshal() []byte { length := 2 + 2 + 2 + len(s.masterSecret) + 2 for _, cert := range s.certificates {