impl: QUIC Header mimicry

This commit is contained in:
Gaukas Wang 2023-07-29 13:17:27 -06:00
parent 1429e6718b
commit 251b3afe6e
No known key found for this signature in database
GPG key ID: 9E2F8986D76F8B5D
11 changed files with 262 additions and 30 deletions

View file

@ -41,6 +41,7 @@ type client struct {
// make it possible to mock connection ID for initial generation in the tests
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
var generateConnectionIDForInitialWithLength = protocol.GenerateConnectionIDForInitialWithLen
// DialAddr establishes a new QUIC connection to a server.
// It resolves the address, and then creates a new UDP connection to dial the QUIC server.
@ -169,11 +170,25 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config
tlsConf = tlsConf.Clone()
}
// // [UQUIC]
// if config.SrcConnIDLength != 0 {
// connIDLen := config.SrcConnIDLength
// connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: connIDLen}
// }
srcConnID, err := connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}
destConnID, err := generateConnectionIDForInitial()
var destConnID protocol.ConnectionID
// [UQUIC]
if config.DestConnIDLength > 0 {
destConnID, err = generateConnectionIDForInitialWithLength(config.DestConnIDLength)
} else {
destConnID, err = generateConnectionIDForInitial()
}
// [/UQUIC]
if err != nil {
return nil, err
}
@ -189,6 +204,8 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config
version: config.Versions[0],
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
initialPacketNumber: protocol.PacketNumber(config.InitPacketNumber), // [UQUIC]
}
return c, nil
}

View file

@ -109,6 +109,8 @@ func populateConfig(config *Config) *Config {
maxIncomingUniStreams = 0
}
// [UQUIC] TODO: reverse populate config from TransportParameters
return &Config{
GetConfigForClient: config.GetConfigForClient,
Versions: versions,
@ -131,5 +133,23 @@ func populateConfig(config *Config) *Config {
DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets,
Allow0RTT: config.Allow0RTT,
Tracer: config.Tracer,
// [UQUIC]
SrcConnIDLength: config.SrcConnIDLength,
DestConnIDLength: config.DestConnIDLength,
InitPacketNumber: config.InitPacketNumber,
TransportParameters: config.TransportParameters,
InitPacketNumberLength: config.InitPacketNumberLength,
}
}
const (
// PacketNumberLen1 is a packet number length of 1 byte
PacketNumberLen1 protocol.PacketNumberLen = 1
// PacketNumberLen2 is a packet number length of 2 bytes
PacketNumberLen2 protocol.PacketNumberLen = 2
// PacketNumberLen3 is a packet number length of 3 bytes
PacketNumberLen3 protocol.PacketNumberLen = 3
// PacketNumberLen4 is a packet number length of 4 bytes
PacketNumberLen4 protocol.PacketNumberLen = 4
)

View file

@ -35,6 +35,8 @@ type connIDManager struct {
addStatelessResetToken func(protocol.StatelessResetToken)
removeStatelessResetToken func(protocol.StatelessResetToken)
queueControlFrame func(wire.Frame)
connectionIDLimit uint64 // [UQUIC] custom Connection ID limit
}
func newConnIDManager(
@ -59,7 +61,13 @@ func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
if err := h.add(f); err != nil {
return err
}
if h.queue.Len() >= protocol.MaxActiveConnectionIDs {
connIDLimit := h.connectionIDLimit
if connIDLimit == 0 {
connIDLimit = protocol.MaxActiveConnectionIDs
}
if uint64(h.queue.Len()) >= connIDLimit {
return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}
}
return nil
@ -183,6 +191,11 @@ func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToke
h.addStatelessResetToken(token)
}
// [UQUIC]
func (h *connIDManager) SetConnectionIDLimit(limit uint64) {
h.connectionIDLimit = limit
}
func (h *connIDManager) SentPacket() {
h.packetsSinceLastChange++
}

View file

@ -366,6 +366,7 @@ var newClientConnection = func(
runner.RemoveResetToken,
s.queueControlFrame,
)
s.connIDGenerator = newConnIDGenerator(
srcConnID,
nil,
@ -388,31 +389,45 @@ var newClientConnection = func(
s.tracer,
s.logger,
)
if conf.InitPacketNumberLength != 0 {
ackhandler.SetInitialPacketNumberLength(s.sentPacketHandler, conf.InitPacketNumberLength)
}
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
MaxIdleTimeout: s.config.MaxIdleTimeout,
MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams),
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
AckDelayExponent: protocol.AckDelayExponent,
DisableActiveMigration: true,
// For interoperability with quic-go versions before May 2023, this value must be set to a value
// different from protocol.DefaultActiveConnectionIDLimit.
// If set to the default value, it will be omitted from the transport parameters, which will make
// old quic-go versions interpret it as 0, instead of the default value of 2.
// See https://github.com/quic-go/quic-go/pull/3806.
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
InitialSourceConnectionID: srcConnID,
}
if s.config.EnableDatagrams {
params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize
var params *wire.TransportParameters
if s.config.TransportParameters != nil {
params = &wire.TransportParameters{
InitialSourceConnectionID: srcConnID,
}
params.PopulateFromUQUIC(s.config.TransportParameters)
s.connIDManager.SetConnectionIDLimit(params.ActiveConnectionIDLimit)
} else {
params.MaxDatagramFrameSize = protocol.InvalidByteCount
params = &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
MaxIdleTimeout: s.config.MaxIdleTimeout,
MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams),
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
AckDelayExponent: protocol.AckDelayExponent,
DisableActiveMigration: true,
// For interoperability with quic-go versions before May 2023, this value must be set to a value
// different from protocol.DefaultActiveConnectionIDLimit.
// If set to the default value, it will be omitted from the transport parameters, which will make
// old quic-go versions interpret it as 0, instead of the default value of 2.
// See https://github.com/quic-go/quic-go/pull/3806.
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
InitialSourceConnectionID: srcConnID,
}
if s.config.EnableDatagrams {
params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize
} else {
params.MaxDatagramFrameSize = protocol.InvalidByteCount
}
}
if s.tracer != nil {
s.tracer.SentTransportParameters(params)
@ -1371,6 +1386,7 @@ func (s *connection) handleHandshakeEvents() error {
case handshake.EventDiscard0RTTKeys:
err = s.dropEncryptionLevel(protocol.Encryption0RTT)
case handshake.EventWriteInitialData:
// fmt.Printf("write initial data: %x\n", ev.Data) // [UQUIC] debug
_, err = s.initialStream.Write(ev.Data)
case handshake.EventWriteHandshakeData:
_, err = s.handshakeStream.Write(ev.Data)
@ -2050,6 +2066,9 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
}
s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer, packet.buffer.Len())
// [UQUIC]
// fmt.Printf("sendPackedCoalescedPacket:Sending %d bytes\n", packet.buffer.Len())
// fmt.Printf("sendPackedCoalescedPacket: %v\n", packet.buffer.Data)
return nil
}

View file

@ -11,6 +11,7 @@ import (
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/transportparameters"
)
// The StreamID is the ID of a QUIC stream.
@ -332,6 +333,13 @@ type Config struct {
// Enable QUIC datagram support (RFC 9221).
EnableDatagrams bool
Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer
// TransportParameters override other transport parameters set by the Config.
TransportParameters transportparameters.TransportParameters // [UQUIC]
SrcConnIDLength int // [UQUIC]
DestConnIDLength int // [UQUIC]
InitPacketNumber uint64 // [UQUIC]
InitPacketNumberLength protocol.PacketNumberLen // [UQUIC]
}
type ClientHelloInfo struct {

View file

@ -96,6 +96,8 @@ type sentPacketHandler struct {
tracer logging.ConnectionTracer
logger utils.Logger
initialPacketNumberLength protocol.PacketNumberLen // [UQUIC]
}
var (
@ -136,6 +138,12 @@ func newSentPacketHandler(
}
}
func SetInitialPacketNumberLength(h SentPacketHandler, pnLen protocol.PacketNumberLen) {
if sph, ok := h.(*sentPacketHandler); ok {
sph.initialPacketNumberLength = pnLen
}
}
func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
if p.includedInBytesInFlight {
if p.Length > h.bytesInFlight {
@ -716,6 +724,12 @@ func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel)
pnSpace := h.getPacketNumberSpace(encLevel)
pn := pnSpace.pns.Peek()
// See section 17.1 of RFC 9000.
// [UQUIC] This kinda breaks PN length mimicry.
if encLevel == protocol.EncryptionInitial && h.initialPacketNumberLength != 0 {
return pn, h.initialPacketNumberLength
}
return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked)
}

View file

@ -250,6 +250,14 @@ func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
h.rejected0RTT()
return false, nil
case qtls.QUICWriteData:
// [UQUIC] debug
// if ev.Level == qtls.QUICEncryptionLevelInitial {
// fmt.Printf("Init: %x\n", ev.Data)
// } else if ev.Level == qtls.QUICEncryptionLevelHandshake {
// fmt.Printf("HS: %x\n", ev.Data)
// } else {
// fmt.Printf("APP: %x\n", ev.Data)
// }
h.WriteRecord(ev.Level, ev.Data)
return false, nil
case qtls.QUICHandshakeDone:

View file

@ -68,6 +68,11 @@ func GenerateConnectionIDForInitial() (ConnectionID, error) {
return GenerateConnectionID(l)
}
// [UQUIC]
func GenerateConnectionIDForInitialWithLen(l int) (ConnectionID, error) {
return GenerateConnectionID(l)
}
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
// It returns io.EOF if there are not enough bytes to read.
func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) {

View file

@ -15,6 +15,7 @@ import (
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
"github.com/quic-go/quic-go/transportparameters"
)
// AdditionalTransportParametersClient are additional transport parameters that will be added
@ -88,6 +89,9 @@ type TransportParameters struct {
ActiveConnectionIDLimit uint64
MaxDatagramFrameSize protocol.ByteCount
// only used internally
ClientOverride transportparameters.TransportParameters // [UQUIC]
}
// Unmarshal the transport parameters
@ -325,6 +329,13 @@ func (p *TransportParameters) readNumericTransportParameter(
// Marshal the transport parameters
func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
// [UQUIC]
if p.ClientOverride != nil {
fmt.Println("ClientOverride!!")
return p.ClientOverride.Marshal()
}
// [/UQUIC]
// Typical Transport Parameters consume around 110 bytes, depending on the exact values,
// especially the lengths of the Connection IDs.
// Allocate 256 bytes, so we won't have to grow the slice in any case.
@ -502,3 +513,47 @@ func (p *TransportParameters) String() string {
logString += "}"
return fmt.Sprintf(logString, logParams...)
}
func (tp *TransportParameters) PopulateFromUQUIC(quicparams transportparameters.TransportParameters) {
for pIdx, param := range quicparams {
switch param.ID() {
case uint64(maxIdleTimeoutParameterID):
tp.MaxIdleTimeout = time.Duration(param.(transportparameters.MaxIdleTimeout)) * time.Millisecond
case uint64(initialMaxDataParameterID):
tp.InitialMaxData = protocol.ByteCount(param.(transportparameters.InitialMaxData))
case uint64(initialMaxStreamDataBidiLocalParameterID):
tp.InitialMaxStreamDataBidiLocal = protocol.ByteCount(param.(transportparameters.InitialMaxStreamDataBidiLocal))
case uint64(initialMaxStreamDataBidiRemoteParameterID):
tp.InitialMaxStreamDataBidiRemote = protocol.ByteCount(param.(transportparameters.InitialMaxStreamDataBidiRemote))
case uint64(initialMaxStreamDataUniParameterID):
tp.InitialMaxStreamDataUni = protocol.ByteCount(param.(transportparameters.InitialMaxStreamDataUni))
case uint64(initialMaxStreamsBidiParameterID):
tp.MaxBidiStreamNum = protocol.StreamNum(param.(transportparameters.InitialMaxStreamsBidi))
case uint64(initialMaxStreamsUniParameterID):
tp.MaxUniStreamNum = protocol.StreamNum(param.(transportparameters.InitialMaxStreamsUni))
case uint64(maxAckDelayParameterID):
tp.MaxAckDelay = time.Duration(param.(transportparameters.MaxAckDelay)) * time.Millisecond
case uint64(disableActiveMigrationParameterID):
tp.DisableActiveMigration = true
case uint64(activeConnectionIDLimitParameterID):
tp.ActiveConnectionIDLimit = uint64(param.(transportparameters.ActiveConnectionIDLimit))
case uint64(initialSourceConnectionIDParameterID):
srcConnIDOverride, ok := param.(transportparameters.InitialSourceConnectionID)
if ok {
if len(srcConnIDOverride) > 0 { // when nil/empty, will leave default srcConnID
tp.InitialSourceConnectionID = protocol.ParseConnectionID(srcConnIDOverride)
} else {
// reversely populate the transport parameter, for it must be written to network
quicparams[pIdx] = transportparameters.InitialSourceConnectionID(tp.InitialSourceConnectionID.Bytes())
}
}
case uint64(maxDatagramFrameSizeParameterID):
tp.MaxDatagramFrameSize = protocol.ByteCount(param.(transportparameters.MaxDatagramFrameSize))
default:
// ignore unknown parameters
continue
}
}
tp.ClientOverride = quicparams
}

View file

@ -332,6 +332,11 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
if initialPayload.length > 0 {
size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead())
}
// // [UQUIC]
// if len(initialPayload.frames) > 0 {
// fmt.Printf("onlyAck: %t, PackCoalescedPacket: %v\n", onlyAck, initialPayload.frames[0].Frame)
// }
}
// Add a Handshake packet.
@ -396,12 +401,24 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
longHdrPackets: make([]*longHeaderPacket, 0, 3),
}
if initialPayload.length > 0 {
padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize)
cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v)
if err != nil {
return nil, err
if onlyAck || len(initialPayload.frames) == 0 {
// padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize)
// cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v)
// if err != nil {
// return nil, err
// }
// packet.longHdrPackets = append(packet.longHdrPackets, cont)
return nil, nil // [UQUIC] not to send the ACK frame for Initial
} else { // [UQUIC]
cont, err := p.appendLongHeaderPacketExternalPadding(buffer, initialHdr, initialPayload, protocol.EncryptionInitial, initialSealer, v)
if err != nil {
return nil, err
}
// fmt.Printf("!onlyAck buffer: %v\n", buffer.Data)
packet.longHdrPackets = append(packet.longHdrPackets, cont)
}
packet.longHdrPackets = append(packet.longHdrPackets, cont)
}
if handshakePayload.length > 0 {
cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v)
@ -523,6 +540,10 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, en
}
} else if s.HasData() {
cf := s.PopCryptoFrame(maxPacketSize)
if encLevel == protocol.EncryptionInitial {
fmt.Printf("PopCryptoFrame for Initial: %x...\n", cf.Data[:5])
// fmt.Printf("Offset: %d\n", cf.Offset)
}
pl.frames = []ackhandler.Frame{{Frame: cf, Handler: handler}}
pl.length += cf.Length(v)
}
@ -671,6 +692,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m
return nil, err
}
hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v)
fmt.Printf("MaybePackProbePacket: %x\n", pl.frames[0])
case protocol.EncryptionHandshake:
var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer()
@ -749,6 +771,11 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
paddingLen = 4 - pnLen - pl.length
}
paddingLen += padding
if encLevel == protocol.EncryptionInitial {
paddingLen = 0
}
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length + paddingLen
startLen := len(buffer.Data)
@ -758,7 +785,6 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
return nil, err
}
payloadOffset := protocol.ByteCount(len(raw))
raw, err = p.appendPacketPayload(raw, pl, paddingLen, v)
if err != nil {
return nil, err
@ -778,6 +804,41 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
}, nil
}
// [UQUIC]
func (p *packetPacker) appendLongHeaderPacketExternalPadding(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) {
pnLen := protocol.ByteCount(header.PacketNumberLen)
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length
startLen := len(buffer.Data)
raw := buffer.Data[startLen:]
raw, err := header.Append(raw, v)
if err != nil {
return nil, err
}
payloadOffset := protocol.ByteCount(len(raw))
raw, err = p.appendPacketPayload(raw, pl, 0, v)
if err != nil {
return nil, err
}
raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen)
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
// [UQUIC]
// append zero to buffer.Data until 1200 bytes
buffer.Data = append(buffer.Data, make([]byte, 1357-len(buffer.Data))...)
if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber {
return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber)
}
return &longHeaderPacket{
header: header,
ack: pl.ack,
frames: pl.frames,
streamFrames: pl.streamFrames,
length: protocol.ByteCount(len(raw)),
}, nil
}
func (p *packetPacker) appendShortHeaderPacket(
buffer *packetBuffer,
connID protocol.ConnectionID,

View file

@ -152,6 +152,12 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
return nil, err
}
conf = populateConfig(conf)
// [UQUIC]
if conf.SrcConnIDLength != 0 {
t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conf.SrcConnIDLength}
}
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
@ -170,6 +176,12 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
return nil, err
}
conf = populateConfig(conf)
// [UQUIC]
if conf.SrcConnIDLength != 0 {
t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conf.SrcConnIDLength}
}
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}