From 20e2a487b8b13cc1513a864061d1a477c4af5a8f Mon Sep 17 00:00:00 2001 From: Gaukas Wang Date: Sun, 30 Jul 2023 20:01:07 -0600 Subject: [PATCH 1/3] wip: implement InitialSpec type (1/n) - TransportParameters are now set as a part of ClientHelloSpecs - Removes transportparameters package and uses tls.TransportParameters --- config.go | 11 +- connection.go | 130 ++++---- example/uquic/main.go | 63 ++-- interface.go | 10 +- internal/handshake/crypto_setup.go | 2 +- internal/wire/transport_parameters.go | 30 +- packet_packer.go | 50 ++- transportparameters/u_transport_parameters.go | 303 ------------------ .../u_transport_parameters_test.go | 71 ---- 9 files changed, 173 insertions(+), 497 deletions(-) delete mode 100644 transportparameters/u_transport_parameters.go delete mode 100644 transportparameters/u_transport_parameters_test.go diff --git a/config.go b/config.go index a562c30d..f478b979 100644 --- a/config.go +++ b/config.go @@ -136,18 +136,19 @@ func populateConfig(config *Config) *Config { SrcConnIDLength: config.SrcConnIDLength, DestConnIDLength: config.DestConnIDLength, InitPacketNumber: config.InitPacketNumber, - TransportParameters: config.TransportParameters, InitPacketNumberLength: config.InitPacketNumberLength, } } +type PacketNumberLen = protocol.PacketNumberLen + const ( // PacketNumberLen1 is a packet number length of 1 byte - PacketNumberLen1 protocol.PacketNumberLen = 1 + PacketNumberLen1 PacketNumberLen = 1 // PacketNumberLen2 is a packet number length of 2 bytes - PacketNumberLen2 protocol.PacketNumberLen = 2 + PacketNumberLen2 PacketNumberLen = 2 // PacketNumberLen3 is a packet number length of 3 bytes - PacketNumberLen3 protocol.PacketNumberLen = 3 + PacketNumberLen3 PacketNumberLen = 3 // PacketNumberLen4 is a packet number length of 4 bytes - PacketNumberLen4 protocol.PacketNumberLen = 4 + PacketNumberLen4 PacketNumberLen = 4 ) diff --git a/connection.go b/connection.go index 2941be60..06a9f10a 100644 --- a/connection.go +++ b/connection.go @@ -397,39 +397,31 @@ var newClientConnection = func( s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) oneRTTStream := newCryptoStream() - var params *wire.TransportParameters - if s.config.TransportParameters != nil { - params = &wire.TransportParameters{ - InitialSourceConnectionID: srcConnID, - } - params.PopulateFromUQUIC(s.config.TransportParameters) - s.connIDManager.SetConnectionIDLimit(params.ActiveConnectionIDLimit) - } else { - params = &wire.TransportParameters{ - InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), - MaxIdleTimeout: s.config.MaxIdleTimeout, - MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), - MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), - MaxAckDelay: protocol.MaxAckDelayInclGranularity, - AckDelayExponent: protocol.AckDelayExponent, - DisableActiveMigration: true, - // For interoperability with quic-go versions before May 2023, this value must be set to a value - // different from protocol.DefaultActiveConnectionIDLimit. - // If set to the default value, it will be omitted from the transport parameters, which will make - // old quic-go versions interpret it as 0, instead of the default value of 2. - // See https://github.com/quic-go/quic-go/pull/3806. - ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, - InitialSourceConnectionID: srcConnID, - } - if s.config.EnableDatagrams { - params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize - } else { - params.MaxDatagramFrameSize = protocol.InvalidByteCount - } + params := &wire.TransportParameters{ + InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), + MaxIdleTimeout: s.config.MaxIdleTimeout, + MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), + MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), + MaxAckDelay: protocol.MaxAckDelayInclGranularity, + AckDelayExponent: protocol.AckDelayExponent, + DisableActiveMigration: true, + // For interoperability with quic-go versions before May 2023, this value must be set to a value + // different from protocol.DefaultActiveConnectionIDLimit. + // If set to the default value, it will be omitted from the transport parameters, which will make + // old quic-go versions interpret it as 0, instead of the default value of 2. + // See https://github.com/quic-go/quic-go/pull/3806. + ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, + InitialSourceConnectionID: srcConnID, } + if s.config.EnableDatagrams { + params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + } else { + params.MaxDatagramFrameSize = protocol.InvalidByteCount + } + if s.tracer != nil { s.tracer.SentTransportParameters(params) } @@ -529,38 +521,52 @@ var newUClientConnection = func( oneRTTStream := newCryptoStream() var params *wire.TransportParameters - if s.config.TransportParameters != nil { - params = &wire.TransportParameters{ - InitialSourceConnectionID: srcConnID, - } - params.PopulateFromUQUIC(s.config.TransportParameters) - s.connIDManager.SetConnectionIDLimit(params.ActiveConnectionIDLimit) - } else { - params = &wire.TransportParameters{ - InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), - MaxIdleTimeout: s.config.MaxIdleTimeout, - MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), - MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), - MaxAckDelay: protocol.MaxAckDelayInclGranularity, - AckDelayExponent: protocol.AckDelayExponent, - DisableActiveMigration: true, - // For interoperability with quic-go versions before May 2023, this value must be set to a value - // different from protocol.DefaultActiveConnectionIDLimit. - // If set to the default value, it will be omitted from the transport parameters, which will make - // old quic-go versions interpret it as 0, instead of the default value of 2. - // See https://github.com/quic-go/quic-go/pull/3806. - ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, - InitialSourceConnectionID: srcConnID, - } - if s.config.EnableDatagrams { - params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize - } else { - params.MaxDatagramFrameSize = protocol.InvalidByteCount + // params := &wire.TransportParameters{ + // InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + // InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + // InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + // InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), + // MaxIdleTimeout: s.config.MaxIdleTimeout, + // MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), + // MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), + // MaxAckDelay: protocol.MaxAckDelayInclGranularity, + // AckDelayExponent: protocol.AckDelayExponent, + // DisableActiveMigration: true, + // // For interoperability with quic-go versions before May 2023, this value must be set to a value + // // different from protocol.DefaultActiveConnectionIDLimit. + // // If set to the default value, it will be omitted from the transport parameters, which will make + // // old quic-go versions interpret it as 0, instead of the default value of 2. + // // See https://github.com/quic-go/quic-go/pull/3806. + // ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, + // InitialSourceConnectionID: srcConnID, + // } + // if s.config.EnableDatagrams { + // params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + // } else { + // params.MaxDatagramFrameSize = protocol.InvalidByteCount + // } + + // [UQUIC] iterate over all Extensions to set the TransportParameters + var tpSet bool +FOR_EACH_TLS_EXTENSION: + for _, ext := range chs.Extensions { + switch ext := ext.(type) { + case *tls.QUICTransportParametersExtension: + params = &wire.TransportParameters{ + InitialSourceConnectionID: srcConnID, + } + params.PopulateFromUQUIC(ext.TransportParameters) + s.connIDManager.SetConnectionIDLimit(params.ActiveConnectionIDLimit) + tpSet = true + break FOR_EACH_TLS_EXTENSION + default: + continue FOR_EACH_TLS_EXTENSION } } + if !tpSet { + panic("applied ClientHelloSpec must contain a QUICTransportParametersExtension to proceed") + } + if s.tracer != nil { s.tracer.SentTransportParameters(params) } diff --git a/example/uquic/main.go b/example/uquic/main.go index b2fcca80..c97654fa 100644 --- a/example/uquic/main.go +++ b/example/uquic/main.go @@ -12,8 +12,6 @@ import ( quic "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" - - qtp "github.com/quic-go/quic-go/transportparameters" ) func getCHS() *tls.ClientHelloSpec { @@ -99,7 +97,37 @@ func getCHS() *tls.ClientHelloSpec { &tls.FakeRecordSizeLimitExtension{ Limit: 0x4001, }, - &tls.QUICTransportParametersExtension{}, + &tls.QUICTransportParametersExtension{ + TransportParameters: tls.TransportParameters{ + tls.InitialMaxStreamDataBidiRemote(0x100000), + tls.InitialMaxStreamsBidi(16), + tls.MaxDatagramFrameSize(1200), + tls.MaxIdleTimeout(30000), + tls.ActiveConnectionIDLimit(8), + &tls.GREASEQUICBit{}, + &tls.VersionInformation{ + ChoosenVersion: tls.VERSION_1, + AvailableVersions: []uint32{ + tls.VERSION_GREASE, + tls.VERSION_1, + }, + LegacyID: true, + }, + tls.InitialMaxStreamsUni(16), + &tls.GREASE{ + IdOverride: 0xff02de1a, + ValueOverride: []byte{ + 0x43, 0xe8, + }, + }, + tls.InitialMaxStreamDataBidiLocal(0xc00000), + tls.InitialMaxStreamDataUni(0x100000), + tls.InitialSourceConnectionID([]byte{}), + tls.MaxAckDelay(20), + tls.InitialMaxData(0x1800000), + &tls.DisableActiveMigration{}, + }, + }, &tls.UtlsPaddingExtension{ GetPaddingLen: tls.BoringPaddingStyle, }, @@ -129,35 +157,6 @@ func main() { InitPacketNumber: 0, InitPacketNumberLength: quic.PacketNumberLen1, // currently only affects the initial packet number // Versions: []quic.VersionNumber{quic.Version2}, - TransportParameters: qtp.TransportParameters{ - qtp.InitialMaxStreamDataBidiRemote(0x100000), - qtp.InitialMaxStreamsBidi(16), - qtp.MaxDatagramFrameSize(1200), - qtp.MaxIdleTimeout(30000), - qtp.ActiveConnectionIDLimit(8), - &qtp.GREASEQUICBit{}, - &qtp.VersionInformation{ - ChoosenVersion: qtp.VERSION_1, - AvailableVersions: []uint32{ - qtp.VERSION_GREASE, - qtp.VERSION_1, - }, - LegacyID: true, - }, - qtp.InitialMaxStreamsUni(16), - &qtp.GREASE{ - IdOverride: 0xff02de1a, - ValueOverride: []byte{ - 0x43, 0xe8, - }, - }, - qtp.InitialMaxStreamDataBidiLocal(0xc00000), - qtp.InitialMaxStreamDataUni(0x100000), - qtp.InitialSourceConnectionID([]byte{}), - qtp.MaxAckDelay(20), - qtp.InitialMaxData(0x1800000), - &qtp.DisableActiveMigration{}, - }, } roundTripper := &http3.RoundTripper{ diff --git a/interface.go b/interface.go index e4b1085a..92530923 100644 --- a/interface.go +++ b/interface.go @@ -12,7 +12,6 @@ import ( "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/logging" - "github.com/quic-go/quic-go/transportparameters" ) // The StreamID is the ID of a QUIC stream. @@ -336,11 +335,10 @@ type Config struct { Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer // TransportParameters override other transport parameters set by the Config. - TransportParameters transportparameters.TransportParameters // [UQUIC] - SrcConnIDLength int // [UQUIC] - DestConnIDLength int // [UQUIC] - InitPacketNumber uint64 // [UQUIC] - InitPacketNumberLength protocol.PacketNumberLen // [UQUIC] + SrcConnIDLength int // [UQUIC] + DestConnIDLength int // [UQUIC] + InitPacketNumber uint64 // [UQUIC] + InitPacketNumberLength PacketNumberLen // [UQUIC] } type ClientHelloInfo struct { diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 782bfa3e..95255d14 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -132,7 +132,7 @@ func NewUCryptoSetupClient( cs.tlsConf = tlsConf cs.conn = qtls.UQUICClient(quicConf, chs) - cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) + // cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) // [UQUIC] doesn't require this return cs } diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index 02b9c03a..fa5b1683 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -15,7 +15,7 @@ import ( "github.com/quic-go/quic-go/internal/qerr" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/quicvarint" - "github.com/quic-go/quic-go/transportparameters" + tls "github.com/refraction-networking/utls" ) // AdditionalTransportParametersClient are additional transport parameters that will be added @@ -91,7 +91,7 @@ type TransportParameters struct { MaxDatagramFrameSize protocol.ByteCount // only used internally - ClientOverride transportparameters.TransportParameters // [UQUIC] + ClientOverride tls.TransportParameters // [UQUIC] } // Unmarshal the transport parameters @@ -513,41 +513,41 @@ func (p *TransportParameters) String() string { return fmt.Sprintf(logString, logParams...) } -func (tp *TransportParameters) PopulateFromUQUIC(quicparams transportparameters.TransportParameters) { +func (tp *TransportParameters) PopulateFromUQUIC(quicparams tls.TransportParameters) { for pIdx, param := range quicparams { switch param.ID() { case uint64(maxIdleTimeoutParameterID): - tp.MaxIdleTimeout = time.Duration(param.(transportparameters.MaxIdleTimeout)) * time.Millisecond + tp.MaxIdleTimeout = time.Duration(param.(tls.MaxIdleTimeout)) * time.Millisecond case uint64(initialMaxDataParameterID): - tp.InitialMaxData = protocol.ByteCount(param.(transportparameters.InitialMaxData)) + tp.InitialMaxData = protocol.ByteCount(param.(tls.InitialMaxData)) case uint64(initialMaxStreamDataBidiLocalParameterID): - tp.InitialMaxStreamDataBidiLocal = protocol.ByteCount(param.(transportparameters.InitialMaxStreamDataBidiLocal)) + tp.InitialMaxStreamDataBidiLocal = protocol.ByteCount(param.(tls.InitialMaxStreamDataBidiLocal)) case uint64(initialMaxStreamDataBidiRemoteParameterID): - tp.InitialMaxStreamDataBidiRemote = protocol.ByteCount(param.(transportparameters.InitialMaxStreamDataBidiRemote)) + tp.InitialMaxStreamDataBidiRemote = protocol.ByteCount(param.(tls.InitialMaxStreamDataBidiRemote)) case uint64(initialMaxStreamDataUniParameterID): - tp.InitialMaxStreamDataUni = protocol.ByteCount(param.(transportparameters.InitialMaxStreamDataUni)) + tp.InitialMaxStreamDataUni = protocol.ByteCount(param.(tls.InitialMaxStreamDataUni)) case uint64(initialMaxStreamsBidiParameterID): - tp.MaxBidiStreamNum = protocol.StreamNum(param.(transportparameters.InitialMaxStreamsBidi)) + tp.MaxBidiStreamNum = protocol.StreamNum(param.(tls.InitialMaxStreamsBidi)) case uint64(initialMaxStreamsUniParameterID): - tp.MaxUniStreamNum = protocol.StreamNum(param.(transportparameters.InitialMaxStreamsUni)) + tp.MaxUniStreamNum = protocol.StreamNum(param.(tls.InitialMaxStreamsUni)) case uint64(maxAckDelayParameterID): - tp.MaxAckDelay = time.Duration(param.(transportparameters.MaxAckDelay)) * time.Millisecond + tp.MaxAckDelay = time.Duration(param.(tls.MaxAckDelay)) * time.Millisecond case uint64(disableActiveMigrationParameterID): tp.DisableActiveMigration = true case uint64(activeConnectionIDLimitParameterID): - tp.ActiveConnectionIDLimit = uint64(param.(transportparameters.ActiveConnectionIDLimit)) + tp.ActiveConnectionIDLimit = uint64(param.(tls.ActiveConnectionIDLimit)) case uint64(initialSourceConnectionIDParameterID): - srcConnIDOverride, ok := param.(transportparameters.InitialSourceConnectionID) + srcConnIDOverride, ok := param.(tls.InitialSourceConnectionID) if ok { if len(srcConnIDOverride) > 0 { // when nil/empty, will leave default srcConnID tp.InitialSourceConnectionID = protocol.ParseConnectionID(srcConnIDOverride) } else { // reversely populate the transport parameter, for it must be written to network - quicparams[pIdx] = transportparameters.InitialSourceConnectionID(tp.InitialSourceConnectionID.Bytes()) + quicparams[pIdx] = tls.InitialSourceConnectionID(tp.InitialSourceConnectionID.Bytes()) } } case uint64(maxDatagramFrameSizeParameterID): - tp.MaxDatagramFrameSize = protocol.ByteCount(param.(transportparameters.MaxDatagramFrameSize)) + tp.MaxDatagramFrameSize = protocol.ByteCount(param.(tls.MaxDatagramFrameSize)) default: // ignore unknown parameters continue diff --git a/packet_packer.go b/packet_packer.go index dbf8d745..e6ab4bec 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -806,19 +806,27 @@ func (p *packetPacker) appendLongHeaderPacketExternalPadding(buffer *packetBuffe header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length startLen := len(buffer.Data) - raw := buffer.Data[startLen:] + raw := buffer.Data[startLen:] // [UQUIC] raw is a sub-slice of buffer.Data, whose len < size raw, err := header.Append(raw, v) if err != nil { return nil, err } + + fmt.Printf("Pre-Payload: %x\n", raw) + payloadOffset := protocol.ByteCount(len(raw)) - raw, err = p.appendPacketPayload(raw, pl, 0, v) + raw, err = p.appendCustomInitialPacketPayload(raw, pl, 0, v) if err != nil { return nil, err } + + fmt.Printf("Pre-Encryption: %x\n", raw) + raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen) buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] + fmt.Printf("Post-Encryption: %x\n", raw) + // [UQUIC] // append zero to buffer.Data until 1200 bytes buffer.Data = append(buffer.Data, make([]byte, 1357-len(buffer.Data))...) @@ -922,6 +930,44 @@ func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen pr return raw, nil } +func (p *packetPacker) appendCustomInitialPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { + payloadOffset := len(raw) + + // [UQUIC] ignores the default ACK/PADDING frame and uses its own frames + // if pl.ack != nil { + // var err error + // raw, err = pl.ack.Append(raw, v) + // if err != nil { + // return nil, err + // } + // } + // if paddingLen > 0 { + // raw = append(raw, make([]byte, paddingLen)...) + // } + + for _, f := range pl.frames { + var err error + raw, err = f.Frame.Append(raw, v) + if err != nil { + return nil, err + } + fmt.Printf("UQUIC: appending frame %v\n", f) + } + for _, f := range pl.streamFrames { + var err error + raw, err = f.Frame.Append(raw, v) + if err != nil { + return nil, err + } + fmt.Printf("UQUIC: appending stream frame %v\n", f) + } + + if payloadSize := protocol.ByteCount(len(raw)-payloadOffset) - paddingLen; payloadSize != pl.length { + return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", pl.length, payloadSize) + } + return raw, nil +} + func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.PacketNumber, payloadOffset, pnLen protocol.ByteCount) []byte { _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], pn, raw[:payloadOffset]) raw = raw[:len(raw)+sealer.Overhead()] diff --git a/transportparameters/u_transport_parameters.go b/transportparameters/u_transport_parameters.go deleted file mode 100644 index 0305a7b0..00000000 --- a/transportparameters/u_transport_parameters.go +++ /dev/null @@ -1,303 +0,0 @@ -package transportparameters - -import ( - "crypto/rand" - "encoding/binary" - "math" - "math/big" - - "github.com/quic-go/quic-go/quicvarint" -) - -const ( - // RFC IDs - max_idle_timeout uint64 = 0x1 - max_udp_payload_size uint64 = 0x3 - initial_max_data uint64 = 0x4 - initial_max_stream_data_bidi_local uint64 = 0x5 - initial_max_stream_data_bidi_remote uint64 = 0x6 - initial_max_stream_data_uni uint64 = 0x7 - initial_max_streams_bidi uint64 = 0x8 - initial_max_streams_uni uint64 = 0x9 - max_ack_delay uint64 = 0xb - disable_active_migration uint64 = 0xc - active_connection_id_limit uint64 = 0xe - initial_source_connection_id uint64 = 0xf - version_information uint64 = 0x11 // RFC 9368 - padding uint64 = 0x15 - max_datagram_frame_size uint64 = 0x20 // RFC 9221 - grease_quic_bit uint64 = 0x2ab2 - - // Legacy IDs from draft - version_information_legacy uint64 = 0xff73db // draft-ietf-quic-version-negotiation-13 and early -) - -type TransportParameters []TransportParameter - -func (tps TransportParameters) Marshal() []byte { - var b []byte - for _, tp := range tps { - b = quicvarint.Append(b, tp.ID()) - b = quicvarint.Append(b, uint64(len(tp.Value()))) - b = append(b, tp.Value()...) - } - return b -} - -// TransportParameter represents a QUIC transport parameter. -// -// Caller will write the following to the wire: -// -// var b []byte -// b = quicvarint.Append(b, ID()) -// b = quicvarint.Append(b, len(Value())) -// b = append(b, Value()) -// -// Therefore Value() should return the exact bytes to be written to the wire AFTER the length field, -// i.e., the bytes MAY be a Variable Length Integer per RFC depending on the type of the transport -// parameter, but MUST NOT including the length field unless the parameter is defined so. -type TransportParameter interface { - ID() uint64 - Value() []byte -} - -type GREASE struct { - IdOverride uint64 // if set to a valid GREASE ID, use this instead of randomly generated one. - Length uint16 // if len(ValueOverride) == 0, will generate random data of this size. - ValueOverride []byte // if len(ValueOverride) > 0, use this instead of random bytes. -} - -const ( - GREASE_MAX_MULTIPLIER = (0x3FFFFFFFFFFFFFFF - 27) / 31 -) - -// IsGREASEID returns true if id is a valid GREASE ID for -// transport parameters. -func (GREASE) IsGREASEID(id uint64) bool { - return (id-27)%31 == 0 -} - -// GetGREASEID returns a random valid GREASE ID for transport parameters. -func (GREASE) GetGREASEID() uint64 { - max := big.NewInt(GREASE_MAX_MULTIPLIER) - - randMultiply, err := rand.Int(rand.Reader, max) - if err != nil { - return 27 - } - - return 27 + randMultiply.Uint64()*31 -} - -func (g *GREASE) ID() uint64 { - if !g.IsGREASEID(g.IdOverride) { - g.IdOverride = g.GetGREASEID() - } - return g.IdOverride -} - -func (g *GREASE) Value() []byte { - if len(g.ValueOverride) == 0 { - g.ValueOverride = make([]byte, g.Length) - rand.Read(g.ValueOverride) - } - return g.ValueOverride -} - -type MaxIdleTimeout uint64 // in milliseconds - -func (MaxIdleTimeout) ID() uint64 { - return max_idle_timeout -} - -func (m MaxIdleTimeout) Value() []byte { - return quicvarint.Append([]byte{}, uint64(m)) -} - -type MaxUDPPayloadSize uint64 - -func (MaxUDPPayloadSize) ID() uint64 { - return max_udp_payload_size -} - -func (m MaxUDPPayloadSize) Value() []byte { - return quicvarint.Append([]byte{}, uint64(m)) -} - -type InitialMaxData uint64 - -func (InitialMaxData) ID() uint64 { - return initial_max_data -} - -func (i InitialMaxData) Value() []byte { - return quicvarint.Append([]byte{}, uint64(i)) -} - -type InitialMaxStreamDataBidiLocal uint64 - -func (InitialMaxStreamDataBidiLocal) ID() uint64 { - return initial_max_stream_data_bidi_local -} - -func (i InitialMaxStreamDataBidiLocal) Value() []byte { - return quicvarint.Append([]byte{}, uint64(i)) -} - -type InitialMaxStreamDataBidiRemote uint64 - -func (InitialMaxStreamDataBidiRemote) ID() uint64 { - return initial_max_stream_data_bidi_remote -} - -func (i InitialMaxStreamDataBidiRemote) Value() []byte { - return quicvarint.Append([]byte{}, uint64(i)) -} - -type InitialMaxStreamDataUni uint64 - -func (InitialMaxStreamDataUni) ID() uint64 { - return initial_max_stream_data_uni -} - -func (i InitialMaxStreamDataUni) Value() []byte { - return quicvarint.Append([]byte{}, uint64(i)) -} - -type InitialMaxStreamsBidi uint64 - -func (InitialMaxStreamsBidi) ID() uint64 { - return initial_max_streams_bidi -} - -func (i InitialMaxStreamsBidi) Value() []byte { - return quicvarint.Append([]byte{}, uint64(i)) -} - -type InitialMaxStreamsUni uint64 - -func (InitialMaxStreamsUni) ID() uint64 { - return initial_max_streams_uni -} - -func (i InitialMaxStreamsUni) Value() []byte { - return quicvarint.Append([]byte{}, uint64(i)) -} - -type MaxAckDelay uint64 - -func (MaxAckDelay) ID() uint64 { - return max_ack_delay -} - -func (m MaxAckDelay) Value() []byte { - return quicvarint.Append([]byte{}, uint64(m)) -} - -type DisableActiveMigration struct{} - -func (*DisableActiveMigration) ID() uint64 { - return disable_active_migration -} - -// Its Value MUST ALWAYS be empty. -func (*DisableActiveMigration) Value() []byte { - return []byte{} -} - -type ActiveConnectionIDLimit uint64 - -func (ActiveConnectionIDLimit) ID() uint64 { - return active_connection_id_limit -} - -func (a ActiveConnectionIDLimit) Value() []byte { - return quicvarint.Append([]byte{}, uint64(a)) -} - -type InitialSourceConnectionID []byte // if empty, will be set to the Connection ID used for the Initial packet. - -func (InitialSourceConnectionID) ID() uint64 { - return initial_source_connection_id -} - -func (i InitialSourceConnectionID) Value() []byte { - return []byte(i) -} - -type VersionInformation struct { - ChoosenVersion uint32 - AvailableVersions []uint32 // Also known as "Other Versions" in early drafts. - - LegacyID bool // If true, use the legacy-assigned ID (0xff73db) instead of the RFC-assigned one (0x11). -} - -const ( - VERSION_NEGOTIATION uint32 = 0x00000000 // rfc9000 - VERSION_1 uint32 = 0x00000001 // rfc9000 - VERSION_2 uint32 = 0x6b3343cf // rfc9369 - - VERSION_GREASE uint32 = 0x0a0a0a0a // -> 0x?a?a?a?a -) - -func (v *VersionInformation) ID() uint64 { - if v.LegacyID { - return version_information_legacy - } - return version_information -} - -func (v *VersionInformation) Value() []byte { - var b []byte - b = binary.BigEndian.AppendUint32(b, v.ChoosenVersion) - for _, version := range v.AvailableVersions { - if version != VERSION_GREASE { - b = binary.BigEndian.AppendUint32(b, version) - } else { - b = binary.BigEndian.AppendUint32(b, v.GetGREASEVersion()) - } - } - return b -} - -func (*VersionInformation) GetGREASEVersion() uint32 { - // get a random uint32 - max := big.NewInt(math.MaxUint32) - randVal, err := rand.Int(rand.Reader, max) - if err != nil { - return VERSION_GREASE - } - - return uint32(randVal.Uint64()&math.MaxUint32) | 0x0a0a0a0a // all GREASE versions are in 0x?a?a?a?a -} - -type Padding []byte - -func (Padding) ID() uint64 { - return padding -} - -func (p Padding) Value() []byte { - return p -} - -type MaxDatagramFrameSize uint64 - -func (MaxDatagramFrameSize) ID() uint64 { - return max_datagram_frame_size -} - -func (m MaxDatagramFrameSize) Value() []byte { - return quicvarint.Append([]byte{}, uint64(m)) -} - -type GREASEQUICBit struct{} - -func (*GREASEQUICBit) ID() uint64 { - return grease_quic_bit -} - -// Its Value MUST ALWAYS be empty. -func (*GREASEQUICBit) Value() []byte { - return []byte{} -} diff --git a/transportparameters/u_transport_parameters_test.go b/transportparameters/u_transport_parameters_test.go deleted file mode 100644 index 65a64dfa..00000000 --- a/transportparameters/u_transport_parameters_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package transportparameters - -import ( - "bytes" - "testing" -) - -func TestMarshal(t *testing.T) { - t.Run("Firefox", testTransportParametersFirefox) -} - -func testTransportParametersFirefox(t *testing.T) { - if !bytes.Equal(_inputTransportParametersFirefox.Marshal(), _truthTransportParametersFirefox) { - t.Errorf("TransportParameters.Marshal() = %v, want %v", _inputTransportParametersFirefox.Marshal(), _truthTransportParametersFirefox) - } -} - -var ( - _inputTransportParametersFirefox = TransportParameters{ - InitialMaxStreamDataBidiRemote(0x100000), - InitialMaxStreamsBidi(16), - MaxDatagramFrameSize(1200), - MaxIdleTimeout(30000), - ActiveConnectionIDLimit(8), - &GREASEQUICBit{}, - &VersionInformation{ - ChoosenVersion: 0x00000001, - AvailableVersions: []uint32{ - 0x8acafaea, - 0x00000001, - }, - LegacyID: true, - }, - InitialMaxStreamsUni(16), - &GREASE{ - IdOverride: 0xff02de1a, - ValueOverride: []byte{ - 0x43, 0xe8, - }, - }, - InitialMaxStreamDataBidiLocal(0xc00000), - InitialMaxStreamDataUni(0x100000), - InitialSourceConnectionID([]byte{0x53, 0xf0, 0xb2}), - MaxAckDelay(20), - InitialMaxData(0x1800000), - &DisableActiveMigration{}, - } - _truthTransportParametersFirefox = []byte{ - 0x06, 0x04, 0x80, 0x10, - 0x00, 0x00, 0x08, 0x01, - 0x10, 0x20, 0x02, 0x44, - 0xb0, 0x01, 0x04, 0x80, - 0x00, 0x75, 0x30, 0x0e, - 0x01, 0x08, 0x6a, 0xb2, - 0x00, 0x80, 0xff, 0x73, - 0xdb, 0x0c, 0x00, 0x00, - 0x00, 0x01, 0x8a, 0xca, - 0xfa, 0xea, 0x00, 0x00, - 0x00, 0x01, 0x09, 0x01, - 0x10, 0xc0, 0x00, 0x00, - 0x00, 0xff, 0x02, 0xde, - 0x1a, 0x02, 0x43, 0xe8, - 0x05, 0x04, 0x80, 0xc0, - 0x00, 0x00, 0x07, 0x04, - 0x80, 0x10, 0x00, 0x00, - 0x0f, 0x03, 0x53, 0xf0, - 0xb2, 0x0b, 0x01, 0x14, - 0x04, 0x04, 0x81, 0x80, - 0x00, 0x00, 0x0c, 0x00, - } -) From 95f3eaaa66613cfbdc6b0a8324f10857fd79d8e3 Mon Sep 17 00:00:00 2001 From: Gaukas Wang Date: Sun, 30 Jul 2023 23:20:36 -0600 Subject: [PATCH 2/3] wip: InitialSpec (2/n) - Added QUICFrame to describe QUIC Frame found in an Initial Packet, including PADDING, PING, and CRYPTO. - Added QUICSpec to describe the QUIC Header and QUIC Frames' order/length/offset. --- go.mod | 7 +- go.sum | 15 ++-- u_quic_spec.go | 200 ++++++++++++++++++++++++++++++++++++++++++++ u_quic_spec_test.go | 118 ++++++++++++++++++++++++++ 4 files changed, 332 insertions(+), 8 deletions(-) create mode 100644 u_quic_spec.go create mode 100644 u_quic_spec_test.go diff --git a/go.mod b/go.mod index 122e9c8c..70085c11 100644 --- a/go.mod +++ b/go.mod @@ -6,12 +6,12 @@ replace github.com/refraction-networking/utls => ../utls require ( github.com/francoispqt/gojay v1.2.13 + github.com/gaukas/clienthellod v0.4.0 github.com/golang/mock v1.6.0 github.com/onsi/ginkgo/v2 v2.9.5 github.com/onsi/gomega v1.27.6 github.com/quic-go/qpack v0.4.0 - github.com/quic-go/qtls-go1-20 v0.3.0 - github.com/refraction-networking/utls v0.0.0-00010101000000-000000000000 + github.com/refraction-networking/utls v1.3.2 golang.org/x/crypto v0.10.0 golang.org/x/exp v0.0.0-20221205204356-47842c84f3db golang.org/x/net v0.11.0 @@ -21,10 +21,11 @@ require ( require ( github.com/andybalholm/brotli v1.0.5 // indirect - github.com/gaukas/godicttls v0.0.3 // indirect + github.com/gaukas/godicttls v0.0.4 // indirect github.com/go-logr/logr v1.2.4 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/google/go-cmp v0.5.9 // indirect + github.com/google/gopacket v1.1.19 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/klauspost/compress v1.16.6 // indirect golang.org/x/mod v0.10.0 // indirect diff --git a/go.sum b/go.sum index 95e32f22..20095f99 100644 --- a/go.sum +++ b/go.sum @@ -27,8 +27,10 @@ github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk= -github.com/gaukas/godicttls v0.0.3/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= +github.com/gaukas/clienthellod v0.4.0 h1:DySeZT4c3Xw6OGMzHRlAuOHx9q1P7vQNjA7YkyHrqac= +github.com/gaukas/clienthellod v0.4.0/go.mod h1:gjt7a7cNNzZV4yTe0jKcXtj0a7u6RL2KQvijxFOvcZE= +github.com/gaukas/godicttls v0.0.4 h1:NlRaXb3J6hAnTmWdsEKb9bcSBD6BvcIjdGdeb0zfXbk= +github.com/gaukas/godicttls v0.0.4/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= @@ -52,6 +54,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= @@ -96,8 +100,6 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-20 v0.3.0 h1:NrCXmDl8BddZwO67vlvEpBTwT89bJfKYygxv4HQvuDk= -github.com/quic-go/qtls-go1-20 v0.3.0/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= @@ -147,6 +149,8 @@ golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZ golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -202,6 +206,7 @@ golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= @@ -224,7 +229,7 @@ google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmE google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/u_quic_spec.go b/u_quic_spec.go new file mode 100644 index 00000000..4b4999bb --- /dev/null +++ b/u_quic_spec.go @@ -0,0 +1,200 @@ +package quic + +import ( + "bytes" + "crypto/rand" + "errors" + + "github.com/gaukas/clienthellod" + "github.com/quic-go/quic-go/quicvarint" +) + +type QUICSpec struct { + // SrcConnIDLength specifies how many bytes should the SrcConnID be + SrcConnIDLength int + + // DestConnIDLength specifies how many bytes should the DestConnID be + DstConnIDLength int + + // InitPacketNumberLength specifies how many bytes should the InitPacketNumber + // be interpreted as. It is usually 1 or 2 bytes. If unset, UQUIC will use the + // default algorithm to compute the length which is at least 2 bytes. + InitPacketNumberLength PacketNumberLen + + // InitPacketNumber is the packet number of the first Initial packet. Following + // Initial packets, if any, will increment the Packet Number accordingly. + InitPacketNumber uint64 // [UQUIC] + + // TokenStore is used to store and retrieve tokens. If set, will override the + // one set in the Config. + TokenStore TokenStore + + // If ClientTokenLength is set when TokenStore is not set, a dummy TokenStore + // will be created to randomly generate tokens of the specified length for + // Pop() calls with any key and silently drop any Put() calls. + // + // However, the tokens will not be stored anywhere and are expected to be + // invalid since not assigned by the server. + ClientTokenLength int + + // QUICFrames specifies a list of QUIC frames to be sent in the first Initial + // packet. + // + // If nil, it will be treated as a list with only a single QUICFrameCrypto. + QUICFrames []QUICFrame +} + +func (s *QUICSpec) getTokenStore() TokenStore { + if s.TokenStore != nil { + return s.TokenStore + } + + if s.ClientTokenLength > 0 { + return &dummyTokenStore{ + tokenLength: s.ClientTokenLength, + } + } + + return nil +} + +type dummyTokenStore struct { + tokenLength int +} + +func (d *dummyTokenStore) Pop(key string) (token *ClientToken) { + var data []byte = make([]byte, d.tokenLength) + rand.Read(data) + + return &ClientToken{ + data: data, + } +} + +func (d *dummyTokenStore) Put(_ string, _ *ClientToken) { + // Do nothing +} + +type QUICFrames []QUICFrame + +func (qfs QUICFrames) MarshalWithCryptoData(cryptoData []byte) (payload []byte, err error) { + if len(qfs) == 0 { // If no frames specified, send a single crypto frame + payload = make([]byte, len(cryptoData)+1) + } + + for _, frame := range qfs { + var frameBytes []byte + if offset, length, cryptoOK := frame.CryptoFrameInfo(); cryptoOK { + if length == 0 { + // calculate length: from offset to the end of cryptoData + length = len(cryptoData) - offset + } + frameBytes = []byte{0x06} // CRYPTO frame type + frameBytes = quicvarint.Append(frameBytes, uint64(offset)) + frameBytes = quicvarint.Append(frameBytes, uint64(length)) + frameCryptoData := make([]byte, length) + copy(frameCryptoData, cryptoData[offset:]) // copy at most length bytes + frameBytes = append(frameBytes, frameCryptoData...) + } else { // Handle none crypto frames: read and append to payload + frameBytes, err = frame.Read() + if err != nil { + return nil, err + } + } + payload = append(payload, frameBytes...) + } + return payload, nil +} + +func (qfs QUICFrames) MarshalWithFrames(frames []byte) (payload []byte, err error) { + // parse frames + r := bytes.NewReader(frames) + qchframes, err := clienthellod.ReadAllFrames(r) + if err != nil { + return nil, err + } + + // parse crypto data + cryptoData, err := clienthellod.ReassembleCRYPTOFrames(qchframes) + if err != nil { + return nil, err + } + + // marshal + return qfs.MarshalWithCryptoData(cryptoData) +} + +type QUICFrame interface { + // None crypto frames should return false for cryptoOK + CryptoFrameInfo() (offset, length int, cryptoOK bool) + + // None crypto frames should return the byte representation of the frame. + // Crypto frames' behavior is undefined and unused. + Read() ([]byte, error) +} + +// QUICFrameCrypto is used to specify the crypto frames containing the TLS ClientHello +// to be sent in the first Initial packet. +type QUICFrameCrypto struct { + // Offset is used to specify the starting offset of the crypto frame. + // Used when sending multiple crypto frames in a single packet. + // + // Multiple crypto frames in a single packet must not overlap and must + // make up an entire crypto stream continuously. + Offset int + + // Length is used to specify the length of the crypto frame. + // + // Must be set if it is NOT the last crypto frame in a packet. + Length int +} + +// CryptoFrameInfo() implements the QUICFrame interface. +// +// Crypto frames are later replaced by the crypto message using the information +// returned by this function. +func (q QUICFrameCrypto) CryptoFrameInfo() (offset, length int, cryptoOK bool) { + return q.Offset, q.Length, true +} + +// Read() implements the QUICFrame interface. +// +// Crypto frames are later replaced by the crypto message, so they are not Read()-able. +func (q QUICFrameCrypto) Read() ([]byte, error) { + return nil, errors.New("crypto frames are not Read()-able") +} + +// QUICFramePadding is used to specify the padding frames to be sent in the first Initial +// packet. +type QUICFramePadding struct { + // Length is used to specify the length of the padding frame. + Length int +} + +// CryptoFrameInfo() implements the QUICFrame interface. +func (q QUICFramePadding) CryptoFrameInfo() (offset, length int, cryptoOK bool) { + return 0, 0, false +} + +// Read() implements the QUICFrame interface. +// +// Padding simply returns a slice of bytes of the specified length filled with 0. +func (q QUICFramePadding) Read() ([]byte, error) { + return make([]byte, q.Length), nil +} + +// QUICFramePing is used to specify the ping frames to be sent in the first Initial +// packet. +type QUICFramePing struct{} + +// CryptoFrameInfo() implements the QUICFrame interface. +func (q QUICFramePing) CryptoFrameInfo() (offset, length int, cryptoOK bool) { + return 0, 0, false +} + +// Read() implements the QUICFrame interface. +// +// Ping simply returns a slice of bytes of size 1 with value 0x01(PING). +func (q QUICFramePing) Read() ([]byte, error) { + return []byte{0x01}, nil +} diff --git a/u_quic_spec_test.go b/u_quic_spec_test.go new file mode 100644 index 00000000..9d546c19 --- /dev/null +++ b/u_quic_spec_test.go @@ -0,0 +1,118 @@ +package quic + +import ( + "bytes" + "testing" + + "github.com/gaukas/clienthellod" +) + +func TestQUICFramesMarshalWithCryptoData(t *testing.T) { + resultQUICPayload, err := testQUICFrames.MarshalWithCryptoData(testCryptoFrameBytes) + if err != nil { + t.Fatalf("Failed to marshal QUIC frames: %v", err) + } + + if len(resultQUICPayload) != len(truthQUICPayload) { + t.Fatalf("QUIC payload length mismatch: got %d, want %d. \n%x", len(resultQUICPayload), len(truthQUICPayload), resultQUICPayload) + } + + // verify that the crypto frames would actually assemble the original crypto data + r := bytes.NewReader(resultQUICPayload) + qchframes, err := clienthellod.ReadAllFrames(r) + if err != nil { + t.Fatalf("Failed to read QUIC frames: %v", err) + } + + reassembledCryptoData, err := clienthellod.ReassembleCRYPTOFrames(qchframes) + if err != nil { + t.Fatalf("Failed to reassemble crypto data: %v", err) + } + if !bytes.Equal(reassembledCryptoData, testCryptoFrameBytes) { + t.Fatalf("Reassembled crypto data mismatch: \n%x", reassembledCryptoData) + } +} + +var ( + testCryptoFrameBytes = []byte{ + 0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, + 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, + 0x20, 0x21, 0x22, 0x23, + 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, + 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, + 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, + 0x3c, 0x3d, 0x3e, 0x3f, + } // 64 bytes + + testQUICFrames = QUICFrames{ + // first 64 bytes: 01 + 63 bytes of padding + &QUICFramePing{}, + &QUICFramePadding{Length: 63}, + // second 64 bytes: last 32 bytes of crypto frame + 29 bytes of padding + &QUICFrameCrypto{ + Offset: 32, + Length: 0, + }, + &QUICFramePadding{Length: 29}, + // third 64 bytes: first 16 bytes of crypto frame + 45 bytes of padding + &QUICFrameCrypto{ + Offset: 0, + Length: 16, + }, + &QUICFramePadding{Length: 45}, + // fourth 64 bytes: second 16 bytes of crypto frame + 45 bytes of padding + &QUICFrameCrypto{ + Offset: 16, + Length: 16, + }, + &QUICFramePadding{Length: 45}, + } + + truthQUICPayload = []byte{ + 0x01, // ping + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 63 bytes of padding + 0x06, 0x20, 0x20, // 3 bytes header + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, // 32 bytes of crypto frame + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, // 29 bytes of padding + 0x06, 0x00, 0x10, // 3 bytes header + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, // 16 bytes of crypto frame + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, // 45 bytes of padding + 0x06, 0x10, 0x10, // 3 bytes header + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, // 16 bytes of crypto frame + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, // 45 bytes of padding + } +) From ea40752ca3052bb6f44212261d7c4a37a587dfbd Mon Sep 17 00:00:00 2001 From: Gaukas Wang Date: Wed, 2 Aug 2023 15:38:16 -0600 Subject: [PATCH 3/3] new: uquic --- client.go | 109 ++----- config.go | 6 - conn_id_manager.go | 7 +- connection.go | 152 +--------- example/uquic/main.go | 268 ++++++++++++++---- http3/roundtrip.go | 6 +- http3/u_roundtrip.go | 192 +++++++++++++ interface.go | 6 - internal/ackhandler/sent_packet_handler.go | 13 - internal/ackhandler/u_ackhandler.go | 23 ++ internal/ackhandler/u_sent_packet_handler.go | 30 ++ internal/handshake/crypto_setup.go | 36 +-- internal/handshake/u_crypto_setup.go | 45 +++ internal/protocol/connection_id.go | 5 - internal/protocol/u_connection_id.go | 16 ++ packet_packer.go | 113 +------- transport.go | 20 -- u_client.go | 150 ++++++++++ u_conn_id_manager.go | 6 + u_connection.go | 170 +++++++++++ u_initial_packet_spec.go | 205 ++++++++++++++ ...c_test.go => u_initial_packet_spec_test.go | 0 u_packet_packer.go | 243 ++++++++++++++++ u_quic_spec.go | 198 +------------ u_transport.go | 87 ++++++ 25 files changed, 1420 insertions(+), 686 deletions(-) create mode 100644 http3/u_roundtrip.go create mode 100644 internal/ackhandler/u_ackhandler.go create mode 100644 internal/ackhandler/u_sent_packet_handler.go create mode 100644 internal/handshake/u_crypto_setup.go create mode 100644 internal/protocol/u_connection_id.go create mode 100644 u_client.go create mode 100644 u_conn_id_manager.go create mode 100644 u_connection.go create mode 100644 u_initial_packet_spec.go rename u_quic_spec_test.go => u_initial_packet_spec_test.go (100%) create mode 100644 u_packet_packer.go create mode 100644 u_transport.go diff --git a/client.go b/client.go index dd72f668..b6969573 100644 --- a/client.go +++ b/client.go @@ -38,8 +38,6 @@ type client struct { tracer logging.ConnectionTracer tracingID uint64 logger utils.Logger - - chs *tls.ClientHelloSpec // [UQUIC] } // make it possible to mock connection ID for initial generation in the tests @@ -167,41 +165,6 @@ func dial( return c.conn, nil } -func dialWithCHS( - ctx context.Context, - conn sendConn, - connIDGenerator ConnectionIDGenerator, - packetHandlers packetHandlerManager, - tlsConf *tls.Config, - config *Config, - onClose func(), - use0RTT bool, - chs *tls.ClientHelloSpec, -) (quicConn, error) { - c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT) - if err != nil { - return nil, err - } - c.packetHandlers = packetHandlers - - c.tracingID = nextConnTracingID() - if c.config.Tracer != nil { - c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) - } - if c.tracer != nil { - c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) - } - - // [UQUIC] - c.chs = chs - // [/UQUIC] - - if err := c.dial(ctx); err != nil { - return nil, err - } - return c.conn, nil -} - func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) { if tlsConf == nil { tlsConf = &tls.Config{} @@ -209,25 +172,12 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config tlsConf = tlsConf.Clone() } - // // [UQUIC] - // if config.SrcConnIDLength != 0 { - // connIDLen := config.SrcConnIDLength - // connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: connIDLen} - // } - srcConnID, err := connIDGenerator.GenerateConnectionID() if err != nil { return nil, err } - var destConnID protocol.ConnectionID - // [UQUIC] - if config.DestConnIDLength > 0 { - destConnID, err = generateConnectionIDForInitialWithLength(config.DestConnIDLength) - } else { - destConnID, err = generateConnectionIDForInitial() - } - // [/UQUIC] + destConnID, err := generateConnectionIDForInitial() if err != nil { return nil, err } @@ -243,8 +193,6 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config version: config.Versions[0], handshakeChan: make(chan struct{}), logger: utils.DefaultLogger.WithPrefix("client"), - - initialPacketNumber: protocol.PacketNumber(config.InitPacketNumber), // [UQUIC] } return c, nil } @@ -252,45 +200,22 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config func (c *client) dial(ctx context.Context) error { c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) - // [UQUIC] - if c.chs == nil { - c.conn = newClientConnection( - c.sendConn, - c.packetHandlers, - c.destConnID, - c.srcConnID, - c.connIDGenerator, - c.config, - c.tlsConf, - c.initialPacketNumber, - c.use0RTT, - c.hasNegotiatedVersion, - c.tracer, - c.tracingID, - c.logger, - c.version, - ) - } else { - // [UQUIC]: use custom version of the connection - c.conn = newUClientConnection( - c.sendConn, - c.packetHandlers, - c.destConnID, - c.srcConnID, - c.connIDGenerator, - c.config, - c.tlsConf, - c.initialPacketNumber, - c.use0RTT, - c.hasNegotiatedVersion, - c.tracer, - c.tracingID, - c.logger, - c.version, - c.chs, - ) - } - // [/UQUIC] + c.conn = newClientConnection( + c.sendConn, + c.packetHandlers, + c.destConnID, + c.srcConnID, + c.connIDGenerator, + c.config, + c.tlsConf, + c.initialPacketNumber, + c.use0RTT, + c.hasNegotiatedVersion, + c.tracer, + c.tracingID, + c.logger, + c.version, + ) c.packetHandlers.Add(c.srcConnID, c.conn) diff --git a/config.go b/config.go index f478b979..de41656c 100644 --- a/config.go +++ b/config.go @@ -131,12 +131,6 @@ func populateConfig(config *Config) *Config { DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets, Allow0RTT: config.Allow0RTT, Tracer: config.Tracer, - - // [UQUIC] - SrcConnIDLength: config.SrcConnIDLength, - DestConnIDLength: config.DestConnIDLength, - InitPacketNumber: config.InitPacketNumber, - InitPacketNumberLength: config.InitPacketNumberLength, } } diff --git a/conn_id_manager.go b/conn_id_manager.go index 9008f7f7..7cb38e2e 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -62,10 +62,12 @@ func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { return err } + // [UQUIC] connIDLimit := h.connectionIDLimit if connIDLimit == 0 { connIDLimit = protocol.MaxActiveConnectionIDs } + // [/UQUIC] if uint64(h.queue.Len()) >= connIDLimit { return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError} @@ -191,11 +193,6 @@ func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToke h.addStatelessResetToken(token) } -// [UQUIC] -func (h *connIDManager) SetConnectionIDLimit(limit uint64) { - h.connectionIDLimit = limit -} - func (h *connIDManager) SentPacket() { h.packetsSinceLastChange++ } diff --git a/connection.go b/connection.go index 06a9f10a..47589a23 100644 --- a/connection.go +++ b/connection.go @@ -390,9 +390,6 @@ var newClientConnection = func( s.tracer, s.logger, ) - if conf.InitPacketNumberLength != 0 { - ackhandler.SetInitialPacketNumberLength(s.sentPacketHandler, conf.InitPacketNumberLength) - } s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) oneRTTStream := newCryptoStream() @@ -453,151 +450,6 @@ var newClientConnection = func( return s } -// [UQUIC] -var newUClientConnection = func( - conn sendConn, - runner connRunner, - destConnID protocol.ConnectionID, - srcConnID protocol.ConnectionID, - connIDGenerator ConnectionIDGenerator, - conf *Config, - tlsConf *tls.Config, - initialPacketNumber protocol.PacketNumber, - enable0RTT bool, - hasNegotiatedVersion bool, - tracer logging.ConnectionTracer, - tracingID uint64, - logger utils.Logger, - v protocol.VersionNumber, - chs *tls.ClientHelloSpec, -) quicConn { - s := &connection{ - conn: conn, - config: conf, - origDestConnID: destConnID, - handshakeDestConnID: destConnID, - srcConnIDLen: srcConnID.Len(), - perspective: protocol.PerspectiveClient, - logID: destConnID.String(), - logger: logger, - tracer: tracer, - versionNegotiated: hasNegotiatedVersion, - version: v, - } - s.connIDManager = newConnIDManager( - destConnID, - func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, - runner.RemoveResetToken, - s.queueControlFrame, - ) - - s.connIDGenerator = newConnIDGenerator( - srcConnID, - nil, - func(connID protocol.ConnectionID) { runner.Add(connID, s) }, - runner.GetStatelessResetToken, - runner.Remove, - runner.Retire, - runner.ReplaceWithClosed, - s.queueControlFrame, - connIDGenerator, - ) - s.preSetup() - s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) - s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( - initialPacketNumber, - getMaxPacketSize(s.conn.RemoteAddr()), - s.rttStats, - false, /* has no effect */ - s.perspective, - s.tracer, - s.logger, - ) - if conf.InitPacketNumberLength != 0 { - ackhandler.SetInitialPacketNumberLength(s.sentPacketHandler, conf.InitPacketNumberLength) - } - - s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) - oneRTTStream := newCryptoStream() - - var params *wire.TransportParameters - // params := &wire.TransportParameters{ - // InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - // InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - // InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - // InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), - // MaxIdleTimeout: s.config.MaxIdleTimeout, - // MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), - // MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), - // MaxAckDelay: protocol.MaxAckDelayInclGranularity, - // AckDelayExponent: protocol.AckDelayExponent, - // DisableActiveMigration: true, - // // For interoperability with quic-go versions before May 2023, this value must be set to a value - // // different from protocol.DefaultActiveConnectionIDLimit. - // // If set to the default value, it will be omitted from the transport parameters, which will make - // // old quic-go versions interpret it as 0, instead of the default value of 2. - // // See https://github.com/quic-go/quic-go/pull/3806. - // ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, - // InitialSourceConnectionID: srcConnID, - // } - // if s.config.EnableDatagrams { - // params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize - // } else { - // params.MaxDatagramFrameSize = protocol.InvalidByteCount - // } - - // [UQUIC] iterate over all Extensions to set the TransportParameters - var tpSet bool -FOR_EACH_TLS_EXTENSION: - for _, ext := range chs.Extensions { - switch ext := ext.(type) { - case *tls.QUICTransportParametersExtension: - params = &wire.TransportParameters{ - InitialSourceConnectionID: srcConnID, - } - params.PopulateFromUQUIC(ext.TransportParameters) - s.connIDManager.SetConnectionIDLimit(params.ActiveConnectionIDLimit) - tpSet = true - break FOR_EACH_TLS_EXTENSION - default: - continue FOR_EACH_TLS_EXTENSION - } - } - if !tpSet { - panic("applied ClientHelloSpec must contain a QUICTransportParametersExtension to proceed") - } - - if s.tracer != nil { - s.tracer.SentTransportParameters(params) - } - cs := handshake.NewUCryptoSetupClient( - destConnID, - params, - tlsConf, - enable0RTT, - s.rttStats, - tracer, - logger, - s.version, - chs, - ) - s.cryptoStreamHandler = cs - s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, oneRTTStream) - s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) - s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) - if len(tlsConf.ServerName) > 0 { - s.tokenStoreKey = tlsConf.ServerName - } else { - s.tokenStoreKey = conn.RemoteAddr().String() - } - if s.config.TokenStore != nil { - if token := s.config.TokenStore.Pop(s.tokenStoreKey); token != nil { - s.packer.SetToken(token.data) - } - } - return s -} - func (s *connection) preSetup() { s.initialStream = newCryptoStream() s.handshakeStream = newCryptoStream() @@ -2204,9 +2056,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time } s.connIDManager.SentPacket() s.sendQueue.Send(packet.buffer, packet.buffer.Len()) - // [UQUIC] - // fmt.Printf("sendPackedCoalescedPacket:Sending %d bytes\n", packet.buffer.Len()) - // fmt.Printf("sendPackedCoalescedPacket: %v\n", packet.buffer.Data) + return nil } diff --git a/example/uquic/main.go b/example/uquic/main.go index c97654fa..78bde371 100644 --- a/example/uquic/main.go +++ b/example/uquic/main.go @@ -14,7 +14,83 @@ import ( "github.com/quic-go/quic-go/http3" ) -func getCHS() *tls.ClientHelloSpec { +func main() { + keyLogWriter, err := os.Create("./keylog.txt") + if err != nil { + panic(err) + } + + tlsConf := &tls.Config{ + ServerName: "quic.tlsfingerprint.io", + // ServerName: "www.cloudflare.com", + // MinVersion: tls.VersionTLS13, + KeyLogWriter: keyLogWriter, + // NextProtos: []string{"h3"}, + } + + quicConf := &quic.Config{} + + roundTripper := &http3.RoundTripper{ + TLSClientConfig: tlsConf, + QuicConfig: quicConf, + } + uRoundTripper := http3.GetURoundTripper( + roundTripper, + // getFFQUICSpec(), + getCRQUICSpec(), + nil, + ) + defer uRoundTripper.Close() + + hclient := &http.Client{ + Transport: uRoundTripper, + } + + addr := "https://quic.tlsfingerprint.io/qfp/?beautify=true" + // addr := "https://www.cloudflare.com" + + rsp, err := hclient.Get(addr) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Got response for %s: %#v", addr, rsp) + + body := &bytes.Buffer{} + _, err = io.Copy(body, rsp.Body) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Response Body: %s", body.Bytes()) +} + +func getFFQUICSpec() *quic.QUICSpec { + return &quic.QUICSpec{ + InitialPacketSpec: quic.InitialPacketSpec{ + SrcConnIDLength: 3, + DestConnIDLength: 8, + InitPacketNumberLength: 1, + InitPacketNumber: 1, + ClientTokenLength: 0, + FrameOrder: quic.QUICFrames{ + &quic.QUICFrameCrypto{ + Offset: 300, + Length: 0, + }, + &quic.QUICFramePadding{ + Length: 125, + }, + &quic.QUICFramePing{}, + &quic.QUICFrameCrypto{ + Offset: 0, + Length: 300, + }, + }, + }, + ClientHelloSpec: getFFCHS(), + } +} + +func getFFCHS() *tls.ClientHelloSpec { return &tls.ClientHelloSpec{ TLSVersMin: tls.VersionTLS13, TLSVersMax: tls.VersionTLS13, @@ -135,54 +211,146 @@ func getCHS() *tls.ClientHelloSpec { } } -func main() { - keyLogWriter, err := os.Create("./keylog.txt") - if err != nil { - panic(err) +func getCRQUICSpec() *quic.QUICSpec { + return &quic.QUICSpec{ + InitialPacketSpec: quic.InitialPacketSpec{ + SrcConnIDLength: 0, + DestConnIDLength: 8, + InitPacketNumberLength: 1, + InitPacketNumber: 1, + ClientTokenLength: 0, + FrameOrder: quic.QUICFrames{ + &quic.QUICFrameCrypto{ + Offset: 300, + Length: 0, + }, + &quic.QUICFramePadding{ + Length: 125, + }, + &quic.QUICFramePing{}, + &quic.QUICFrameCrypto{ + Offset: 0, + Length: 300, + }, + }, + }, + ClientHelloSpec: getCRCHS(), + } +} +func getCRCHS() *tls.ClientHelloSpec { + return &tls.ClientHelloSpec{ + TLSVersMin: tls.VersionTLS13, + TLSVersMax: tls.VersionTLS13, + CipherSuites: []uint16{ + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_CHACHA20_POLY1305_SHA256, + tls.TLS_AES_256_GCM_SHA384, + }, + CompressionMethods: []uint8{ + 0x0, // no compression + }, + Extensions: []tls.TLSExtension{ + &tls.SNIExtension{}, + &tls.ExtendedMasterSecretExtension{}, + &tls.RenegotiationInfoExtension{ + Renegotiation: tls.RenegotiateOnceAsClient, + }, + &tls.SupportedCurvesExtension{ + Curves: []tls.CurveID{ + tls.CurveX25519, + tls.CurveSECP256R1, + tls.CurveSECP384R1, + tls.CurveSECP521R1, + tls.FakeCurveFFDHE2048, + tls.FakeCurveFFDHE3072, + tls.FakeCurveFFDHE4096, + tls.FakeCurveFFDHE6144, + tls.FakeCurveFFDHE8192, + }, + }, + &tls.ALPNExtension{ + AlpnProtocols: []string{ + "h3", + }, + }, + &tls.StatusRequestExtension{}, + &tls.FakeDelegatedCredentialsExtension{ + SupportedSignatureAlgorithms: []tls.SignatureScheme{ + tls.ECDSAWithP256AndSHA256, + tls.ECDSAWithP384AndSHA384, + tls.ECDSAWithP521AndSHA512, + tls.ECDSAWithSHA1, + }, + }, + &tls.KeyShareExtension{ + KeyShares: []tls.KeyShare{ + { + Group: tls.X25519, + }, + // { + // Group: tls.CurveP256, + // }, + }, + }, + &tls.SupportedVersionsExtension{ + Versions: []uint16{ + tls.VersionTLS13, + }, + }, + &tls.SignatureAlgorithmsExtension{ + SupportedSignatureAlgorithms: []tls.SignatureScheme{ + tls.ECDSAWithP256AndSHA256, + tls.ECDSAWithP384AndSHA384, + tls.ECDSAWithP521AndSHA512, + tls.ECDSAWithSHA1, + tls.PSSWithSHA256, + tls.PSSWithSHA384, + tls.PSSWithSHA512, + tls.PKCS1WithSHA256, + tls.PKCS1WithSHA384, + tls.PKCS1WithSHA512, + tls.PKCS1WithSHA1, + }, + }, + &tls.PSKKeyExchangeModesExtension{ + Modes: []uint8{ + tls.PskModeDHE, + }, + }, + &tls.FakeRecordSizeLimitExtension{ + Limit: 0x4001, + }, + &tls.QUICTransportParametersExtension{ + TransportParameters: tls.TransportParameters{ + &tls.GREASE{ + IdOverride: 0x35967c5b9c37e023, + ValueOverride: []byte{ + 0xfc, 0x97, 0xbb, 0x57, 0xb8, 0x02, 0x19, 0xcd, + }, + }, + tls.InitialMaxStreamsUni(103), + tls.InitialSourceConnectionID([]byte{}), + tls.InitialMaxStreamsBidi(100), + tls.InitialMaxData(15728640), + &tls.VersionInformation{ + ChoosenVersion: tls.VERSION_1, + AvailableVersions: []uint32{ + tls.VERSION_1, + tls.VERSION_GREASE, + }, + LegacyID: true, + }, + tls.MaxIdleTimeout(30000), + tls.MaxUDPPayloadSize(1472), + tls.MaxDatagramFrameSize(65536), + tls.InitialMaxStreamDataBidiLocal(6291456), + tls.InitialMaxStreamDataUni(6291456), + tls.InitialMaxStreamDataBidiRemote(6291456), + }, + }, + &tls.UtlsPaddingExtension{ + GetPaddingLen: tls.BoringPaddingStyle, + }, + }, } - - tlsConf := &tls.Config{ - ServerName: "quic.tlsfingerprint.io", - // ServerName: "www.cloudflare.com", - // MinVersion: tls.VersionTLS13, - KeyLogWriter: keyLogWriter, - // NextProtos: []string{"h3"}, - } - - quicConf := &quic.Config{ - Versions: []quic.VersionNumber{quic.Version1}, - // EnableDatagrams: true, - SrcConnIDLength: 3, // <4 causes timeout - DestConnIDLength: 8, - InitPacketNumber: 0, - InitPacketNumberLength: quic.PacketNumberLen1, // currently only affects the initial packet number - // Versions: []quic.VersionNumber{quic.Version2}, - } - - roundTripper := &http3.RoundTripper{ - TLSClientConfig: tlsConf, - QuicConfig: quicConf, - ClientHelloSpec: getCHS(), - } - defer roundTripper.Close() - - hclient := &http.Client{ - Transport: roundTripper, - } - - addr := "https://quic.tlsfingerprint.io/qfp/" - // addr := "https://www.cloudflare.com" - - rsp, err := hclient.Get(addr) - if err != nil { - log.Fatal(err) - } - fmt.Printf("Got response for %s: %#v", addr, rsp) - - body := &bytes.Buffer{} - _, err = io.Copy(body, rsp.Body) - if err != nil { - log.Fatal(err) - } - fmt.Printf("Response Body: %s", body.Bytes()) } diff --git a/http3/roundtrip.go b/http3/roundtrip.go index d2b9ae2c..ae589438 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -88,9 +88,6 @@ type RoundTripper struct { newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests clients map[string]*roundTripCloserWithCount transport *quic.Transport - - // [UQUIC] - ClientHelloSpec *tls.ClientHelloSpec } // RoundTripOpt are options for the Transport.RoundTripOpt method. @@ -194,8 +191,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTr return nil, false, err } r.transport = &quic.Transport{ - Conn: udpConn, - ClientHelloSpec: r.ClientHelloSpec, + Conn: udpConn, } } dial = r.makeDialer() diff --git a/http3/u_roundtrip.go b/http3/u_roundtrip.go new file mode 100644 index 00000000..8e61741f --- /dev/null +++ b/http3/u_roundtrip.go @@ -0,0 +1,192 @@ +package http3 + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + + "github.com/quic-go/quic-go" + tls "github.com/refraction-networking/utls" + "golang.org/x/net/http/httpguts" +) + +type URoundTripper struct { + *RoundTripper + + quicSpec *quic.QUICSpec + uTransportOverride *quic.UTransport +} + +func GetURoundTripper(r *RoundTripper, QUICSpec *quic.QUICSpec, uTransport *quic.UTransport) *URoundTripper { + QUICSpec.UpdateConfig(r.QuicConfig) + + return &URoundTripper{ + RoundTripper: r, + quicSpec: QUICSpec, + uTransportOverride: uTransport, + } +} + +// RoundTripOpt is like RoundTrip, but takes options. +func (r *URoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + if req.URL == nil { + closeRequestBody(req) + return nil, errors.New("http3: nil Request.URL") + } + if req.URL.Scheme != "https" { + closeRequestBody(req) + return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme) + } + if req.URL.Host == "" { + closeRequestBody(req) + return nil, errors.New("http3: no Host in request URL") + } + if req.Header == nil { + closeRequestBody(req) + return nil, errors.New("http3: nil Request.Header") + } + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("http3: invalid http header field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k) + } + } + } + + if req.Method != "" && !validMethod(req.Method) { + closeRequestBody(req) + return nil, fmt.Errorf("http3: invalid method %q", req.Method) + } + + hostname := authorityAddr("https", hostnameFromRequest(req)) + cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn) + if err != nil { + return nil, err + } + defer cl.useCount.Add(-1) + rsp, err := cl.RoundTripOpt(req, opt) + if err != nil { + r.removeClient(hostname) + if isReused { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + return r.RoundTripOpt(req, opt) + } + } + } + return rsp, err +} + +// RoundTrip does a round trip. +func (r *URoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return r.RoundTripOpt(req, RoundTripOpt{}) +} + +func (r *URoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) { + r.mutex.Lock() + defer r.mutex.Unlock() + + if r.clients == nil { + r.clients = make(map[string]*roundTripCloserWithCount) + } + + client, ok := r.clients[hostname] + if !ok { + if onlyCached { + return nil, false, ErrNoCachedConn + } + var err error + newCl := newClient + if r.newClient != nil { + newCl = r.newClient + } + dial := r.Dial + if dial == nil { + if r.transport == nil && r.uTransportOverride == nil { + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, false, err + } + r.uTransportOverride = &quic.UTransport{ + Transport: &quic.Transport{ + Conn: udpConn, + }, + QUICSpec: r.quicSpec, + } + } + dial = r.makeDialer() + } + c, err := newCl( + hostname, + r.TLSClientConfig, + &roundTripperOpts{ + EnableDatagram: r.EnableDatagrams, + DisableCompression: r.DisableCompression, + MaxHeaderBytes: r.MaxResponseHeaderBytes, + StreamHijacker: r.StreamHijacker, + UniStreamHijacker: r.UniStreamHijacker, + }, + r.QuicConfig, + dial, + ) + if err != nil { + return nil, false, err + } + client = &roundTripCloserWithCount{roundTripCloser: c} + r.clients[hostname] = client + } else if client.HandshakeComplete() { + isReused = true + } + client.useCount.Add(1) + return client, isReused, nil +} + +func (r *URoundTripper) Close() error { + r.mutex.Lock() + defer r.mutex.Unlock() + for _, client := range r.clients { + if err := client.Close(); err != nil { + return err + } + } + r.clients = nil + if r.transport != nil { + if err := r.transport.Close(); err != nil { + return err + } + if err := r.transport.Conn.Close(); err != nil { + return err + } + r.transport = nil + } + if r.uTransportOverride != nil { + if err := r.uTransportOverride.Close(); err != nil { + return err + } + if err := r.uTransportOverride.Conn.Close(); err != nil { + return err + } + r.uTransportOverride = nil + } + return nil +} + +// makeDialer makes a QUIC dialer using r.udpConn. +func (r *URoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + if r.uTransportOverride != nil { + return r.uTransportOverride.DialEarly(ctx, udpAddr, tlsCfg, cfg) + } else if r.transport == nil { + return nil, errors.New("http3: no QUIC transport available") + } + return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) + } +} diff --git a/interface.go b/interface.go index 92530923..dbfedccf 100644 --- a/interface.go +++ b/interface.go @@ -333,12 +333,6 @@ type Config struct { // Enable QUIC datagram support (RFC 9221). EnableDatagrams bool Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer - - // TransportParameters override other transport parameters set by the Config. - SrcConnIDLength int // [UQUIC] - DestConnIDLength int // [UQUIC] - InitPacketNumber uint64 // [UQUIC] - InitPacketNumberLength PacketNumberLen // [UQUIC] } type ClientHelloInfo struct { diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 82d45c25..08684aa2 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -96,8 +96,6 @@ type sentPacketHandler struct { tracer logging.ConnectionTracer logger utils.Logger - - initialPacketNumberLength protocol.PacketNumberLen // [UQUIC] } var ( @@ -138,12 +136,6 @@ func newSentPacketHandler( } } -func SetInitialPacketNumberLength(h SentPacketHandler, pnLen protocol.PacketNumberLen) { - if sph, ok := h.(*sentPacketHandler); ok { - sph.initialPacketNumberLength = pnLen - } -} - func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) { if p.includedInBytesInFlight { if p.Length > h.bytesInFlight { @@ -725,11 +717,6 @@ func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) pn := pnSpace.pns.Peek() // See section 17.1 of RFC 9000. - // [UQUIC] This kinda breaks PN length mimicry. - if encLevel == protocol.EncryptionInitial && h.initialPacketNumberLength != 0 { - return pn, h.initialPacketNumberLength - } - return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked) } diff --git a/internal/ackhandler/u_ackhandler.go b/internal/ackhandler/u_ackhandler.go new file mode 100644 index 00000000..56886795 --- /dev/null +++ b/internal/ackhandler/u_ackhandler.go @@ -0,0 +1,23 @@ +package ackhandler + +import ( + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" +) + +// [UQUIC] +func NewUAckHandler( + initialPacketNumber protocol.PacketNumber, + initialMaxDatagramSize protocol.ByteCount, + rttStats *utils.RTTStats, + clientAddressValidated bool, + pers protocol.Perspective, + tracer logging.ConnectionTracer, + logger utils.Logger, +) (SentPacketHandler, ReceivedPacketHandler) { + sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger) + return &uSentPacketHandler{ + sentPacketHandler: sph, + }, newReceivedPacketHandler(sph, rttStats, logger) +} diff --git a/internal/ackhandler/u_sent_packet_handler.go b/internal/ackhandler/u_sent_packet_handler.go new file mode 100644 index 00000000..1e1dfcd3 --- /dev/null +++ b/internal/ackhandler/u_sent_packet_handler.go @@ -0,0 +1,30 @@ +package ackhandler + +import "github.com/quic-go/quic-go/internal/protocol" + +type uSentPacketHandler struct { + *sentPacketHandler + + initialPacketNumberLength protocol.PacketNumberLen // [UQUIC] +} + +func (h *uSentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { + pnSpace := h.getPacketNumberSpace(encLevel) + pn := pnSpace.pns.Peek() + // See section 17.1 of RFC 9000. + + // [UQUIC] Otherwise it kinda breaks PN length mimicry. + if encLevel == protocol.EncryptionInitial && h.initialPacketNumberLength != 0 { + return pn, h.initialPacketNumberLength + } + // [/UQUIC] + + return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked) +} + +// [UQUIC] +func SetInitialPacketNumberLength(h SentPacketHandler, pnLen protocol.PacketNumberLen) { + if sph, ok := h.(*uSentPacketHandler); ok { + sph.initialPacketNumberLength = pnLen + } +} diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 95255d14..670dbd71 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -102,41 +102,6 @@ func NewCryptoSetupClient( return cs } -// [UQUIC] -// NewUCryptoSetupClient creates a new crypto setup for the client with UTLS -func NewUCryptoSetupClient( - connID protocol.ConnectionID, - tp *wire.TransportParameters, - tlsConf *tls.Config, - enable0RTT bool, - rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, - logger utils.Logger, - version protocol.VersionNumber, - chs *tls.ClientHelloSpec, -) CryptoSetup { - cs := newCryptoSetup( - connID, - tp, - rttStats, - tracer, - logger, - protocol.PerspectiveClient, - version, - ) - - tlsConf = tlsConf.Clone() - tlsConf.MinVersion = tls.VersionTLS13 - quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} - qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState) - cs.tlsConf = tlsConf - - cs.conn = qtls.UQUICClient(quicConf, chs) - // cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) // [UQUIC] doesn't require this - - return cs -} - // NewCryptoSetupServer creates a new crypto setup for the server func NewCryptoSetupServer( connID protocol.ConnectionID, @@ -281,6 +246,7 @@ func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) { return false, h.handleTransportParameters(ev.Data) case qtls.QUICTransportParametersRequired: h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective)) + // [UQUIC] doesn't expect this and may fail return false, nil case qtls.QUICRejectedEarlyData: h.rejected0RTT() diff --git a/internal/handshake/u_crypto_setup.go b/internal/handshake/u_crypto_setup.go new file mode 100644 index 00000000..38488ddd --- /dev/null +++ b/internal/handshake/u_crypto_setup.go @@ -0,0 +1,45 @@ +package handshake + +import ( + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qtls" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" + tls "github.com/refraction-networking/utls" +) + +// [UQUIC] +// NewUCryptoSetupClient creates a new crypto setup for the client with UTLS +func NewUCryptoSetupClient( + connID protocol.ConnectionID, + tp *wire.TransportParameters, + tlsConf *tls.Config, + enable0RTT bool, + rttStats *utils.RTTStats, + tracer logging.ConnectionTracer, + logger utils.Logger, + version protocol.VersionNumber, + chs *tls.ClientHelloSpec, +) CryptoSetup { + cs := newCryptoSetup( + connID, + tp, + rttStats, + tracer, + logger, + protocol.PerspectiveClient, + version, + ) + + tlsConf = tlsConf.Clone() + tlsConf.MinVersion = tls.VersionTLS13 + quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} + qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState) + cs.tlsConf = tlsConf + + cs.conn = qtls.UQUICClient(quicConf, chs) + // cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) // [UQUIC] doesn't require this + + return cs +} diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go index 042d3a99..77259b5f 100644 --- a/internal/protocol/connection_id.go +++ b/internal/protocol/connection_id.go @@ -68,11 +68,6 @@ func GenerateConnectionIDForInitial() (ConnectionID, error) { return GenerateConnectionID(l) } -// [UQUIC] -func GenerateConnectionIDForInitialWithLen(l int) (ConnectionID, error) { - return GenerateConnectionID(l) -} - // ReadConnectionID reads a connection ID of length len from the given io.Reader. // It returns io.EOF if there are not enough bytes to read. func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) { diff --git a/internal/protocol/u_connection_id.go b/internal/protocol/u_connection_id.go new file mode 100644 index 00000000..8f3b388d --- /dev/null +++ b/internal/protocol/u_connection_id.go @@ -0,0 +1,16 @@ +package protocol + +// [UQUIC] +func GenerateConnectionIDForInitialWithLen(l int) (ConnectionID, error) { + return GenerateConnectionID(l) +} + +type ExpEmptyConnectionIDGenerator struct{} + +func (g *ExpEmptyConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) { + return GenerateConnectionID(0) +} + +func (g *ExpEmptyConnectionIDGenerator) ConnectionIDLen() int { + return 0 +} diff --git a/packet_packer.go b/packet_packer.go index e6ab4bec..21af817c 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -332,11 +332,6 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. if initialPayload.length > 0 { size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead()) } - - // // [UQUIC] - // if len(initialPayload.frames) > 0 { - // fmt.Printf("onlyAck: %t, PackCoalescedPacket: %v\n", onlyAck, initialPayload.frames[0].Frame) - // } } // Add a Handshake packet. @@ -401,24 +396,12 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. longHdrPackets: make([]*longHeaderPacket, 0, 3), } if initialPayload.length > 0 { - if onlyAck || len(initialPayload.frames) == 0 { - // padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize) - // cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) - // if err != nil { - // return nil, err - // } - // packet.longHdrPackets = append(packet.longHdrPackets, cont) - return nil, nil // [UQUIC] not to send the ACK frame for Initial - } else { // [UQUIC] - cont, err := p.appendLongHeaderPacketExternalPadding(buffer, initialHdr, initialPayload, protocol.EncryptionInitial, initialSealer, v) - if err != nil { - return nil, err - } - - // fmt.Printf("!onlyAck buffer: %v\n", buffer.Data) - - packet.longHdrPackets = append(packet.longHdrPackets, cont) + padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize) + cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) + if err != nil { + return nil, err } + packet.longHdrPackets = append(packet.longHdrPackets, cont) } if handshakePayload.length > 0 { cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v) @@ -688,7 +671,6 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m return nil, err } hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v) - fmt.Printf("MaybePackProbePacket: %x\n", pl.frames[0]) case protocol.EncryptionHandshake: var err error sealer, err = p.cryptoSetup.GetHandshakeSealer() @@ -768,10 +750,6 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire } paddingLen += padding - if encLevel == protocol.EncryptionInitial { - paddingLen = 0 - } - header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length + paddingLen startLen := len(buffer.Data) @@ -800,49 +778,6 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire }, nil } -// [UQUIC] -func (p *packetPacker) appendLongHeaderPacketExternalPadding(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) { - pnLen := protocol.ByteCount(header.PacketNumberLen) - header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length - - startLen := len(buffer.Data) - raw := buffer.Data[startLen:] // [UQUIC] raw is a sub-slice of buffer.Data, whose len < size - raw, err := header.Append(raw, v) - if err != nil { - return nil, err - } - - fmt.Printf("Pre-Payload: %x\n", raw) - - payloadOffset := protocol.ByteCount(len(raw)) - raw, err = p.appendCustomInitialPacketPayload(raw, pl, 0, v) - if err != nil { - return nil, err - } - - fmt.Printf("Pre-Encryption: %x\n", raw) - - raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen) - buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] - - fmt.Printf("Post-Encryption: %x\n", raw) - - // [UQUIC] - // append zero to buffer.Data until 1200 bytes - buffer.Data = append(buffer.Data, make([]byte, 1357-len(buffer.Data))...) - - if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber { - return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber) - } - return &longHeaderPacket{ - header: header, - ack: pl.ack, - frames: pl.frames, - streamFrames: pl.streamFrames, - length: protocol.ByteCount(len(raw)), - }, nil -} - func (p *packetPacker) appendShortHeaderPacket( buffer *packetBuffer, connID protocol.ConnectionID, @@ -930,44 +865,6 @@ func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen pr return raw, nil } -func (p *packetPacker) appendCustomInitialPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { - payloadOffset := len(raw) - - // [UQUIC] ignores the default ACK/PADDING frame and uses its own frames - // if pl.ack != nil { - // var err error - // raw, err = pl.ack.Append(raw, v) - // if err != nil { - // return nil, err - // } - // } - // if paddingLen > 0 { - // raw = append(raw, make([]byte, paddingLen)...) - // } - - for _, f := range pl.frames { - var err error - raw, err = f.Frame.Append(raw, v) - if err != nil { - return nil, err - } - fmt.Printf("UQUIC: appending frame %v\n", f) - } - for _, f := range pl.streamFrames { - var err error - raw, err = f.Frame.Append(raw, v) - if err != nil { - return nil, err - } - fmt.Printf("UQUIC: appending stream frame %v\n", f) - } - - if payloadSize := protocol.ByteCount(len(raw)-payloadOffset) - paddingLen; payloadSize != pl.length { - return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", pl.length, payloadSize) - } - return raw, nil -} - func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.PacketNumber, payloadOffset, pnLen protocol.ByteCount) []byte { _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], pn, raw[:payloadOffset]) raw = raw[:len(raw)+sealer.Overhead()] diff --git a/transport.go b/transport.go index b0b527c4..e002a261 100644 --- a/transport.go +++ b/transport.go @@ -87,8 +87,6 @@ type Transport struct { isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial logger utils.Logger - - ClientHelloSpec *tls.ClientHelloSpec // [UQUIC] } // Listen starts listening for incoming QUIC connections. @@ -156,12 +154,6 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config } conf = populateConfig(conf) - // [UQUIC] - if conf.SrcConnIDLength != 0 { - t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conf.SrcConnIDLength} - } - // [/UQUIC] - if err := t.init(t.isSingleUse); err != nil { return nil, err } @@ -172,9 +164,6 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - if t.ClientHelloSpec != nil { // [UQUIC] - return dialWithCHS(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false, t.ClientHelloSpec) - } return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) } @@ -185,12 +174,6 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C } conf = populateConfig(conf) - // [UQUIC] - if conf.SrcConnIDLength != 0 { - t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conf.SrcConnIDLength} - } - // [/UQUIC] - if err := t.init(t.isSingleUse); err != nil { return nil, err } @@ -201,9 +184,6 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - if t.ClientHelloSpec != nil { // [UQUIC] - return dialWithCHS(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false, t.ClientHelloSpec) - } return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) } diff --git a/u_client.go b/u_client.go new file mode 100644 index 00000000..e3b0a408 --- /dev/null +++ b/u_client.go @@ -0,0 +1,150 @@ +package quic + +import ( + "context" + "errors" + + "github.com/quic-go/quic-go/internal/protocol" + tls "github.com/refraction-networking/utls" +) + +type uClient struct { + *client + uSpec *QUICSpec // [UQUIC] +} + +func udial( + ctx context.Context, + conn sendConn, + connIDGenerator ConnectionIDGenerator, + packetHandlers packetHandlerManager, + tlsConf *tls.Config, + config *Config, + onClose func(), + use0RTT bool, + uSpec *QUICSpec, // [UQUIC] +) (quicConn, error) { + c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT) + if err != nil { + return nil, err + } + c.packetHandlers = packetHandlers + + // [UQUIC] + if uSpec.InitialPacketSpec.DestConnIDLength > 0 { + destConnID, err := generateConnectionIDForInitialWithLength(uSpec.InitialPacketSpec.DestConnIDLength) + if err != nil { + return nil, err + } + c.destConnID = destConnID + } + c.initialPacketNumber = protocol.PacketNumber(uSpec.InitialPacketSpec.InitPacketNumber) + // [/UQUIC] + + c.tracingID = nextConnTracingID() + if c.config.Tracer != nil { + c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) + } + if c.tracer != nil { + c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) + } + + // [UQUIC] + uc := &uClient{ + client: c, + uSpec: uSpec, + } + // [/UQUIC] + + if err := uc.dial(ctx); err != nil { + return nil, err + } + return uc.conn, nil +} + +func (c *uClient) dial(ctx context.Context) error { + c.logger.Infof("Starting new uQUIC connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + + // [UQUIC] + if c.uSpec.ClientHelloSpec == nil { + c.conn = newClientConnection( + c.sendConn, + c.packetHandlers, + c.destConnID, + c.srcConnID, + c.connIDGenerator, + c.config, + c.tlsConf, + c.initialPacketNumber, + c.use0RTT, + c.hasNegotiatedVersion, + c.tracer, + c.tracingID, + c.logger, + c.version, + ) + } else { + // [UQUIC]: use custom version of the connection + c.conn = newUClientConnection( + c.sendConn, + c.packetHandlers, + c.destConnID, + c.srcConnID, + c.connIDGenerator, + c.config, + c.tlsConf, + c.initialPacketNumber, + c.use0RTT, + c.hasNegotiatedVersion, + c.tracer, + c.tracingID, + c.logger, + c.version, + c.uSpec, + ) + } + // [/UQUIC] + + c.packetHandlers.Add(c.srcConnID, c.conn) + + errorChan := make(chan error, 1) + recreateChan := make(chan errCloseForRecreating) + go func() { + err := c.conn.run() + var recreateErr *errCloseForRecreating + if errors.As(err, &recreateErr) { + recreateChan <- *recreateErr + return + } + if c.onClose != nil { + c.onClose() + } + errorChan <- err // returns as soon as the connection is closed + }() + + // only set when we're using 0-RTT + // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever. + var earlyConnChan <-chan struct{} + if c.use0RTT { + earlyConnChan = c.conn.earlyConnReady() + } + + select { + case <-ctx.Done(): + c.conn.shutdown() + return ctx.Err() + case err := <-errorChan: + return err + case recreateErr := <-recreateChan: + c.initialPacketNumber = recreateErr.nextPacketNumber + c.version = recreateErr.nextVersion + c.hasNegotiatedVersion = true + return c.dial(ctx) + case <-earlyConnChan: + // ready to send 0-RTT data + return nil + case <-c.conn.HandshakeComplete(): + // handshake successfully completed + return nil + } +} diff --git a/u_conn_id_manager.go b/u_conn_id_manager.go new file mode 100644 index 00000000..9a9d1e93 --- /dev/null +++ b/u_conn_id_manager.go @@ -0,0 +1,6 @@ +package quic + +// [UQUIC] +func (h *connIDManager) SetConnectionIDLimit(limit uint64) { + h.connectionIDLimit = limit +} diff --git a/u_connection.go b/u_connection.go new file mode 100644 index 00000000..9a1803e4 --- /dev/null +++ b/u_connection.go @@ -0,0 +1,170 @@ +package quic + +import ( + "context" + + "github.com/quic-go/quic-go/internal/ackhandler" + "github.com/quic-go/quic-go/internal/handshake" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" + tls "github.com/refraction-networking/utls" +) + +// [UQUIC] +var newUClientConnection = func( + conn sendConn, + runner connRunner, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + connIDGenerator ConnectionIDGenerator, + conf *Config, + tlsConf *tls.Config, + initialPacketNumber protocol.PacketNumber, + enable0RTT bool, + hasNegotiatedVersion bool, + tracer logging.ConnectionTracer, + tracingID uint64, + logger utils.Logger, + v protocol.VersionNumber, + // chs *tls.ClientHelloSpec, + // initPktNbrLen PacketNumberLen, + // qfs QUICFrames, + // udpDatagramMinSize int, + uSpec *QUICSpec, // [UQUIC] +) quicConn { + s := &connection{ + conn: conn, + config: conf, + origDestConnID: destConnID, + handshakeDestConnID: destConnID, + srcConnIDLen: srcConnID.Len(), + perspective: protocol.PerspectiveClient, + logID: destConnID.String(), + logger: logger, + tracer: tracer, + versionNegotiated: hasNegotiatedVersion, + version: v, + } + s.connIDManager = newConnIDManager( + destConnID, + func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, + runner.RemoveResetToken, + s.queueControlFrame, + ) + + s.connIDGenerator = newConnIDGenerator( + srcConnID, + nil, + func(connID protocol.ConnectionID) { runner.Add(connID, s) }, + runner.GetStatelessResetToken, + runner.Remove, + runner.Retire, + runner.ReplaceWithClosed, + s.queueControlFrame, + connIDGenerator, + ) + s.preSetup() + s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) + s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewUAckHandler( // [UQUIC] + initialPacketNumber, + getMaxPacketSize(s.conn.RemoteAddr()), + s.rttStats, + false, /* has no effect */ + s.perspective, + s.tracer, + s.logger, + ) + // [UQUIC] + if uSpec.InitialPacketSpec.InitPacketNumberLength != 0 { + ackhandler.SetInitialPacketNumberLength(s.sentPacketHandler, uSpec.InitialPacketSpec.InitPacketNumberLength) + } + + s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) + oneRTTStream := newCryptoStream() + + var params *wire.TransportParameters + + if uSpec.ClientHelloSpec != nil { + // iterate over all Extensions to set the TransportParameters + var tpSet bool + FOR_EACH_TLS_EXTENSION: + for _, ext := range uSpec.ClientHelloSpec.Extensions { + switch ext := ext.(type) { + case *tls.QUICTransportParametersExtension: + params = &wire.TransportParameters{ + InitialSourceConnectionID: srcConnID, + } + params.PopulateFromUQUIC(ext.TransportParameters) + s.connIDManager.SetConnectionIDLimit(params.ActiveConnectionIDLimit) + tpSet = true + break FOR_EACH_TLS_EXTENSION + default: + continue FOR_EACH_TLS_EXTENSION + } + } + if !tpSet { + panic("applied ClientHelloSpec must contain a QUICTransportParametersExtension to proceed") + } + } else { + // use default TransportParameters + params = &wire.TransportParameters{ + InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), + MaxIdleTimeout: s.config.MaxIdleTimeout, + MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), + MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), + MaxAckDelay: protocol.MaxAckDelayInclGranularity, + AckDelayExponent: protocol.AckDelayExponent, + DisableActiveMigration: true, + // For interoperability with quic-go versions before May 2023, this value must be set to a value + // different from protocol.DefaultActiveConnectionIDLimit. + // If set to the default value, it will be omitted from the transport parameters, which will make + // old quic-go versions interpret it as 0, instead of the default value of 2. + // See https://github.com/quic-go/quic-go/pull/3806. + ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, + InitialSourceConnectionID: srcConnID, + } + if s.config.EnableDatagrams { + params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + } else { + params.MaxDatagramFrameSize = protocol.InvalidByteCount + } + } + + if s.tracer != nil { + s.tracer.SentTransportParameters(params) + } + cs := handshake.NewUCryptoSetupClient( + destConnID, + params, + tlsConf, + enable0RTT, + s.rttStats, + tracer, + logger, + s.version, + uSpec.ClientHelloSpec, + ) + s.cryptoStreamHandler = cs + s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, oneRTTStream) + s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) + s.packer = newUPacketPacker( + newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective), + uSpec, + ) + if len(tlsConf.ServerName) > 0 { + s.tokenStoreKey = tlsConf.ServerName + } else { + s.tokenStoreKey = conn.RemoteAddr().String() + } + if s.config.TokenStore != nil { + if token := s.config.TokenStore.Pop(s.tokenStoreKey); token != nil { + s.packer.SetToken(token.data) + } + } + return s +} diff --git a/u_initial_packet_spec.go b/u_initial_packet_spec.go new file mode 100644 index 00000000..070fd292 --- /dev/null +++ b/u_initial_packet_spec.go @@ -0,0 +1,205 @@ +package quic + +import ( + "bytes" + "crypto/rand" + "errors" + + "github.com/gaukas/clienthellod" + "github.com/quic-go/quic-go/quicvarint" +) + +type InitialPacketSpec struct { + // SrcConnIDLength specifies how many bytes should the SrcConnID be + SrcConnIDLength int + + // DestConnIDLength specifies how many bytes should the DestConnID be + DestConnIDLength int + + // InitPacketNumberLength specifies how many bytes should the InitPacketNumber + // be interpreted as. It is usually 1 or 2 bytes. If unset, UQUIC will use the + // default algorithm to compute the length which is at least 2 bytes. + InitPacketNumberLength PacketNumberLen + + // InitPacketNumber is the packet number of the first Initial packet. Following + // Initial packets, if any, will increment the Packet Number accordingly. + InitPacketNumber uint64 // [UQUIC] + + // TokenStore is used to store and retrieve tokens. If set, will override the + // one set in the Config. + TokenStore TokenStore + + // If ClientTokenLength is set when TokenStore is not set, a dummy TokenStore + // will be created to randomly generate tokens of the specified length for + // Pop() calls with any key and silently drop any Put() calls. + // + // However, the tokens will not be stored anywhere and are expected to be + // invalid since not assigned by the server. + ClientTokenLength int + + // QUICFrames specifies a list of QUIC frames to be sent in the first Initial + // packet. + // + // If nil, it will be treated as a list with only a single QUICFrameCrypto. + FrameOrder QUICFrames +} + +func (ps *InitialPacketSpec) UpdateConfig(conf *Config) { + conf.TokenStore = ps.getTokenStore() +} + +func (ps *InitialPacketSpec) getTokenStore() TokenStore { + if ps.TokenStore != nil { + return ps.TokenStore + } + + if ps.ClientTokenLength > 0 { + return &dummyTokenStore{ + tokenLength: ps.ClientTokenLength, + } + } + + return nil +} + +type dummyTokenStore struct { + tokenLength int +} + +func (d *dummyTokenStore) Pop(key string) (token *ClientToken) { + var data []byte = make([]byte, d.tokenLength) + rand.Read(data) + + return &ClientToken{ + data: data, + } +} + +func (d *dummyTokenStore) Put(_ string, _ *ClientToken) { + // Do nothing +} + +type QUICFrames []QUICFrame + +func (qfs QUICFrames) MarshalWithCryptoData(cryptoData []byte) (payload []byte, err error) { + if len(qfs) == 0 { // If no frames specified, send a single crypto frame + qfs = QUICFrames{QUICFrameCrypto{0, 0}} + return qfs.MarshalWithCryptoData(cryptoData) + } + + for _, frame := range qfs { + var frameBytes []byte + if offset, length, cryptoOK := frame.CryptoFrameInfo(); cryptoOK { + if length == 0 { + // calculate length: from offset to the end of cryptoData + length = len(cryptoData) - offset + } + frameBytes = []byte{0x06} // CRYPTO frame type + frameBytes = quicvarint.Append(frameBytes, uint64(offset)) + frameBytes = quicvarint.Append(frameBytes, uint64(length)) + frameCryptoData := make([]byte, length) + copy(frameCryptoData, cryptoData[offset:]) // copy at most length bytes + frameBytes = append(frameBytes, frameCryptoData...) + } else { // Handle none crypto frames: read and append to payload + frameBytes, err = frame.Read() + if err != nil { + return nil, err + } + } + payload = append(payload, frameBytes...) + } + return payload, nil +} + +func (qfs QUICFrames) MarshalWithFrames(frames []byte) (payload []byte, err error) { + // parse frames + r := bytes.NewReader(frames) + qchframes, err := clienthellod.ReadAllFrames(r) + if err != nil { + return nil, err + } + + // parse crypto data + cryptoData, err := clienthellod.ReassembleCRYPTOFrames(qchframes) + if err != nil { + return nil, err + } + + // marshal + return qfs.MarshalWithCryptoData(cryptoData) +} + +type QUICFrame interface { + // None crypto frames should return false for cryptoOK + CryptoFrameInfo() (offset, length int, cryptoOK bool) + + // None crypto frames should return the byte representation of the frame. + // Crypto frames' behavior is undefined and unused. + Read() ([]byte, error) +} + +// QUICFrameCrypto is used to specify the crypto frames containing the TLS ClientHello +// to be sent in the first Initial packet. +type QUICFrameCrypto struct { + // Offset is used to specify the starting offset of the crypto frame. + // Used when sending multiple crypto frames in a single packet. + // + // Multiple crypto frames in a single packet must not overlap and must + // make up an entire crypto stream continuously. + Offset int + + // Length is used to specify the length of the crypto frame. + // + // Must be set if it is NOT the last crypto frame in a packet. + Length int +} + +// CryptoFrameInfo() implements the QUICFrame interface. +// +// Crypto frames are later replaced by the crypto message using the information +// returned by this function. +func (q QUICFrameCrypto) CryptoFrameInfo() (offset, length int, cryptoOK bool) { + return q.Offset, q.Length, true +} + +// Read() implements the QUICFrame interface. +// +// Crypto frames are later replaced by the crypto message, so they are not Read()-able. +func (q QUICFrameCrypto) Read() ([]byte, error) { + return nil, errors.New("crypto frames are not Read()-able") +} + +// QUICFramePadding is used to specify the padding frames to be sent in the first Initial +// packet. +type QUICFramePadding struct { + // Length is used to specify the length of the padding frame. + Length int +} + +// CryptoFrameInfo() implements the QUICFrame interface. +func (q QUICFramePadding) CryptoFrameInfo() (offset, length int, cryptoOK bool) { + return 0, 0, false +} + +// Read() implements the QUICFrame interface. +// +// Padding simply returns a slice of bytes of the specified length filled with 0. +func (q QUICFramePadding) Read() ([]byte, error) { + return make([]byte, q.Length), nil +} + +// QUICFramePing is used to specify the ping frames to be sent in the first Initial +// packet. +type QUICFramePing struct{} + +// CryptoFrameInfo() implements the QUICFrame interface. +func (q QUICFramePing) CryptoFrameInfo() (offset, length int, cryptoOK bool) { + return 0, 0, false +} + +// Read() implements the QUICFrame interface. +// +// Ping simply returns a slice of bytes of size 1 with value 0x01(PING). +func (q QUICFramePing) Read() ([]byte, error) { + return []byte{0x01}, nil +} diff --git a/u_quic_spec_test.go b/u_initial_packet_spec_test.go similarity index 100% rename from u_quic_spec_test.go rename to u_initial_packet_spec_test.go diff --git a/u_packet_packer.go b/u_packet_packer.go new file mode 100644 index 00000000..0d984507 --- /dev/null +++ b/u_packet_packer.go @@ -0,0 +1,243 @@ +package quic + +import ( + "fmt" + + "github.com/quic-go/quic-go/internal/handshake" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" +) + +// uPacketPacker is an extended packetPacker which is used +// to customize some of the packetPacker's behaviors for +// UQUIC. +type uPacketPacker struct { + *packetPacker + + // initPktNbrLen PacketNumberLen + // qfs QUICFrames // [UQUIC] uses QUICFrames to customize encrypted frames + // udpDatagramMinSize int + uSpec *QUICSpec // [UQUIC] +} + +func newUPacketPacker( + packetPacker *packetPacker, + uSpec *QUICSpec, // [UQUIC] +) *uPacketPacker { + return &uPacketPacker{ + packetPacker: packetPacker, + uSpec: uSpec, // [UQUIC] + } +} + +// PackCoalescedPacket packs a new packet. +// It packs an Initial / Handshake if there is data to send in these packet number spaces. +// It should only be called before the handshake is confirmed. +func (p *uPacketPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { + var ( + initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader + initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload + oneRTTPacketNumber protocol.PacketNumber + oneRTTPacketNumberLen protocol.PacketNumberLen + ) + // Try packing an Initial packet. + initialSealer, err := p.cryptoSetup.GetInitialSealer() + if err != nil && err != handshake.ErrKeysDropped { + return nil, err + } + var size protocol.ByteCount + if initialSealer != nil { + initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, onlyAck, true, v) + if initialPayload.length > 0 { + size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead()) + } + + // // [UQUIC] + // if len(initialPayload.frames) > 0 { + // fmt.Printf("onlyAck: %t, PackCoalescedPacket: %v\n", onlyAck, initialPayload.frames[0].Frame) + // } + } + + // Add a Handshake packet. + var handshakeSealer sealer + if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) { + var err error + handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if handshakeSealer != nil { + handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, onlyAck, size == 0, v) + if handshakePayload.length > 0 { + s := p.longHeaderPacketLength(handshakeHdr, handshakePayload, v) + protocol.ByteCount(handshakeSealer.Overhead()) + size += s + } + } + } + + // Add a 0-RTT / 1-RTT packet. + var zeroRTTSealer sealer + var oneRTTSealer handshake.ShortHeaderSealer + var connID protocol.ConnectionID + var kp protocol.KeyPhaseBit + if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) { + var err error + oneRTTSealer, err = p.cryptoSetup.Get1RTTSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if err == nil { // 1-RTT + kp = oneRTTSealer.KeyPhase() + connID = p.getDestConnID() + oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen) + oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxPacketSize-size, onlyAck, size == 0, v) + if oneRTTPayload.length > 0 { + size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead()) + } + } else if p.perspective == protocol.PerspectiveClient && !onlyAck { // 0-RTT packets can't contain ACK frames + var err error + zeroRTTSealer, err = p.cryptoSetup.Get0RTTSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if zeroRTTSealer != nil { + zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxPacketSize-size, v) + if zeroRTTPayload.length > 0 { + size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload, v) + protocol.ByteCount(zeroRTTSealer.Overhead()) + } + } + } + } + + if initialPayload.length == 0 && handshakePayload.length == 0 && zeroRTTPayload.length == 0 && oneRTTPayload.length == 0 { + return nil, nil + } + + buffer := getPacketBuffer() + packet := &coalescedPacket{ + buffer: buffer, + longHdrPackets: make([]*longHeaderPacket, 0, 3), + } + if initialPayload.length > 0 { + if onlyAck || len(initialPayload.frames) == 0 { + // TODO: uQUIC should send Initial Packet if requested. + // However, it should be otherwise configurable whether to request + // to send Initial Packet or not. See quic-go#4007 + padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize) + cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) + if err != nil { + return nil, err + } + packet.longHdrPackets = append(packet.longHdrPackets, cont) + } else { // [UQUIC] + cont, err := p.appendInitialPacket(buffer, initialHdr, initialPayload, protocol.EncryptionInitial, initialSealer, v) + if err != nil { + return nil, err + } + + packet.longHdrPackets = append(packet.longHdrPackets, cont) + } + } + if handshakePayload.length > 0 { + cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v) + if err != nil { + return nil, err + } + packet.longHdrPackets = append(packet.longHdrPackets, cont) + } + if zeroRTTPayload.length > 0 { + longHdrPacket, err := p.appendLongHeaderPacket(buffer, zeroRTTHdr, zeroRTTPayload, 0, protocol.Encryption0RTT, zeroRTTSealer, v) + if err != nil { + return nil, err + } + packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) + } else if oneRTTPayload.length > 0 { + shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxPacketSize, oneRTTSealer, false, v) + if err != nil { + return nil, err + } + packet.shortHdrPacket = &shp + } + return packet, nil +} + +// [UQUIC] +func (p *uPacketPacker) appendInitialPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) { + // Shouldn't need this? + // if p.uSpec.InitialPacketSpec.InitPacketNumberLength > 0 { + // header.PacketNumberLen = p.uSpec.InitialPacketSpec.InitPacketNumberLength + // } + + uPayload, err := p.MarshalInitialPacketPayload(pl, v) + if err != nil { + return nil, err + } + + pnLen := protocol.ByteCount(header.PacketNumberLen) + header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(len(uPayload)) + + startLen := len(buffer.Data) + raw := buffer.Data[startLen:] // [UQUIC] the raw here is a sub-slice of buffer.Data, latter's len < size + + raw, err = header.Append(raw, v) + if err != nil { + return nil, err + } + payloadOffset := protocol.ByteCount(len(raw)) + raw = append(raw, uPayload...) + + // fmt.Printf("Payload: %x\n", raw[payloadOffset:]) + + // fmt.Printf("Pre-Encryption: %x\n", raw) + + raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen) + buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] + + // fmt.Printf("Post-Encryption: %x\n", raw) + + // [UQUIC] + // append zero to buffer.Data until min size is reached + minUDPSize := p.uSpec.UDPDatagramMinSize + if minUDPSize == 0 { + minUDPSize = DefaultUDPDatagramMinSize + } + if len(buffer.Data) < minUDPSize { + buffer.Data = append(buffer.Data, make([]byte, minUDPSize-len(buffer.Data))...) + } + + if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber { + return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber) + } + return &longHeaderPacket{ + header: header, + ack: pl.ack, + frames: pl.frames, + streamFrames: pl.streamFrames, + length: protocol.ByteCount(len(raw)), + }, nil +} + +func (p *uPacketPacker) MarshalInitialPacketPayload(pl payload, v protocol.VersionNumber) ([]byte, error) { + var originalFrameBytes []byte + + for _, f := range pl.frames { + var err error + // only append crypto frames + if _, ok := f.Frame.(*wire.CryptoFrame); !ok { + continue + } + + originalFrameBytes, err = f.Frame.Append(originalFrameBytes, v) + if err != nil { + return nil, err + } + } + + uPayload, err := p.uSpec.InitialPacketSpec.FrameOrder.MarshalWithFrames(originalFrameBytes) + if err != nil { + return nil, err + } + + return uPayload, nil +} diff --git a/u_quic_spec.go b/u_quic_spec.go index 4b4999bb..a7ef41dd 100644 --- a/u_quic_spec.go +++ b/u_quic_spec.go @@ -1,200 +1,18 @@ package quic -import ( - "bytes" - "crypto/rand" - "errors" +import tls "github.com/refraction-networking/utls" - "github.com/gaukas/clienthellod" - "github.com/quic-go/quic-go/quicvarint" +const ( + DefaultUDPDatagramMinSize = 1200 ) type QUICSpec struct { - // SrcConnIDLength specifies how many bytes should the SrcConnID be - SrcConnIDLength int + InitialPacketSpec InitialPacketSpec + ClientHelloSpec *tls.ClientHelloSpec - // DestConnIDLength specifies how many bytes should the DestConnID be - DstConnIDLength int - - // InitPacketNumberLength specifies how many bytes should the InitPacketNumber - // be interpreted as. It is usually 1 or 2 bytes. If unset, UQUIC will use the - // default algorithm to compute the length which is at least 2 bytes. - InitPacketNumberLength PacketNumberLen - - // InitPacketNumber is the packet number of the first Initial packet. Following - // Initial packets, if any, will increment the Packet Number accordingly. - InitPacketNumber uint64 // [UQUIC] - - // TokenStore is used to store and retrieve tokens. If set, will override the - // one set in the Config. - TokenStore TokenStore - - // If ClientTokenLength is set when TokenStore is not set, a dummy TokenStore - // will be created to randomly generate tokens of the specified length for - // Pop() calls with any key and silently drop any Put() calls. - // - // However, the tokens will not be stored anywhere and are expected to be - // invalid since not assigned by the server. - ClientTokenLength int - - // QUICFrames specifies a list of QUIC frames to be sent in the first Initial - // packet. - // - // If nil, it will be treated as a list with only a single QUICFrameCrypto. - QUICFrames []QUICFrame + UDPDatagramMinSize int } -func (s *QUICSpec) getTokenStore() TokenStore { - if s.TokenStore != nil { - return s.TokenStore - } - - if s.ClientTokenLength > 0 { - return &dummyTokenStore{ - tokenLength: s.ClientTokenLength, - } - } - - return nil -} - -type dummyTokenStore struct { - tokenLength int -} - -func (d *dummyTokenStore) Pop(key string) (token *ClientToken) { - var data []byte = make([]byte, d.tokenLength) - rand.Read(data) - - return &ClientToken{ - data: data, - } -} - -func (d *dummyTokenStore) Put(_ string, _ *ClientToken) { - // Do nothing -} - -type QUICFrames []QUICFrame - -func (qfs QUICFrames) MarshalWithCryptoData(cryptoData []byte) (payload []byte, err error) { - if len(qfs) == 0 { // If no frames specified, send a single crypto frame - payload = make([]byte, len(cryptoData)+1) - } - - for _, frame := range qfs { - var frameBytes []byte - if offset, length, cryptoOK := frame.CryptoFrameInfo(); cryptoOK { - if length == 0 { - // calculate length: from offset to the end of cryptoData - length = len(cryptoData) - offset - } - frameBytes = []byte{0x06} // CRYPTO frame type - frameBytes = quicvarint.Append(frameBytes, uint64(offset)) - frameBytes = quicvarint.Append(frameBytes, uint64(length)) - frameCryptoData := make([]byte, length) - copy(frameCryptoData, cryptoData[offset:]) // copy at most length bytes - frameBytes = append(frameBytes, frameCryptoData...) - } else { // Handle none crypto frames: read and append to payload - frameBytes, err = frame.Read() - if err != nil { - return nil, err - } - } - payload = append(payload, frameBytes...) - } - return payload, nil -} - -func (qfs QUICFrames) MarshalWithFrames(frames []byte) (payload []byte, err error) { - // parse frames - r := bytes.NewReader(frames) - qchframes, err := clienthellod.ReadAllFrames(r) - if err != nil { - return nil, err - } - - // parse crypto data - cryptoData, err := clienthellod.ReassembleCRYPTOFrames(qchframes) - if err != nil { - return nil, err - } - - // marshal - return qfs.MarshalWithCryptoData(cryptoData) -} - -type QUICFrame interface { - // None crypto frames should return false for cryptoOK - CryptoFrameInfo() (offset, length int, cryptoOK bool) - - // None crypto frames should return the byte representation of the frame. - // Crypto frames' behavior is undefined and unused. - Read() ([]byte, error) -} - -// QUICFrameCrypto is used to specify the crypto frames containing the TLS ClientHello -// to be sent in the first Initial packet. -type QUICFrameCrypto struct { - // Offset is used to specify the starting offset of the crypto frame. - // Used when sending multiple crypto frames in a single packet. - // - // Multiple crypto frames in a single packet must not overlap and must - // make up an entire crypto stream continuously. - Offset int - - // Length is used to specify the length of the crypto frame. - // - // Must be set if it is NOT the last crypto frame in a packet. - Length int -} - -// CryptoFrameInfo() implements the QUICFrame interface. -// -// Crypto frames are later replaced by the crypto message using the information -// returned by this function. -func (q QUICFrameCrypto) CryptoFrameInfo() (offset, length int, cryptoOK bool) { - return q.Offset, q.Length, true -} - -// Read() implements the QUICFrame interface. -// -// Crypto frames are later replaced by the crypto message, so they are not Read()-able. -func (q QUICFrameCrypto) Read() ([]byte, error) { - return nil, errors.New("crypto frames are not Read()-able") -} - -// QUICFramePadding is used to specify the padding frames to be sent in the first Initial -// packet. -type QUICFramePadding struct { - // Length is used to specify the length of the padding frame. - Length int -} - -// CryptoFrameInfo() implements the QUICFrame interface. -func (q QUICFramePadding) CryptoFrameInfo() (offset, length int, cryptoOK bool) { - return 0, 0, false -} - -// Read() implements the QUICFrame interface. -// -// Padding simply returns a slice of bytes of the specified length filled with 0. -func (q QUICFramePadding) Read() ([]byte, error) { - return make([]byte, q.Length), nil -} - -// QUICFramePing is used to specify the ping frames to be sent in the first Initial -// packet. -type QUICFramePing struct{} - -// CryptoFrameInfo() implements the QUICFrame interface. -func (q QUICFramePing) CryptoFrameInfo() (offset, length int, cryptoOK bool) { - return 0, 0, false -} - -// Read() implements the QUICFrame interface. -// -// Ping simply returns a slice of bytes of size 1 with value 0x01(PING). -func (q QUICFramePing) Read() ([]byte, error) { - return []byte{0x01}, nil +func (s *QUICSpec) UpdateConfig(config *Config) { + s.InitialPacketSpec.UpdateConfig(config) } diff --git a/u_transport.go b/u_transport.go new file mode 100644 index 00000000..488d41fa --- /dev/null +++ b/u_transport.go @@ -0,0 +1,87 @@ +package quic + +import ( + "context" + "net" + + "github.com/quic-go/quic-go/internal/protocol" + tls "github.com/refraction-networking/utls" +) + +type UTransport struct { + *Transport + + QUICSpec *QUICSpec // [UQUIC] using ptr to avoid copying +} + +// Dial dials a new connection to a remote host (not using 0-RTT). +func (t *UTransport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { + if err := validateConfig(conf); err != nil { + return nil, err + } + conf = populateConfig(conf) + + // [UQUIC] + // Override the default connection ID generator if the user has specified a length in QUICSpec. + if t.QUICSpec != nil { + if t.QUICSpec.InitialPacketSpec.SrcConnIDLength != 0 { + t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.QUICSpec.InitialPacketSpec.SrcConnIDLength} + } else { + t.ConnectionIDGenerator = &protocol.ExpEmptyConnectionIDGenerator{} + } + } + // [/UQUIC] + + if err := t.init(t.isSingleUse); err != nil { + return nil, err + } + var onClose func() + if t.isSingleUse { + onClose = func() { t.Close() } + } + tlsConf = tlsConf.Clone() + tlsConf.MinVersion = tls.VersionTLS13 + + return udial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false, t.QUICSpec) +} + +// DialEarly dials a new connection, attempting to use 0-RTT if possible. +func (t *UTransport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + if err := validateConfig(conf); err != nil { + return nil, err + } + conf = populateConfig(conf) + + // [UQUIC] + // Override the default connection ID generator if the user has specified a length in QUICSpec. + if t.QUICSpec != nil { + if t.QUICSpec.InitialPacketSpec.SrcConnIDLength != 0 { + t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.QUICSpec.InitialPacketSpec.SrcConnIDLength} + } else { + t.ConnectionIDGenerator = &protocol.ExpEmptyConnectionIDGenerator{} + } + } + // [/UQUIC] + + if err := t.init(t.isSingleUse); err != nil { + return nil, err + } + var onClose func() + if t.isSingleUse { + onClose = func() { t.Close() } + } + tlsConf = tlsConf.Clone() + tlsConf.MinVersion = tls.VersionTLS13 + + return udial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true, t.QUICSpec) +} + +func (ut *UTransport) MakeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *Config) (EarlyConnection, error) { + return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *Config) (EarlyConnection, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + return ut.DialEarly(ctx, udpAddr, tlsCfg, cfg) + } +}