new: uquic

This commit is contained in:
Gaukas Wang 2023-08-02 15:38:16 -06:00
parent 95f3eaaa66
commit ea40752ca3
No known key found for this signature in database
GPG key ID: 9E2F8986D76F8B5D
25 changed files with 1420 additions and 686 deletions

109
client.go
View file

@ -38,8 +38,6 @@ type client struct {
tracer logging.ConnectionTracer
tracingID uint64
logger utils.Logger
chs *tls.ClientHelloSpec // [UQUIC]
}
// make it possible to mock connection ID for initial generation in the tests
@ -167,41 +165,6 @@ func dial(
return c.conn, nil
}
func dialWithCHS(
ctx context.Context,
conn sendConn,
connIDGenerator ConnectionIDGenerator,
packetHandlers packetHandlerManager,
tlsConf *tls.Config,
config *Config,
onClose func(),
use0RTT bool,
chs *tls.ClientHelloSpec,
) (quicConn, error) {
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
if err != nil {
return nil, err
}
c.packetHandlers = packetHandlers
c.tracingID = nextConnTracingID()
if c.config.Tracer != nil {
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
}
if c.tracer != nil {
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
}
// [UQUIC]
c.chs = chs
// [/UQUIC]
if err := c.dial(ctx); err != nil {
return nil, err
}
return c.conn, nil
}
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
if tlsConf == nil {
tlsConf = &tls.Config{}
@ -209,25 +172,12 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config
tlsConf = tlsConf.Clone()
}
// // [UQUIC]
// if config.SrcConnIDLength != 0 {
// connIDLen := config.SrcConnIDLength
// connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: connIDLen}
// }
srcConnID, err := connIDGenerator.GenerateConnectionID()
if err != nil {
return nil, err
}
var destConnID protocol.ConnectionID
// [UQUIC]
if config.DestConnIDLength > 0 {
destConnID, err = generateConnectionIDForInitialWithLength(config.DestConnIDLength)
} else {
destConnID, err = generateConnectionIDForInitial()
}
// [/UQUIC]
destConnID, err := generateConnectionIDForInitial()
if err != nil {
return nil, err
}
@ -243,8 +193,6 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config
version: config.Versions[0],
handshakeChan: make(chan struct{}),
logger: utils.DefaultLogger.WithPrefix("client"),
initialPacketNumber: protocol.PacketNumber(config.InitPacketNumber), // [UQUIC]
}
return c, nil
}
@ -252,45 +200,22 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
// [UQUIC]
if c.chs == nil {
c.conn = newClientConnection(
c.sendConn,
c.packetHandlers,
c.destConnID,
c.srcConnID,
c.connIDGenerator,
c.config,
c.tlsConf,
c.initialPacketNumber,
c.use0RTT,
c.hasNegotiatedVersion,
c.tracer,
c.tracingID,
c.logger,
c.version,
)
} else {
// [UQUIC]: use custom version of the connection
c.conn = newUClientConnection(
c.sendConn,
c.packetHandlers,
c.destConnID,
c.srcConnID,
c.connIDGenerator,
c.config,
c.tlsConf,
c.initialPacketNumber,
c.use0RTT,
c.hasNegotiatedVersion,
c.tracer,
c.tracingID,
c.logger,
c.version,
c.chs,
)
}
// [/UQUIC]
c.conn = newClientConnection(
c.sendConn,
c.packetHandlers,
c.destConnID,
c.srcConnID,
c.connIDGenerator,
c.config,
c.tlsConf,
c.initialPacketNumber,
c.use0RTT,
c.hasNegotiatedVersion,
c.tracer,
c.tracingID,
c.logger,
c.version,
)
c.packetHandlers.Add(c.srcConnID, c.conn)

View file

@ -131,12 +131,6 @@ func populateConfig(config *Config) *Config {
DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets,
Allow0RTT: config.Allow0RTT,
Tracer: config.Tracer,
// [UQUIC]
SrcConnIDLength: config.SrcConnIDLength,
DestConnIDLength: config.DestConnIDLength,
InitPacketNumber: config.InitPacketNumber,
InitPacketNumberLength: config.InitPacketNumberLength,
}
}

View file

@ -62,10 +62,12 @@ func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
return err
}
// [UQUIC]
connIDLimit := h.connectionIDLimit
if connIDLimit == 0 {
connIDLimit = protocol.MaxActiveConnectionIDs
}
// [/UQUIC]
if uint64(h.queue.Len()) >= connIDLimit {
return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}
@ -191,11 +193,6 @@ func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToke
h.addStatelessResetToken(token)
}
// [UQUIC]
func (h *connIDManager) SetConnectionIDLimit(limit uint64) {
h.connectionIDLimit = limit
}
func (h *connIDManager) SentPacket() {
h.packetsSinceLastChange++
}

View file

@ -390,9 +390,6 @@ var newClientConnection = func(
s.tracer,
s.logger,
)
if conf.InitPacketNumberLength != 0 {
ackhandler.SetInitialPacketNumberLength(s.sentPacketHandler, conf.InitPacketNumberLength)
}
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
oneRTTStream := newCryptoStream()
@ -453,151 +450,6 @@ var newClientConnection = func(
return s
}
// [UQUIC]
var newUClientConnection = func(
conn sendConn,
runner connRunner,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
conf *Config,
tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
enable0RTT bool,
hasNegotiatedVersion bool,
tracer logging.ConnectionTracer,
tracingID uint64,
logger utils.Logger,
v protocol.VersionNumber,
chs *tls.ClientHelloSpec,
) quicConn {
s := &connection{
conn: conn,
config: conf,
origDestConnID: destConnID,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
perspective: protocol.PerspectiveClient,
logID: destConnID.String(),
logger: logger,
tracer: tracer,
versionNegotiated: hasNegotiatedVersion,
version: v,
}
s.connIDManager = newConnIDManager(
destConnID,
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
runner.RemoveResetToken,
s.queueControlFrame,
)
s.connIDGenerator = newConnIDGenerator(
srcConnID,
nil,
func(connID protocol.ConnectionID) { runner.Add(connID, s) },
runner.GetStatelessResetToken,
runner.Remove,
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
connIDGenerator,
)
s.preSetup()
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
initialPacketNumber,
getMaxPacketSize(s.conn.RemoteAddr()),
s.rttStats,
false, /* has no effect */
s.perspective,
s.tracer,
s.logger,
)
if conf.InitPacketNumberLength != 0 {
ackhandler.SetInitialPacketNumberLength(s.sentPacketHandler, conf.InitPacketNumberLength)
}
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
oneRTTStream := newCryptoStream()
var params *wire.TransportParameters
// params := &wire.TransportParameters{
// InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
// InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
// InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
// InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
// MaxIdleTimeout: s.config.MaxIdleTimeout,
// MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams),
// MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
// MaxAckDelay: protocol.MaxAckDelayInclGranularity,
// AckDelayExponent: protocol.AckDelayExponent,
// DisableActiveMigration: true,
// // For interoperability with quic-go versions before May 2023, this value must be set to a value
// // different from protocol.DefaultActiveConnectionIDLimit.
// // If set to the default value, it will be omitted from the transport parameters, which will make
// // old quic-go versions interpret it as 0, instead of the default value of 2.
// // See https://github.com/quic-go/quic-go/pull/3806.
// ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
// InitialSourceConnectionID: srcConnID,
// }
// if s.config.EnableDatagrams {
// params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize
// } else {
// params.MaxDatagramFrameSize = protocol.InvalidByteCount
// }
// [UQUIC] iterate over all Extensions to set the TransportParameters
var tpSet bool
FOR_EACH_TLS_EXTENSION:
for _, ext := range chs.Extensions {
switch ext := ext.(type) {
case *tls.QUICTransportParametersExtension:
params = &wire.TransportParameters{
InitialSourceConnectionID: srcConnID,
}
params.PopulateFromUQUIC(ext.TransportParameters)
s.connIDManager.SetConnectionIDLimit(params.ActiveConnectionIDLimit)
tpSet = true
break FOR_EACH_TLS_EXTENSION
default:
continue FOR_EACH_TLS_EXTENSION
}
}
if !tpSet {
panic("applied ClientHelloSpec must contain a QUICTransportParametersExtension to proceed")
}
if s.tracer != nil {
s.tracer.SentTransportParameters(params)
}
cs := handshake.NewUCryptoSetupClient(
destConnID,
params,
tlsConf,
enable0RTT,
s.rttStats,
tracer,
logger,
s.version,
chs,
)
s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, oneRTTStream)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
if len(tlsConf.ServerName) > 0 {
s.tokenStoreKey = tlsConf.ServerName
} else {
s.tokenStoreKey = conn.RemoteAddr().String()
}
if s.config.TokenStore != nil {
if token := s.config.TokenStore.Pop(s.tokenStoreKey); token != nil {
s.packer.SetToken(token.data)
}
}
return s
}
func (s *connection) preSetup() {
s.initialStream = newCryptoStream()
s.handshakeStream = newCryptoStream()
@ -2204,9 +2056,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time
}
s.connIDManager.SentPacket()
s.sendQueue.Send(packet.buffer, packet.buffer.Len())
// [UQUIC]
// fmt.Printf("sendPackedCoalescedPacket:Sending %d bytes\n", packet.buffer.Len())
// fmt.Printf("sendPackedCoalescedPacket: %v\n", packet.buffer.Data)
return nil
}

