mirror of
https://github.com/refraction-networking/utls.git
synced 2025-04-03 20:17:36 +03:00
crypto/tls: replace custom equal implementations with reflect.DeepEqual
The equal methods were only there for testing, and I remember regularly getting them wrong while developing tls-tris. Replace them with simple reflect.DeepEqual calls. The only special thing that equal() would do is ignore the difference between a nil and a zero-length slice. Fixed the Generate methods so that they create the same value that unmarshal will decode. The difference is not important: it wasn't tested, all checks are "len(slice) > 0", and all cases in which presence matters are accompanied by a boolean. Change-Id: Iaabf56ea17c2406b5107c808c32f6c85b611aaa8 Reviewed-on: https://go-review.googlesource.com/c/144114 Run-TryBot: Filippo Valsorda <filippo@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Adam Langley <agl@golang.org>
This commit is contained in:
parent
721ba6fc20
commit
e58787c05a
3 changed files with 14 additions and 270 deletions
|
@ -5,7 +5,6 @@
|
|||
package tls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
)
|
||||
|
||||
|
@ -30,32 +29,6 @@ type clientHelloMsg struct {
|
|||
alpnProtocols []string
|
||||
}
|
||||
|
||||
func (m *clientHelloMsg) equal(i interface{}) bool {
|
||||
m1, ok := i.(*clientHelloMsg)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
m.vers == m1.vers &&
|
||||
bytes.Equal(m.random, m1.random) &&
|
||||
bytes.Equal(m.sessionId, m1.sessionId) &&
|
||||
eqUint16s(m.cipherSuites, m1.cipherSuites) &&
|
||||
bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
|
||||
m.nextProtoNeg == m1.nextProtoNeg &&
|
||||
m.serverName == m1.serverName &&
|
||||
m.ocspStapling == m1.ocspStapling &&
|
||||
m.scts == m1.scts &&
|
||||
eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
|
||||
bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
|
||||
m.ticketSupported == m1.ticketSupported &&
|
||||
bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
|
||||
eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) &&
|
||||
m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
|
||||
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
|
||||
eqStrings(m.alpnProtocols, m1.alpnProtocols)
|
||||
}
|
||||
|
||||
func (m *clientHelloMsg) marshal() []byte {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
|
@ -519,36 +492,6 @@ type serverHelloMsg struct {
|
|||
alpnProtocol string
|
||||
}
|
||||
|
||||
func (m *serverHelloMsg) equal(i interface{}) bool {
|
||||
m1, ok := i.(*serverHelloMsg)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(m.scts) != len(m1.scts) {
|
||||
return false
|
||||
}
|
||||
for i, sct := range m.scts {
|
||||
if !bytes.Equal(sct, m1.scts[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
m.vers == m1.vers &&
|
||||
bytes.Equal(m.random, m1.random) &&
|
||||
bytes.Equal(m.sessionId, m1.sessionId) &&
|
||||
m.cipherSuite == m1.cipherSuite &&
|
||||
m.compressionMethod == m1.compressionMethod &&
|
||||
m.nextProtoNeg == m1.nextProtoNeg &&
|
||||
eqStrings(m.nextProtos, m1.nextProtos) &&
|
||||
m.ocspStapling == m1.ocspStapling &&
|
||||
m.ticketSupported == m1.ticketSupported &&
|
||||
m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
|
||||
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
|
||||
m.alpnProtocol == m1.alpnProtocol
|
||||
}
|
||||
|
||||
func (m *serverHelloMsg) marshal() []byte {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
|
@ -838,16 +781,6 @@ type certificateMsg struct {
|
|||
certificates [][]byte
|
||||
}
|
||||
|
||||
func (m *certificateMsg) equal(i interface{}) bool {
|
||||
m1, ok := i.(*certificateMsg)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
eqByteSlices(m.certificates, m1.certificates)
|
||||
}
|
||||
|
||||
func (m *certificateMsg) marshal() (x []byte) {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
|
@ -925,16 +858,6 @@ type serverKeyExchangeMsg struct {
|
|||
key []byte
|
||||
}
|
||||
|
||||
func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
|
||||
m1, ok := i.(*serverKeyExchangeMsg)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
bytes.Equal(m.key, m1.key)
|
||||
}
|
||||
|
||||
func (m *serverKeyExchangeMsg) marshal() []byte {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
|
@ -966,17 +889,6 @@ type certificateStatusMsg struct {
|
|||
response []byte
|
||||
}
|
||||
|
||||
func (m *certificateStatusMsg) equal(i interface{}) bool {
|
||||
m1, ok := i.(*certificateStatusMsg)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
m.statusType == m1.statusType &&
|
||||
bytes.Equal(m.response, m1.response)
|
||||
}
|
||||
|
||||
func (m *certificateStatusMsg) marshal() []byte {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
|
@ -1028,11 +940,6 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool {
|
|||
|
||||
type serverHelloDoneMsg struct{}
|
||||
|
||||
func (m *serverHelloDoneMsg) equal(i interface{}) bool {
|
||||
_, ok := i.(*serverHelloDoneMsg)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (m *serverHelloDoneMsg) marshal() []byte {
|
||||
x := make([]byte, 4)
|
||||
x[0] = typeServerHelloDone
|
||||
|
@ -1048,16 +955,6 @@ type clientKeyExchangeMsg struct {
|
|||
ciphertext []byte
|
||||
}
|
||||
|
||||
func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
|
||||
m1, ok := i.(*clientKeyExchangeMsg)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
bytes.Equal(m.ciphertext, m1.ciphertext)
|
||||
}
|
||||
|
||||
func (m *clientKeyExchangeMsg) marshal() []byte {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
|
@ -1092,16 +989,6 @@ type finishedMsg struct {
|
|||
verifyData []byte
|
||||
}
|
||||
|
||||
func (m *finishedMsg) equal(i interface{}) bool {
|
||||
m1, ok := i.(*finishedMsg)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
bytes.Equal(m.verifyData, m1.verifyData)
|
||||
}
|
||||
|
||||
func (m *finishedMsg) marshal() (x []byte) {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
|
@ -1129,16 +1016,6 @@ type nextProtoMsg struct {
|
|||
proto string
|
||||
}
|
||||
|
||||
func (m *nextProtoMsg) equal(i interface{}) bool {
|
||||
m1, ok := i.(*nextProtoMsg)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
m.proto == m1.proto
|
||||
}
|
||||
|
||||
func (m *nextProtoMsg) marshal() []byte {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
|
@ -1206,18 +1083,6 @@ type certificateRequestMsg struct {
|
|||
certificateAuthorities [][]byte
|
||||
}
|
||||
|
||||
func (m *certificateRequestMsg) equal(i interface{}) bool {
|
||||
m1, ok := i.(*certificateRequestMsg)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
|
||||
eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
|
||||
eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms)
|
||||
}
|
||||
|
||||
func (m *certificateRequestMsg) marshal() (x []byte) {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
|
@ -1356,18 +1221,6 @@ type certificateVerifyMsg struct {
|
|||
signature []byte
|
||||
}
|
||||
|
||||
func (m *certificateVerifyMsg) equal(i interface{}) bool {
|
||||
m1, ok := i.(*certificateVerifyMsg)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
m.hasSignatureAndHash == m1.hasSignatureAndHash &&
|
||||
m.signatureAlgorithm == m1.signatureAlgorithm &&
|
||||
bytes.Equal(m.signature, m1.signature)
|
||||
}
|
||||
|
||||
func (m *certificateVerifyMsg) marshal() (x []byte) {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
|
@ -1436,16 +1289,6 @@ type newSessionTicketMsg struct {
|
|||
ticket []byte
|
||||
}
|
||||
|
||||
func (m *newSessionTicketMsg) equal(i interface{}) bool {
|
||||
m1, ok := i.(*newSessionTicketMsg)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
bytes.Equal(m.ticket, m1.ticket)
|
||||
}
|
||||
|
||||
func (m *newSessionTicketMsg) marshal() (x []byte) {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
|
@ -1500,63 +1343,3 @@ func (*helloRequestMsg) marshal() []byte {
|
|||
func (*helloRequestMsg) unmarshal(data []byte) bool {
|
||||
return len(data) == 4
|
||||
}
|
||||
|
||||
func eqUint16s(x, y []uint16) bool {
|
||||
if len(x) != len(y) {
|
||||
return false
|
||||
}
|
||||
for i, v := range x {
|
||||
if y[i] != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func eqCurveIDs(x, y []CurveID) bool {
|
||||
if len(x) != len(y) {
|
||||
return false
|
||||
}
|
||||
for i, v := range x {
|
||||
if y[i] != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func eqStrings(x, y []string) bool {
|
||||
if len(x) != len(y) {
|
||||
return false
|
||||
}
|
||||
for i, v := range x {
|
||||
if y[i] != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func eqByteSlices(x, y [][]byte) bool {
|
||||
if len(x) != len(y) {
|
||||
return false
|
||||
}
|
||||
for i, v := range x {
|
||||
if !bytes.Equal(v, y[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func eqSignatureAlgorithms(x, y []SignatureScheme) bool {
|
||||
if len(x) != len(y) {
|
||||
return false
|
||||
}
|
||||
for i, v := range x {
|
||||
if v != y[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -28,12 +28,6 @@ var tests = []interface{}{
|
|||
&sessionState{},
|
||||
}
|
||||
|
||||
type testMessage interface {
|
||||
marshal() []byte
|
||||
unmarshal([]byte) bool
|
||||
equal(interface{}) bool
|
||||
}
|
||||
|
||||
func TestMarshalUnmarshal(t *testing.T) {
|
||||
rand := rand.New(rand.NewSource(0))
|
||||
|
||||
|
@ -51,16 +45,16 @@ func TestMarshalUnmarshal(t *testing.T) {
|
|||
break
|
||||
}
|
||||
|
||||
m1 := v.Interface().(testMessage)
|
||||
m1 := v.Interface().(handshakeMessage)
|
||||
marshaled := m1.marshal()
|
||||
m2 := iface.(testMessage)
|
||||
m2 := iface.(handshakeMessage)
|
||||
if !m2.unmarshal(marshaled) {
|
||||
t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
|
||||
break
|
||||
}
|
||||
m2.marshal() // to fill any marshal cache in the message
|
||||
|
||||
if !m1.equal(m2) {
|
||||
if !reflect.DeepEqual(m1, m2) {
|
||||
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
|
||||
break
|
||||
}
|
||||
|
@ -85,7 +79,7 @@ func TestMarshalUnmarshal(t *testing.T) {
|
|||
func TestFuzz(t *testing.T) {
|
||||
rand := rand.New(rand.NewSource(0))
|
||||
for _, iface := range tests {
|
||||
m := iface.(testMessage)
|
||||
m := iface.(handshakeMessage)
|
||||
|
||||
for j := 0; j < 1000; j++ {
|
||||
len := rand.Intn(100)
|
||||
|
@ -142,14 +136,15 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
|||
m.ticketSupported = true
|
||||
if rand.Intn(10) > 5 {
|
||||
m.sessionTicket = randomBytes(rand.Intn(300), rand)
|
||||
} else {
|
||||
m.sessionTicket = make([]byte, 0)
|
||||
}
|
||||
}
|
||||
if rand.Intn(10) > 5 {
|
||||
m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
|
||||
}
|
||||
m.alpnProtocols = make([]string, rand.Intn(5))
|
||||
for i := range m.alpnProtocols {
|
||||
m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
|
||||
for i := 0; i < rand.Intn(5); i++ {
|
||||
m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
|
||||
}
|
||||
if rand.Intn(10) > 5 {
|
||||
m.scts = true
|
||||
|
@ -168,11 +163,8 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
|||
|
||||
if rand.Intn(10) > 5 {
|
||||
m.nextProtoNeg = true
|
||||
|
||||
n := rand.Intn(10)
|
||||
m.nextProtos = make([]string, n)
|
||||
for i := 0; i < n; i++ {
|
||||
m.nextProtos[i] = randomString(20, rand)
|
||||
for i := 0; i < rand.Intn(10); i++ {
|
||||
m.nextProtos = append(m.nextProtos, randomString(20, rand))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -184,12 +176,8 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
|||
}
|
||||
m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
|
||||
|
||||
if rand.Intn(10) > 5 {
|
||||
numSCTs := rand.Intn(4)
|
||||
m.scts = make([][]byte, numSCTs)
|
||||
for i := range m.scts {
|
||||
m.scts[i] = randomBytes(rand.Intn(500)+1, rand)
|
||||
}
|
||||
for i := 0; i < rand.Intn(4); i++ {
|
||||
m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
|
||||
}
|
||||
|
||||
return reflect.ValueOf(m)
|
||||
|
@ -208,10 +196,8 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
|||
func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
m := &certificateRequestMsg{}
|
||||
m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
|
||||
numCAs := rand.Intn(100)
|
||||
m.certificateAuthorities = make([][]byte, numCAs)
|
||||
for i := 0; i < numCAs; i++ {
|
||||
m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
|
||||
for i := 0; i < rand.Intn(100); i++ {
|
||||
m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
|
||||
}
|
||||
return reflect.ValueOf(m)
|
||||
}
|
||||
|
|
25
ticket.go
25
ticket.go
|
@ -27,31 +27,6 @@ type sessionState struct {
|
|||
usedOldKey bool
|
||||
}
|
||||
|
||||
func (s *sessionState) equal(i interface{}) bool {
|
||||
s1, ok := i.(*sessionState)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.vers != s1.vers ||
|
||||
s.cipherSuite != s1.cipherSuite ||
|
||||
!bytes.Equal(s.masterSecret, s1.masterSecret) {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(s.certificates) != len(s1.certificates) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range s.certificates {
|
||||
if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *sessionState) marshal() []byte {
|
||||
length := 2 + 2 + 2 + len(s.masterSecret) + 2
|
||||
for _, cert := range s.certificates {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue