crypto/tls: don't cache marshal'd bytes

Only cache the wire representation for clientHelloMsg and serverHelloMsg
during unmarshal, which are the only places we actually need to hold
onto them. For everything else, remove the raw field.

This appears to have zero performance impact:

name                                               old time/op   new time/op   delta
CertCache/0-10                                       177µs ± 2%    189µs ±11%   ~     (p=0.700 n=3+3)
CertCache/1-10                                       184µs ± 3%    182µs ± 6%   ~     (p=1.000 n=3+3)
CertCache/2-10                                       187µs ±12%    187µs ± 2%   ~     (p=1.000 n=3+3)
CertCache/3-10                                       204µs ±21%    187µs ± 1%   ~     (p=0.700 n=3+3)
HandshakeServer/RSA-10                               410µs ± 2%    410µs ± 3%   ~     (p=1.000 n=3+3)
HandshakeServer/ECDHE-P256-RSA/TLSv13-10             473µs ± 3%    460µs ± 2%   ~     (p=0.200 n=3+3)
HandshakeServer/ECDHE-P256-RSA/TLSv12-10             498µs ± 3%    489µs ± 2%   ~     (p=0.700 n=3+3)
HandshakeServer/ECDHE-P256-ECDSA-P256/TLSv13-10      140µs ± 5%    138µs ± 5%   ~     (p=1.000 n=3+3)
HandshakeServer/ECDHE-P256-ECDSA-P256/TLSv12-10      132µs ± 1%    133µs ± 2%   ~     (p=0.400 n=3+3)
HandshakeServer/ECDHE-X25519-ECDSA-P256/TLSv13-10    168µs ± 1%    171µs ± 4%   ~     (p=1.000 n=3+3)
HandshakeServer/ECDHE-X25519-ECDSA-P256/TLSv12-10    166µs ± 3%    163µs ± 0%   ~     (p=0.700 n=3+3)
HandshakeServer/ECDHE-P521-ECDSA-P521/TLSv13-10     1.87ms ± 2%   1.81ms ± 0%   ~     (p=0.100 n=3+3)
HandshakeServer/ECDHE-P521-ECDSA-P521/TLSv12-10     1.86ms ± 0%   1.86ms ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/1MB/TLSv12-10                  6.79ms ± 3%   6.73ms ± 0%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/1MB/TLSv13-10                  6.73ms ± 1%   6.75ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/2MB/TLSv12-10                  12.8ms ± 2%   12.7ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/2MB/TLSv13-10                  13.1ms ± 3%   12.8ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/4MB/TLSv12-10                  24.9ms ± 2%   24.7ms ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/4MB/TLSv13-10                  26.0ms ± 4%   24.9ms ± 1%   ~     (p=0.100 n=3+3)
Throughput/MaxPacket/8MB/TLSv12-10                  50.0ms ± 3%   48.9ms ± 0%   ~     (p=0.200 n=3+3)
Throughput/MaxPacket/8MB/TLSv13-10                  49.8ms ± 2%   49.3ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/16MB/TLSv12-10                 97.3ms ± 1%   97.4ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/16MB/TLSv13-10                 97.9ms ± 0%   97.9ms ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/32MB/TLSv12-10                  195ms ± 0%    194ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/32MB/TLSv13-10                  196ms ± 0%    196ms ± 1%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/64MB/TLSv12-10                  405ms ± 3%    385ms ± 0%   ~     (p=0.100 n=3+3)
Throughput/MaxPacket/64MB/TLSv13-10                  391ms ± 1%    388ms ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/1MB/TLSv12-10              6.75ms ± 0%   6.75ms ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/1MB/TLSv13-10              6.84ms ± 1%   6.77ms ± 0%   ~     (p=0.100 n=3+3)
Throughput/DynamicPacket/2MB/TLSv12-10              12.8ms ± 1%   12.8ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/DynamicPacket/2MB/TLSv13-10              12.8ms ± 1%   13.0ms ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/4MB/TLSv12-10              24.8ms ± 1%   24.8ms ± 0%   ~     (p=1.000 n=3+3)
Throughput/DynamicPacket/4MB/TLSv13-10              25.1ms ± 2%   25.1ms ± 1%   ~     (p=1.000 n=3+3)
Throughput/DynamicPacket/8MB/TLSv12-10              49.2ms ± 2%   48.9ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/8MB/TLSv13-10              49.3ms ± 1%   49.4ms ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/16MB/TLSv12-10             97.1ms ± 0%   98.0ms ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/16MB/TLSv13-10             98.8ms ± 1%   98.4ms ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/32MB/TLSv12-10              192ms ± 0%    198ms ± 5%   ~     (p=0.100 n=3+3)
Throughput/DynamicPacket/32MB/TLSv13-10              194ms ± 0%    196ms ± 1%   ~     (p=0.400 n=3+3)
Throughput/DynamicPacket/64MB/TLSv12-10              385ms ± 1%    384ms ± 0%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/64MB/TLSv13-10              387ms ± 0%    388ms ± 0%   ~     (p=0.400 n=3+3)
Latency/MaxPacket/200kbps/TLSv12-10                  694ms ± 0%    694ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/200kbps/TLSv13-10                  699ms ± 0%    699ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/500kbps/TLSv12-10                  278ms ± 0%    278ms ± 0%   ~     (p=0.400 n=3+3)
Latency/MaxPacket/500kbps/TLSv13-10                  280ms ± 0%    280ms ± 0%   ~     (p=1.000 n=3+3)
Latency/MaxPacket/1000kbps/TLSv12-10                 140ms ± 1%    140ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/1000kbps/TLSv13-10                 141ms ± 0%    141ms ± 0%   ~     (p=1.000 n=3+3)
Latency/MaxPacket/2000kbps/TLSv12-10                70.5ms ± 0%   70.4ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/2000kbps/TLSv13-10                70.7ms ± 0%   70.7ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/5000kbps/TLSv12-10                28.8ms ± 0%   28.8ms ± 0%   ~     (p=0.700 n=3+3)
Latency/MaxPacket/5000kbps/TLSv13-10                28.9ms ± 0%   28.9ms ± 0%   ~     (p=0.700 n=3+3)
Latency/DynamicPacket/200kbps/TLSv12-10              134ms ± 0%    134ms ± 0%   ~     (p=0.700 n=3+3)
Latency/DynamicPacket/200kbps/TLSv13-10              138ms ± 0%    138ms ± 0%   ~     (p=1.000 n=3+3)
Latency/DynamicPacket/500kbps/TLSv12-10             54.1ms ± 0%   54.1ms ± 0%   ~     (p=1.000 n=3+3)
Latency/DynamicPacket/500kbps/TLSv13-10             55.7ms ± 0%   55.7ms ± 0%   ~     (p=0.100 n=3+3)
Latency/DynamicPacket/1000kbps/TLSv12-10            27.6ms ± 0%   27.6ms ± 0%   ~     (p=0.200 n=3+3)
Latency/DynamicPacket/1000kbps/TLSv13-10            28.4ms ± 0%   28.4ms ± 0%   ~     (p=0.200 n=3+3)
Latency/DynamicPacket/2000kbps/TLSv12-10            14.4ms ± 0%   14.4ms ± 0%   ~     (p=1.000 n=3+3)
Latency/DynamicPacket/2000kbps/TLSv13-10            14.6ms ± 0%   14.6ms ± 0%   ~     (p=1.000 n=3+3)
Latency/DynamicPacket/5000kbps/TLSv12-10            6.44ms ± 0%   6.45ms ± 0%   ~     (p=0.100 n=3+3)
Latency/DynamicPacket/5000kbps/TLSv13-10            6.49ms ± 0%   6.49ms ± 0%   ~     (p=0.700 n=3+3)