View file

@ -14,7 +14,83 @@ import (
"github.com/quic-go/quic-go/http3"
)
func getCHS() *tls.ClientHelloSpec {
func main() {
keyLogWriter, err := os.Create("./keylog.txt")
if err != nil {
panic(err)
}
tlsConf := &tls.Config{
ServerName: "quic.tlsfingerprint.io",
// ServerName: "www.cloudflare.com",
// MinVersion: tls.VersionTLS13,
KeyLogWriter: keyLogWriter,
// NextProtos: []string{"h3"},
}
quicConf := &quic.Config{}
roundTripper := &http3.RoundTripper{
TLSClientConfig: tlsConf,
QuicConfig: quicConf,
}
uRoundTripper := http3.GetURoundTripper(
roundTripper,
// getFFQUICSpec(),
getCRQUICSpec(),
nil,
)
defer uRoundTripper.Close()
hclient := &http.Client{
Transport: uRoundTripper,
}
addr := "https://quic.tlsfingerprint.io/qfp/?beautify=true"
// addr := "https://www.cloudflare.com"
rsp, err := hclient.Get(addr)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Got response for %s: %#v", addr, rsp)
body := &bytes.Buffer{}
_, err = io.Copy(body, rsp.Body)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Response Body: %s", body.Bytes())
}
func getFFQUICSpec() *quic.QUICSpec {
return &quic.QUICSpec{
InitialPacketSpec: quic.InitialPacketSpec{
SrcConnIDLength: 3,
DestConnIDLength: 8,
InitPacketNumberLength: 1,
InitPacketNumber: 1,
ClientTokenLength: 0,
FrameOrder: quic.QUICFrames{
&quic.QUICFrameCrypto{
Offset: 300,
Length: 0,
},
&quic.QUICFramePadding{
Length: 125,
},
&quic.QUICFramePing{},
&quic.QUICFrameCrypto{
Offset: 0,
Length: 300,
},
},
},
ClientHelloSpec: getFFCHS(),
}
}
func getFFCHS() *tls.ClientHelloSpec {
return &tls.ClientHelloSpec{
TLSVersMin: tls.VersionTLS13,
TLSVersMax: tls.VersionTLS13,
@ -135,54 +211,146 @@ func getCHS() *tls.ClientHelloSpec {
}
}
func main() {
keyLogWriter, err := os.Create("./keylog.txt")
if err != nil {
panic(err)
func getCRQUICSpec() *quic.QUICSpec {
return &quic.QUICSpec{
InitialPacketSpec: quic.InitialPacketSpec{
SrcConnIDLength: 0,
DestConnIDLength: 8,
InitPacketNumberLength: 1,
InitPacketNumber: 1,
ClientTokenLength: 0,
FrameOrder: quic.QUICFrames{
&quic.QUICFrameCrypto{
Offset: 300,
Length: 0,
},
&quic.QUICFramePadding{
Length: 125,
},
&quic.QUICFramePing{},
&quic.QUICFrameCrypto{
Offset: 0,
Length: 300,
},
},
},
ClientHelloSpec: getCRCHS(),
}
}
func getCRCHS() *tls.ClientHelloSpec {
return &tls.ClientHelloSpec{
TLSVersMin: tls.VersionTLS13,
TLSVersMax: tls.VersionTLS13,
CipherSuites: []uint16{
tls.TLS_AES_128_GCM_SHA256,
tls.TLS_CHACHA20_POLY1305_SHA256,
tls.TLS_AES_256_GCM_SHA384,
},
CompressionMethods: []uint8{
0x0, // no compression
},
Extensions: []tls.TLSExtension{
&tls.SNIExtension{},
&tls.ExtendedMasterSecretExtension{},
&tls.RenegotiationInfoExtension{
Renegotiation: tls.RenegotiateOnceAsClient,
},
&tls.SupportedCurvesExtension{
Curves: []tls.CurveID{
tls.CurveX25519,
tls.CurveSECP256R1,
tls.CurveSECP384R1,
tls.CurveSECP521R1,
tls.FakeCurveFFDHE2048,
tls.FakeCurveFFDHE3072,
tls.FakeCurveFFDHE4096,
tls.FakeCurveFFDHE6144,
tls.FakeCurveFFDHE8192,
},
},
&tls.ALPNExtension{
AlpnProtocols: []string{
"h3",
},
},
&tls.StatusRequestExtension{},
&tls.FakeDelegatedCredentialsExtension{
SupportedSignatureAlgorithms: []tls.SignatureScheme{
tls.ECDSAWithP256AndSHA256,
tls.ECDSAWithP384AndSHA384,
tls.ECDSAWithP521AndSHA512,
tls.ECDSAWithSHA1,
},
},
&tls.KeyShareExtension{
KeyShares: []tls.KeyShare{
{
Group: tls.X25519,
},
// {
// Group: tls.CurveP256,
// },
},
},
&tls.SupportedVersionsExtension{
Versions: []uint16{
tls.VersionTLS13,
},
},
&tls.SignatureAlgorithmsExtension{
SupportedSignatureAlgorithms: []tls.SignatureScheme{
tls.ECDSAWithP256AndSHA256,
tls.ECDSAWithP384AndSHA384,
tls.ECDSAWithP521AndSHA512,
tls.ECDSAWithSHA1,
tls.PSSWithSHA256,
tls.PSSWithSHA384,
tls.PSSWithSHA512,
tls.PKCS1WithSHA256,
tls.PKCS1WithSHA384,
tls.PKCS1WithSHA512,
tls.PKCS1WithSHA1,
},
},
&tls.PSKKeyExchangeModesExtension{
Modes: []uint8{
tls.PskModeDHE,
},
},
&tls.FakeRecordSizeLimitExtension{
Limit: 0x4001,
},
&tls.QUICTransportParametersExtension{
TransportParameters: tls.TransportParameters{
&tls.GREASE{
IdOverride: 0x35967c5b9c37e023,
ValueOverride: []byte{
0xfc, 0x97, 0xbb, 0x57, 0xb8, 0x02, 0x19, 0xcd,
},
},
tls.InitialMaxStreamsUni(103),
tls.InitialSourceConnectionID([]byte{}),
tls.InitialMaxStreamsBidi(100),
tls.InitialMaxData(15728640),
&tls.VersionInformation{
ChoosenVersion: tls.VERSION_1,
AvailableVersions: []uint32{
tls.VERSION_1,
tls.VERSION_GREASE,
},
LegacyID: true,
},
tls.MaxIdleTimeout(30000),
tls.MaxUDPPayloadSize(1472),
tls.MaxDatagramFrameSize(65536),
tls.InitialMaxStreamDataBidiLocal(6291456),
tls.InitialMaxStreamDataUni(6291456),
tls.InitialMaxStreamDataBidiRemote(6291456),
},
},
&tls.UtlsPaddingExtension{
GetPaddingLen: tls.BoringPaddingStyle,
},
},
}
tlsConf := &tls.Config{
ServerName: "quic.tlsfingerprint.io",
// ServerName: "www.cloudflare.com",
// MinVersion: tls.VersionTLS13,
KeyLogWriter: keyLogWriter,
// NextProtos: []string{"h3"},
}
quicConf := &quic.Config{
Versions: []quic.VersionNumber{quic.Version1},
// EnableDatagrams: true,
SrcConnIDLength: 3, // <4 causes timeout
DestConnIDLength: 8,
InitPacketNumber: 0,
InitPacketNumberLength: quic.PacketNumberLen1, // currently only affects the initial packet number
// Versions: []quic.VersionNumber{quic.Version2},
}
roundTripper := &http3.RoundTripper{
TLSClientConfig: tlsConf,
QuicConfig: quicConf,
ClientHelloSpec: getCHS(),
}
defer roundTripper.Close()
hclient := &http.Client{
Transport: roundTripper,
}
addr := "https://quic.tlsfingerprint.io/qfp/"
// addr := "https://www.cloudflare.com"
rsp, err := hclient.Get(addr)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Got response for %s: %#v", addr, rsp)
body := &bytes.Buffer{}
_, err = io.Copy(body, rsp.Body)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Response Body: %s", body.Bytes())
}

View file

@ -88,9 +88,6 @@ type RoundTripper struct {
newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests
clients map[string]*roundTripCloserWithCount
transport *quic.Transport
// [UQUIC]
ClientHelloSpec *tls.ClientHelloSpec
}
// RoundTripOpt are options for the Transport.RoundTripOpt method.
@ -194,8 +191,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTr
return nil, false, err
}
r.transport = &quic.Transport{
Conn: udpConn,
ClientHelloSpec: r.ClientHelloSpec,
Conn: udpConn,
}
}
dial = r.makeDialer()

192
http3/u_roundtrip.go Normal file
View file

@ -0,0 +1,192 @@
package http3
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"github.com/quic-go/quic-go"
tls "github.com/refraction-networking/utls"
"golang.org/x/net/http/httpguts"
)
type URoundTripper struct {
*RoundTripper
quicSpec *quic.QUICSpec
uTransportOverride *quic.UTransport
}
func GetURoundTripper(r *RoundTripper, QUICSpec *quic.QUICSpec, uTransport *quic.UTransport) *URoundTripper {
QUICSpec.UpdateConfig(r.QuicConfig)
return &URoundTripper{
RoundTripper: r,
quicSpec: QUICSpec,
uTransportOverride: uTransport,
}
}
// RoundTripOpt is like RoundTrip, but takes options.
func (r *URoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if req.URL == nil {
closeRequestBody(req)
return nil, errors.New("http3: nil Request.URL")
}
if req.URL.Scheme != "https" {
closeRequestBody(req)
return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme)
}
if req.URL.Host == "" {
closeRequestBody(req)
return nil, errors.New("http3: no Host in request URL")
}
if req.Header == nil {
closeRequestBody(req)
return nil, errors.New("http3: nil Request.Header")
}
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("http3: invalid http header field name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k)
}
}
}
if req.Method != "" && !validMethod(req.Method) {
closeRequestBody(req)
return nil, fmt.Errorf("http3: invalid method %q", req.Method)
}
hostname := authorityAddr("https", hostnameFromRequest(req))
cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn)
if err != nil {
return nil, err
}
defer cl.useCount.Add(-1)
rsp, err := cl.RoundTripOpt(req, opt)
if err != nil {
r.removeClient(hostname)
if isReused {
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
return r.RoundTripOpt(req, opt)
}
}
}
return rsp, err
}
// RoundTrip does a round trip.
func (r *URoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
}
func (r *URoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
r.clients = make(map[string]*roundTripCloserWithCount)
}
client, ok := r.clients[hostname]
if !ok {
if onlyCached {
return nil, false, ErrNoCachedConn
}
var err error
newCl := newClient
if r.newClient != nil {
newCl = r.newClient
}
dial := r.Dial
if dial == nil {
if r.transport == nil && r.uTransportOverride == nil {
udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, false, err
}
r.uTransportOverride = &quic.UTransport{
Transport: &quic.Transport{
Conn: udpConn,
},
QUICSpec: r.quicSpec,
}
}
dial = r.makeDialer()
}
c, err := newCl(
hostname,
r.TLSClientConfig,
&roundTripperOpts{
EnableDatagram: r.EnableDatagrams,
DisableCompression: r.DisableCompression,
MaxHeaderBytes: r.MaxResponseHeaderBytes,
StreamHijacker: r.StreamHijacker,
UniStreamHijacker: r.UniStreamHijacker,
},
r.QuicConfig,
dial,
)
if err != nil {
return nil, false, err
}
client = &roundTripCloserWithCount{roundTripCloser: c}
r.clients[hostname] = client
} else if client.HandshakeComplete() {
isReused = true
}
client.useCount.Add(1)
return client, isReused, nil
}
func (r *URoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, client := range r.clients {
if err := client.Close(); err != nil {
return err
}
}
r.clients = nil
if r.transport != nil {
if err := r.transport.Close(); err != nil {
return err
}
if err := r.transport.Conn.Close(); err != nil {
return err
}
r.transport = nil
}
if r.uTransportOverride != nil {
if err := r.uTransportOverride.Close(); err != nil {
return err
}
if err := r.uTransportOverride.Conn.Close(); err != nil {
return err
}
r.uTransportOverride = nil
}
return nil
}
// makeDialer makes a QUIC dialer using r.udpConn.
func (r *URoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
if r.uTransportOverride != nil {
return r.uTransportOverride.DialEarly(ctx, udpAddr, tlsCfg, cfg)
} else if r.transport == nil {
return nil, errors.New("http3: no QUIC transport available")
}
return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
}
}

