speed up marshaling of transport parameters (#3531)

The speedup comes from multiple sources:
1. We now preallocate a byte slice, instead of appending multiple times.
2. Marshaling into a byte slice is faster than using a bytes.Buffer.
3. quicvarint.Write allocates, while quicvarint.Append doesn't.
This commit is contained in:
Marten Seemann 2022-08-29 23:05:52 +03:00 committed by GitHub
parent 3f1adfd822
commit 7023b52e13
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 74 additions and 78 deletions

View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"bytes"
"log" "log"
"math" "math"
"math/rand" "math/rand"
@ -78,9 +77,7 @@ func main() {
} }
data = tp.Marshal(pers) data = tp.Marshal(pers)
} else { } else {
b := &bytes.Buffer{} data = tp.MarshalForSessionTicket(nil)
tp.MarshalForSessionTicket(b)
data = b.Bytes()
} }
if err := helper.WriteCorpusFileWithPrefix("corpus", data, transportparameters.PrefixLen); err != nil { if err := helper.WriteCorpusFileWithPrefix("corpus", data, transportparameters.PrefixLen); err != nil {
log.Fatal(err) log.Fatal(err)

View file

@ -51,10 +51,9 @@ func fuzzTransportParametersForSessionTicket(data []byte) int {
if err := tp.UnmarshalFromSessionTicket(bytes.NewReader(data)); err != nil { if err := tp.UnmarshalFromSessionTicket(bytes.NewReader(data)); err != nil {
return 0 return 0
} }
buf := &bytes.Buffer{} b := tp.MarshalForSessionTicket(nil)
tp.MarshalForSessionTicket(buf)
tp2 := &wire.TransportParameters{} tp2 := &wire.TransportParameters{}
if err := tp2.UnmarshalFromSessionTicket(bytes.NewReader(buf.Bytes())); err != nil { if err := tp2.UnmarshalFromSessionTicket(bytes.NewReader(b)); err != nil {
panic(err) panic(err)
} }
return 1 return 1

View file

@ -432,11 +432,10 @@ func (h *cryptoSetup) handleTransportParameters(data []byte) {
// must be called after receiving the transport parameters // must be called after receiving the transport parameters
func (h *cryptoSetup) marshalDataForSessionState() []byte { func (h *cryptoSetup) marshalDataForSessionState() []byte {
buf := &bytes.Buffer{} b := make([]byte, 0, 256)
quicvarint.Write(buf, clientSessionStateRevision) b = quicvarint.Append(b, clientSessionStateRevision)
quicvarint.Write(buf, uint64(h.rttStats.SmoothedRTT().Microseconds())) b = quicvarint.Append(b, uint64(h.rttStats.SmoothedRTT().Microseconds()))
h.peerParams.MarshalForSessionTicket(buf) return h.peerParams.MarshalForSessionTicket(b)
return buf.Bytes()
} }
func (h *cryptoSetup) handleDataFromSessionState(data []byte) { func (h *cryptoSetup) handleDataFromSessionState(data []byte) {

View file

@ -18,11 +18,10 @@ type sessionTicket struct {
} }
func (t *sessionTicket) Marshal() []byte { func (t *sessionTicket) Marshal() []byte {
b := &bytes.Buffer{} b := make([]byte, 0, 256)
quicvarint.Write(b, sessionTicketRevision) b = quicvarint.Append(b, sessionTicketRevision)
quicvarint.Write(b, uint64(t.RTT.Microseconds())) b = quicvarint.Append(b, uint64(t.RTT.Microseconds()))
t.Parameters.MarshalForSessionTicket(b) return t.Parameters.MarshalForSessionTicket(b)
return b.Bytes()
} }
func (t *sessionTicket) Unmarshal(b []byte) error { func (t *sessionTicket) Unmarshal(b []byte) error {

View file

@ -486,10 +486,9 @@ var _ = Describe("Transport Parameters", func() {
ActiveConnectionIDLimit: getRandomValue(), ActiveConnectionIDLimit: getRandomValue(),
} }
Expect(params.ValidFor0RTT(params)).To(BeTrue()) Expect(params.ValidFor0RTT(params)).To(BeTrue())
b := &bytes.Buffer{} b := params.MarshalForSessionTicket(nil)
params.MarshalForSessionTicket(b)
var tp TransportParameters var tp TransportParameters
Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b.Bytes()))).To(Succeed()) Expect(tp.UnmarshalFromSessionTicket(bytes.NewReader(b))).To(Succeed())
Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) Expect(tp.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal))
Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) Expect(tp.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote))
Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) Expect(tp.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni))
@ -506,9 +505,7 @@ var _ = Describe("Transport Parameters", func() {
It("rejects the parameters if the version changed", func() { It("rejects the parameters if the version changed", func() {
var p TransportParameters var p TransportParameters
buf := &bytes.Buffer{} data := p.MarshalForSessionTicket(nil)
p.MarshalForSessionTicket(buf)
data := buf.Bytes()
b := &bytes.Buffer{} b := &bytes.Buffer{}
quicvarint.Write(b, transportParameterMarshalingVersion+1) quicvarint.Write(b, transportParameterMarshalingVersion+1)
b.Write(data[quicvarint.Len(transportParameterMarshalingVersion):]) b.Write(data[quicvarint.Len(transportParameterMarshalingVersion):])

View file

@ -2,6 +2,7 @@ package wire
import ( import (
"bytes" "bytes"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -313,94 +314,98 @@ func (p *TransportParameters) readNumericTransportParameter(
// Marshal the transport parameters // Marshal the transport parameters
func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
b := &bytes.Buffer{} // Typical Transport Parameters consume around 110 bytes, depending on the exact values,
// especially the lengths of the Connection IDs.
// Allocate 256 bytes, so we won't have to grow the slice in any case.
b := make([]byte, 0, 256)
// add a greased value // add a greased value
quicvarint.Write(b, uint64(27+31*rand.Intn(100))) b = quicvarint.Append(b, uint64(27+31*rand.Intn(100)))
length := rand.Intn(16) length := rand.Intn(16)
randomData := make([]byte, length) b = quicvarint.Append(b, uint64(length))
rand.Read(randomData) b = b[:len(b)+length]
quicvarint.Write(b, uint64(length)) rand.Read(b[len(b)-length:])
b.Write(randomData)
// initial_max_stream_data_bidi_local // initial_max_stream_data_bidi_local
p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal))
// initial_max_stream_data_bidi_remote // initial_max_stream_data_bidi_remote
p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) b = p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote))
// initial_max_stream_data_uni // initial_max_stream_data_uni
p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) b = p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni))
// initial_max_data // initial_max_data
p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) b = p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData))
// initial_max_bidi_streams // initial_max_bidi_streams
p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum))
// initial_max_uni_streams // initial_max_uni_streams
p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum))
// idle_timeout // idle_timeout
p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond)) b = p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond))
// max_packet_size // max_packet_size
p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize)) b = p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize))
// max_ack_delay // max_ack_delay
// Only send it if is different from the default value. // Only send it if is different from the default value.
if p.MaxAckDelay != protocol.DefaultMaxAckDelay { if p.MaxAckDelay != protocol.DefaultMaxAckDelay {
p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond)) b = p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond))
} }
// ack_delay_exponent // ack_delay_exponent
// Only send it if is different from the default value. // Only send it if is different from the default value.
if p.AckDelayExponent != protocol.DefaultAckDelayExponent { if p.AckDelayExponent != protocol.DefaultAckDelayExponent {
p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent)) b = p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent))
} }
// disable_active_migration // disable_active_migration
if p.DisableActiveMigration { if p.DisableActiveMigration {
quicvarint.Write(b, uint64(disableActiveMigrationParameterID)) b = quicvarint.Append(b, uint64(disableActiveMigrationParameterID))
quicvarint.Write(b, 0) b = quicvarint.Append(b, 0)
} }
if pers == protocol.PerspectiveServer { if pers == protocol.PerspectiveServer {
// stateless_reset_token // stateless_reset_token
if p.StatelessResetToken != nil { if p.StatelessResetToken != nil {
quicvarint.Write(b, uint64(statelessResetTokenParameterID)) b = quicvarint.Append(b, uint64(statelessResetTokenParameterID))
quicvarint.Write(b, 16) b = quicvarint.Append(b, 16)
b.Write(p.StatelessResetToken[:]) b = append(b, p.StatelessResetToken[:]...)
} }
// original_destination_connection_id // original_destination_connection_id
quicvarint.Write(b, uint64(originalDestinationConnectionIDParameterID)) b = quicvarint.Append(b, uint64(originalDestinationConnectionIDParameterID))
quicvarint.Write(b, uint64(p.OriginalDestinationConnectionID.Len())) b = quicvarint.Append(b, uint64(p.OriginalDestinationConnectionID.Len()))
b.Write(p.OriginalDestinationConnectionID.Bytes()) b = append(b, p.OriginalDestinationConnectionID.Bytes()...)
// preferred_address // preferred_address
if p.PreferredAddress != nil { if p.PreferredAddress != nil {
quicvarint.Write(b, uint64(preferredAddressParameterID)) b = quicvarint.Append(b, uint64(preferredAddressParameterID))
quicvarint.Write(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16) b = quicvarint.Append(b, 4+2+16+2+1+uint64(p.PreferredAddress.ConnectionID.Len())+16)
ipv4 := p.PreferredAddress.IPv4 ipv4 := p.PreferredAddress.IPv4
b.Write(ipv4[len(ipv4)-4:]) b = append(b, ipv4[len(ipv4)-4:]...)
utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv4Port) b = append(b, []byte{0, 0}...)
b.Write(p.PreferredAddress.IPv6) binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv4Port)
utils.BigEndian.WriteUint16(b, p.PreferredAddress.IPv6Port) b = append(b, p.PreferredAddress.IPv6...)
b.WriteByte(uint8(p.PreferredAddress.ConnectionID.Len())) b = append(b, []byte{0, 0}...)
b.Write(p.PreferredAddress.ConnectionID.Bytes()) binary.BigEndian.PutUint16(b[len(b)-2:], p.PreferredAddress.IPv6Port)
b.Write(p.PreferredAddress.StatelessResetToken[:]) b = append(b, uint8(p.PreferredAddress.ConnectionID.Len()))
b = append(b, p.PreferredAddress.ConnectionID.Bytes()...)
b = append(b, p.PreferredAddress.StatelessResetToken[:]...)
} }
} }
// active_connection_id_limit // active_connection_id_limit
p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) b = p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit)
// initial_source_connection_id // initial_source_connection_id
quicvarint.Write(b, uint64(initialSourceConnectionIDParameterID)) b = quicvarint.Append(b, uint64(initialSourceConnectionIDParameterID))
quicvarint.Write(b, uint64(p.InitialSourceConnectionID.Len())) b = quicvarint.Append(b, uint64(p.InitialSourceConnectionID.Len()))
b.Write(p.InitialSourceConnectionID.Bytes()) b = append(b, p.InitialSourceConnectionID.Bytes()...)
// retry_source_connection_id // retry_source_connection_id
if pers == protocol.PerspectiveServer && p.RetrySourceConnectionID != nil { if pers == protocol.PerspectiveServer && p.RetrySourceConnectionID != nil {
quicvarint.Write(b, uint64(retrySourceConnectionIDParameterID)) b = quicvarint.Append(b, uint64(retrySourceConnectionIDParameterID))
quicvarint.Write(b, uint64(p.RetrySourceConnectionID.Len())) b = quicvarint.Append(b, uint64(p.RetrySourceConnectionID.Len()))
b.Write(p.RetrySourceConnectionID.Bytes()) b = append(b, p.RetrySourceConnectionID.Bytes()...)
} }
if p.MaxDatagramFrameSize != protocol.InvalidByteCount { if p.MaxDatagramFrameSize != protocol.InvalidByteCount {
p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) b = p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize))
} }
return b.Bytes() return b
} }
func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportParameterID, val uint64) { func (p *TransportParameters) marshalVarintParam(b []byte, id transportParameterID, val uint64) []byte {
quicvarint.Write(b, uint64(id)) b = quicvarint.Append(b, uint64(id))
quicvarint.Write(b, uint64(quicvarint.Len(val))) b = quicvarint.Append(b, uint64(quicvarint.Len(val)))
quicvarint.Write(b, val) return quicvarint.Append(b, val)
} }
// MarshalForSessionTicket marshals the transport parameters we save in the session ticket. // MarshalForSessionTicket marshals the transport parameters we save in the session ticket.
@ -411,23 +416,23 @@ func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportPa
// if the transport parameters changed. // if the transport parameters changed.
// Since the session ticket is encrypted, the serialization format is defined by the server. // Since the session ticket is encrypted, the serialization format is defined by the server.
// For convenience, we use the same format that we also use for sending the transport parameters. // For convenience, we use the same format that we also use for sending the transport parameters.
func (p *TransportParameters) MarshalForSessionTicket(b *bytes.Buffer) { func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte {
quicvarint.Write(b, transportParameterMarshalingVersion) b = quicvarint.Append(b, transportParameterMarshalingVersion)
// initial_max_stream_data_bidi_local // initial_max_stream_data_bidi_local
p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) b = p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal))
// initial_max_stream_data_bidi_remote // initial_max_stream_data_bidi_remote
p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) b = p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote))
// initial_max_stream_data_uni // initial_max_stream_data_uni
p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) b = p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni))
// initial_max_data // initial_max_data
p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) b = p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData))
// initial_max_bidi_streams // initial_max_bidi_streams
p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum))
// initial_max_uni_streams // initial_max_uni_streams
p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum))
// active_connection_id_limit // active_connection_id_limit
p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) return p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit)
} }
// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket. // UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket.