impl: QUIC Header mimicry

This commit is contained in:
Gaukas Wang 2023-07-29 13:17:27 -06:00
parent 1429e6718b
commit 251b3afe6e
No known key found for this signature in database
GPG key ID: 9E2F8986D76F8B5D
11 changed files with 262 additions and 30 deletions

View file

@ -41,6 +41,7 @@ type client struct {
// make it possible to mock connection ID for initial generation in the tests // make it possible to mock connection ID for initial generation in the tests
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
var generateConnectionIDForInitialWithLength = protocol.GenerateConnectionIDForInitialWithLen
// DialAddr establishes a new QUIC connection to a server. // DialAddr establishes a new QUIC connection to a server.
// It resolves the address, and then creates a new UDP connection to dial the QUIC server. // It resolves the address, and then creates a new UDP connection to dial the QUIC server.
@ -169,11 +170,25 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config
tlsConf = tlsConf.Clone() tlsConf = tlsConf.Clone()
} }
// // [UQUIC]
// if config.SrcConnIDLength != 0 {
// connIDLen := config.SrcConnIDLength
// connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: connIDLen}
// }
srcConnID, err := connIDGenerator.GenerateConnectionID() srcConnID, err := connIDGenerator.GenerateConnectionID()
if err != nil { if err != nil {
return nil, err return nil, err
} }
destConnID, err := generateConnectionIDForInitial()
var destConnID protocol.ConnectionID
// [UQUIC]
if config.DestConnIDLength > 0 {
destConnID, err = generateConnectionIDForInitialWithLength(config.DestConnIDLength)
} else {
destConnID, err = generateConnectionIDForInitial()
}
// [/UQUIC]
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -189,6 +204,8 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config
version: config.Versions[0], version: config.Versions[0],
handshakeChan: make(chan struct{}), handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"), logger: utils.DefaultLogger.WithPrefix("client"),
initialPacketNumber: protocol.PacketNumber(config.InitPacketNumber), // [UQUIC]
} }
return c, nil return c, nil
} }

View file

@ -109,6 +109,8 @@ func populateConfig(config *Config) *Config {
maxIncomingUniStreams = 0 maxIncomingUniStreams = 0
} }
// [UQUIC] TODO: reverse populate config from TransportParameters
return &Config{ return &Config{
GetConfigForClient: config.GetConfigForClient, GetConfigForClient: config.GetConfigForClient,
Versions: versions, Versions: versions,
@ -131,5 +133,23 @@ func populateConfig(config *Config) *Config {
DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets, DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets,
Allow0RTT: config.Allow0RTT, Allow0RTT: config.Allow0RTT,
Tracer: config.Tracer, Tracer: config.Tracer,
// [UQUIC]
SrcConnIDLength: config.SrcConnIDLength,
DestConnIDLength: config.DestConnIDLength,
InitPacketNumber: config.InitPacketNumber,
TransportParameters: config.TransportParameters,
InitPacketNumberLength: config.InitPacketNumberLength,
} }
} }
const (
// PacketNumberLen1 is a packet number length of 1 byte
PacketNumberLen1 protocol.PacketNumberLen = 1
// PacketNumberLen2 is a packet number length of 2 bytes
PacketNumberLen2 protocol.PacketNumberLen = 2
// PacketNumberLen3 is a packet number length of 3 bytes
PacketNumberLen3 protocol.PacketNumberLen = 3
// PacketNumberLen4 is a packet number length of 4 bytes
PacketNumberLen4 protocol.PacketNumberLen = 4
)

View file

@ -35,6 +35,8 @@ type connIDManager struct {
addStatelessResetToken func(protocol.StatelessResetToken) addStatelessResetToken func(protocol.StatelessResetToken)
removeStatelessResetToken func(protocol.StatelessResetToken) removeStatelessResetToken func(protocol.StatelessResetToken)
queueControlFrame func(wire.Frame) queueControlFrame func(wire.Frame)
connectionIDLimit uint64 // [UQUIC] custom Connection ID limit
} }
func newConnIDManager( func newConnIDManager(
@ -59,7 +61,13 @@ func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
if err := h.add(f); err != nil { if err := h.add(f); err != nil {
return err return err
} }
if h.queue.Len() >= protocol.MaxActiveConnectionIDs {
connIDLimit := h.connectionIDLimit
if connIDLimit == 0 {
connIDLimit = protocol.MaxActiveConnectionIDs
}
if uint64(h.queue.Len()) >= connIDLimit {
return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError} return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}
} }
return nil return nil
@ -183,6 +191,11 @@ func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToke
h.addStatelessResetToken(token) h.addStatelessResetToken(token)
} }
// [UQUIC]
func (h *connIDManager) SetConnectionIDLimit(limit uint64) {
h.connectionIDLimit = limit
}
func (h *connIDManager) SentPacket() { func (h *connIDManager) SentPacket() {
h.packetsSinceLastChange++ h.packetsSinceLastChange++
} }

View file

@ -366,6 +366,7 @@ var newClientConnection = func(
runner.RemoveResetToken, runner.RemoveResetToken,
s.queueControlFrame, s.queueControlFrame,
) )
s.connIDGenerator = newConnIDGenerator( s.connIDGenerator = newConnIDGenerator(
srcConnID, srcConnID,
nil, nil,
@ -388,31 +389,45 @@ var newClientConnection = func(
s.tracer, s.tracer,
s.logger, s.logger,
) )
if conf.InitPacketNumberLength != 0 {
ackhandler.SetInitialPacketNumberLength(s.sentPacketHandler, conf.InitPacketNumberLength)
}
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
oneRTTStream := newCryptoStream() oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), var params *wire.TransportParameters
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), if s.config.TransportParameters != nil {
InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), params = &wire.TransportParameters{
InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), InitialSourceConnectionID: srcConnID,
MaxIdleTimeout: s.config.MaxIdleTimeout, }
MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), params.PopulateFromUQUIC(s.config.TransportParameters)
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), s.connIDManager.SetConnectionIDLimit(params.ActiveConnectionIDLimit)
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
AckDelayExponent: protocol.AckDelayExponent,
DisableActiveMigration: true,
// For interoperability with quic-go versions before May 2023, this value must be set to a value
// different from protocol.DefaultActiveConnectionIDLimit.
// If set to the default value, it will be omitted from the transport parameters, which will make
// old quic-go versions interpret it as 0, instead of the default value of 2.
// See https://github.com/quic-go/quic-go/pull/3806.
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
InitialSourceConnectionID: srcConnID,
}
if s.config.EnableDatagrams {
params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize
} else { } else {
params.MaxDatagramFrameSize = protocol.InvalidByteCount params = &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
MaxIdleTimeout: s.config.MaxIdleTimeout,
MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams),
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
AckDelayExponent: protocol.AckDelayExponent,
DisableActiveMigration: true,
// For interoperability with quic-go versions before May 2023, this value must be set to a value
// different from protocol.DefaultActiveConnectionIDLimit.
// If set to the default value, it will be omitted from the transport parameters, which will make
// old quic-go versions interpret it as 0, instead of the default value of 2.
// See https://github.com/quic-go/quic-go/pull/3806.
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
InitialSourceConnectionID: srcConnID,
}
if s.config.EnableDatagrams {
params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize
} else {
params.MaxDatagramFrameSize = protocol.InvalidByteCount
}
} }
if s.tracer != nil { if s.tracer != nil {
s.tracer.SentTransportParameters(params) s.tracer.SentTransportParameters(params)
@ -1371,6 +1386,7 @@ func (s *connection) handleHandshakeEvents() error {
case handshake.EventDiscard0RTTKeys: case handshake.EventDiscard0RTTKeys:
err = s.dropEncryptionLevel(protocol.Encryption0RTT) err = s.dropEncryptionLevel(protocol.Encryption0RTT)
case handshake.EventWriteInitialData: case handshake.EventWriteInitialData:
// fmt.Printf("write initial data: %x\n", ev.Data) // [UQUIC] debug
_, err = s.initialStream.Write(ev.Data) _, err = s.initialStream.Write(ev.Data)
case handshake.EventWriteHandshakeData: case handshake.EventWriteHandshakeData:
_, err = s.handshakeStream.Write(ev.Data) _, err = s.handshakeStream.Write(ev.Data)
@ -2050,6 +2066,9 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
} }
s.connIDManager.SentPacket() s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer, packet.buffer.Len()) s.sendQueue.Send(packet.buffer, packet.buffer.Len())
// [UQUIC]
// fmt.Printf("sendPackedCoalescedPacket:Sending %d bytes\n", packet.buffer.Len())
// fmt.Printf("sendPackedCoalescedPacket: %v\n", packet.buffer.Data)
return nil return nil
} }

View file