View file

@ -333,12 +333,6 @@ type Config struct {
// Enable QUIC datagram support (RFC 9221).
EnableDatagrams bool
Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer
// TransportParameters override other transport parameters set by the Config.
SrcConnIDLength int // [UQUIC]
DestConnIDLength int // [UQUIC]
InitPacketNumber uint64 // [UQUIC]
InitPacketNumberLength PacketNumberLen // [UQUIC]
}
type ClientHelloInfo struct {

View file

@ -96,8 +96,6 @@ type sentPacketHandler struct {
tracer logging.ConnectionTracer
logger utils.Logger
initialPacketNumberLength protocol.PacketNumberLen // [UQUIC]
}
var (
@ -138,12 +136,6 @@ func newSentPacketHandler(
}
}
func SetInitialPacketNumberLength(h SentPacketHandler, pnLen protocol.PacketNumberLen) {
if sph, ok := h.(*sentPacketHandler); ok {
sph.initialPacketNumberLength = pnLen
}
}
func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) {
if p.includedInBytesInFlight {
if p.Length > h.bytesInFlight {
@ -725,11 +717,6 @@ func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel)
pn := pnSpace.pns.Peek()
// See section 17.1 of RFC 9000.
// [UQUIC] This kinda breaks PN length mimicry.
if encLevel == protocol.EncryptionInitial && h.initialPacketNumberLength != 0 {
return pn, h.initialPacketNumberLength
}
return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked)
}

View file