name                                               old speed     new speed     delta
Throughput/MaxPacket/1MB/TLSv12-10                 155MB/s ± 3%  156MB/s ± 0%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/1MB/TLSv13-10                 156MB/s ± 1%  155MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/2MB/TLSv12-10                 163MB/s ± 2%  165MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/2MB/TLSv13-10                 160MB/s ± 3%  164MB/s ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/4MB/TLSv12-10                 168MB/s ± 2%  170MB/s ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/4MB/TLSv13-10                 162MB/s ± 4%  168MB/s ± 1%   ~     (p=0.100 n=3+3)
Throughput/MaxPacket/8MB/TLSv12-10                 168MB/s ± 3%  172MB/s ± 0%   ~     (p=0.200 n=3+3)
Throughput/MaxPacket/8MB/TLSv13-10                 168MB/s ± 2%  170MB/s ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/16MB/TLSv12-10                172MB/s ± 1%  172MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/16MB/TLSv13-10                171MB/s ± 0%  171MB/s ± 1%   ~     (p=1.000 n=3+3)
Throughput/MaxPacket/32MB/TLSv12-10                172MB/s ± 0%  173MB/s ± 1%   ~     (p=0.400 n=3+3)
Throughput/MaxPacket/32MB/TLSv13-10                171MB/s ± 0%  172MB/s ± 1%   ~     (p=0.700 n=3+3)
Throughput/MaxPacket/64MB/TLSv12-10                166MB/s ± 3%  174MB/s ± 0%   ~     (p=0.100 n=3+3)
Throughput/MaxPacket/64MB/TLSv13-10                171MB/s ± 1%  173MB/s ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/1MB/TLSv12-10             155MB/s ± 0%  155MB/s ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/1MB/TLSv13-10             153MB/s ± 1%  155MB/s ± 0%   ~     (p=0.100 n=3+3)
Throughput/DynamicPacket/2MB/TLSv12-10             164MB/s ± 1%  164MB/s ± 1%   ~     (p=0.400 n=3+3)
Throughput/DynamicPacket/2MB/TLSv13-10             163MB/s ± 1%  162MB/s ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/4MB/TLSv12-10             169MB/s ± 1%  169MB/s ± 0%   ~     (p=1.000 n=3+3)
Throughput/DynamicPacket/4MB/TLSv13-10             167MB/s ± 1%  167MB/s ± 1%   ~     (p=1.000 n=3+3)
Throughput/DynamicPacket/8MB/TLSv12-10             170MB/s ± 2%  171MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/8MB/TLSv13-10             170MB/s ± 1%  170MB/s ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/16MB/TLSv12-10            173MB/s ± 0%  171MB/s ± 1%   ~     (p=0.200 n=3+3)
Throughput/DynamicPacket/16MB/TLSv13-10            170MB/s ± 1%  170MB/s ± 1%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/32MB/TLSv12-10            175MB/s ± 0%  170MB/s ± 5%   ~     (p=0.100 n=3+3)
Throughput/DynamicPacket/32MB/TLSv13-10            173MB/s ± 0%  171MB/s ± 1%   ~     (p=0.300 n=3+3)
Throughput/DynamicPacket/64MB/TLSv12-10            174MB/s ± 1%  175MB/s ± 0%   ~     (p=0.700 n=3+3)
Throughput/DynamicPacket/64MB/TLSv13-10            174MB/s ± 0%  173MB/s ± 0%   ~     (p=0.400 n=3+3)