@ -11,6 +11,7 @@ import (
"github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging" "github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/transportparameters"
) )
// The StreamID is the ID of a QUIC stream. // The StreamID is the ID of a QUIC stream.
@ -332,6 +333,13 @@ type Config struct {
// Enable QUIC datagram support (RFC 9221). // Enable QUIC datagram support (RFC 9221).
EnableDatagrams bool EnableDatagrams bool
Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer
// TransportParameters override other transport parameters set by the Config.
TransportParameters transportparameters.TransportParameters // [UQUIC]
SrcConnIDLength int // [UQUIC]
DestConnIDLength int // [UQUIC]
InitPacketNumber uint64 // [UQUIC]
InitPacketNumberLength protocol.PacketNumberLen // [UQUIC]
} }
type ClientHelloInfo struct { type ClientHelloInfo struct {

View file

@ -96,6 +96,8 @@ type sentPacketHandler struct {
tracer logging.ConnectionTracer tracer logging.ConnectionTracer
logger utils.Logger logger utils.Logger
initialPacketNumberLength protocol.PacketNumberLen // [UQUIC]
} }
var ( var (
@ -136,6 +138,12 @@ func newSentPacketHandler(
} }
} }
func SetInitialPacketNumberLength(h SentPacketHandler, pnLen protocol.PacketNumberLen) {
if sph, ok := h.(*sentPacketHandler); ok {
sph.initialPacketNumberLength = pnLen
}
}
func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) { func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
if p.includedInBytesInFlight { if p.includedInBytesInFlight {
if p.Length > h.bytesInFlight { if p.Length > h.bytesInFlight {
@ -716,6 +724,12 @@ func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel)
pnSpace := h.getPacketNumberSpace(encLevel) pnSpace := h.getPacketNumberSpace(encLevel)
pn := pnSpace.pns.Peek() pn := pnSpace.pns.Peek()
// See section 17.1 of RFC 9000. // See section 17.1 of RFC 9000.
// [UQUIC] This kinda breaks PN length mimicry.
if encLevel == protocol.EncryptionInitial && h.initialPacketNumberLength != 0 {
return pn, h.initialPacketNumberLength
}
return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked) return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked)
} }

View file

@ -250,6 +250,14 @@ func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
h.rejected0RTT() h.rejected0RTT()
return false, nil return false, nil
case qtls.QUICWriteData: case qtls.QUICWriteData:
// [UQUIC] debug
// if ev.Level == qtls.QUICEncryptionLevelInitial {
// fmt.Printf("Init: %x\n", ev.Data)
// } else if ev.Level == qtls.QUICEncryptionLevelHandshake {
// fmt.Printf("HS: %x\n", ev.Data)
// } else {
// fmt.Printf("APP: %x\n", ev.Data)
// }
h.WriteRecord(ev.Level, ev.Data) h.WriteRecord(ev.Level, ev.Data)
return false, nil return false, nil
case qtls.QUICHandshakeDone: case qtls.QUICHandshakeDone:

View file