@ -0,0 +1,23 @@
package ackhandler
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
)
// [UQUIC]
func NewUAckHandler(
initialPacketNumber protocol.PacketNumber,
initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats,
clientAddressValidated bool,
pers protocol.Perspective,
tracer logging.ConnectionTracer,
logger utils.Logger,
) (SentPacketHandler, ReceivedPacketHandler) {
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger)
return &uSentPacketHandler{
sentPacketHandler: sph,
}, newReceivedPacketHandler(sph, rttStats, logger)
}

View file

@ -0,0 +1,30 @@
package ackhandler
import "github.com/quic-go/quic-go/internal/protocol"
type uSentPacketHandler struct {
*sentPacketHandler
initialPacketNumberLength protocol.PacketNumberLen // [UQUIC]
}
func (h *uSentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
pnSpace := h.getPacketNumberSpace(encLevel)
pn := pnSpace.pns.Peek()
// See section 17.1 of RFC 9000.
// [UQUIC] Otherwise it kinda breaks PN length mimicry.
if encLevel == protocol.EncryptionInitial && h.initialPacketNumberLength != 0 {
return pn, h.initialPacketNumberLength
}
// [/UQUIC]
return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked)
}
// [UQUIC]
func SetInitialPacketNumberLength(h SentPacketHandler, pnLen protocol.PacketNumberLen) {
if sph, ok := h.(*uSentPacketHandler); ok {
sph.initialPacketNumberLength = pnLen
}
}

View file

@ -102,41 +102,6 @@ func NewCryptoSetupClient(
return cs
}
// [UQUIC]
// NewUCryptoSetupClient creates a new crypto setup for the client with UTLS
func NewUCryptoSetupClient(
connID protocol.ConnectionID,
tp *wire.TransportParameters,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
chs *tls.ClientHelloSpec,
) CryptoSetup {
cs := newCryptoSetup(
connID,
tp,
rttStats,
tracer,
logger,
protocol.PerspectiveClient,
version,
)
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState)
cs.tlsConf = tlsConf
cs.conn = qtls.UQUICClient(quicConf, chs)
// cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) // [UQUIC] doesn't require this
return cs
}
// NewCryptoSetupServer creates a new crypto setup for the server
func NewCryptoSetupServer(
connID protocol.ConnectionID,
@ -281,6 +246,7 @@ func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) {
return false, h.handleTransportParameters(ev.Data)
case qtls.QUICTransportParametersRequired:
h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective))
// [UQUIC] doesn't expect this and may fail
return false, nil
case qtls.QUICRejectedEarlyData:
h.rejected0RTT()

View file

@ -0,0 +1,45 @@
package handshake
import (
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
tls "github.com/refraction-networking/utls"
)
// [UQUIC]
// NewUCryptoSetupClient creates a new crypto setup for the client with UTLS
func NewUCryptoSetupClient(
connID protocol.ConnectionID,
tp *wire.TransportParameters,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
chs *tls.ClientHelloSpec,
) CryptoSetup {
cs := newCryptoSetup(
connID,
tp,
rttStats,
tracer,
logger,
protocol.PerspectiveClient,
version,
)
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
quicConf := &qtls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState)
cs.tlsConf = tlsConf
cs.conn = qtls.UQUICClient(quicConf, chs)
// cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) // [UQUIC] doesn't require this
return cs
}

View file

@ -68,11 +68,6 @@ func GenerateConnectionIDForInitial() (ConnectionID, error) {
return GenerateConnectionID(l)
}
// [UQUIC]
func GenerateConnectionIDForInitialWithLen(l int) (ConnectionID, error) {
return GenerateConnectionID(l)
}
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
// It returns io.EOF if there are not enough bytes to read.
func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) {

View file

@ -0,0 +1,16 @@
package protocol
// [UQUIC]
func GenerateConnectionIDForInitialWithLen(l int) (ConnectionID, error) {
return GenerateConnectionID(l)
}
type ExpEmptyConnectionIDGenerator struct{}
func (g *ExpEmptyConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) {
return GenerateConnectionID(0)
}
func (g *ExpEmptyConnectionIDGenerator) ConnectionIDLen() int {
return 0
}

View file

@ -332,11 +332,6 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
if initialPayload.length > 0 {
size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead())
}
// // [UQUIC]
// if len(initialPayload.frames) > 0 {
// fmt.Printf("onlyAck: %t, PackCoalescedPacket: %v\n", onlyAck, initialPayload.frames[0].Frame)
// }
}
// Add a Handshake packet.
@ -401,24 +396,12 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.
longHdrPackets: make([]*longHeaderPacket, 0, 3),
}
if initialPayload.length > 0 {
if onlyAck || len(initialPayload.frames) == 0 {
// padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize)
// cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v)
// if err != nil {
// return nil, err
// }
// packet.longHdrPackets = append(packet.longHdrPackets, cont)
return nil, nil // [UQUIC] not to send the ACK frame for Initial
} else { // [UQUIC]
cont, err := p.appendLongHeaderPacketExternalPadding(buffer, initialHdr, initialPayload, protocol.EncryptionInitial, initialSealer, v)
if err != nil {
return nil, err
}
// fmt.Printf("!onlyAck buffer: %v\n", buffer.Data)
packet.longHdrPackets = append(packet.longHdrPackets, cont)
padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize)
cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v)
if err != nil {
return nil, err
}
packet.longHdrPackets = append(packet.longHdrPackets, cont)
}
if handshakePayload.length > 0 {
cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v)
@ -688,7 +671,6 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m
return nil, err
}
hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v)
fmt.Printf("MaybePackProbePacket: %x\n", pl.frames[0])
case protocol.EncryptionHandshake:
var err error
sealer, err = p.cryptoSetup.GetHandshakeSealer()
@ -768,10 +750,6 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
}
paddingLen += padding
if encLevel == protocol.EncryptionInitial {
paddingLen = 0
}
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length + paddingLen
startLen := len(buffer.Data)
@ -800,49 +778,6 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire
}, nil
}
// [UQUIC]
func (p *packetPacker) appendLongHeaderPacketExternalPadding(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) {
pnLen := protocol.ByteCount(header.PacketNumberLen)
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length
startLen := len(buffer.Data)
raw := buffer.Data[startLen:] // [UQUIC] raw is a sub-slice of buffer.Data, whose len < size
raw, err := header.Append(raw, v)
if err != nil {
return nil, err
}
fmt.Printf("Pre-Payload: %x\n", raw)
payloadOffset := protocol.ByteCount(len(raw))
raw, err = p.appendCustomInitialPacketPayload(raw, pl, 0, v)
if err != nil {
return nil, err
}
fmt.Printf("Pre-Encryption: %x\n", raw)
raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen)
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
fmt.Printf("Post-Encryption: %x\n", raw)
// [UQUIC]
// append zero to buffer.Data until 1200 bytes
buffer.Data = append(buffer.Data, make([]byte, 1357-len(buffer.Data))...)
if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber {
return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber)
}
return &longHeaderPacket{
header: header,
ack: pl.ack,
frames: pl.frames,
streamFrames: pl.streamFrames,
length: protocol.ByteCount(len(raw)),
}, nil
}
func (p *packetPacker) appendShortHeaderPacket(
buffer *packetBuffer,
connID protocol.ConnectionID,
@ -930,44 +865,6 @@ func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen pr
return raw, nil
}
func (p *packetPacker) appendCustomInitialPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) {
payloadOffset := len(raw)
// [UQUIC] ignores the default ACK/PADDING frame and uses its own frames
// if pl.ack != nil {
// var err error
// raw, err = pl.ack.Append(raw, v)
// if err != nil {
// return nil, err
// }
// }
// if paddingLen > 0 {
// raw = append(raw, make([]byte, paddingLen)...)
// }
for _, f := range pl.frames {
var err error
raw, err = f.Frame.Append(raw, v)
if err != nil {
return nil, err
}
fmt.Printf("UQUIC: appending frame %v\n", f)
}
for _, f := range pl.streamFrames {
var err error
raw, err = f.Frame.Append(raw, v)
if err != nil {
return nil, err
}
fmt.Printf("UQUIC: appending stream frame %v\n", f)
}
if payloadSize := protocol.ByteCount(len(raw)-payloadOffset) - paddingLen; payloadSize != pl.length {
return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", pl.length, payloadSize)
}
return raw, nil
}
func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.PacketNumber, payloadOffset, pnLen protocol.ByteCount) []byte {
_ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], pn, raw[:payloadOffset])
raw = raw[:len(raw)+sealer.Overhead()]