Change-Id: Ifa79cce002011850ed8b2835edd34f60e014eea8
Cq-Include-Trybots: luci.golang.try:gotip-linux-amd64-longtest,gotip-linux-arm64-longtest
Reviewed-on: https://go-review.googlesource.com/c/go/+/580215
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Auto-Submit: Roland Shoemaker <roland@golang.org>
This commit is contained in:
Roland Shoemaker 2024-04-18 10:51:25 -07:00 committed by Gopher Robot
parent c16fc29ae8
commit 990f526003
4 changed files with 114 additions and 189 deletions

View file

@ -1445,6 +1445,15 @@ type handshakeMessage interface {
unmarshal([]byte) bool unmarshal([]byte) bool
} }
type handshakeMessageWithOriginalBytes interface {
handshakeMessage
// originalBytes should return the original bytes that were passed to
// unmarshal to create the message. If the message was not produced by
// unmarshal, it should return nil.
originalBytes() []byte
}
// lruSessionCache is a ClientSessionCache implementation that uses an LRU // lruSessionCache is a ClientSessionCache implementation that uses an LRU
// caching strategy. // caching strategy.
type lruSessionCache struct { type lruSessionCache struct {

View file

@ -249,7 +249,6 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} hs.hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
} }
hs.hello.raw = nil
if len(hs.hello.pskIdentities) > 0 { if len(hs.hello.pskIdentities) > 0 {
pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite) pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
if pskSuite == nil { if pskSuite == nil {

View file

@ -68,7 +68,7 @@ func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
} }
type clientHelloMsg struct { type clientHelloMsg struct {
raw []byte original []byte
vers uint16 vers uint16
random []byte random []byte
sessionId []byte sessionId []byte
@ -98,10 +98,6 @@ type clientHelloMsg struct {
} }
func (m *clientHelloMsg) marshal() ([]byte, error) { func (m *clientHelloMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var exts cryptobyte.Builder var exts cryptobyte.Builder
if len(m.serverName) > 0 { if len(m.serverName) > 0 {
// RFC 6066, Section 3 // RFC 6066, Section 3
@ -310,8 +306,7 @@ func (m *clientHelloMsg) marshal() ([]byte, error) {
} }
}) })
m.raw, err = b.Bytes() return b.Bytes()
return m.raw, err
} }
// marshalWithoutBinders returns the ClientHello through the // marshalWithoutBinders returns the ClientHello through the
@ -324,16 +319,21 @@ func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) {
bindersLen += len(binder) bindersLen += len(binder)
} }
fullMessage, err := m.marshal() var fullMessage []byte
if m.original != nil {
fullMessage = m.original
} else {
var err error
fullMessage, err = m.marshal()
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
return fullMessage[:len(fullMessage)-bindersLen], nil return fullMessage[:len(fullMessage)-bindersLen], nil
} }
// updateBinders updates the m.pskBinders field, if necessary updating the // updateBinders updates the m.pskBinders field. The supplied binders must have
// cached marshaled representation. The supplied binders must have the same // the same length as the current m.pskBinders.
// length as the current m.pskBinders.
func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error { func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
if len(pskBinders) != len(m.pskBinders) { if len(pskBinders) != len(m.pskBinders) {
return errors.New("tls: internal error: pskBinders length mismatch") return errors.New("tls: internal error: pskBinders length mismatch")
@ -344,30 +344,12 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
} }
} }
m.pskBinders = pskBinders m.pskBinders = pskBinders
if m.raw != nil {
helloBytes, err := m.marshalWithoutBinders()
if err != nil {
return err
}
lenWithoutBinders := len(helloBytes)
b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders])
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, binder := range m.pskBinders {
b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(binder)
})
}
})
if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) {
return errors.New("tls: internal error: failed to update binders")
}
}
return nil return nil
} }
func (m *clientHelloMsg) unmarshal(data []byte) bool { func (m *clientHelloMsg) unmarshal(data []byte) bool {
*m = clientHelloMsg{raw: data} *m = clientHelloMsg{original: data}
s := cryptobyte.String(data) s := cryptobyte.String(data)
if !s.Skip(4) || // message type and uint24 length field if !s.Skip(4) || // message type and uint24 length field
@ -625,8 +607,12 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
return true return true
} }
func (m *clientHelloMsg) originalBytes() []byte {
return m.original
}
type serverHelloMsg struct { type serverHelloMsg struct {
raw []byte original []byte
vers uint16 vers uint16
random []byte random []byte
sessionId []byte sessionId []byte
@ -651,10 +637,6 @@ type serverHelloMsg struct {
} }
func (m *serverHelloMsg) marshal() ([]byte, error) { func (m *serverHelloMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var exts cryptobyte.Builder var exts cryptobyte.Builder
if m.ocspStapling { if m.ocspStapling {
exts.AddUint16(extensionStatusRequest) exts.AddUint16(extensionStatusRequest)
@ -766,12 +748,11 @@ func (m *serverHelloMsg) marshal() ([]byte, error) {
} }
}) })
m.raw, err = b.Bytes() return b.Bytes()
return m.raw, err
} }
func (m *serverHelloMsg) unmarshal(data []byte) bool { func (m *serverHelloMsg) unmarshal(data []byte) bool {
*m = serverHelloMsg{raw: data} *m = serverHelloMsg{original: data}
s := cryptobyte.String(data) s := cryptobyte.String(data)
if !s.Skip(4) || // message type and uint24 length field if !s.Skip(4) || // message type and uint24 length field
@ -888,18 +869,17 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
return true return true
} }
func (m *serverHelloMsg) originalBytes() []byte {
return m.original
}
type encryptedExtensionsMsg struct { type encryptedExtensionsMsg struct {
raw []byte
alpnProtocol string alpnProtocol string
quicTransportParameters []byte quicTransportParameters []byte
earlyData bool earlyData bool
} }
func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder var b cryptobyte.Builder
b.AddUint8(typeEncryptedExtensions) b.AddUint8(typeEncryptedExtensions)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -929,13 +909,11 @@ func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
}) })
}) })
var err error return b.Bytes()
m.raw, err = b.Bytes()
return m.raw, err
} }
func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
*m = encryptedExtensionsMsg{raw: data} *m = encryptedExtensionsMsg{}
s := cryptobyte.String(data) s := cryptobyte.String(data)
var extensions cryptobyte.String var extensions cryptobyte.String
@ -998,15 +976,10 @@ func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
} }
type keyUpdateMsg struct { type keyUpdateMsg struct {
raw []byte
updateRequested bool updateRequested bool
} }
func (m *keyUpdateMsg) marshal() ([]byte, error) { func (m *keyUpdateMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder var b cryptobyte.Builder
b.AddUint8(typeKeyUpdate) b.AddUint8(typeKeyUpdate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -1017,13 +990,10 @@ func (m *keyUpdateMsg) marshal() ([]byte, error) {
} }
}) })
var err error return b.Bytes()
m.raw, err = b.Bytes()
return m.raw, err
} }
func (m *keyUpdateMsg) unmarshal(data []byte) bool { func (m *keyUpdateMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data) s := cryptobyte.String(data)
var updateRequested uint8 var updateRequested uint8
@ -1043,7 +1013,6 @@ func (m *keyUpdateMsg) unmarshal(data []byte) bool {
} }
type newSessionTicketMsgTLS13 struct { type newSessionTicketMsgTLS13 struct {
raw []byte
lifetime uint32 lifetime uint32
ageAdd uint32 ageAdd uint32
nonce []byte nonce []byte
@ -1052,10 +1021,6 @@ type newSessionTicketMsgTLS13 struct {
} }
func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder var b cryptobyte.Builder
b.AddUint8(typeNewSessionTicket) b.AddUint8(typeNewSessionTicket)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -1078,13 +1043,11 @@ func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
}) })
}) })
var err error return b.Bytes()
m.raw, err = b.Bytes()
return m.raw, err
} }
func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
*m = newSessionTicketMsgTLS13{raw: data} *m = newSessionTicketMsgTLS13{}
s := cryptobyte.String(data) s := cryptobyte.String(data)
var extensions cryptobyte.String var extensions cryptobyte.String
@ -1125,7 +1088,6 @@ func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
} }
type certificateRequestMsgTLS13 struct { type certificateRequestMsgTLS13 struct {
raw []byte
ocspStapling bool ocspStapling bool
scts bool scts bool
supportedSignatureAlgorithms []SignatureScheme supportedSignatureAlgorithms []SignatureScheme
@ -1134,10 +1096,6 @@ type certificateRequestMsgTLS13 struct {
} }
func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) { func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder var b cryptobyte.Builder
b.AddUint8(typeCertificateRequest) b.AddUint8(typeCertificateRequest)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -1194,13 +1152,11 @@ func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
}) })
}) })
var err error return b.Bytes()
m.raw, err = b.Bytes()
return m.raw, err
} }
func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
*m = certificateRequestMsgTLS13{raw: data} *m = certificateRequestMsgTLS13{}
s := cryptobyte.String(data) s := cryptobyte.String(data)
var context, extensions cryptobyte.String var context, extensions cryptobyte.String
@ -1276,15 +1232,10 @@ func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
} }
type certificateMsg struct { type certificateMsg struct {
raw []byte
certificates [][]byte certificates [][]byte
} }
func (m *certificateMsg) marshal() ([]byte, error) { func (m *certificateMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var i int var i int
for _, slice := range m.certificates { for _, slice := range m.certificates {
i += len(slice) i += len(slice)
@ -1311,8 +1262,7 @@ func (m *certificateMsg) marshal() ([]byte, error) {
y = y[3+len(slice):] y = y[3+len(slice):]
} }
m.raw = x return x, nil
return m.raw, nil
} }
func (m *certificateMsg) unmarshal(data []byte) bool { func (m *certificateMsg) unmarshal(data []byte) bool {
@ -1320,7 +1270,6 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
return false return false
} }
m.raw = data
certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
if uint32(len(data)) != certsLen+7 { if uint32(len(data)) != certsLen+7 {
return false return false
@ -1353,17 +1302,12 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
} }
type certificateMsgTLS13 struct { type certificateMsgTLS13 struct {
raw []byte
certificate Certificate certificate Certificate
ocspStapling bool ocspStapling bool
scts bool scts bool
} }
func (m *certificateMsgTLS13) marshal() ([]byte, error) { func (m *certificateMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder var b cryptobyte.Builder
b.AddUint8(typeCertificate) b.AddUint8(typeCertificate)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -1379,9 +1323,7 @@ func (m *certificateMsgTLS13) marshal() ([]byte, error) {
marshalCertificate(b, certificate) marshalCertificate(b, certificate)
}) })
var err error return b.Bytes()
m.raw, err = b.Bytes()
return m.raw, err
} }
func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
@ -1422,7 +1364,7 @@ func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
} }
func (m *certificateMsgTLS13) unmarshal(data []byte) bool { func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
*m = certificateMsgTLS13{raw: data} *m = certificateMsgTLS13{}
s := cryptobyte.String(data) s := cryptobyte.String(data)
var context cryptobyte.String var context cryptobyte.String
@ -1500,14 +1442,10 @@ func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
} }
type serverKeyExchangeMsg struct { type serverKeyExchangeMsg struct {
raw []byte
key []byte key []byte
} }
func (m *serverKeyExchangeMsg) marshal() ([]byte, error) { func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
length := len(m.key) length := len(m.key)
x := make([]byte, length+4) x := make([]byte, length+4)
x[0] = typeServerKeyExchange x[0] = typeServerKeyExchange
@ -1516,12 +1454,10 @@ func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
x[3] = uint8(length) x[3] = uint8(length)
copy(x[4:], m.key) copy(x[4:], m.key)
m.raw = x
return x, nil return x, nil
} }
func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 4 { if len(data) < 4 {
return false return false
} }
@ -1530,15 +1466,10 @@ func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
} }
type certificateStatusMsg struct { type certificateStatusMsg struct {
raw []byte
response []byte response []byte
} }
func (m *certificateStatusMsg) marshal() ([]byte, error) { func (m *certificateStatusMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder var b cryptobyte.Builder
b.AddUint8(typeCertificateStatus) b.AddUint8(typeCertificateStatus)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -1548,13 +1479,10 @@ func (m *certificateStatusMsg) marshal() ([]byte, error) {
}) })
}) })
var err error return b.Bytes()
m.raw, err = b.Bytes()
return m.raw, err
} }
func (m *certificateStatusMsg) unmarshal(data []byte) bool { func (m *certificateStatusMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data) s := cryptobyte.String(data)
var statusType uint8 var statusType uint8
@ -1580,14 +1508,10 @@ func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
} }
type clientKeyExchangeMsg struct { type clientKeyExchangeMsg struct {
raw []byte
ciphertext []byte ciphertext []byte
} }
func (m *clientKeyExchangeMsg) marshal() ([]byte, error) { func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
length := len(m.ciphertext) length := len(m.ciphertext)
x := make([]byte, length+4) x := make([]byte, length+4)
x[0] = typeClientKeyExchange x[0] = typeClientKeyExchange
@ -1596,12 +1520,10 @@ func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
x[3] = uint8(length) x[3] = uint8(length)
copy(x[4:], m.ciphertext) copy(x[4:], m.ciphertext)
m.raw = x
return x, nil return x, nil
} }
func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 4 { if len(data) < 4 {
return false return false
} }
@ -1614,28 +1536,20 @@ func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
} }
type finishedMsg struct { type finishedMsg struct {
raw []byte
verifyData []byte verifyData []byte
} }
func (m *finishedMsg) marshal() ([]byte, error) { func (m *finishedMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder var b cryptobyte.Builder
b.AddUint8(typeFinished) b.AddUint8(typeFinished)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(m.verifyData) b.AddBytes(m.verifyData)
}) })
var err error return b.Bytes()
m.raw, err = b.Bytes()
return m.raw, err
} }
func (m *finishedMsg) unmarshal(data []byte) bool { func (m *finishedMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data) s := cryptobyte.String(data)
return s.Skip(1) && return s.Skip(1) &&
readUint24LengthPrefixed(&s, &m.verifyData) && readUint24LengthPrefixed(&s, &m.verifyData) &&
@ -1643,7 +1557,6 @@ func (m *finishedMsg) unmarshal(data []byte) bool {
} }
type certificateRequestMsg struct { type certificateRequestMsg struct {
raw []byte
// hasSignatureAlgorithm indicates whether this message includes a list of // hasSignatureAlgorithm indicates whether this message includes a list of
// supported signature algorithms. This change was introduced with TLS 1.2. // supported signature algorithms. This change was introduced with TLS 1.2.
hasSignatureAlgorithm bool hasSignatureAlgorithm bool
@ -1654,10 +1567,6 @@ type certificateRequestMsg struct {
} }
func (m *certificateRequestMsg) marshal() ([]byte, error) { func (m *certificateRequestMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
// See RFC 4346, Section 7.4.4. // See RFC 4346, Section 7.4.4.
length := 1 + len(m.certificateTypes) + 2 length := 1 + len(m.certificateTypes) + 2
casLength := 0 casLength := 0
@ -1704,13 +1613,10 @@ func (m *certificateRequestMsg) marshal() ([]byte, error) {
y = y[len(ca):] y = y[len(ca):]
} }
m.raw = x return x, nil
return m.raw, nil
} }
func (m *certificateRequestMsg) unmarshal(data []byte) bool { func (m *certificateRequestMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 5 { if len(data) < 5 {
return false return false
} }
@ -1785,17 +1691,12 @@ func (m *certificateRequestMsg) unmarshal(data []byte) bool {
} }
type certificateVerifyMsg struct { type certificateVerifyMsg struct {
raw []byte
hasSignatureAlgorithm bool // format change introduced in TLS 1.2 hasSignatureAlgorithm bool // format change introduced in TLS 1.2
signatureAlgorithm SignatureScheme signatureAlgorithm SignatureScheme
signature []byte signature []byte
} }
func (m *certificateVerifyMsg) marshal() ([]byte, error) { func (m *certificateVerifyMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
var b cryptobyte.Builder var b cryptobyte.Builder
b.AddUint8(typeCertificateVerify) b.AddUint8(typeCertificateVerify)
b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
@ -1807,13 +1708,10 @@ func (m *certificateVerifyMsg) marshal() ([]byte, error) {
}) })
}) })
var err error return b.Bytes()
m.raw, err = b.Bytes()
return m.raw, err
} }
func (m *certificateVerifyMsg) unmarshal(data []byte) bool { func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
m.raw = data
s := cryptobyte.String(data) s := cryptobyte.String(data)
if !s.Skip(4) { // message type and uint24 length field if !s.Skip(4) { // message type and uint24 length field
@ -1828,15 +1726,10 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
} }
type newSessionTicketMsg struct { type newSessionTicketMsg struct {
raw []byte
ticket []byte ticket []byte
} }
func (m *newSessionTicketMsg) marshal() ([]byte, error) { func (m *newSessionTicketMsg) marshal() ([]byte, error) {
if m.raw != nil {
return m.raw, nil
}
// See RFC 5077, Section 3.3. // See RFC 5077, Section 3.3.
ticketLen := len(m.ticket) ticketLen := len(m.ticket)
length := 2 + 4 + ticketLen length := 2 + 4 + ticketLen
@ -1849,14 +1742,10 @@ func (m *newSessionTicketMsg) marshal() ([]byte, error) {
x[9] = uint8(ticketLen) x[9] = uint8(ticketLen)
copy(x[10:], m.ticket) copy(x[10:], m.ticket)
m.raw = x return x, nil
return m.raw, nil
} }
func (m *newSessionTicketMsg) unmarshal(data []byte) bool { func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
m.raw = data
if len(data) < 10 { if len(data) < 10 {
return false return false
} }
@ -1891,9 +1780,25 @@ type transcriptHash interface {
Write([]byte) (int, error) Write([]byte) (int, error)
} }
// transcriptMsg is a helper used to marshal and hash messages which typically // transcriptMsg is a helper used to hash messages which are not hashed when
// are not written to the wire, and as such aren't hashed during Conn.writeRecord. // they are read from, or written to, the wire. This is typically the case for
// messages which are either not sent, or need to be hashed out of order from
// when they are read/written.
//
// For most messages, the message is marshalled using their marshal method,
// since their wire representation is idempotent. For clientHelloMsg and
// serverHelloMsg, we store the original wire representation of the message and
// use that for hashing, since unmarshal/marshal are not idempotent due to
// extension ordering and other malleable fields, which may cause differences
// between what was received and what we marshal.
func transcriptMsg(msg handshakeMessage, h transcriptHash) error { func transcriptMsg(msg handshakeMessage, h transcriptHash) error {
if msgWithOrig, ok := msg.(handshakeMessageWithOriginalBytes); ok {
if orig := msgWithOrig.originalBytes(); orig != nil {
h.Write(msgWithOrig.originalBytes())
return nil
}
}
data, err := msg.marshal() data, err := msg.marshal()
if err != nil { if err != nil {
return err return err

View file

@ -53,7 +53,7 @@ func TestMarshalUnmarshal(t *testing.T) {
for i, m := range tests { for i, m := range tests {
ty := reflect.ValueOf(m).Type() ty := reflect.ValueOf(m).Type()
t.Run(ty.String(), func(t *testing.T) {
n := 100 n := 100
if testing.Short() { if testing.Short() {
n = 5 n = 5
@ -71,12 +71,23 @@ func TestMarshalUnmarshal(t *testing.T) {
t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
break break
} }
m.marshal() // to fill any marshal cache in the message
if m, ok := m.(*SessionState); ok { if m, ok := m.(*SessionState); ok {
m.activeCertHandles = nil m.activeCertHandles = nil
} }
// clientHelloMsg and serverHelloMsg, when unmarshalled, store
// their original representation, for later use in the handshake
// transcript. In order to prevent DeepEqual from failing since
// we didn't create the original message via unmarshalling, nil
// the field.
switch t := m.(type) {
case *clientHelloMsg:
t.original = nil
case *serverHelloMsg:
t.original = nil
}
if !reflect.DeepEqual(m1, m) { if !reflect.DeepEqual(m1, m) {
t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled) t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
break break
@ -96,6 +107,7 @@ func TestMarshalUnmarshal(t *testing.T) {
} }
} }
} }
})
} }
} }