From 9986f579382fcc44ef468570cc0f90ea6529aa67 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 15 Mar 2018 08:21:44 +0300 Subject: [PATCH 1/5] crypto/tls, net/http: reject HTTP requests to HTTPS server This adds a crypto/tls.RecordHeaderError.Conn field containing the TLS underlying net.Conn for non-TLS handshake errors, and then uses it in the net/http Server to return plaintext HTTP 400 errors when a client mistakenly sends a plaintext HTTP request to an HTTPS server. This is the same behavior as Apache. Also in crypto/tls: swap two error paths to not use a value before it's valid, and don't send a alert record when a handshake contains a bogus TLS record (a TLS record in response won't help a non-TLS client). Fixes #23689 Change-Id: Ife774b1e3886beb66f25ae4587c62123ccefe847 Reviewed-on: https://go-review.googlesource.com/c/143177 Reviewed-by: Filippo Valsorda --- conn.go | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index 13cebc9..8e23643 100644 --- a/conn.go +++ b/conn.go @@ -478,12 +478,18 @@ type RecordHeaderError struct { // RecordHeader contains the five bytes of TLS record header that // triggered the error. RecordHeader [5]byte + // Conn provides the underlying net.Conn in the case that a client + // sent an initial handshake that didn't look like TLS. + // It is nil if there's already been a handshake or a TLS alert has + // been written to the connection. + Conn net.Conn } func (e RecordHeaderError) Error() string { return "tls: " + e.Msg } -func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) { +func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) { err.Msg = msg + err.Conn = conn copy(err.RecordHeader[:], c.rawInput.Bytes()) return err } @@ -535,7 +541,7 @@ func (c *Conn) readRecord(want recordType) error { // an SSLv2 client. if want == recordTypeHandshake && typ == 0x80 { c.sendAlert(alertProtocolVersion) - return c.in.setErrorLocked(c.newRecordHeaderError("unsupported SSLv2 handshake received")) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received")) } vers := uint16(hdr[1])<<8 | uint16(hdr[2]) @@ -543,12 +549,7 @@ func (c *Conn) readRecord(want recordType) error { if c.haveVers && vers != c.vers { c.sendAlert(alertProtocolVersion) msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers) - return c.in.setErrorLocked(c.newRecordHeaderError(msg)) - } - if n > maxCiphertext { - c.sendAlert(alertRecordOverflow) - msg := fmt.Sprintf("oversized record received with length %d", n) - return c.in.setErrorLocked(c.newRecordHeaderError(msg)) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) } if !c.haveVers { // First message, be extra suspicious: this might not be a TLS @@ -556,10 +557,14 @@ func (c *Conn) readRecord(want recordType) error { // The current max version is 3.3 so if the version is >= 16.0, // it's probably not real. if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 { - c.sendAlert(alertUnexpectedMessage) - return c.in.setErrorLocked(c.newRecordHeaderError("first record does not look like a TLS handshake")) + return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake")) } } + if n > maxCiphertext { + c.sendAlert(alertRecordOverflow) + msg := fmt.Sprintf("oversized record received with length %d", n) + return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg)) + } if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil { if e, ok := err.(net.Error); !ok || !e.Temporary() { c.in.setErrorLocked(err) From 721ba6fc20c3fcd3d997a4abef8136ce88b4bc1e Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Wed, 24 Oct 2018 21:31:18 -0400 Subject: [PATCH 2/5] crypto/tls: add timeouts to recorded tests If something causes the recorded tests to deviate from the expected flows, they might wait forever for data that is not coming. Add a short timeout, after which a useful error message is shown. Change-Id: Ib11ccc0e17dcb8b2180493556017275678abbb08 Reviewed-on: https://go-review.googlesource.com/c/144116 Run-TryBot: Filippo Valsorda TryBot-Result: Gobot Gobot Reviewed-by: Adam Langley --- handshake_client_test.go | 2 ++ handshake_server_test.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/handshake_client_test.go b/handshake_client_test.go index dcd6914..5a1e608 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -384,10 +384,12 @@ func (test *clientTest) run(t *testing.T, write bool) { } for i, b := range flows { if i%2 == 1 { + serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second)) serverConn.Write(b) continue } bb := make([]byte, len(b)) + serverConn.SetReadDeadline(time.Now().Add(1 * time.Second)) _, err := io.ReadFull(serverConn, bb) if err != nil { t.Fatalf("%s #%d: %s", test.name, i, err) diff --git a/handshake_server_test.go b/handshake_server_test.go index 44c67ed..2a77584 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -615,10 +615,12 @@ func (test *serverTest) run(t *testing.T, write bool) { } for i, b := range flows { if i%2 == 0 { + clientConn.SetWriteDeadline(time.Now().Add(1 * time.Second)) clientConn.Write(b) continue } bb := make([]byte, len(b)) + clientConn.SetReadDeadline(time.Now().Add(1 * time.Second)) n, err := io.ReadFull(clientConn, bb) if err != nil { t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b) From e58787c05a05c47a975a6f2f41a19a35e8aa5544 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Wed, 24 Oct 2018 13:16:04 -0400 Subject: [PATCH 3/5] crypto/tls: replace custom equal implementations with reflect.DeepEqual The equal methods were only there for testing, and I remember regularly getting them wrong while developing tls-tris. Replace them with simple reflect.DeepEqual calls. The only special thing that equal() would do is ignore the difference between a nil and a zero-length slice. Fixed the Generate methods so that they create the same value that unmarshal will decode. The difference is not important: it wasn't tested, all checks are "len(slice) > 0", and all cases in which presence matters are accompanied by a boolean. Change-Id: Iaabf56ea17c2406b5107c808c32f6c85b611aaa8 Reviewed-on: https://go-review.googlesource.com/c/144114 Run-TryBot: Filippo Valsorda TryBot-Result: Gobot Gobot Reviewed-by: Adam Langley --- handshake_messages.go | 217 ------------------------------------- handshake_messages_test.go | 42 +++---- ticket.go | 25 ----- 3 files changed, 14 insertions(+), 270 deletions(-) diff --git a/handshake_messages.go b/handshake_messages.go index 27004b2..c5d9950 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -5,7 +5,6 @@ package tls import ( - "bytes" "strings" ) @@ -30,32 +29,6 @@ type clientHelloMsg struct { alpnProtocols []string } -func (m *clientHelloMsg) equal(i interface{}) bool { - m1, ok := i.(*clientHelloMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - m.vers == m1.vers && - bytes.Equal(m.random, m1.random) && - bytes.Equal(m.sessionId, m1.sessionId) && - eqUint16s(m.cipherSuites, m1.cipherSuites) && - bytes.Equal(m.compressionMethods, m1.compressionMethods) && - m.nextProtoNeg == m1.nextProtoNeg && - m.serverName == m1.serverName && - m.ocspStapling == m1.ocspStapling && - m.scts == m1.scts && - eqCurveIDs(m.supportedCurves, m1.supportedCurves) && - bytes.Equal(m.supportedPoints, m1.supportedPoints) && - m.ticketSupported == m1.ticketSupported && - bytes.Equal(m.sessionTicket, m1.sessionTicket) && - eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) && - m.secureRenegotiationSupported == m1.secureRenegotiationSupported && - bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && - eqStrings(m.alpnProtocols, m1.alpnProtocols) -} - func (m *clientHelloMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -519,36 +492,6 @@ type serverHelloMsg struct { alpnProtocol string } -func (m *serverHelloMsg) equal(i interface{}) bool { - m1, ok := i.(*serverHelloMsg) - if !ok { - return false - } - - if len(m.scts) != len(m1.scts) { - return false - } - for i, sct := range m.scts { - if !bytes.Equal(sct, m1.scts[i]) { - return false - } - } - - return bytes.Equal(m.raw, m1.raw) && - m.vers == m1.vers && - bytes.Equal(m.random, m1.random) && - bytes.Equal(m.sessionId, m1.sessionId) && - m.cipherSuite == m1.cipherSuite && - m.compressionMethod == m1.compressionMethod && - m.nextProtoNeg == m1.nextProtoNeg && - eqStrings(m.nextProtos, m1.nextProtos) && - m.ocspStapling == m1.ocspStapling && - m.ticketSupported == m1.ticketSupported && - m.secureRenegotiationSupported == m1.secureRenegotiationSupported && - bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && - m.alpnProtocol == m1.alpnProtocol -} - func (m *serverHelloMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -838,16 +781,6 @@ type certificateMsg struct { certificates [][]byte } -func (m *certificateMsg) equal(i interface{}) bool { - m1, ok := i.(*certificateMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - eqByteSlices(m.certificates, m1.certificates) -} - func (m *certificateMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -925,16 +858,6 @@ type serverKeyExchangeMsg struct { key []byte } -func (m *serverKeyExchangeMsg) equal(i interface{}) bool { - m1, ok := i.(*serverKeyExchangeMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - bytes.Equal(m.key, m1.key) -} - func (m *serverKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -966,17 +889,6 @@ type certificateStatusMsg struct { response []byte } -func (m *certificateStatusMsg) equal(i interface{}) bool { - m1, ok := i.(*certificateStatusMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - m.statusType == m1.statusType && - bytes.Equal(m.response, m1.response) -} - func (m *certificateStatusMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -1028,11 +940,6 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool { type serverHelloDoneMsg struct{} -func (m *serverHelloDoneMsg) equal(i interface{}) bool { - _, ok := i.(*serverHelloDoneMsg) - return ok -} - func (m *serverHelloDoneMsg) marshal() []byte { x := make([]byte, 4) x[0] = typeServerHelloDone @@ -1048,16 +955,6 @@ type clientKeyExchangeMsg struct { ciphertext []byte } -func (m *clientKeyExchangeMsg) equal(i interface{}) bool { - m1, ok := i.(*clientKeyExchangeMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - bytes.Equal(m.ciphertext, m1.ciphertext) -} - func (m *clientKeyExchangeMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -1092,16 +989,6 @@ type finishedMsg struct { verifyData []byte } -func (m *finishedMsg) equal(i interface{}) bool { - m1, ok := i.(*finishedMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - bytes.Equal(m.verifyData, m1.verifyData) -} - func (m *finishedMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1129,16 +1016,6 @@ type nextProtoMsg struct { proto string } -func (m *nextProtoMsg) equal(i interface{}) bool { - m1, ok := i.(*nextProtoMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - m.proto == m1.proto -} - func (m *nextProtoMsg) marshal() []byte { if m.raw != nil { return m.raw @@ -1206,18 +1083,6 @@ type certificateRequestMsg struct { certificateAuthorities [][]byte } -func (m *certificateRequestMsg) equal(i interface{}) bool { - m1, ok := i.(*certificateRequestMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - bytes.Equal(m.certificateTypes, m1.certificateTypes) && - eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) && - eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) -} - func (m *certificateRequestMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1356,18 +1221,6 @@ type certificateVerifyMsg struct { signature []byte } -func (m *certificateVerifyMsg) equal(i interface{}) bool { - m1, ok := i.(*certificateVerifyMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - m.hasSignatureAndHash == m1.hasSignatureAndHash && - m.signatureAlgorithm == m1.signatureAlgorithm && - bytes.Equal(m.signature, m1.signature) -} - func (m *certificateVerifyMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1436,16 +1289,6 @@ type newSessionTicketMsg struct { ticket []byte } -func (m *newSessionTicketMsg) equal(i interface{}) bool { - m1, ok := i.(*newSessionTicketMsg) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - bytes.Equal(m.ticket, m1.ticket) -} - func (m *newSessionTicketMsg) marshal() (x []byte) { if m.raw != nil { return m.raw @@ -1500,63 +1343,3 @@ func (*helloRequestMsg) marshal() []byte { func (*helloRequestMsg) unmarshal(data []byte) bool { return len(data) == 4 } - -func eqUint16s(x, y []uint16) bool { - if len(x) != len(y) { - return false - } - for i, v := range x { - if y[i] != v { - return false - } - } - return true -} - -func eqCurveIDs(x, y []CurveID) bool { - if len(x) != len(y) { - return false - } - for i, v := range x { - if y[i] != v { - return false - } - } - return true -} - -func eqStrings(x, y []string) bool { - if len(x) != len(y) { - return false - } - for i, v := range x { - if y[i] != v { - return false - } - } - return true -} - -func eqByteSlices(x, y [][]byte) bool { - if len(x) != len(y) { - return false - } - for i, v := range x { - if !bytes.Equal(v, y[i]) { - return false - } - } - return true -} - -func eqSignatureAlgorithms(x, y []SignatureScheme) bool { - if len(x) != len(y) { - return false - } - for i, v := range x { - if v != y[i] { - return false - } - } - return true -} diff --git a/handshake_messages_test.go b/handshake_messages_test.go index 52c5d30..c8cc0d6 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -28,12 +28,6 @@ var tests = []interface{}{ &sessionState{}, } -type testMessage interface { - marshal() []byte - unmarshal([]byte) bool - equal(interface{}) bool -} - func TestMarshalUnmarshal(t *testing.T) { rand := rand.New(rand.NewSource(0)) @@ -51,16 +45,16 @@ func TestMarshalUnmarshal(t *testing.T) { break } - m1 := v.Interface().(testMessage) + m1 := v.Interface().(handshakeMessage) marshaled := m1.marshal() - m2 := iface.(testMessage) + m2 := iface.(handshakeMessage) if !m2.unmarshal(marshaled) { t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) break } m2.marshal() // to fill any marshal cache in the message - if !m1.equal(m2) { + if !reflect.DeepEqual(m1, m2) { t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) break } @@ -85,7 +79,7 @@ func TestMarshalUnmarshal(t *testing.T) { func TestFuzz(t *testing.T) { rand := rand.New(rand.NewSource(0)) for _, iface := range tests { - m := iface.(testMessage) + m := iface.(handshakeMessage) for j := 0; j < 1000; j++ { len := rand.Intn(100) @@ -142,14 +136,15 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m.ticketSupported = true if rand.Intn(10) > 5 { m.sessionTicket = randomBytes(rand.Intn(300), rand) + } else { + m.sessionTicket = make([]byte, 0) } } if rand.Intn(10) > 5 { m.supportedSignatureAlgorithms = supportedSignatureAlgorithms } - m.alpnProtocols = make([]string, rand.Intn(5)) - for i := range m.alpnProtocols { - m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand) + for i := 0; i < rand.Intn(5); i++ { + m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) } if rand.Intn(10) > 5 { m.scts = true @@ -168,11 +163,8 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { if rand.Intn(10) > 5 { m.nextProtoNeg = true - - n := rand.Intn(10) - m.nextProtos = make([]string, n) - for i := 0; i < n; i++ { - m.nextProtos[i] = randomString(20, rand) + for i := 0; i < rand.Intn(10); i++ { + m.nextProtos = append(m.nextProtos, randomString(20, rand)) } } @@ -184,12 +176,8 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { } m.alpnProtocol = randomString(rand.Intn(32)+1, rand) - if rand.Intn(10) > 5 { - numSCTs := rand.Intn(4) - m.scts = make([][]byte, numSCTs) - for i := range m.scts { - m.scts[i] = randomBytes(rand.Intn(500)+1, rand) - } + for i := 0; i < rand.Intn(4); i++ { + m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) } return reflect.ValueOf(m) @@ -208,10 +196,8 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &certificateRequestMsg{} m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) - numCAs := rand.Intn(100) - m.certificateAuthorities = make([][]byte, numCAs) - for i := 0; i < numCAs; i++ { - m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand) + for i := 0; i < rand.Intn(100); i++ { + m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) } return reflect.ValueOf(m) } diff --git a/ticket.go b/ticket.go index 3e7aa93..c1077e5 100644 --- a/ticket.go +++ b/ticket.go @@ -27,31 +27,6 @@ type sessionState struct { usedOldKey bool } -func (s *sessionState) equal(i interface{}) bool { - s1, ok := i.(*sessionState) - if !ok { - return false - } - - if s.vers != s1.vers || - s.cipherSuite != s1.cipherSuite || - !bytes.Equal(s.masterSecret, s1.masterSecret) { - return false - } - - if len(s.certificates) != len(s1.certificates) { - return false - } - - for i := range s.certificates { - if !bytes.Equal(s.certificates[i], s1.certificates[i]) { - return false - } - } - - return true -} - func (s *sessionState) marshal() []byte { length := 2 + 2 + 2 + len(s.masterSecret) + 2 for _, cert := range s.certificates { From d2f8d6286156f6fb9ddaf01ea7d3a1aee3b5b8d2 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Fri, 26 Oct 2018 11:41:02 -0400 Subject: [PATCH 4/5] crypto/tls: bump test timeouts from 1s to 1m for slow builders The arm5 and mips builders are can't-send-a-packet-to-localhost-in-1s slow apparently. 1m is less useful, but still better than an obscure test timeout panic. Fixes #28405 Change-Id: I2feeae6ea1b095114caccaab4f6709f405faebad Reviewed-on: https://go-review.googlesource.com/c/145037 Run-TryBot: Filippo Valsorda TryBot-Result: Gobot Gobot Reviewed-by: Brad Fitzpatrick --- handshake_client_test.go | 6 +++--- handshake_server_test.go | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/handshake_client_test.go b/handshake_client_test.go index 5a1e608..437aaed 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -384,12 +384,12 @@ func (test *clientTest) run(t *testing.T, write bool) { } for i, b := range flows { if i%2 == 1 { - serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute)) serverConn.Write(b) continue } bb := make([]byte, len(b)) - serverConn.SetReadDeadline(time.Now().Add(1 * time.Second)) + serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute)) _, err := io.ReadFull(serverConn, bb) if err != nil { t.Fatalf("%s #%d: %s", test.name, i, err) @@ -1646,7 +1646,7 @@ func TestCloseClientConnectionOnIdleServer(t *testing.T) { serverConn.Read(b[:]) client.Close() }() - client.SetWriteDeadline(time.Now().Add(time.Second)) + client.SetWriteDeadline(time.Now().Add(time.Minute)) err := client.Handshake() if err != nil { if err, ok := err.(net.Error); ok && err.Timeout() { diff --git a/handshake_server_test.go b/handshake_server_test.go index 2a77584..e14adbd 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -615,12 +615,12 @@ func (test *serverTest) run(t *testing.T, write bool) { } for i, b := range flows { if i%2 == 0 { - clientConn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + clientConn.SetWriteDeadline(time.Now().Add(1 * time.Minute)) clientConn.Write(b) continue } bb := make([]byte, len(b)) - clientConn.SetReadDeadline(time.Now().Add(1 * time.Second)) + clientConn.SetReadDeadline(time.Now().Add(1 * time.Minute)) n, err := io.ReadFull(clientConn, bb) if err != nil { t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b) @@ -1434,7 +1434,7 @@ func TestCloseServerConnectionOnIdleClient(t *testing.T) { clientConn.Write([]byte{'0'}) server.Close() }() - server.SetReadDeadline(time.Now().Add(time.Second)) + server.SetReadDeadline(time.Now().Add(time.Minute)) err := server.Handshake() if err != nil { if err, ok := err.(net.Error); ok && err.Timeout() { From 8d011ce74c8762f1bea0ccfc2f25fa241d54a5e7 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Wed, 24 Oct 2018 21:22:00 -0400 Subject: [PATCH 5/5] crypto/tls: rewrite some messages with golang.org/x/crypto/cryptobyte As a first round, rewrite those handshake message types which can be reused in TLS 1.3 with golang.org/x/crypto/cryptobyte. All other types changed significantly in TLS 1.3 and will require separate implementations. They will be ported to cryptobyte in a later CL. The only semantic changes should be enforcing the random length on the marshaling side, enforcing a couple more "must not be empty" on the unmarshaling side, and checking the rest of the SNI list even if we only take the first. Change-Id: Idd2ced60c558fafcf02ee489195b6f3b4735fe22 Reviewed-on: https://go-review.googlesource.com/c/144115 Run-TryBot: Filippo Valsorda TryBot-Result: Gobot Gobot Reviewed-by: Adam Langley --- conn.go | 4 +- handshake_client.go | 6 +- handshake_messages.go | 1098 ++++++++++++++---------------------- handshake_messages_test.go | 15 +- handshake_server.go | 2 +- handshake_server_test.go | 24 +- 6 files changed, 460 insertions(+), 689 deletions(-) diff --git a/conn.go b/conn.go index 8e23643..dae5fd1 100644 --- a/conn.go +++ b/conn.go @@ -899,7 +899,7 @@ func (c *Conn) readHandshake() (interface{}, error) { m = new(certificateMsg) case typeCertificateRequest: m = &certificateRequestMsg{ - hasSignatureAndHash: c.vers >= VersionTLS12, + hasSignatureAlgorithm: c.vers >= VersionTLS12, } case typeCertificateStatus: m = new(certificateStatusMsg) @@ -911,7 +911,7 @@ func (c *Conn) readHandshake() (interface{}, error) { m = new(clientKeyExchangeMsg) case typeCertificateVerify: m = &certificateVerifyMsg{ - hasSignatureAndHash: c.vers >= VersionTLS12, + hasSignatureAlgorithm: c.vers >= VersionTLS12, } case typeNextProtocol: m = new(nextProtoMsg) diff --git a/handshake_client.go b/handshake_client.go index af290e3..fb74f79 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -471,7 +471,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { if chainToSend != nil && len(chainToSend.Certificate) > 0 { certVerify := &certificateVerifyMsg{ - hasSignatureAndHash: c.vers >= VersionTLS12, + hasSignatureAlgorithm: c.vers >= VersionTLS12, } key, ok := chainToSend.PrivateKey.(crypto.Signer) @@ -486,7 +486,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { return err } // SignatureAndHashAlgorithm was introduced in TLS 1.2. - if certVerify.hasSignatureAndHash { + if certVerify.hasSignatureAlgorithm { certVerify.signatureAlgorithm = signatureAlgorithm } digest, err := hs.finishedHash.hashForClientCertificate(sigType, hashFunc, hs.masterSecret) @@ -739,7 +739,7 @@ func (hs *clientHandshakeState) getCertificate(certReq *certificateRequestMsg) ( if c.config.GetClientCertificate != nil { var signatureSchemes []SignatureScheme - if !certReq.hasSignatureAndHash { + if !certReq.hasSignatureAlgorithm { // Prior to TLS 1.2, the signature schemes were not // included in the certificate request message. In this // case we use a plausible list based on the acceptable diff --git a/handshake_messages.go b/handshake_messages.go index c5d9950..d678555 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -5,9 +5,49 @@ package tls import ( + "fmt" + "golang_org/x/crypto/cryptobyte" "strings" ) +// The marshalingFunction type is an adapter to allow the use of ordinary +// functions as cryptobyte.MarshalingValue. +type marshalingFunction func(b *cryptobyte.Builder) error + +func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error { + return f(b) +} + +// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If +// the length of the sequence is not the value specified, it produces an error. +func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) { + b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error { + if len(v) != n { + return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v)) + } + b.AddBytes(v) + return nil + })) +} + +// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out)) +} + +// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a +// []byte instead of a cryptobyte.String. +func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool { + return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out)) +} + type clientHelloMsg struct { raw []byte vers uint16 @@ -34,442 +74,289 @@ func (m *clientHelloMsg) marshal() []byte { return m.raw } - length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods) - numExtensions := 0 - extensionsLength := 0 - if m.nextProtoNeg { - numExtensions++ - } - if m.ocspStapling { - extensionsLength += 1 + 2 + 2 - numExtensions++ - } - if len(m.serverName) > 0 { - extensionsLength += 5 + len(m.serverName) - numExtensions++ - } - if len(m.supportedCurves) > 0 { - extensionsLength += 2 + 2*len(m.supportedCurves) - numExtensions++ - } - if len(m.supportedPoints) > 0 { - extensionsLength += 1 + len(m.supportedPoints) - numExtensions++ - } - if m.ticketSupported { - extensionsLength += len(m.sessionTicket) - numExtensions++ - } - if len(m.supportedSignatureAlgorithms) > 0 { - extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms) - numExtensions++ - } - if m.secureRenegotiationSupported { - extensionsLength += 1 + len(m.secureRenegotiation) - numExtensions++ - } - if len(m.alpnProtocols) > 0 { - extensionsLength += 2 - for _, s := range m.alpnProtocols { - if l := len(s); l == 0 || l > 255 { - panic("invalid ALPN protocol") + var b cryptobyte.Builder + b.AddUint8(typeClientHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, suite := range m.cipherSuites { + b.AddUint16(suite) } - extensionsLength++ - extensionsLength += len(s) + }) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.compressionMethods) + }) + + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.nextProtoNeg { + // draft-agl-tls-nextprotoneg-04 + b.AddUint16(extensionNextProtoNeg) + b.AddUint16(0) // empty extension_data + } + if len(m.serverName) > 0 { + // RFC 6066, Section 3 + b.AddUint16(extensionServerName) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // name_type = host_name + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.serverName)) + }) + }) + }) + } + if m.ocspStapling { + // RFC 4366, Section 3.6 + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(1) // status_type = ocsp + b.AddUint16(0) // empty responder_id_list + b.AddUint16(0) // empty request_extensions + }) + } + if len(m.supportedCurves) > 0 { + // RFC 4492, Section 5.1.1 + b.AddUint16(extensionSupportedCurves) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + b.AddUint16(uint16(curve)) + } + }) + }) + } + if len(m.supportedPoints) > 0 { + // RFC 4492, Section 5.1.2 + b.AddUint16(extensionSupportedPoints) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.supportedPoints) + }) + }) + } + if m.ticketSupported { + // RFC 5077, Section 3.2 + b.AddUint16(extensionSessionTicket) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionTicket) + }) + } + if len(m.supportedSignatureAlgorithms) > 0 { + // RFC 5246, Section 7.4.1.4.1 + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if m.secureRenegotiationSupported { + // RFC 5746, Section 3.2 + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocols) > 0 { + // RFC 7301, Section 3.1 + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, proto := range m.alpnProtocols { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(proto)) + }) + } + }) + }) + } + if m.scts { + // RFC 6962, Section 3.3.1 + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions } - numExtensions++ - } - if m.scts { - numExtensions++ - } - if numExtensions > 0 { - extensionsLength += 4 * numExtensions - length += 2 + extensionsLength - } + }) - x := make([]byte, 4+length) - x[0] = typeClientHello - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - x[4] = uint8(m.vers >> 8) - x[5] = uint8(m.vers) - copy(x[6:38], m.random) - x[38] = uint8(len(m.sessionId)) - copy(x[39:39+len(m.sessionId)], m.sessionId) - y := x[39+len(m.sessionId):] - y[0] = uint8(len(m.cipherSuites) >> 7) - y[1] = uint8(len(m.cipherSuites) << 1) - for i, suite := range m.cipherSuites { - y[2+i*2] = uint8(suite >> 8) - y[3+i*2] = uint8(suite) - } - z := y[2+len(m.cipherSuites)*2:] - z[0] = uint8(len(m.compressionMethods)) - copy(z[1:], m.compressionMethods) - - z = z[1+len(m.compressionMethods):] - if numExtensions > 0 { - z[0] = byte(extensionsLength >> 8) - z[1] = byte(extensionsLength) - z = z[2:] - } - if m.nextProtoNeg { - z[0] = byte(extensionNextProtoNeg >> 8) - z[1] = byte(extensionNextProtoNeg & 0xff) - // The length is always 0 - z = z[4:] - } - if len(m.serverName) > 0 { - z[0] = byte(extensionServerName >> 8) - z[1] = byte(extensionServerName & 0xff) - l := len(m.serverName) + 5 - z[2] = byte(l >> 8) - z[3] = byte(l) - z = z[4:] - - // RFC 3546, Section 3.1 - // - // struct { - // NameType name_type; - // select (name_type) { - // case host_name: HostName; - // } name; - // } ServerName; - // - // enum { - // host_name(0), (255) - // } NameType; - // - // opaque HostName<1..2^16-1>; - // - // struct { - // ServerName server_name_list<1..2^16-1> - // } ServerNameList; - - z[0] = byte((len(m.serverName) + 3) >> 8) - z[1] = byte(len(m.serverName) + 3) - z[3] = byte(len(m.serverName) >> 8) - z[4] = byte(len(m.serverName)) - copy(z[5:], []byte(m.serverName)) - z = z[l:] - } - if m.ocspStapling { - // RFC 4366, Section 3.6 - z[0] = byte(extensionStatusRequest >> 8) - z[1] = byte(extensionStatusRequest) - z[2] = 0 - z[3] = 5 - z[4] = 1 // OCSP type - // Two zero valued uint16s for the two lengths. - z = z[9:] - } - if len(m.supportedCurves) > 0 { - // RFC 4492, Section 5.5.1 - z[0] = byte(extensionSupportedCurves >> 8) - z[1] = byte(extensionSupportedCurves) - l := 2 + 2*len(m.supportedCurves) - z[2] = byte(l >> 8) - z[3] = byte(l) - l -= 2 - z[4] = byte(l >> 8) - z[5] = byte(l) - z = z[6:] - for _, curve := range m.supportedCurves { - z[0] = byte(curve >> 8) - z[1] = byte(curve) - z = z[2:] - } - } - if len(m.supportedPoints) > 0 { - // RFC 4492, Section 5.5.2 - z[0] = byte(extensionSupportedPoints >> 8) - z[1] = byte(extensionSupportedPoints) - l := 1 + len(m.supportedPoints) - z[2] = byte(l >> 8) - z[3] = byte(l) - l-- - z[4] = byte(l) - z = z[5:] - for _, pointFormat := range m.supportedPoints { - z[0] = pointFormat - z = z[1:] - } - } - if m.ticketSupported { - // RFC 5077, Section 3.2 - z[0] = byte(extensionSessionTicket >> 8) - z[1] = byte(extensionSessionTicket) - l := len(m.sessionTicket) - z[2] = byte(l >> 8) - z[3] = byte(l) - z = z[4:] - copy(z, m.sessionTicket) - z = z[len(m.sessionTicket):] - } - if len(m.supportedSignatureAlgorithms) > 0 { - // RFC 5246, Section 7.4.1.4.1 - z[0] = byte(extensionSignatureAlgorithms >> 8) - z[1] = byte(extensionSignatureAlgorithms) - l := 2 + 2*len(m.supportedSignatureAlgorithms) - z[2] = byte(l >> 8) - z[3] = byte(l) - z = z[4:] - - l -= 2 - z[0] = byte(l >> 8) - z[1] = byte(l) - z = z[2:] - for _, sigAlgo := range m.supportedSignatureAlgorithms { - z[0] = byte(sigAlgo >> 8) - z[1] = byte(sigAlgo) - z = z[2:] - } - } - if m.secureRenegotiationSupported { - z[0] = byte(extensionRenegotiationInfo >> 8) - z[1] = byte(extensionRenegotiationInfo & 0xff) - z[2] = 0 - z[3] = byte(len(m.secureRenegotiation) + 1) - z[4] = byte(len(m.secureRenegotiation)) - z = z[5:] - copy(z, m.secureRenegotiation) - z = z[len(m.secureRenegotiation):] - } - if len(m.alpnProtocols) > 0 { - z[0] = byte(extensionALPN >> 8) - z[1] = byte(extensionALPN & 0xff) - lengths := z[2:] - z = z[6:] - - stringsLength := 0 - for _, s := range m.alpnProtocols { - l := len(s) - z[0] = byte(l) - copy(z[1:], s) - z = z[1+l:] - stringsLength += 1 + l - } - - lengths[2] = byte(stringsLength >> 8) - lengths[3] = byte(stringsLength) - stringsLength += 2 - lengths[0] = byte(stringsLength >> 8) - lengths[1] = byte(stringsLength) - } - if m.scts { - // RFC 6962, Section 3.3.1 - z[0] = byte(extensionSCT >> 8) - z[1] = byte(extensionSCT) - // zero uint16 for the zero-length extension_data - z = z[4:] - } - - m.raw = x - - return x + m.raw = b.BytesOrPanic() + return m.raw } func (m *clientHelloMsg) unmarshal(data []byte) bool { - if len(data) < 42 { + *m = clientHelloMsg{raw: data} + s := cryptobyte.String(data) + + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) { return false } - m.raw = data - m.vers = uint16(data[4])<<8 | uint16(data[5]) - m.random = data[6:38] - sessionIdLen := int(data[38]) - if sessionIdLen > 32 || len(data) < 39+sessionIdLen { + + var cipherSuites cryptobyte.String + if !s.ReadUint16LengthPrefixed(&cipherSuites) { return false } - m.sessionId = data[39 : 39+sessionIdLen] - data = data[39+sessionIdLen:] - if len(data) < 2 { - return false - } - // cipherSuiteLen is the number of bytes of cipher suite numbers. Since - // they are uint16s, the number must be even. - cipherSuiteLen := int(data[0])<<8 | int(data[1]) - if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { - return false - } - numCipherSuites := cipherSuiteLen / 2 - m.cipherSuites = make([]uint16, numCipherSuites) - for i := 0; i < numCipherSuites; i++ { - m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i]) - if m.cipherSuites[i] == scsvRenegotiation { + m.cipherSuites = []uint16{} + m.secureRenegotiationSupported = false + for !cipherSuites.Empty() { + var suite uint16 + if !cipherSuites.ReadUint16(&suite) { + return false + } + if suite == scsvRenegotiation { m.secureRenegotiationSupported = true } + m.cipherSuites = append(m.cipherSuites, suite) } - data = data[2+cipherSuiteLen:] - if len(data) < 1 { + + if !readUint8LengthPrefixed(&s, &m.compressionMethods) { return false } - compressionMethodsLen := int(data[0]) - if len(data) < 1+compressionMethodsLen { - return false - } - m.compressionMethods = data[1 : 1+compressionMethodsLen] - data = data[1+compressionMethodsLen:] - - m.nextProtoNeg = false - m.serverName = "" - m.ocspStapling = false - m.ticketSupported = false - m.sessionTicket = nil - m.supportedSignatureAlgorithms = nil - m.alpnProtocols = nil - m.scts = false - - if len(data) == 0 { + if s.Empty() { // ClientHello is optionally followed by extension data return true } - if len(data) < 2 { + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { return false } - extensionsLength := int(data[0])<<8 | int(data[1]) - data = data[2:] - if extensionsLength != len(data) { - return false - } - - for len(data) != 0 { - if len(data) < 4 { - return false - } - extension := uint16(data[0])<<8 | uint16(data[1]) - length := int(data[2])<<8 | int(data[3]) - data = data[4:] - if len(data) < length { + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { return false } switch extension { case extensionServerName: - d := data[:length] - if len(d) < 2 { + // RFC 6066, Section 3 + var nameList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { return false } - namesLen := int(d[0])<<8 | int(d[1]) - d = d[2:] - if len(d) != namesLen { - return false - } - for len(d) > 0 { - if len(d) < 3 { + for !nameList.Empty() { + var nameType uint8 + var serverName cryptobyte.String + if !nameList.ReadUint8(&nameType) || + !nameList.ReadUint16LengthPrefixed(&serverName) || + serverName.Empty() { return false } - nameType := d[0] - nameLen := int(d[1])<<8 | int(d[2]) - d = d[3:] - if len(d) < nameLen { + if nameType != 0 { + continue + } + if len(m.serverName) != 0 { + // Multiple names of the same name_type are prohibited. return false } - if nameType == 0 { - m.serverName = string(d[:nameLen]) - // An SNI value may not include a trailing dot. - // See RFC 6066, Section 3. - if strings.HasSuffix(m.serverName, ".") { - return false - } - break + m.serverName = string(serverName) + // An SNI value may not include a trailing dot. + if strings.HasSuffix(m.serverName, ".") { + return false } - d = d[nameLen:] } case extensionNextProtoNeg: - if length > 0 { - return false - } + // draft-agl-tls-nextprotoneg-04 m.nextProtoNeg = true case extensionStatusRequest: - m.ocspStapling = length > 0 && data[0] == statusTypeOCSP + // RFC 4366, Section 3.6 + var statusType uint8 + var ignored cryptobyte.String + if !extData.ReadUint8(&statusType) || + !extData.ReadUint16LengthPrefixed(&ignored) || + !extData.ReadUint16LengthPrefixed(&ignored) { + return false + } + m.ocspStapling = statusType == statusTypeOCSP case extensionSupportedCurves: - // RFC 4492, Section 5.5.1 - if length < 2 { + // RFC 4492, Section 5.1.1 + var curves cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() { return false } - l := int(data[0])<<8 | int(data[1]) - if l%2 == 1 || length != l+2 { - return false - } - numCurves := l / 2 - m.supportedCurves = make([]CurveID, numCurves) - d := data[2:] - for i := 0; i < numCurves; i++ { - m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1]) - d = d[2:] + for !curves.Empty() { + var curve uint16 + if !curves.ReadUint16(&curve) { + return false + } + m.supportedCurves = append(m.supportedCurves, CurveID(curve)) } case extensionSupportedPoints: - // RFC 4492, Section 5.5.2 - if length < 1 { + // RFC 4492, Section 5.1.2 + if !readUint8LengthPrefixed(&extData, &m.supportedPoints) || + len(m.supportedPoints) == 0 { return false } - l := int(data[0]) - if length != l+1 { - return false - } - m.supportedPoints = make([]uint8, l) - copy(m.supportedPoints, data[1:]) case extensionSessionTicket: // RFC 5077, Section 3.2 m.ticketSupported = true - m.sessionTicket = data[:length] + extData.ReadBytes(&m.sessionTicket, len(extData)) case extensionSignatureAlgorithms: // RFC 5246, Section 7.4.1.4.1 - if length < 2 || length&1 != 0 { + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { return false } - l := int(data[0])<<8 | int(data[1]) - if l != length-2 { - return false - } - n := l / 2 - d := data[2:] - m.supportedSignatureAlgorithms = make([]SignatureScheme, n) - for i := range m.supportedSignatureAlgorithms { - m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1]) - d = d[2:] - } - case extensionRenegotiationInfo: - if length == 0 { - return false - } - d := data[:length] - l := int(d[0]) - d = d[1:] - if l != len(d) { - return false - } - - m.secureRenegotiation = d - m.secureRenegotiationSupported = true - case extensionALPN: - if length < 2 { - return false - } - l := int(data[0])<<8 | int(data[1]) - if l != length-2 { - return false - } - d := data[2:length] - for len(d) != 0 { - stringLen := int(d[0]) - d = d[1:] - if stringLen == 0 || stringLen > len(d) { + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { return false } - m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen])) - d = d[stringLen:] + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) } - case extensionSCT: - m.scts = true - if length != 0 { + case extensionRenegotiationInfo: + // RFC 5746, Section 3.2 + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { return false } + m.secureRenegotiationSupported = true + case extensionALPN: + // RFC 7301, Section 3.1 + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + for !protoList.Empty() { + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() { + return false + } + m.alpnProtocols = append(m.alpnProtocols, string(proto)) + } + case extensionSCT: + // RFC 6962, Section 3.3.1 + m.scts = true + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false } - data = data[length:] } return true @@ -497,280 +384,165 @@ func (m *serverHelloMsg) marshal() []byte { return m.raw } - length := 38 + len(m.sessionId) - numExtensions := 0 - extensionsLength := 0 + var b cryptobyte.Builder + b.AddUint8(typeServerHello) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16(m.vers) + addBytesWithLength(b, m.random, 32) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.sessionId) + }) + b.AddUint16(m.cipherSuite) + b.AddUint8(m.compressionMethod) - nextProtoLen := 0 - if m.nextProtoNeg { - numExtensions++ - for _, v := range m.nextProtos { - nextProtoLen += len(v) - } - nextProtoLen += len(m.nextProtos) - extensionsLength += nextProtoLen - } - if m.ocspStapling { - numExtensions++ - } - if m.ticketSupported { - numExtensions++ - } - if m.secureRenegotiationSupported { - extensionsLength += 1 + len(m.secureRenegotiation) - numExtensions++ - } - if alpnLen := len(m.alpnProtocol); alpnLen > 0 { - if alpnLen >= 256 { - panic("invalid ALPN protocol") - } - extensionsLength += 2 + 1 + alpnLen - numExtensions++ - } - sctLen := 0 - if len(m.scts) > 0 { - for _, sct := range m.scts { - sctLen += len(sct) + 2 - } - extensionsLength += 2 + sctLen - numExtensions++ - } + // If extensions aren't present, omit them. + var extensionsPresent bool + bWithoutExtensions := *b - if numExtensions > 0 { - extensionsLength += 4 * numExtensions - length += 2 + extensionsLength - } - - x := make([]byte, 4+length) - x[0] = typeServerHello - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - x[4] = uint8(m.vers >> 8) - x[5] = uint8(m.vers) - copy(x[6:38], m.random) - x[38] = uint8(len(m.sessionId)) - copy(x[39:39+len(m.sessionId)], m.sessionId) - z := x[39+len(m.sessionId):] - z[0] = uint8(m.cipherSuite >> 8) - z[1] = uint8(m.cipherSuite) - z[2] = m.compressionMethod - - z = z[3:] - if numExtensions > 0 { - z[0] = byte(extensionsLength >> 8) - z[1] = byte(extensionsLength) - z = z[2:] - } - if m.nextProtoNeg { - z[0] = byte(extensionNextProtoNeg >> 8) - z[1] = byte(extensionNextProtoNeg & 0xff) - z[2] = byte(nextProtoLen >> 8) - z[3] = byte(nextProtoLen) - z = z[4:] - - for _, v := range m.nextProtos { - l := len(v) - if l > 255 { - l = 255 + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.nextProtoNeg { + b.AddUint16(extensionNextProtoNeg) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, proto := range m.nextProtos { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(proto)) + }) + } + }) } - z[0] = byte(l) - copy(z[1:], []byte(v[0:l])) - z = z[1+l:] + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.ticketSupported { + b.AddUint16(extensionSessionTicket) + b.AddUint16(0) // empty extension_data + } + if m.secureRenegotiationSupported { + b.AddUint16(extensionRenegotiationInfo) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if len(m.scts) > 0 { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range m.scts { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + + extensionsPresent = len(b.BytesOrPanic()) > 2 + }) + + if !extensionsPresent { + *b = bWithoutExtensions } - } - if m.ocspStapling { - z[0] = byte(extensionStatusRequest >> 8) - z[1] = byte(extensionStatusRequest) - z = z[4:] - } - if m.ticketSupported { - z[0] = byte(extensionSessionTicket >> 8) - z[1] = byte(extensionSessionTicket) - z = z[4:] - } - if m.secureRenegotiationSupported { - z[0] = byte(extensionRenegotiationInfo >> 8) - z[1] = byte(extensionRenegotiationInfo & 0xff) - z[2] = 0 - z[3] = byte(len(m.secureRenegotiation) + 1) - z[4] = byte(len(m.secureRenegotiation)) - z = z[5:] - copy(z, m.secureRenegotiation) - z = z[len(m.secureRenegotiation):] - } - if alpnLen := len(m.alpnProtocol); alpnLen > 0 { - z[0] = byte(extensionALPN >> 8) - z[1] = byte(extensionALPN & 0xff) - l := 2 + 1 + alpnLen - z[2] = byte(l >> 8) - z[3] = byte(l) - l -= 2 - z[4] = byte(l >> 8) - z[5] = byte(l) - l -= 1 - z[6] = byte(l) - copy(z[7:], []byte(m.alpnProtocol)) - z = z[7+alpnLen:] - } - if sctLen > 0 { - z[0] = byte(extensionSCT >> 8) - z[1] = byte(extensionSCT) - l := sctLen + 2 - z[2] = byte(l >> 8) - z[3] = byte(l) - z[4] = byte(sctLen >> 8) - z[5] = byte(sctLen) + }) - z = z[6:] - for _, sct := range m.scts { - z[0] = byte(len(sct) >> 8) - z[1] = byte(len(sct)) - copy(z[2:], sct) - z = z[len(sct)+2:] - } - } - - m.raw = x - - return x + m.raw = b.BytesOrPanic() + return m.raw } func (m *serverHelloMsg) unmarshal(data []byte) bool { - if len(data) < 42 { - return false - } - m.raw = data - m.vers = uint16(data[4])<<8 | uint16(data[5]) - m.random = data[6:38] - sessionIdLen := int(data[38]) - if sessionIdLen > 32 || len(data) < 39+sessionIdLen { - return false - } - m.sessionId = data[39 : 39+sessionIdLen] - data = data[39+sessionIdLen:] - if len(data) < 3 { - return false - } - m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) - m.compressionMethod = data[2] - data = data[3:] + *m = serverHelloMsg{raw: data} + s := cryptobyte.String(data) - m.nextProtoNeg = false - m.nextProtos = nil - m.ocspStapling = false - m.scts = nil - m.ticketSupported = false - m.alpnProtocol = "" + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) || + !readUint8LengthPrefixed(&s, &m.sessionId) || + !s.ReadUint16(&m.cipherSuite) || + !s.ReadUint8(&m.compressionMethod) { + return false + } - if len(data) == 0 { + if s.Empty() { // ServerHello is optionally followed by extension data return true } - if len(data) < 2 { + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { return false } - extensionsLength := int(data[0])<<8 | int(data[1]) - data = data[2:] - if len(data) != extensionsLength { - return false - } - - for len(data) != 0 { - if len(data) < 4 { - return false - } - extension := uint16(data[0])<<8 | uint16(data[1]) - length := int(data[2])<<8 | int(data[3]) - data = data[4:] - if len(data) < length { + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { return false } switch extension { case extensionNextProtoNeg: m.nextProtoNeg = true - d := data[:length] - for len(d) > 0 { - l := int(d[0]) - d = d[1:] - if l == 0 || l > len(d) { + for !extData.Empty() { + var proto cryptobyte.String + if !extData.ReadUint8LengthPrefixed(&proto) || + proto.Empty() { return false } - m.nextProtos = append(m.nextProtos, string(d[:l])) - d = d[l:] + m.nextProtos = append(m.nextProtos, string(proto)) } case extensionStatusRequest: - if length > 0 { - return false - } m.ocspStapling = true case extensionSessionTicket: - if length > 0 { - return false - } m.ticketSupported = true case extensionRenegotiationInfo: - if length == 0 { + if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) { return false } - d := data[:length] - l := int(d[0]) - d = d[1:] - if l != len(d) { - return false - } - - m.secureRenegotiation = d m.secureRenegotiationSupported = true case extensionALPN: - d := data[:length] - if len(d) < 3 { + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { return false } - l := int(d[0])<<8 | int(d[1]) - if l != len(d)-2 { + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { return false } - d = d[2:] - l = int(d[0]) - if l != len(d)-1 { - return false - } - d = d[1:] - if len(d) == 0 { - // ALPN protocols must not be empty. - return false - } - m.alpnProtocol = string(d) + m.alpnProtocol = string(proto) case extensionSCT: - d := data[:length] - - if len(d) < 2 { + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { return false } - l := int(d[0])<<8 | int(d[1]) - d = d[2:] - if len(d) != l || l == 0 { - return false - } - - m.scts = make([][]byte, 0, 3) - for len(d) != 0 { - if len(d) < 2 { + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { return false } - sctLen := int(d[0])<<8 | int(d[1]) - d = d[2:] - if sctLen == 0 || len(d) < sctLen { - return false - } - m.scts = append(m.scts, d[:sctLen]) - d = d[sctLen:] + m.scts = append(m.scts, sct) } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false } - data = data[length:] } return true @@ -989,26 +761,27 @@ type finishedMsg struct { verifyData []byte } -func (m *finishedMsg) marshal() (x []byte) { +func (m *finishedMsg) marshal() []byte { if m.raw != nil { return m.raw } - x = make([]byte, 4+len(m.verifyData)) - x[0] = typeFinished - x[3] = byte(len(m.verifyData)) - copy(x[4:], m.verifyData) - m.raw = x - return + var b cryptobyte.Builder + b.AddUint8(typeFinished) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.verifyData) + }) + + m.raw = b.BytesOrPanic() + return m.raw } func (m *finishedMsg) unmarshal(data []byte) bool { m.raw = data - if len(data) < 4 { - return false - } - m.verifyData = data[4:] - return true + s := cryptobyte.String(data) + return s.Skip(1) && + readUint24LengthPrefixed(&s, &m.verifyData) && + s.Empty() } type nextProtoMsg struct { @@ -1073,10 +846,9 @@ func (m *nextProtoMsg) unmarshal(data []byte) bool { type certificateRequestMsg struct { raw []byte - // hasSignatureAndHash indicates whether this message includes a list - // of signature and hash functions. This change was introduced with TLS - // 1.2. - hasSignatureAndHash bool + // hasSignatureAlgorithm indicates whether this message includes a list of + // supported signature algorithms. This change was introduced with TLS 1.2. + hasSignatureAlgorithm bool certificateTypes []byte supportedSignatureAlgorithms []SignatureScheme @@ -1096,7 +868,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) { } length += casLength - if m.hasSignatureAndHash { + if m.hasSignatureAlgorithm { length += 2 + 2*len(m.supportedSignatureAlgorithms) } @@ -1111,7 +883,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) { copy(x[5:], m.certificateTypes) y := x[5+len(m.certificateTypes):] - if m.hasSignatureAndHash { + if m.hasSignatureAlgorithm { n := len(m.supportedSignatureAlgorithms) * 2 y[0] = uint8(n >> 8) y[1] = uint8(n) @@ -1163,7 +935,7 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { data = data[numCertTypes:] - if m.hasSignatureAndHash { + if m.hasSignatureAlgorithm { if len(data) < 2 { return false } @@ -1215,10 +987,10 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool { } type certificateVerifyMsg struct { - raw []byte - hasSignatureAndHash bool - signatureAlgorithm SignatureScheme - signature []byte + raw []byte + hasSignatureAlgorithm bool // format change introduced in TLS 1.2 + signatureAlgorithm SignatureScheme + signature []byte } func (m *certificateVerifyMsg) marshal() (x []byte) { @@ -1226,62 +998,34 @@ func (m *certificateVerifyMsg) marshal() (x []byte) { return m.raw } - // See RFC 4346, Section 7.4.8. - siglength := len(m.signature) - length := 2 + siglength - if m.hasSignatureAndHash { - length += 2 - } - x = make([]byte, 4+length) - x[0] = typeCertificateVerify - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - y := x[4:] - if m.hasSignatureAndHash { - y[0] = uint8(m.signatureAlgorithm >> 8) - y[1] = uint8(m.signatureAlgorithm) - y = y[2:] - } - y[0] = uint8(siglength >> 8) - y[1] = uint8(siglength) - copy(y[2:], m.signature) + var b cryptobyte.Builder + b.AddUint8(typeCertificateVerify) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.hasSignatureAlgorithm { + b.AddUint16(uint16(m.signatureAlgorithm)) + } + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.signature) + }) + }) - m.raw = x - - return + m.raw = b.BytesOrPanic() + return m.raw } func (m *certificateVerifyMsg) unmarshal(data []byte) bool { m.raw = data + s := cryptobyte.String(data) - if len(data) < 6 { + if !s.Skip(4) { // message type and uint24 length field return false } - - length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) - if uint32(len(data))-4 != length { - return false + if m.hasSignatureAlgorithm { + if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) { + return false + } } - - data = data[4:] - if m.hasSignatureAndHash { - m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1]) - data = data[2:] - } - - if len(data) < 2 { - return false - } - siglength := int(data[0])<<8 + int(data[1]) - data = data[2:] - if len(data) != siglength { - return false - } - - m.signature = data - - return true + return readUint16LengthPrefixed(&s, &m.signature) && s.Empty() } type newSessionTicketMsg struct { diff --git a/handshake_messages_test.go b/handshake_messages_test.go index c8cc0d6..fbc294b 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -20,7 +20,9 @@ var tests = []interface{}{ &certificateMsg{}, &certificateRequestMsg{}, - &certificateVerifyMsg{}, + &certificateVerifyMsg{ + hasSignatureAlgorithm: true, + }, &certificateStatusMsg{}, &clientKeyExchangeMsg{}, &nextProtoMsg{}, @@ -149,6 +151,10 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { if rand.Intn(10) > 5 { m.scts = true } + if rand.Intn(10) > 5 { + m.secureRenegotiationSupported = true + m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) + } return reflect.ValueOf(m) } @@ -180,6 +186,11 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) } + if rand.Intn(10) > 5 { + m.secureRenegotiationSupported = true + m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) + } + return reflect.ValueOf(m) } @@ -204,6 +215,8 @@ func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &certificateVerifyMsg{} + m.hasSignatureAlgorithm = true + m.signatureAlgorithm = SignatureScheme(rand.Intn(30000)) m.signature = randomBytes(rand.Intn(15)+1, rand) return reflect.ValueOf(m) } diff --git a/handshake_server.go b/handshake_server.go index b077c90..bec128f 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -418,7 +418,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { byte(certTypeECDSASign), } if c.vers >= VersionTLS12 { - certReq.hasSignatureAndHash = true + certReq.hasSignatureAlgorithm = true certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms } diff --git a/handshake_server_test.go b/handshake_server_test.go index e14adbd..01de92d 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -101,13 +101,17 @@ var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x020 func TestRejectBadProtocolVersion(t *testing.T) { for _, v := range badProtocolVersions { - testClientHelloFailure(t, testConfig, &clientHelloMsg{vers: v}, "unsupported, maximum protocol version") + testClientHelloFailure(t, testConfig, &clientHelloMsg{ + vers: v, + random: make([]byte, 32), + }, "unsupported, maximum protocol version") } } func TestNoSuiteOverlap(t *testing.T) { clientHello := &clientHelloMsg{ vers: VersionTLS10, + random: make([]byte, 32), cipherSuites: []uint16{0xff00}, compressionMethods: []uint8{compressionNone}, } @@ -117,6 +121,7 @@ func TestNoSuiteOverlap(t *testing.T) { func TestNoCompressionOverlap(t *testing.T) { clientHello := &clientHelloMsg{ vers: VersionTLS10, + random: make([]byte, 32), cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, compressionMethods: []uint8{0xff}, } @@ -126,6 +131,7 @@ func TestNoCompressionOverlap(t *testing.T) { func TestNoRC4ByDefault(t *testing.T) { clientHello := &clientHelloMsg{ vers: VersionTLS10, + random: make([]byte, 32), cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, compressionMethods: []uint8{compressionNone}, } @@ -137,7 +143,11 @@ func TestNoRC4ByDefault(t *testing.T) { } func TestRejectSNIWithTrailingDot(t *testing.T) { - testClientHelloFailure(t, testConfig, &clientHelloMsg{vers: VersionTLS12, serverName: "foo.com."}, "unexpected message") + testClientHelloFailure(t, testConfig, &clientHelloMsg{ + vers: VersionTLS12, + random: make([]byte, 32), + serverName: "foo.com.", + }, "unexpected message") } func TestDontSelectECDSAWithRSAKey(t *testing.T) { @@ -145,6 +155,7 @@ func TestDontSelectECDSAWithRSAKey(t *testing.T) { // won't be selected if the server's private key doesn't support it. clientHello := &clientHelloMsg{ vers: VersionTLS10, + random: make([]byte, 32), cipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, compressionMethods: []uint8{compressionNone}, supportedCurves: []CurveID{CurveP256}, @@ -170,6 +181,7 @@ func TestDontSelectRSAWithECDSAKey(t *testing.T) { // won't be selected if the server's private key doesn't support it. clientHello := &clientHelloMsg{ vers: VersionTLS10, + random: make([]byte, 32), cipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA}, compressionMethods: []uint8{compressionNone}, supportedCurves: []CurveID{CurveP256}, @@ -242,11 +254,9 @@ func TestRenegotiationExtension(t *testing.T) { func TestTLS12OnlyCipherSuites(t *testing.T) { // Test that a Server doesn't select a TLS 1.2-only cipher suite when // the client negotiates TLS 1.1. - var zeros [32]byte - clientHello := &clientHelloMsg{ vers: VersionTLS11, - random: zeros[:], + random: make([]byte, 32), cipherSuites: []uint16{ // The Server, by default, will use the client's // preference order. So the GCM cipher suite @@ -878,6 +888,7 @@ func TestHandshakeServerSNIGetCertificateError(t *testing.T) { clientHello := &clientHelloMsg{ vers: VersionTLS10, + random: make([]byte, 32), cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, compressionMethods: []uint8{compressionNone}, serverName: "test", @@ -898,6 +909,7 @@ func TestHandshakeServerEmptyCertificates(t *testing.T) { clientHello := &clientHelloMsg{ vers: VersionTLS10, + random: make([]byte, 32), cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, compressionMethods: []uint8{compressionNone}, } @@ -909,6 +921,7 @@ func TestHandshakeServerEmptyCertificates(t *testing.T) { clientHello = &clientHelloMsg{ vers: VersionTLS10, + random: make([]byte, 32), cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, compressionMethods: []uint8{compressionNone}, } @@ -1212,6 +1225,7 @@ func TestSNIGivenOnFailure(t *testing.T) { clientHello := &clientHelloMsg{ vers: VersionTLS10, + random: make([]byte, 32), cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, compressionMethods: []uint8{compressionNone}, serverName: expectedServerName,