View file

@ -87,8 +87,6 @@ type Transport struct {
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
logger utils.Logger
ClientHelloSpec *tls.ClientHelloSpec // [UQUIC]
}
// Listen starts listening for incoming QUIC connections.
@ -156,12 +154,6 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
}
conf = populateConfig(conf)
// [UQUIC]
if conf.SrcConnIDLength != 0 {
t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conf.SrcConnIDLength}
}
// [/UQUIC]
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
@ -172,9 +164,6 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
if t.ClientHelloSpec != nil { // [UQUIC]
return dialWithCHS(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false, t.ClientHelloSpec)
}
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
}
@ -185,12 +174,6 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
}
conf = populateConfig(conf)
// [UQUIC]
if conf.SrcConnIDLength != 0 {
t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conf.SrcConnIDLength}
}
// [/UQUIC]
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
@ -201,9 +184,6 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
if t.ClientHelloSpec != nil { // [UQUIC]
return dialWithCHS(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false, t.ClientHelloSpec)
}
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
}

150
u_client.go Normal file
View file

@ -0,0 +1,150 @@
package quic
import (
"context"
"errors"
"github.com/quic-go/quic-go/internal/protocol"
tls "github.com/refraction-networking/utls"
)
type uClient struct {
*client
uSpec *QUICSpec // [UQUIC]
}
func udial(
ctx context.Context,
conn sendConn,
connIDGenerator ConnectionIDGenerator,
packetHandlers packetHandlerManager,
tlsConf *tls.Config,
config *Config,
onClose func(),
use0RTT bool,
uSpec *QUICSpec, // [UQUIC]
) (quicConn, error) {
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
if err != nil {
return nil, err
}
c.packetHandlers = packetHandlers
// [UQUIC]
if uSpec.InitialPacketSpec.DestConnIDLength > 0 {
destConnID, err := generateConnectionIDForInitialWithLength(uSpec.InitialPacketSpec.DestConnIDLength)
if err != nil {
return nil, err
}
c.destConnID = destConnID
}
c.initialPacketNumber = protocol.PacketNumber(uSpec.InitialPacketSpec.InitPacketNumber)
// [/UQUIC]
c.tracingID = nextConnTracingID()
if c.config.Tracer != nil {
c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID)
}
if c.tracer != nil {
c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID)
}
// [UQUIC]
uc := &uClient{
client: c,
uSpec: uSpec,
}
// [/UQUIC]
if err := uc.dial(ctx); err != nil {
return nil, err
}
return uc.conn, nil
}
func (c *uClient) dial(ctx context.Context) error {
c.logger.Infof("Starting new uQUIC connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
// [UQUIC]
if c.uSpec.ClientHelloSpec == nil {
c.conn = newClientConnection(
c.sendConn,
c.packetHandlers,
c.destConnID,
c.srcConnID,
c.connIDGenerator,
c.config,
c.tlsConf,
c.initialPacketNumber,
c.use0RTT,
c.hasNegotiatedVersion,
c.tracer,
c.tracingID,
c.logger,
c.version,
)
} else {
// [UQUIC]: use custom version of the connection
c.conn = newUClientConnection(
c.sendConn,
c.packetHandlers,
c.destConnID,
c.srcConnID,
c.connIDGenerator,
c.config,
c.tlsConf,
c.initialPacketNumber,
c.use0RTT,
c.hasNegotiatedVersion,
c.tracer,
c.tracingID,
c.logger,
c.version,
c.uSpec,
)
}
// [/UQUIC]
c.packetHandlers.Add(c.srcConnID, c.conn)
errorChan := make(chan error, 1)
recreateChan := make(chan errCloseForRecreating)
go func() {
err := c.conn.run()
var recreateErr *errCloseForRecreating
if errors.As(err, &recreateErr) {
recreateChan <- *recreateErr
return
}
if c.onClose != nil {
c.onClose()
}
errorChan <- err // returns as soon as the connection is closed
}()
// only set when we're using 0-RTT
// Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever.
var earlyConnChan <-chan struct{}
if c.use0RTT {
earlyConnChan = c.conn.earlyConnReady()
}
select {
case <-ctx.Done():
c.conn.shutdown()
return ctx.Err()
case err := <-errorChan:
return err
case recreateErr := <-recreateChan:
c.initialPacketNumber = recreateErr.nextPacketNumber
c.version = recreateErr.nextVersion
c.hasNegotiatedVersion = true
return c.dial(ctx)
case <-earlyConnChan:
// ready to send 0-RTT data
return nil
case <-c.conn.HandshakeComplete():
// handshake successfully completed
return nil
}
}

6
u_conn_id_manager.go Normal file
View file

@ -0,0 +1,6 @@
package quic
// [UQUIC]
func (h *connIDManager) SetConnectionIDLimit(limit uint64) {
h.connectionIDLimit = limit
}

170
u_connection.go Normal file
View file

@ -0,0 +1,170 @@
package quic
import (
"context"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
tls "github.com/refraction-networking/utls"
)
// [UQUIC]
var newUClientConnection = func(
conn sendConn,
runner connRunner,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
conf *Config,
tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
enable0RTT bool,
hasNegotiatedVersion bool,
tracer logging.ConnectionTracer,
tracingID uint64,
logger utils.Logger,
v protocol.VersionNumber,
// chs *tls.ClientHelloSpec,
// initPktNbrLen PacketNumberLen,
// qfs QUICFrames,
// udpDatagramMinSize int,
uSpec *QUICSpec, // [UQUIC]
) quicConn {
s := &connection{
conn: conn,
config: conf,
origDestConnID: destConnID,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
perspective: protocol.PerspectiveClient,
logID: destConnID.String(),
logger: logger,
tracer: tracer,
versionNegotiated: hasNegotiatedVersion,
version: v,
}
s.connIDManager = newConnIDManager(
destConnID,
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
runner.RemoveResetToken,
s.queueControlFrame,
)
s.connIDGenerator = newConnIDGenerator(
srcConnID,
nil,
func(connID protocol.ConnectionID) { runner.Add(connID, s) },
runner.GetStatelessResetToken,
runner.Remove,
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
connIDGenerator,
)
s.preSetup()
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewUAckHandler( // [UQUIC]
initialPacketNumber,
getMaxPacketSize(s.conn.RemoteAddr()),
s.rttStats,
false, /* has no effect */
s.perspective,
s.tracer,
s.logger,
)
// [UQUIC]
if uSpec.InitialPacketSpec.InitPacketNumberLength != 0 {
ackhandler.SetInitialPacketNumberLength(s.sentPacketHandler, uSpec.InitialPacketSpec.InitPacketNumberLength)
}
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
oneRTTStream := newCryptoStream()
var params *wire.TransportParameters
if uSpec.ClientHelloSpec != nil {
// iterate over all Extensions to set the TransportParameters
var tpSet bool
FOR_EACH_TLS_EXTENSION:
for _, ext := range uSpec.ClientHelloSpec.Extensions {
switch ext := ext.(type) {
case *tls.QUICTransportParametersExtension:
params = &wire.TransportParameters{
InitialSourceConnectionID: srcConnID,
}
params.PopulateFromUQUIC(ext.TransportParameters)
s.connIDManager.SetConnectionIDLimit(params.ActiveConnectionIDLimit)
tpSet = true
break FOR_EACH_TLS_EXTENSION
default:
continue FOR_EACH_TLS_EXTENSION
}
}
if !tpSet {
panic("applied ClientHelloSpec must contain a QUICTransportParametersExtension to proceed")
}
} else {
// use default TransportParameters
params = &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
MaxIdleTimeout: s.config.MaxIdleTimeout,
MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams),
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
AckDelayExponent: protocol.AckDelayExponent,
DisableActiveMigration: true,
// For interoperability with quic-go versions before May 2023, this value must be set to a value
// different from protocol.DefaultActiveConnectionIDLimit.
// If set to the default value, it will be omitted from the transport parameters, which will make
// old quic-go versions interpret it as 0, instead of the default value of 2.
// See https://github.com/quic-go/quic-go/pull/3806.
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
InitialSourceConnectionID: srcConnID,
}
if s.config.EnableDatagrams {
params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize
} else {
params.MaxDatagramFrameSize = protocol.InvalidByteCount
}
}
if s.tracer != nil {
s.tracer.SentTransportParameters(params)
}
cs := handshake.NewUCryptoSetupClient(
destConnID,
params,
tlsConf,
enable0RTT,
s.rttStats,
tracer,
logger,
s.version,
uSpec.ClientHelloSpec,
)
s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, oneRTTStream)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.packer = newUPacketPacker(
newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective),
uSpec,
)
if len(tlsConf.ServerName) > 0 {
s.tokenStoreKey = tlsConf.ServerName
} else {
s.tokenStoreKey = conn.RemoteAddr().String()
}
if s.config.TokenStore != nil {
if token := s.config.TokenStore.Pop(s.tokenStoreKey); token != nil {
s.packer.SetToken(token.data)
}
}
return s
}