@ -68,6 +68,11 @@ func GenerateConnectionIDForInitial() (ConnectionID, error) {
return GenerateConnectionID(l) return GenerateConnectionID(l)
} }
// [UQUIC]
func GenerateConnectionIDForInitialWithLen(l int) (ConnectionID, error) {
return GenerateConnectionID(l)
}
// ReadConnectionID reads a connection ID of length len from the given io.Reader. // ReadConnectionID reads a connection ID of length len from the given io.Reader.
// It returns io.EOF if there are not enough bytes to read. // It returns io.EOF if there are not enough bytes to read.
func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) { func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) {

View file

@ -15,6 +15,7 @@ import (
"github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint" "github.com/quic-go/quic-go/quicvarint"
"github.com/quic-go/quic-go/transportparameters"
) )
// AdditionalTransportParametersClient are additional transport parameters that will be added // AdditionalTransportParametersClient are additional transport parameters that will be added
@ -88,6 +89,9 @@ type TransportParameters struct {
ActiveConnectionIDLimit uint64 ActiveConnectionIDLimit uint64
MaxDatagramFrameSize protocol.ByteCount MaxDatagramFrameSize protocol.ByteCount
// only used internally
ClientOverride transportparameters.TransportParameters // [UQUIC]
} }
// Unmarshal the transport parameters // Unmarshal the transport parameters
@ -325,6 +329,13 @@ func (p *TransportParameters) readNumericTransportParameter(
// Marshal the transport parameters // Marshal the transport parameters
func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte { func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
// [UQUIC]
if p.ClientOverride != nil {
fmt.Println("ClientOverride!!")
return p.ClientOverride.Marshal()
}
// [/UQUIC]
// Typical Transport Parameters consume around 110 bytes, depending on the exact values, // Typical Transport Parameters consume around 110 bytes, depending on the exact values,
// especially the lengths of the Connection IDs. // especially the lengths of the Connection IDs.
// Allocate 256 bytes, so we won't have to grow the slice in any case. // Allocate 256 bytes, so we won't have to grow the slice in any case.
@ -502,3 +513,47 @@ func (p *TransportParameters) String() string {
logString += "}" logString += "}"
return fmt.Sprintf(logString, logParams...) return fmt.Sprintf(logString, logParams...)
} }
func (tp *TransportParameters) PopulateFromUQUIC(quicparams transportparameters.TransportParameters) {
for pIdx, param := range quicparams {
switch param.ID() {
case uint64(maxIdleTimeoutParameterID):
tp.MaxIdleTimeout = time.Duration(param.(transportparameters.MaxIdleTimeout)) * time.Millisecond
case uint64(initialMaxDataParameterID):
tp.InitialMaxData = protocol.ByteCount(param.(transportparameters.InitialMaxData))
case uint64(initialMaxStreamDataBidiLocalParameterID):
tp.InitialMaxStreamDataBidiLocal = protocol.ByteCount(param.(transportparameters.InitialMaxStreamDataBidiLocal))
case uint64(initialMaxStreamDataBidiRemoteParameterID):
tp.InitialMaxStreamDataBidiRemote = protocol.ByteCount(param.(transportparameters.InitialMaxStreamDataBidiRemote))
case uint64(initialMaxStreamDataUniParameterID):
tp.InitialMaxStreamDataUni = protocol.ByteCount(param.(transportparameters.InitialMaxStreamDataUni))
case uint64(initialMaxStreamsBidiParameterID):
tp.MaxBidiStreamNum = protocol.StreamNum(param.(transportparameters.InitialMaxStreamsBidi))
case uint64(initialMaxStreamsUniParameterID):
tp.MaxUniStreamNum = protocol.StreamNum(param.(transportparameters.InitialMaxStreamsUni))
case uint64(maxAckDelayParameterID):
tp.MaxAckDelay = time.Duration(param.(transportparameters.MaxAckDelay)) * time.Millisecond
case uint64(disableActiveMigrationParameterID):
tp.DisableActiveMigration = true
case uint64(activeConnectionIDLimitParameterID):
tp.ActiveConnectionIDLimit = uint64(param.(transportparameters.ActiveConnectionIDLimit))
case uint64(initialSourceConnectionIDParameterID):
srcConnIDOverride, ok := param.(transportparameters.InitialSourceConnectionID)
if ok {
if len(srcConnIDOverride) > 0 { // when nil/empty, will leave default srcConnID
tp.InitialSourceConnectionID = protocol.ParseConnectionID(srcConnIDOverride)
} else {
// reversely populate the transport parameter, for it must be written to network
quicparams[pIdx] = transportparameters.InitialSourceConnectionID(tp.InitialSourceConnectionID.Bytes())
}
}
case uint64(maxDatagramFrameSizeParameterID):
tp.MaxDatagramFrameSize = protocol.ByteCount(param.(transportparameters.MaxDatagramFrameSize))
default:
// ignore unknown parameters
continue
}
}
tp.ClientOverride = quicparams
}

View file

@ -332,6 +332,11 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
if initialPayload.length > 0 { if initialPayload.length > 0 {
size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead()) size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead())
} }
// // [UQUIC]
// if len(initialPayload.frames) > 0 {
// fmt.Printf("onlyAck: %t, PackCoalescedPacket: %v\n", onlyAck, initialPayload.frames[0].Frame)
// }
} }
// Add a Handshake packet. // Add a Handshake packet.
@ -396,12 +401,24 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
longHdrPackets: make([]*longHeaderPacket, 0, 3), longHdrPackets: make([]*longHeaderPacket, 0, 3),
} }
if initialPayload.length > 0 { if initialPayload.length > 0 {
padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize) if onlyAck || len(initialPayload.frames) == 0 {
cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) // padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize)
if err != nil { // cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v)
return nil, err // if err != nil {
// return nil, err
// }
// packet.longHdrPackets = append(packet.longHdrPackets, cont)
return nil, nil // [UQUIC] not to send the ACK frame for Initial
} else { // [UQUIC]
cont, err := p.appendLongHeaderPacketExternalPadding(buffer, initialHdr, initialPayload, protocol.EncryptionInitial, initialSealer, v)
if err != nil {
return nil, err
}
// fmt.Printf("!onlyAck buffer: %v\n", buffer.Data)
packet.longHdrPackets = append(packet.longHdrPackets, cont)
} }
packet.longHdrPackets = append(packet.longHdrPackets, cont)
} }
if handshakePayload.length > 0 { if handshakePayload.length > 0 {
cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v) cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v)
@ -523,6 +540,10 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
} }
} else if s.HasData() { } else if s.HasData() {
cf := s.PopCryptoFrame(maxPacketSize) cf := s.PopCryptoFrame(maxPacketSize)
if encLevel == protocol.EncryptionInitial {
fmt.Printf("PopCryptoFrame for Initial: %x...\n", cf.Data[:5])
// fmt.Printf("Offset: %d\n", cf.Offset)
}
pl.frames = []ackhandler.Frame{{Frame: cf, Handler: handler}} pl.frames = []ackhandler.Frame{{Frame: cf, Handler: handler}}
pl.length += cf.Length(v) pl.length += cf.Length(v)
} }
@ -671,6 +692,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m
return nil, err return nil, err
} }
hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v) hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v)
fmt.Printf("MaybePackProbePacket: %x\n", pl.frames[0])
case protocol.EncryptionHandshake: case protocol.EncryptionHandshake:
var err error var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer() sealer, err = p.cryptoSetup.GetHandshakeSealer()
@ -749,6 +771,11 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
paddingLen = 4 - pnLen - pl.length paddingLen = 4 - pnLen - pl.length
} }
paddingLen += padding paddingLen += padding
if encLevel == protocol.EncryptionInitial {
paddingLen = 0
}
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length + paddingLen header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length + paddingLen
startLen := len(buffer.Data) startLen := len(buffer.Data)
@ -758,7 +785,6 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
return nil, err return nil, err
} }
payloadOffset := protocol.ByteCount(len(raw)) payloadOffset := protocol.ByteCount(len(raw))
raw, err = p.appendPacketPayload(raw, pl, paddingLen, v) raw, err = p.appendPacketPayload(raw, pl, paddingLen, v)
if err != nil { if err != nil {
return nil, err return nil, err
@ -778,6 +804,41 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
}, nil }, nil
} }
// [UQUIC]
func (p *packetPacker) appendLongHeaderPacketExternalPadding(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) {
pnLen := protocol.ByteCount(header.PacketNumberLen)
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length
startLen := len(buffer.Data)
raw := buffer.Data[startLen:]
raw, err := header.Append(raw, v)
if err != nil {
return nil, err
}
payloadOffset := protocol.ByteCount(len(raw))
raw, err = p.appendPacketPayload(raw, pl, 0, v)
if err != nil {
return nil, err
}
raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen)
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
// [UQUIC]
// append zero to buffer.Data until 1200 bytes
buffer.Data = append(buffer.Data, make([]byte, 1357-len(buffer.Data))...)
if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber {
return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber)
}
return &longHeaderPacket{
header: header,
ack: pl.ack,
frames: pl.frames,
streamFrames: pl.streamFrames,
length: protocol.ByteCount(len(raw)),
}, nil
}
func (p *packetPacker) appendShortHeaderPacket( func (p *packetPacker) appendShortHeaderPacket(
buffer *packetBuffer, buffer *packetBuffer,
connID protocol.ConnectionID, connID protocol.ConnectionID,

View file

@ -152,6 +152,12 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
return nil, err return nil, err
} }
conf = populateConfig(conf) conf = populateConfig(conf)
// [UQUIC]
if conf.SrcConnIDLength != 0 {
t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conf.SrcConnIDLength}
}
if err := t.init(t.isSingleUse); err != nil { if err := t.init(t.isSingleUse); err != nil {
return nil, err return nil, err
} }
@ -170,6 +176,12 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
return nil, err return nil, err
} }
conf = populateConfig(conf) conf = populateConfig(conf)
// [UQUIC]
if conf.SrcConnIDLength != 0 {
t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conf.SrcConnIDLength}
}
if err := t.init(t.isSingleUse); err != nil { if err := t.init(t.isSingleUse); err != nil {
return nil, err return nil, err
} }