crypto/tls: replace custom equal implementations with reflect.DeepEqual

The equal methods were only there for testing, and I remember regularly
getting them wrong while developing tls-tris. Replace them with simple
reflect.DeepEqual calls.

The only special thing that equal() would do is ignore the difference
between a nil and a zero-length slice. Fixed the Generate methods so
that they create the same value that unmarshal will decode. The
difference is not important: it wasn't tested, all checks are
"len(slice) > 0", and all cases in which presence matters are
accompanied by a boolean.

Change-Id: Iaabf56ea17c2406b5107c808c32f6c85b611aaa8
Reviewed-on: https://go-review.googlesource.com/c/144114
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Adam Langley <agl@golang.org>
This commit is contained in:
Filippo Valsorda 2018-10-24 13:16:04 -04:00
parent 721ba6fc20
commit e58787c05a
3 changed files with 14 additions and 270 deletions

View file

@ -28,12 +28,6 @@ var tests = []interface{}{
&sessionState{},
}
type testMessage interface {
marshal() []byte
unmarshal([]byte) bool
equal(interface{}) bool
}
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(0))
@ -51,16 +45,16 @@ func TestMarshalUnmarshal(t *testing.T) {
break
}
m1 := v.Interface().(testMessage)
m1 := v.Interface().(handshakeMessage)
marshaled := m1.marshal()
m2 := iface.(testMessage)
m2 := iface.(handshakeMessage)
if !m2.unmarshal(marshaled) {
t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
break
}
m2.marshal() // to fill any marshal cache in the message
if !m1.equal(m2) {
if !reflect.DeepEqual(m1, m2) {
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
break
}
@ -85,7 +79,7 @@ func TestMarshalUnmarshal(t *testing.T) {
func TestFuzz(t *testing.T) {
rand := rand.New(rand.NewSource(0))
for _, iface := range tests {
m := iface.(testMessage)
m := iface.(handshakeMessage)
for j := 0; j < 1000; j++ {
len := rand.Intn(100)
@ -142,14 +136,15 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m.ticketSupported = true
if rand.Intn(10) > 5 {
m.sessionTicket = randomBytes(rand.Intn(300), rand)
} else {
m.sessionTicket = make([]byte, 0)
}
}
if rand.Intn(10) > 5 {
m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
}
m.alpnProtocols = make([]string, rand.Intn(5))
for i := range m.alpnProtocols {
m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
for i := 0; i < rand.Intn(5); i++ {
m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
}
if rand.Intn(10) > 5 {
m.scts = true
@ -168,11 +163,8 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
if rand.Intn(10) > 5 {
m.nextProtoNeg = true
n := rand.Intn(10)
m.nextProtos = make([]string, n)
for i := 0; i < n; i++ {
m.nextProtos[i] = randomString(20, rand)
for i := 0; i < rand.Intn(10); i++ {
m.nextProtos = append(m.nextProtos, randomString(20, rand))
}
}
@ -184,12 +176,8 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
}
m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
if rand.Intn(10) > 5 {
numSCTs := rand.Intn(4)
m.scts = make([][]byte, numSCTs)
for i := range m.scts {
m.scts[i] = randomBytes(rand.Intn(500)+1, rand)
}
for i := 0; i < rand.Intn(4); i++ {
m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
}
return reflect.ValueOf(m)
@ -208,10 +196,8 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &certificateRequestMsg{}
m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
numCAs := rand.Intn(100)
m.certificateAuthorities = make([][]byte, numCAs)
for i := 0; i < numCAs; i++ {
m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
for i := 0; i < rand.Intn(100); i++ {
m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
}
return reflect.ValueOf(m)
}