205
u_initial_packet_spec.go Normal file
View file

@ -0,0 +1,205 @@
package quic
import (
"bytes"
"crypto/rand"
"errors"
"github.com/gaukas/clienthellod"
"github.com/quic-go/quic-go/quicvarint"
)
type InitialPacketSpec struct {
// SrcConnIDLength specifies how many bytes should the SrcConnID be
SrcConnIDLength int
// DestConnIDLength specifies how many bytes should the DestConnID be
DestConnIDLength int
// InitPacketNumberLength specifies how many bytes should the InitPacketNumber
// be interpreted as. It is usually 1 or 2 bytes. If unset, UQUIC will use the
// default algorithm to compute the length which is at least 2 bytes.
InitPacketNumberLength PacketNumberLen
// InitPacketNumber is the packet number of the first Initial packet. Following
// Initial packets, if any, will increment the Packet Number accordingly.
InitPacketNumber uint64 // [UQUIC]
// TokenStore is used to store and retrieve tokens. If set, will override the
// one set in the Config.
TokenStore TokenStore
// If ClientTokenLength is set when TokenStore is not set, a dummy TokenStore
// will be created to randomly generate tokens of the specified length for
// Pop() calls with any key and silently drop any Put() calls.
//
// However, the tokens will not be stored anywhere and are expected to be
// invalid since not assigned by the server.
ClientTokenLength int
// QUICFrames specifies a list of QUIC frames to be sent in the first Initial
// packet.
//
// If nil, it will be treated as a list with only a single QUICFrameCrypto.
FrameOrder QUICFrames
}
func (ps *InitialPacketSpec) UpdateConfig(conf *Config) {
conf.TokenStore = ps.getTokenStore()
}
func (ps *InitialPacketSpec) getTokenStore() TokenStore {
if ps.TokenStore != nil {
return ps.TokenStore
}
if ps.ClientTokenLength > 0 {
return &dummyTokenStore{
tokenLength: ps.ClientTokenLength,
}
}
return nil
}
type dummyTokenStore struct {
tokenLength int
}
func (d *dummyTokenStore) Pop(key string) (token *ClientToken) {
var data []byte = make([]byte, d.tokenLength)
rand.Read(data)
return &ClientToken{
data: data,
}
}
func (d *dummyTokenStore) Put(_ string, _ *ClientToken) {
// Do nothing
}
type QUICFrames []QUICFrame
func (qfs QUICFrames) MarshalWithCryptoData(cryptoData []byte) (payload []byte, err error) {
if len(qfs) == 0 { // If no frames specified, send a single crypto frame
qfs = QUICFrames{QUICFrameCrypto{0, 0}}
return qfs.MarshalWithCryptoData(cryptoData)
}
for _, frame := range qfs {
var frameBytes []byte
if offset, length, cryptoOK := frame.CryptoFrameInfo(); cryptoOK {
if length == 0 {
// calculate length: from offset to the end of cryptoData
length = len(cryptoData) - offset
}
frameBytes = []byte{0x06} // CRYPTO frame type
frameBytes = quicvarint.Append(frameBytes, uint64(offset))
frameBytes = quicvarint.Append(frameBytes, uint64(length))
frameCryptoData := make([]byte, length)
copy(frameCryptoData, cryptoData[offset:]) // copy at most length bytes
frameBytes = append(frameBytes, frameCryptoData...)
} else { // Handle none crypto frames: read and append to payload
frameBytes, err = frame.Read()
if err != nil {
return nil, err
}
}
payload = append(payload, frameBytes...)
}
return payload, nil
}
func (qfs QUICFrames) MarshalWithFrames(frames []byte) (payload []byte, err error) {
// parse frames
r := bytes.NewReader(frames)
qchframes, err := clienthellod.ReadAllFrames(r)
if err != nil {
return nil, err
}
// parse crypto data
cryptoData, err := clienthellod.ReassembleCRYPTOFrames(qchframes)
if err != nil {
return nil, err
}
// marshal
return qfs.MarshalWithCryptoData(cryptoData)
}
type QUICFrame interface {
// None crypto frames should return false for cryptoOK
CryptoFrameInfo() (offset, length int, cryptoOK bool)
// None crypto frames should return the byte representation of the frame.
// Crypto frames' behavior is undefined and unused.
Read() ([]byte, error)
}
// QUICFrameCrypto is used to specify the crypto frames containing the TLS ClientHello
// to be sent in the first Initial packet.
type QUICFrameCrypto struct {
// Offset is used to specify the starting offset of the crypto frame.
// Used when sending multiple crypto frames in a single packet.
//
// Multiple crypto frames in a single packet must not overlap and must
// make up an entire crypto stream continuously.
Offset int
// Length is used to specify the length of the crypto frame.
//
// Must be set if it is NOT the last crypto frame in a packet.
Length int
}
// CryptoFrameInfo() implements the QUICFrame interface.
//
// Crypto frames are later replaced by the crypto message using the information
// returned by this function.
func (q QUICFrameCrypto) CryptoFrameInfo() (offset, length int, cryptoOK bool) {
return q.Offset, q.Length, true
}
// Read() implements the QUICFrame interface.
//
// Crypto frames are later replaced by the crypto message, so they are not Read()-able.
func (q QUICFrameCrypto) Read() ([]byte, error) {
return nil, errors.New("crypto frames are not Read()-able")
}
// QUICFramePadding is used to specify the padding frames to be sent in the first Initial
// packet.
type QUICFramePadding struct {
// Length is used to specify the length of the padding frame.
Length int
}
// CryptoFrameInfo() implements the QUICFrame interface.
func (q QUICFramePadding) CryptoFrameInfo() (offset, length int, cryptoOK bool) {
return 0, 0, false
}
// Read() implements the QUICFrame interface.
//
// Padding simply returns a slice of bytes of the specified length filled with 0.
func (q QUICFramePadding) Read() ([]byte, error) {
return make([]byte, q.Length), nil
}
// QUICFramePing is used to specify the ping frames to be sent in the first Initial
// packet.
type QUICFramePing struct{}
// CryptoFrameInfo() implements the QUICFrame interface.
func (q QUICFramePing) CryptoFrameInfo() (offset, length int, cryptoOK bool) {
return 0, 0, false
}
// Read() implements the QUICFrame interface.
//
// Ping simply returns a slice of bytes of size 1 with value 0x01(PING).
func (q QUICFramePing) Read() ([]byte, error) {
return []byte{0x01}, nil
}

243
u_packet_packer.go Normal file
View file

@ -0,0 +1,243 @@
package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
// uPacketPacker is an extended packetPacker which is used
// to customize some of the packetPacker's behaviors for
// UQUIC.
type uPacketPacker struct {
*packetPacker
// initPktNbrLen PacketNumberLen
// qfs QUICFrames // [UQUIC] uses QUICFrames to customize encrypted frames
// udpDatagramMinSize int
uSpec *QUICSpec // [UQUIC]
}
func newUPacketPacker(
packetPacker *packetPacker,
uSpec *QUICSpec, // [UQUIC]
) *uPacketPacker {
return &uPacketPacker{
packetPacker: packetPacker,
uSpec: uSpec, // [UQUIC]
}
}
// PackCoalescedPacket packs a new packet.
// It packs an Initial / Handshake if there is data to send in these packet number spaces.
// It should only be called before the handshake is confirmed.
func (p *uPacketPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) {
var (
initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader
initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload
oneRTTPacketNumber protocol.PacketNumber
oneRTTPacketNumberLen protocol.PacketNumberLen
)
// Try packing an Initial packet.
initialSealer, err := p.cryptoSetup.GetInitialSealer()
if err != nil && err != handshake.ErrKeysDropped {
return nil, err
}
var size protocol.ByteCount
if initialSealer != nil {
initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, onlyAck, true, v)
if initialPayload.length > 0 {
size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead())
}
// // [UQUIC]
// if len(initialPayload.frames) > 0 {
// fmt.Printf("onlyAck: %t, PackCoalescedPacket: %v\n", onlyAck, initialPayload.frames[0].Frame)
// }
}
// Add a Handshake packet.
var handshakeSealer sealer
if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) {
var err error
handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer()
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
return nil, err
}
if handshakeSealer != nil {
handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, onlyAck, size == 0, v)
if handshakePayload.length > 0 {
s := p.longHeaderPacketLength(handshakeHdr, handshakePayload, v) + protocol.ByteCount(handshakeSealer.Overhead())
size += s
}
}
}
// Add a 0-RTT / 1-RTT packet.
var zeroRTTSealer sealer
var oneRTTSealer handshake.ShortHeaderSealer
var connID protocol.ConnectionID
var kp protocol.KeyPhaseBit
if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) {
var err error
oneRTTSealer, err = p.cryptoSetup.Get1RTTSealer()
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
return nil, err
}
if err == nil { // 1-RTT
kp = oneRTTSealer.KeyPhase()
connID = p.getDestConnID()
oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT)
hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen)
oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxPacketSize-size, onlyAck, size == 0, v)
if oneRTTPayload.length > 0 {
size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead())
}
} else if p.perspective == protocol.PerspectiveClient && !onlyAck { // 0-RTT packets can't contain ACK frames
var err error
zeroRTTSealer, err = p.cryptoSetup.Get0RTTSealer()
if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable {
return nil, err
}
if zeroRTTSealer != nil {
zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxPacketSize-size, v)
if zeroRTTPayload.length > 0 {
size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload, v) + protocol.ByteCount(zeroRTTSealer.Overhead())
}
}
}
}
if initialPayload.length == 0 && handshakePayload.length == 0 && zeroRTTPayload.length == 0 && oneRTTPayload.length == 0 {
return nil, nil
}
buffer := getPacketBuffer()
packet := &coalescedPacket{
buffer: buffer,
longHdrPackets: make([]*longHeaderPacket, 0, 3),
}
if initialPayload.length > 0 {
if onlyAck || len(initialPayload.frames) == 0 {
// TODO: uQUIC should send Initial Packet if requested.
// However, it should be otherwise configurable whether to request
// to send Initial Packet or not. See quic-go#4007
padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize)
cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v)
if err != nil {
return nil, err
}
packet.longHdrPackets = append(packet.longHdrPackets, cont)
} else { // [UQUIC]
cont, err := p.appendInitialPacket(buffer, initialHdr, initialPayload, protocol.EncryptionInitial, initialSealer, v)
if err != nil {
return nil, err
}
packet.longHdrPackets = append(packet.longHdrPackets, cont)
}
}
if handshakePayload.length > 0 {
cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v)
if err != nil {
return nil, err
}
packet.longHdrPackets = append(packet.longHdrPackets, cont)
}
if zeroRTTPayload.length > 0 {
longHdrPacket, err := p.appendLongHeaderPacket(buffer, zeroRTTHdr, zeroRTTPayload, 0, protocol.Encryption0RTT, zeroRTTSealer, v)
if err != nil {
return nil, err
}
packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket)
} else if oneRTTPayload.length > 0 {
shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxPacketSize, oneRTTSealer, false, v)
if err != nil {
return nil, err
}
packet.shortHdrPacket = &shp
}
return packet, nil
}
// [UQUIC]
func (p *uPacketPacker) appendInitialPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) {
// Shouldn't need this?
// if p.uSpec.InitialPacketSpec.InitPacketNumberLength > 0 {
// header.PacketNumberLen = p.uSpec.InitialPacketSpec.InitPacketNumberLength
// }
uPayload, err := p.MarshalInitialPacketPayload(pl, v)
if err != nil {
return nil, err
}
pnLen := protocol.ByteCount(header.PacketNumberLen)
header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(len(uPayload))
startLen := len(buffer.Data)
raw := buffer.Data[startLen:] // [UQUIC] the raw here is a sub-slice of buffer.Data, latter's len < size
raw, err = header.Append(raw, v)
if err != nil {
return nil, err
}
payloadOffset := protocol.ByteCount(len(raw))
raw = append(raw, uPayload...)
// fmt.Printf("Payload: %x\n", raw[payloadOffset:])
// fmt.Printf("Pre-Encryption: %x\n", raw)
raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen)
buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)]
// fmt.Printf("Post-Encryption: %x\n", raw)
// [UQUIC]
// append zero to buffer.Data until min size is reached
minUDPSize := p.uSpec.UDPDatagramMinSize
if minUDPSize == 0 {
minUDPSize = DefaultUDPDatagramMinSize
}
if len(buffer.Data) < minUDPSize {
buffer.Data = append(buffer.Data, make([]byte, minUDPSize-len(buffer.Data))...)
}
if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber {
return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber)
}
return &longHeaderPacket{
header: header,
ack: pl.ack,
frames: pl.frames,
streamFrames: pl.streamFrames,
length: protocol.ByteCount(len(raw)),
}, nil
}
func (p *uPacketPacker) MarshalInitialPacketPayload(pl payload, v protocol.VersionNumber) ([]byte, error) {
var originalFrameBytes []byte
for _, f := range pl.frames {
var err error
// only append crypto frames
if _, ok := f.Frame.(*wire.CryptoFrame); !ok {
continue
}
originalFrameBytes, err = f.Frame.Append(originalFrameBytes, v)
if err != nil {
return nil, err
}
}
uPayload, err := p.uSpec.InitialPacketSpec.FrameOrder.MarshalWithFrames(originalFrameBytes)
if err != nil {
return nil, err
}
return uPayload, nil
}

View file

@ -1,200 +1,18 @@
package quic
import (
"bytes"
"crypto/rand"
"errors"
import tls "github.com/refraction-networking/utls"
"github.com/gaukas/clienthellod"
"github.com/quic-go/quic-go/quicvarint"
const (
DefaultUDPDatagramMinSize = 1200
)
type QUICSpec struct {
// SrcConnIDLength specifies how many bytes should the SrcConnID be
SrcConnIDLength int
InitialPacketSpec InitialPacketSpec
ClientHelloSpec *tls.ClientHelloSpec
// DestConnIDLength specifies how many bytes should the DestConnID be
DstConnIDLength int
// InitPacketNumberLength specifies how many bytes should the InitPacketNumber
// be interpreted as. It is usually 1 or 2 bytes. If unset, UQUIC will use the
// default algorithm to compute the length which is at least 2 bytes.
InitPacketNumberLength PacketNumberLen
// InitPacketNumber is the packet number of the first Initial packet. Following
// Initial packets, if any, will increment the Packet Number accordingly.
InitPacketNumber uint64 // [UQUIC]
// TokenStore is used to store and retrieve tokens. If set, will override the
// one set in the Config.
TokenStore TokenStore
// If ClientTokenLength is set when TokenStore is not set, a dummy TokenStore
// will be created to randomly generate tokens of the specified length for
// Pop() calls with any key and silently drop any Put() calls.
//
// However, the tokens will not be stored anywhere and are expected to be
// invalid since not assigned by the server.
ClientTokenLength int
// QUICFrames specifies a list of QUIC frames to be sent in the first Initial
// packet.
//
// If nil, it will be treated as a list with only a single QUICFrameCrypto.
QUICFrames []QUICFrame
UDPDatagramMinSize int
}
func (s *QUICSpec) getTokenStore() TokenStore {
if s.TokenStore != nil {
return s.TokenStore
}
if s.ClientTokenLength > 0 {
return &dummyTokenStore{
tokenLength: s.ClientTokenLength,
}
}
return nil
}
type dummyTokenStore struct {
tokenLength int
}
func (d *dummyTokenStore) Pop(key string) (token *ClientToken) {
var data []byte = make([]byte, d.tokenLength)
rand.Read(data)
return &ClientToken{
data: data,
}
}
func (d *dummyTokenStore) Put(_ string, _ *ClientToken) {
// Do nothing
}
type QUICFrames []QUICFrame
func (qfs QUICFrames) MarshalWithCryptoData(cryptoData []byte) (payload []byte, err error) {
if len(qfs) == 0 { // If no frames specified, send a single crypto frame
payload = make([]byte, len(cryptoData)+1)
}
for _, frame := range qfs {
var frameBytes []byte
if offset, length, cryptoOK := frame.CryptoFrameInfo(); cryptoOK {
if length == 0 {
// calculate length: from offset to the end of cryptoData
length = len(cryptoData) - offset
}
frameBytes = []byte{0x06} // CRYPTO frame type
frameBytes = quicvarint.Append(frameBytes, uint64(offset))
frameBytes = quicvarint.Append(frameBytes, uint64(length))
frameCryptoData := make([]byte, length)
copy(frameCryptoData, cryptoData[offset:]) // copy at most length bytes
frameBytes = append(frameBytes, frameCryptoData...)
} else { // Handle none crypto frames: read and append to payload
frameBytes, err = frame.Read()
if err != nil {
return nil, err
}
}
payload = append(payload, frameBytes...)
}
return payload, nil
}
func (qfs QUICFrames) MarshalWithFrames(frames []byte) (payload []byte, err error) {
// parse frames
r := bytes.NewReader(frames)
qchframes, err := clienthellod.ReadAllFrames(r)
if err != nil {
return nil, err
}
// parse crypto data
cryptoData, err := clienthellod.ReassembleCRYPTOFrames(qchframes)
if err != nil {
return nil, err
}
// marshal
return qfs.MarshalWithCryptoData(cryptoData)
}
type QUICFrame interface {
// None crypto frames should return false for cryptoOK
CryptoFrameInfo() (offset, length int, cryptoOK bool)
// None crypto frames should return the byte representation of the frame.
// Crypto frames' behavior is undefined and unused.
Read() ([]byte, error)
}
// QUICFrameCrypto is used to specify the crypto frames containing the TLS ClientHello
// to be sent in the first Initial packet.
type QUICFrameCrypto struct {
// Offset is used to specify the starting offset of the crypto frame.
// Used when sending multiple crypto frames in a single packet.
//
// Multiple crypto frames in a single packet must not overlap and must
// make up an entire crypto stream continuously.
Offset int
// Length is used to specify the length of the crypto frame.
//
// Must be set if it is NOT the last crypto frame in a packet.
Length int
}
// CryptoFrameInfo() implements the QUICFrame interface.
//
// Crypto frames are later replaced by the crypto message using the information
// returned by this function.
func (q QUICFrameCrypto) CryptoFrameInfo() (offset, length int, cryptoOK bool) {
return q.Offset, q.Length, true
}
// Read() implements the QUICFrame interface.
//
// Crypto frames are later replaced by the crypto message, so they are not Read()-able.
func (q QUICFrameCrypto) Read() ([]byte, error) {
return nil, errors.New("crypto frames are not Read()-able")
}
// QUICFramePadding is used to specify the padding frames to be sent in the first Initial
// packet.
type QUICFramePadding struct {
// Length is used to specify the length of the padding frame.
Length int
}
// CryptoFrameInfo() implements the QUICFrame interface.
func (q QUICFramePadding) CryptoFrameInfo() (offset, length int, cryptoOK bool) {
return 0, 0, false
}
// Read() implements the QUICFrame interface.
//
// Padding simply returns a slice of bytes of the specified length filled with 0.
func (q QUICFramePadding) Read() ([]byte, error) {
return make([]byte, q.Length), nil
}
// QUICFramePing is used to specify the ping frames to be sent in the first Initial
// packet.
type QUICFramePing struct{}
// CryptoFrameInfo() implements the QUICFrame interface.
func (q QUICFramePing) CryptoFrameInfo() (offset, length int, cryptoOK bool) {
return 0, 0, false
}
// Read() implements the QUICFrame interface.
//
// Ping simply returns a slice of bytes of size 1 with value 0x01(PING).
func (q QUICFramePing) Read() ([]byte, error) {
return []byte{0x01}, nil
func (s *QUICSpec) UpdateConfig(config *Config) {
s.InitialPacketSpec.UpdateConfig(config)
}

87
u_transport.go Normal file
View file

@ -0,0 +1,87 @@
package quic
import (
"context"
"net"
"github.com/quic-go/quic-go/internal/protocol"
tls "github.com/refraction-networking/utls"
)
type UTransport struct {
*Transport
QUICSpec *QUICSpec // [UQUIC] using ptr to avoid copying
}
// Dial dials a new connection to a remote host (not using 0-RTT).
func (t *UTransport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateConfig(conf)
// [UQUIC]
// Override the default connection ID generator if the user has specified a length in QUICSpec.
if t.QUICSpec != nil {
if t.QUICSpec.InitialPacketSpec.SrcConnIDLength != 0 {
t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.QUICSpec.InitialPacketSpec.SrcConnIDLength}
} else {
t.ConnectionIDGenerator = &protocol.ExpEmptyConnectionIDGenerator{}
}
}
// [/UQUIC]
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
return udial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false, t.QUICSpec)
}
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
func (t *UTransport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
if err := validateConfig(conf); err != nil {
return nil, err
}
conf = populateConfig(conf)
// [UQUIC]
// Override the default connection ID generator if the user has specified a length in QUICSpec.
if t.QUICSpec != nil {
if t.QUICSpec.InitialPacketSpec.SrcConnIDLength != 0 {
t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.QUICSpec.InitialPacketSpec.SrcConnIDLength}
} else {
t.ConnectionIDGenerator = &protocol.ExpEmptyConnectionIDGenerator{}
}
}
// [/UQUIC]
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
var onClose func()
if t.isSingleUse {
onClose = func() { t.Close() }
}
tlsConf = tlsConf.Clone()
tlsConf.MinVersion = tls.VersionTLS13
return udial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true, t.QUICSpec)
}
func (ut *UTransport) MakeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *Config) (EarlyConnection, error) {
return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *Config) (EarlyConnection, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
return ut.DialEarly(ctx, udpAddr, tlsCfg, cfg)
}
}