diff --git a/client.go b/client.go index dd72f668..b6969573 100644 --- a/client.go +++ b/client.go @@ -38,8 +38,6 @@ type client struct { tracer logging.ConnectionTracer tracingID uint64 logger utils.Logger - - chs *tls.ClientHelloSpec // [UQUIC] } // make it possible to mock connection ID for initial generation in the tests @@ -167,41 +165,6 @@ func dial( return c.conn, nil } -func dialWithCHS( - ctx context.Context, - conn sendConn, - connIDGenerator ConnectionIDGenerator, - packetHandlers packetHandlerManager, - tlsConf *tls.Config, - config *Config, - onClose func(), - use0RTT bool, - chs *tls.ClientHelloSpec, -) (quicConn, error) { - c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT) - if err != nil { - return nil, err - } - c.packetHandlers = packetHandlers - - c.tracingID = nextConnTracingID() - if c.config.Tracer != nil { - c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) - } - if c.tracer != nil { - c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) - } - - // [UQUIC] - c.chs = chs - // [/UQUIC] - - if err := c.dial(ctx); err != nil { - return nil, err - } - return c.conn, nil -} - func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) { if tlsConf == nil { tlsConf = &tls.Config{} @@ -209,25 +172,12 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config tlsConf = tlsConf.Clone() } - // // [UQUIC] - // if config.SrcConnIDLength != 0 { - // connIDLen := config.SrcConnIDLength - // connIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: connIDLen} - // } - srcConnID, err := connIDGenerator.GenerateConnectionID() if err != nil { return nil, err } - var destConnID protocol.ConnectionID - // [UQUIC] - if config.DestConnIDLength > 0 { - destConnID, err = generateConnectionIDForInitialWithLength(config.DestConnIDLength) - } else { - destConnID, err = generateConnectionIDForInitial() - } - // [/UQUIC] + destConnID, err := generateConnectionIDForInitial() if err != nil { return nil, err } @@ -243,8 +193,6 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config version: config.Versions[0], handshakeChan: make(chan struct{}), logger: utils.DefaultLogger.WithPrefix("client"), - - initialPacketNumber: protocol.PacketNumber(config.InitPacketNumber), // [UQUIC] } return c, nil } @@ -252,45 +200,22 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config func (c *client) dial(ctx context.Context) error { c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) - // [UQUIC] - if c.chs == nil { - c.conn = newClientConnection( - c.sendConn, - c.packetHandlers, - c.destConnID, - c.srcConnID, - c.connIDGenerator, - c.config, - c.tlsConf, - c.initialPacketNumber, - c.use0RTT, - c.hasNegotiatedVersion, - c.tracer, - c.tracingID, - c.logger, - c.version, - ) - } else { - // [UQUIC]: use custom version of the connection - c.conn = newUClientConnection( - c.sendConn, - c.packetHandlers, - c.destConnID, - c.srcConnID, - c.connIDGenerator, - c.config, - c.tlsConf, - c.initialPacketNumber, - c.use0RTT, - c.hasNegotiatedVersion, - c.tracer, - c.tracingID, - c.logger, - c.version, - c.chs, - ) - } - // [/UQUIC] + c.conn = newClientConnection( + c.sendConn, + c.packetHandlers, + c.destConnID, + c.srcConnID, + c.connIDGenerator, + c.config, + c.tlsConf, + c.initialPacketNumber, + c.use0RTT, + c.hasNegotiatedVersion, + c.tracer, + c.tracingID, + c.logger, + c.version, + ) c.packetHandlers.Add(c.srcConnID, c.conn) diff --git a/config.go b/config.go index f478b979..de41656c 100644 --- a/config.go +++ b/config.go @@ -131,12 +131,6 @@ func populateConfig(config *Config) *Config { DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets, Allow0RTT: config.Allow0RTT, Tracer: config.Tracer, - - // [UQUIC] - SrcConnIDLength: config.SrcConnIDLength, - DestConnIDLength: config.DestConnIDLength, - InitPacketNumber: config.InitPacketNumber, - InitPacketNumberLength: config.InitPacketNumberLength, } } diff --git a/conn_id_manager.go b/conn_id_manager.go index 9008f7f7..7cb38e2e 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -62,10 +62,12 @@ func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { return err } + // [UQUIC] connIDLimit := h.connectionIDLimit if connIDLimit == 0 { connIDLimit = protocol.MaxActiveConnectionIDs } + // [/UQUIC] if uint64(h.queue.Len()) >= connIDLimit { return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError} @@ -191,11 +193,6 @@ func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToke h.addStatelessResetToken(token) } -// [UQUIC] -func (h *connIDManager) SetConnectionIDLimit(limit uint64) { - h.connectionIDLimit = limit -} - func (h *connIDManager) SentPacket() { h.packetsSinceLastChange++ } diff --git a/connection.go b/connection.go index 06a9f10a..47589a23 100644 --- a/connection.go +++ b/connection.go @@ -390,9 +390,6 @@ var newClientConnection = func( s.tracer, s.logger, ) - if conf.InitPacketNumberLength != 0 { - ackhandler.SetInitialPacketNumberLength(s.sentPacketHandler, conf.InitPacketNumberLength) - } s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) oneRTTStream := newCryptoStream() @@ -453,151 +450,6 @@ var newClientConnection = func( return s } -// [UQUIC] -var newUClientConnection = func( - conn sendConn, - runner connRunner, - destConnID protocol.ConnectionID, - srcConnID protocol.ConnectionID, - connIDGenerator ConnectionIDGenerator, - conf *Config, - tlsConf *tls.Config, - initialPacketNumber protocol.PacketNumber, - enable0RTT bool, - hasNegotiatedVersion bool, - tracer logging.ConnectionTracer, - tracingID uint64, - logger utils.Logger, - v protocol.VersionNumber, - chs *tls.ClientHelloSpec, -) quicConn { - s := &connection{ - conn: conn, - config: conf, - origDestConnID: destConnID, - handshakeDestConnID: destConnID, - srcConnIDLen: srcConnID.Len(), - perspective: protocol.PerspectiveClient, - logID: destConnID.String(), - logger: logger, - tracer: tracer, - versionNegotiated: hasNegotiatedVersion, - version: v, - } - s.connIDManager = newConnIDManager( - destConnID, - func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, - runner.RemoveResetToken, - s.queueControlFrame, - ) - - s.connIDGenerator = newConnIDGenerator( - srcConnID, - nil, - func(connID protocol.ConnectionID) { runner.Add(connID, s) }, - runner.GetStatelessResetToken, - runner.Remove, - runner.Retire, - runner.ReplaceWithClosed, - s.queueControlFrame, - connIDGenerator, - ) - s.preSetup() - s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) - s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( - initialPacketNumber, - getMaxPacketSize(s.conn.RemoteAddr()), - s.rttStats, - false, /* has no effect */ - s.perspective, - s.tracer, - s.logger, - ) - if conf.InitPacketNumberLength != 0 { - ackhandler.SetInitialPacketNumberLength(s.sentPacketHandler, conf.InitPacketNumberLength) - } - - s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) - oneRTTStream := newCryptoStream() - - var params *wire.TransportParameters - // params := &wire.TransportParameters{ - // InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - // InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - // InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), - // InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), - // MaxIdleTimeout: s.config.MaxIdleTimeout, - // MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), - // MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), - // MaxAckDelay: protocol.MaxAckDelayInclGranularity, - // AckDelayExponent: protocol.AckDelayExponent, - // DisableActiveMigration: true, - // // For interoperability with quic-go versions before May 2023, this value must be set to a value - // // different from protocol.DefaultActiveConnectionIDLimit. - // // If set to the default value, it will be omitted from the transport parameters, which will make - // // old quic-go versions interpret it as 0, instead of the default value of 2. - // // See https://github.com/quic-go/quic-go/pull/3806. - // ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, - // InitialSourceConnectionID: srcConnID, - // } - // if s.config.EnableDatagrams { - // params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize - // } else { - // params.MaxDatagramFrameSize = protocol.InvalidByteCount - // } - - // [UQUIC] iterate over all Extensions to set the TransportParameters - var tpSet bool -FOR_EACH_TLS_EXTENSION: - for _, ext := range chs.Extensions { - switch ext := ext.(type) { - case *tls.QUICTransportParametersExtension: - params = &wire.TransportParameters{ - InitialSourceConnectionID: srcConnID, - } - params.PopulateFromUQUIC(ext.TransportParameters) - s.connIDManager.SetConnectionIDLimit(params.ActiveConnectionIDLimit) - tpSet = true - break FOR_EACH_TLS_EXTENSION - default: - continue FOR_EACH_TLS_EXTENSION - } - } - if !tpSet { - panic("applied ClientHelloSpec must contain a QUICTransportParametersExtension to proceed") - } - - if s.tracer != nil { - s.tracer.SentTransportParameters(params) - } - cs := handshake.NewUCryptoSetupClient( - destConnID, - params, - tlsConf, - enable0RTT, - s.rttStats, - tracer, - logger, - s.version, - chs, - ) - s.cryptoStreamHandler = cs - s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, oneRTTStream) - s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) - s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) - if len(tlsConf.ServerName) > 0 { - s.tokenStoreKey = tlsConf.ServerName - } else { - s.tokenStoreKey = conn.RemoteAddr().String() - } - if s.config.TokenStore != nil { - if token := s.config.TokenStore.Pop(s.tokenStoreKey); token != nil { - s.packer.SetToken(token.data) - } - } - return s -} - func (s *connection) preSetup() { s.initialStream = newCryptoStream() s.handshakeStream = newCryptoStream() @@ -2204,9 +2056,7 @@ func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, now time } s.connIDManager.SentPacket() s.sendQueue.Send(packet.buffer, packet.buffer.Len()) - // [UQUIC] - // fmt.Printf("sendPackedCoalescedPacket:Sending %d bytes\n", packet.buffer.Len()) - // fmt.Printf("sendPackedCoalescedPacket: %v\n", packet.buffer.Data) + return nil } diff --git a/example/uquic/main.go b/example/uquic/main.go index c97654fa..78bde371 100644 --- a/example/uquic/main.go +++ b/example/uquic/main.go @@ -14,7 +14,83 @@ import ( "github.com/quic-go/quic-go/http3" ) -func getCHS() *tls.ClientHelloSpec { +func main() { + keyLogWriter, err := os.Create("./keylog.txt") + if err != nil { + panic(err) + } + + tlsConf := &tls.Config{ + ServerName: "quic.tlsfingerprint.io", + // ServerName: "www.cloudflare.com", + // MinVersion: tls.VersionTLS13, + KeyLogWriter: keyLogWriter, + // NextProtos: []string{"h3"}, + } + + quicConf := &quic.Config{} + + roundTripper := &http3.RoundTripper{ + TLSClientConfig: tlsConf, + QuicConfig: quicConf, + } + uRoundTripper := http3.GetURoundTripper( + roundTripper, + // getFFQUICSpec(), + getCRQUICSpec(), + nil, + ) + defer uRoundTripper.Close() + + hclient := &http.Client{ + Transport: uRoundTripper, + } + + addr := "https://quic.tlsfingerprint.io/qfp/?beautify=true" + // addr := "https://www.cloudflare.com" + + rsp, err := hclient.Get(addr) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Got response for %s: %#v", addr, rsp) + + body := &bytes.Buffer{} + _, err = io.Copy(body, rsp.Body) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Response Body: %s", body.Bytes()) +} + +func getFFQUICSpec() *quic.QUICSpec { + return &quic.QUICSpec{ + InitialPacketSpec: quic.InitialPacketSpec{ + SrcConnIDLength: 3, + DestConnIDLength: 8, + InitPacketNumberLength: 1, + InitPacketNumber: 1, + ClientTokenLength: 0, + FrameOrder: quic.QUICFrames{ + &quic.QUICFrameCrypto{ + Offset: 300, + Length: 0, + }, + &quic.QUICFramePadding{ + Length: 125, + }, + &quic.QUICFramePing{}, + &quic.QUICFrameCrypto{ + Offset: 0, + Length: 300, + }, + }, + }, + ClientHelloSpec: getFFCHS(), + } +} + +func getFFCHS() *tls.ClientHelloSpec { return &tls.ClientHelloSpec{ TLSVersMin: tls.VersionTLS13, TLSVersMax: tls.VersionTLS13, @@ -135,54 +211,146 @@ func getCHS() *tls.ClientHelloSpec { } } -func main() { - keyLogWriter, err := os.Create("./keylog.txt") - if err != nil { - panic(err) +func getCRQUICSpec() *quic.QUICSpec { + return &quic.QUICSpec{ + InitialPacketSpec: quic.InitialPacketSpec{ + SrcConnIDLength: 0, + DestConnIDLength: 8, + InitPacketNumberLength: 1, + InitPacketNumber: 1, + ClientTokenLength: 0, + FrameOrder: quic.QUICFrames{ + &quic.QUICFrameCrypto{ + Offset: 300, + Length: 0, + }, + &quic.QUICFramePadding{ + Length: 125, + }, + &quic.QUICFramePing{}, + &quic.QUICFrameCrypto{ + Offset: 0, + Length: 300, + }, + }, + }, + ClientHelloSpec: getCRCHS(), + } +} +func getCRCHS() *tls.ClientHelloSpec { + return &tls.ClientHelloSpec{ + TLSVersMin: tls.VersionTLS13, + TLSVersMax: tls.VersionTLS13, + CipherSuites: []uint16{ + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_CHACHA20_POLY1305_SHA256, + tls.TLS_AES_256_GCM_SHA384, + }, + CompressionMethods: []uint8{ + 0x0, // no compression + }, + Extensions: []tls.TLSExtension{ + &tls.SNIExtension{}, + &tls.ExtendedMasterSecretExtension{}, + &tls.RenegotiationInfoExtension{ + Renegotiation: tls.RenegotiateOnceAsClient, + }, + &tls.SupportedCurvesExtension{ + Curves: []tls.CurveID{ + tls.CurveX25519, + tls.CurveSECP256R1, + tls.CurveSECP384R1, + tls.CurveSECP521R1, + tls.FakeCurveFFDHE2048, + tls.FakeCurveFFDHE3072, + tls.FakeCurveFFDHE4096, + tls.FakeCurveFFDHE6144, + tls.FakeCurveFFDHE8192, + }, + }, + &tls.ALPNExtension{ + AlpnProtocols: []string{ + "h3", + }, + }, + &tls.StatusRequestExtension{}, + &tls.FakeDelegatedCredentialsExtension{ + SupportedSignatureAlgorithms: []tls.SignatureScheme{ + tls.ECDSAWithP256AndSHA256, + tls.ECDSAWithP384AndSHA384, + tls.ECDSAWithP521AndSHA512, + tls.ECDSAWithSHA1, + }, + }, + &tls.KeyShareExtension{ + KeyShares: []tls.KeyShare{ + { + Group: tls.X25519, + }, + // { + // Group: tls.CurveP256, + // }, + }, + }, + &tls.SupportedVersionsExtension{ + Versions: []uint16{ + tls.VersionTLS13, + }, + }, + &tls.SignatureAlgorithmsExtension{ + SupportedSignatureAlgorithms: []tls.SignatureScheme{ + tls.ECDSAWithP256AndSHA256, + tls.ECDSAWithP384AndSHA384, + tls.ECDSAWithP521AndSHA512, + tls.ECDSAWithSHA1, + tls.PSSWithSHA256, + tls.PSSWithSHA384, + tls.PSSWithSHA512, + tls.PKCS1WithSHA256, + tls.PKCS1WithSHA384, + tls.PKCS1WithSHA512, + tls.PKCS1WithSHA1, + }, + }, + &tls.PSKKeyExchangeModesExtension{ + Modes: []uint8{ + tls.PskModeDHE, + }, + }, + &tls.FakeRecordSizeLimitExtension{ + Limit: 0x4001, + }, + &tls.QUICTransportParametersExtension{ + TransportParameters: tls.TransportParameters{ + &tls.GREASE{ + IdOverride: 0x35967c5b9c37e023, + ValueOverride: []byte{ + 0xfc, 0x97, 0xbb, 0x57, 0xb8, 0x02, 0x19, 0xcd, + }, + }, + tls.InitialMaxStreamsUni(103), + tls.InitialSourceConnectionID([]byte{}), + tls.InitialMaxStreamsBidi(100), + tls.InitialMaxData(15728640), + &tls.VersionInformation{ + ChoosenVersion: tls.VERSION_1, + AvailableVersions: []uint32{ + tls.VERSION_1, + tls.VERSION_GREASE, + }, + LegacyID: true, + }, + tls.MaxIdleTimeout(30000), + tls.MaxUDPPayloadSize(1472), + tls.MaxDatagramFrameSize(65536), + tls.InitialMaxStreamDataBidiLocal(6291456), + tls.InitialMaxStreamDataUni(6291456), + tls.InitialMaxStreamDataBidiRemote(6291456), + }, + }, + &tls.UtlsPaddingExtension{ + GetPaddingLen: tls.BoringPaddingStyle, + }, + }, } - - tlsConf := &tls.Config{ - ServerName: "quic.tlsfingerprint.io", - // ServerName: "www.cloudflare.com", - // MinVersion: tls.VersionTLS13, - KeyLogWriter: keyLogWriter, - // NextProtos: []string{"h3"}, - } - - quicConf := &quic.Config{ - Versions: []quic.VersionNumber{quic.Version1}, - // EnableDatagrams: true, - SrcConnIDLength: 3, // <4 causes timeout - DestConnIDLength: 8, - InitPacketNumber: 0, - InitPacketNumberLength: quic.PacketNumberLen1, // currently only affects the initial packet number - // Versions: []quic.VersionNumber{quic.Version2}, - } - - roundTripper := &http3.RoundTripper{ - TLSClientConfig: tlsConf, - QuicConfig: quicConf, - ClientHelloSpec: getCHS(), - } - defer roundTripper.Close() - - hclient := &http.Client{ - Transport: roundTripper, - } - - addr := "https://quic.tlsfingerprint.io/qfp/" - // addr := "https://www.cloudflare.com" - - rsp, err := hclient.Get(addr) - if err != nil { - log.Fatal(err) - } - fmt.Printf("Got response for %s: %#v", addr, rsp) - - body := &bytes.Buffer{} - _, err = io.Copy(body, rsp.Body) - if err != nil { - log.Fatal(err) - } - fmt.Printf("Response Body: %s", body.Bytes()) } diff --git a/http3/roundtrip.go b/http3/roundtrip.go index d2b9ae2c..ae589438 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -88,9 +88,6 @@ type RoundTripper struct { newClient func(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) // so we can mock it in tests clients map[string]*roundTripCloserWithCount transport *quic.Transport - - // [UQUIC] - ClientHelloSpec *tls.ClientHelloSpec } // RoundTripOpt are options for the Transport.RoundTripOpt method. @@ -194,8 +191,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTr return nil, false, err } r.transport = &quic.Transport{ - Conn: udpConn, - ClientHelloSpec: r.ClientHelloSpec, + Conn: udpConn, } } dial = r.makeDialer() diff --git a/http3/u_roundtrip.go b/http3/u_roundtrip.go new file mode 100644 index 00000000..8e61741f --- /dev/null +++ b/http3/u_roundtrip.go @@ -0,0 +1,192 @@ +package http3 + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + + "github.com/quic-go/quic-go" + tls "github.com/refraction-networking/utls" + "golang.org/x/net/http/httpguts" +) + +type URoundTripper struct { + *RoundTripper + + quicSpec *quic.QUICSpec + uTransportOverride *quic.UTransport +} + +func GetURoundTripper(r *RoundTripper, QUICSpec *quic.QUICSpec, uTransport *quic.UTransport) *URoundTripper { + QUICSpec.UpdateConfig(r.QuicConfig) + + return &URoundTripper{ + RoundTripper: r, + quicSpec: QUICSpec, + uTransportOverride: uTransport, + } +} + +// RoundTripOpt is like RoundTrip, but takes options. +func (r *URoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { + if req.URL == nil { + closeRequestBody(req) + return nil, errors.New("http3: nil Request.URL") + } + if req.URL.Scheme != "https" { + closeRequestBody(req) + return nil, fmt.Errorf("http3: unsupported protocol scheme: %s", req.URL.Scheme) + } + if req.URL.Host == "" { + closeRequestBody(req) + return nil, errors.New("http3: no Host in request URL") + } + if req.Header == nil { + closeRequestBody(req) + return nil, errors.New("http3: nil Request.Header") + } + for k, vv := range req.Header { + if !httpguts.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("http3: invalid http header field name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("http3: invalid http header field value %q for key %v", v, k) + } + } + } + + if req.Method != "" && !validMethod(req.Method) { + closeRequestBody(req) + return nil, fmt.Errorf("http3: invalid method %q", req.Method) + } + + hostname := authorityAddr("https", hostnameFromRequest(req)) + cl, isReused, err := r.getClient(hostname, opt.OnlyCachedConn) + if err != nil { + return nil, err + } + defer cl.useCount.Add(-1) + rsp, err := cl.RoundTripOpt(req, opt) + if err != nil { + r.removeClient(hostname) + if isReused { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + return r.RoundTripOpt(req, opt) + } + } + } + return rsp, err +} + +// RoundTrip does a round trip. +func (r *URoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return r.RoundTripOpt(req, RoundTripOpt{}) +} + +func (r *URoundTripper) getClient(hostname string, onlyCached bool) (rtc *roundTripCloserWithCount, isReused bool, err error) { + r.mutex.Lock() + defer r.mutex.Unlock() + + if r.clients == nil { + r.clients = make(map[string]*roundTripCloserWithCount) + } + + client, ok := r.clients[hostname] + if !ok { + if onlyCached { + return nil, false, ErrNoCachedConn + } + var err error + newCl := newClient + if r.newClient != nil { + newCl = r.newClient + } + dial := r.Dial + if dial == nil { + if r.transport == nil && r.uTransportOverride == nil { + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, false, err + } + r.uTransportOverride = &quic.UTransport{ + Transport: &quic.Transport{ + Conn: udpConn, + }, + QUICSpec: r.quicSpec, + } + } + dial = r.makeDialer() + } + c, err := newCl( + hostname, + r.TLSClientConfig, + &roundTripperOpts{ + EnableDatagram: r.EnableDatagrams, + DisableCompression: r.DisableCompression, + MaxHeaderBytes: r.MaxResponseHeaderBytes, + StreamHijacker: r.StreamHijacker, + UniStreamHijacker: r.UniStreamHijacker, + }, + r.QuicConfig, + dial, + ) + if err != nil { + return nil, false, err + } + client = &roundTripCloserWithCount{roundTripCloser: c} + r.clients[hostname] = client + } else if client.HandshakeComplete() { + isReused = true + } + client.useCount.Add(1) + return client, isReused, nil +} + +func (r *URoundTripper) Close() error { + r.mutex.Lock() + defer r.mutex.Unlock() + for _, client := range r.clients { + if err := client.Close(); err != nil { + return err + } + } + r.clients = nil + if r.transport != nil { + if err := r.transport.Close(); err != nil { + return err + } + if err := r.transport.Conn.Close(); err != nil { + return err + } + r.transport = nil + } + if r.uTransportOverride != nil { + if err := r.uTransportOverride.Close(); err != nil { + return err + } + if err := r.uTransportOverride.Conn.Close(); err != nil { + return err + } + r.uTransportOverride = nil + } + return nil +} + +// makeDialer makes a QUIC dialer using r.udpConn. +func (r *URoundTripper) makeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + if r.uTransportOverride != nil { + return r.uTransportOverride.DialEarly(ctx, udpAddr, tlsCfg, cfg) + } else if r.transport == nil { + return nil, errors.New("http3: no QUIC transport available") + } + return r.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) + } +} diff --git a/interface.go b/interface.go index 92530923..dbfedccf 100644 --- a/interface.go +++ b/interface.go @@ -333,12 +333,6 @@ type Config struct { // Enable QUIC datagram support (RFC 9221). EnableDatagrams bool Tracer func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer - - // TransportParameters override other transport parameters set by the Config. - SrcConnIDLength int // [UQUIC] - DestConnIDLength int // [UQUIC] - InitPacketNumber uint64 // [UQUIC] - InitPacketNumberLength PacketNumberLen // [UQUIC] } type ClientHelloInfo struct { diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 82d45c25..08684aa2 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -96,8 +96,6 @@ type sentPacketHandler struct { tracer logging.ConnectionTracer logger utils.Logger - - initialPacketNumberLength protocol.PacketNumberLen // [UQUIC] } var ( @@ -138,12 +136,6 @@ func newSentPacketHandler( } } -func SetInitialPacketNumberLength(h SentPacketHandler, pnLen protocol.PacketNumberLen) { - if sph, ok := h.(*sentPacketHandler); ok { - sph.initialPacketNumberLength = pnLen - } -} - func (h *sentPacketHandler) removeFromBytesInFlight(p *packet) { if p.includedInBytesInFlight { if p.Length > h.bytesInFlight { @@ -725,11 +717,6 @@ func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) pn := pnSpace.pns.Peek() // See section 17.1 of RFC 9000. - // [UQUIC] This kinda breaks PN length mimicry. - if encLevel == protocol.EncryptionInitial && h.initialPacketNumberLength != 0 { - return pn, h.initialPacketNumberLength - } - return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked) } diff --git a/internal/ackhandler/u_ackhandler.go b/internal/ackhandler/u_ackhandler.go new file mode 100644 index 00000000..56886795 --- /dev/null +++ b/internal/ackhandler/u_ackhandler.go @@ -0,0 +1,23 @@ +package ackhandler + +import ( + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/logging" +) + +// [UQUIC] +func NewUAckHandler( + initialPacketNumber protocol.PacketNumber, + initialMaxDatagramSize protocol.ByteCount, + rttStats *utils.RTTStats, + clientAddressValidated bool, + pers protocol.Perspective, + tracer logging.ConnectionTracer, + logger utils.Logger, +) (SentPacketHandler, ReceivedPacketHandler) { + sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, clientAddressValidated, pers, tracer, logger) + return &uSentPacketHandler{ + sentPacketHandler: sph, + }, newReceivedPacketHandler(sph, rttStats, logger) +} diff --git a/internal/ackhandler/u_sent_packet_handler.go b/internal/ackhandler/u_sent_packet_handler.go new file mode 100644 index 00000000..1e1dfcd3 --- /dev/null +++ b/internal/ackhandler/u_sent_packet_handler.go @@ -0,0 +1,30 @@ +package ackhandler + +import "github.com/quic-go/quic-go/internal/protocol" + +type uSentPacketHandler struct { + *sentPacketHandler + + initialPacketNumberLength protocol.PacketNumberLen // [UQUIC] +} + +func (h *uSentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) { + pnSpace := h.getPacketNumberSpace(encLevel) + pn := pnSpace.pns.Peek() + // See section 17.1 of RFC 9000. + + // [UQUIC] Otherwise it kinda breaks PN length mimicry. + if encLevel == protocol.EncryptionInitial && h.initialPacketNumberLength != 0 { + return pn, h.initialPacketNumberLength + } + // [/UQUIC] + + return pn, protocol.GetPacketNumberLengthForHeader(pn, pnSpace.largestAcked) +} + +// [UQUIC] +func SetInitialPacketNumberLength(h SentPacketHandler, pnLen protocol.PacketNumberLen) { + if sph, ok := h.(*uSentPacketHandler); ok { + sph.initialPacketNumberLength = pnLen + } +} diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 95255d14..670dbd71 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -102,41 +102,6 @@ func NewCryptoSetupClient( return cs } -// [UQUIC] -// NewUCryptoSetupClient creates a new crypto setup for the client with UTLS -func NewUCryptoSetupClient( - connID protocol.ConnectionID, - tp *wire.TransportParameters, - tlsConf *tls.Config, - enable0RTT bool, - rttStats *utils.RTTStats, - tracer logging.ConnectionTracer, - logger utils.Logger, - version protocol.VersionNumber, - chs *tls.ClientHelloSpec, -) CryptoSetup { - cs := newCryptoSetup( - connID, - tp, - rttStats, - tracer, - logger, - protocol.PerspectiveClient, - version, - ) - - tlsConf = tlsConf.Clone() - tlsConf.MinVersion = tls.VersionTLS13 - quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} - qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState) - cs.tlsConf = tlsConf - - cs.conn = qtls.UQUICClient(quicConf, chs) - // cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) // [UQUIC] doesn't require this - - return cs -} - // NewCryptoSetupServer creates a new crypto setup for the server func NewCryptoSetupServer( connID protocol.ConnectionID, @@ -281,6 +246,7 @@ func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) { return false, h.handleTransportParameters(ev.Data) case qtls.QUICTransportParametersRequired: h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective)) + // [UQUIC] doesn't expect this and may fail return false, nil case qtls.QUICRejectedEarlyData: h.rejected0RTT() diff --git a/internal/handshake/u_crypto_setup.go b/internal/handshake/u_crypto_setup.go new file mode 100644 index 00000000..38488ddd --- /dev/null +++ b/internal/handshake/u_crypto_setup.go @@ -0,0 +1,45 @@ +package handshake + +import ( + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/qtls" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" + tls "github.com/refraction-networking/utls" +) + +// [UQUIC] +// NewUCryptoSetupClient creates a new crypto setup for the client with UTLS +func NewUCryptoSetupClient( + connID protocol.ConnectionID, + tp *wire.TransportParameters, + tlsConf *tls.Config, + enable0RTT bool, + rttStats *utils.RTTStats, + tracer logging.ConnectionTracer, + logger utils.Logger, + version protocol.VersionNumber, + chs *tls.ClientHelloSpec, +) CryptoSetup { + cs := newCryptoSetup( + connID, + tp, + rttStats, + tracer, + logger, + protocol.PerspectiveClient, + version, + ) + + tlsConf = tlsConf.Clone() + tlsConf.MinVersion = tls.VersionTLS13 + quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} + qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState) + cs.tlsConf = tlsConf + + cs.conn = qtls.UQUICClient(quicConf, chs) + // cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) // [UQUIC] doesn't require this + + return cs +} diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go index 042d3a99..77259b5f 100644 --- a/internal/protocol/connection_id.go +++ b/internal/protocol/connection_id.go @@ -68,11 +68,6 @@ func GenerateConnectionIDForInitial() (ConnectionID, error) { return GenerateConnectionID(l) } -// [UQUIC] -func GenerateConnectionIDForInitialWithLen(l int) (ConnectionID, error) { - return GenerateConnectionID(l) -} - // ReadConnectionID reads a connection ID of length len from the given io.Reader. // It returns io.EOF if there are not enough bytes to read. func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) { diff --git a/internal/protocol/u_connection_id.go b/internal/protocol/u_connection_id.go new file mode 100644 index 00000000..8f3b388d --- /dev/null +++ b/internal/protocol/u_connection_id.go @@ -0,0 +1,16 @@ +package protocol + +// [UQUIC] +func GenerateConnectionIDForInitialWithLen(l int) (ConnectionID, error) { + return GenerateConnectionID(l) +} + +type ExpEmptyConnectionIDGenerator struct{} + +func (g *ExpEmptyConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) { + return GenerateConnectionID(0) +} + +func (g *ExpEmptyConnectionIDGenerator) ConnectionIDLen() int { + return 0 +} diff --git a/packet_packer.go b/packet_packer.go index e6ab4bec..21af817c 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -332,11 +332,6 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. if initialPayload.length > 0 { size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead()) } - - // // [UQUIC] - // if len(initialPayload.frames) > 0 { - // fmt.Printf("onlyAck: %t, PackCoalescedPacket: %v\n", onlyAck, initialPayload.frames[0].Frame) - // } } // Add a Handshake packet. @@ -401,24 +396,12 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol. longHdrPackets: make([]*longHeaderPacket, 0, 3), } if initialPayload.length > 0 { - if onlyAck || len(initialPayload.frames) == 0 { - // padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize) - // cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) - // if err != nil { - // return nil, err - // } - // packet.longHdrPackets = append(packet.longHdrPackets, cont) - return nil, nil // [UQUIC] not to send the ACK frame for Initial - } else { // [UQUIC] - cont, err := p.appendLongHeaderPacketExternalPadding(buffer, initialHdr, initialPayload, protocol.EncryptionInitial, initialSealer, v) - if err != nil { - return nil, err - } - - // fmt.Printf("!onlyAck buffer: %v\n", buffer.Data) - - packet.longHdrPackets = append(packet.longHdrPackets, cont) + padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize) + cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) + if err != nil { + return nil, err } + packet.longHdrPackets = append(packet.longHdrPackets, cont) } if handshakePayload.length > 0 { cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v) @@ -688,7 +671,6 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel, m return nil, err } hdr, pl = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true, v) - fmt.Printf("MaybePackProbePacket: %x\n", pl.frames[0]) case protocol.EncryptionHandshake: var err error sealer, err = p.cryptoSetup.GetHandshakeSealer() @@ -768,10 +750,6 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire } paddingLen += padding - if encLevel == protocol.EncryptionInitial { - paddingLen = 0 - } - header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length + paddingLen startLen := len(buffer.Data) @@ -800,49 +778,6 @@ func (p *packetPacker) appendLongHeaderPacket(buffer *packetBuffer, header *wire }, nil } -// [UQUIC] -func (p *packetPacker) appendLongHeaderPacketExternalPadding(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) { - pnLen := protocol.ByteCount(header.PacketNumberLen) - header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + pl.length - - startLen := len(buffer.Data) - raw := buffer.Data[startLen:] // [UQUIC] raw is a sub-slice of buffer.Data, whose len < size - raw, err := header.Append(raw, v) - if err != nil { - return nil, err - } - - fmt.Printf("Pre-Payload: %x\n", raw) - - payloadOffset := protocol.ByteCount(len(raw)) - raw, err = p.appendCustomInitialPacketPayload(raw, pl, 0, v) - if err != nil { - return nil, err - } - - fmt.Printf("Pre-Encryption: %x\n", raw) - - raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen) - buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] - - fmt.Printf("Post-Encryption: %x\n", raw) - - // [UQUIC] - // append zero to buffer.Data until 1200 bytes - buffer.Data = append(buffer.Data, make([]byte, 1357-len(buffer.Data))...) - - if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber { - return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber) - } - return &longHeaderPacket{ - header: header, - ack: pl.ack, - frames: pl.frames, - streamFrames: pl.streamFrames, - length: protocol.ByteCount(len(raw)), - }, nil -} - func (p *packetPacker) appendShortHeaderPacket( buffer *packetBuffer, connID protocol.ConnectionID, @@ -930,44 +865,6 @@ func (p *packetPacker) appendPacketPayload(raw []byte, pl payload, paddingLen pr return raw, nil } -func (p *packetPacker) appendCustomInitialPacketPayload(raw []byte, pl payload, paddingLen protocol.ByteCount, v protocol.VersionNumber) ([]byte, error) { - payloadOffset := len(raw) - - // [UQUIC] ignores the default ACK/PADDING frame and uses its own frames - // if pl.ack != nil { - // var err error - // raw, err = pl.ack.Append(raw, v) - // if err != nil { - // return nil, err - // } - // } - // if paddingLen > 0 { - // raw = append(raw, make([]byte, paddingLen)...) - // } - - for _, f := range pl.frames { - var err error - raw, err = f.Frame.Append(raw, v) - if err != nil { - return nil, err - } - fmt.Printf("UQUIC: appending frame %v\n", f) - } - for _, f := range pl.streamFrames { - var err error - raw, err = f.Frame.Append(raw, v) - if err != nil { - return nil, err - } - fmt.Printf("UQUIC: appending stream frame %v\n", f) - } - - if payloadSize := protocol.ByteCount(len(raw)-payloadOffset) - paddingLen; payloadSize != pl.length { - return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", pl.length, payloadSize) - } - return raw, nil -} - func (p *packetPacker) encryptPacket(raw []byte, sealer sealer, pn protocol.PacketNumber, payloadOffset, pnLen protocol.ByteCount) []byte { _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], pn, raw[:payloadOffset]) raw = raw[:len(raw)+sealer.Overhead()] diff --git a/transport.go b/transport.go index b0b527c4..e002a261 100644 --- a/transport.go +++ b/transport.go @@ -87,8 +87,6 @@ type Transport struct { isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial logger utils.Logger - - ClientHelloSpec *tls.ClientHelloSpec // [UQUIC] } // Listen starts listening for incoming QUIC connections. @@ -156,12 +154,6 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config } conf = populateConfig(conf) - // [UQUIC] - if conf.SrcConnIDLength != 0 { - t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conf.SrcConnIDLength} - } - // [/UQUIC] - if err := t.init(t.isSingleUse); err != nil { return nil, err } @@ -172,9 +164,6 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - if t.ClientHelloSpec != nil { // [UQUIC] - return dialWithCHS(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false, t.ClientHelloSpec) - } return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) } @@ -185,12 +174,6 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C } conf = populateConfig(conf) - // [UQUIC] - if conf.SrcConnIDLength != 0 { - t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: conf.SrcConnIDLength} - } - // [/UQUIC] - if err := t.init(t.isSingleUse); err != nil { return nil, err } @@ -201,9 +184,6 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - if t.ClientHelloSpec != nil { // [UQUIC] - return dialWithCHS(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false, t.ClientHelloSpec) - } return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) } diff --git a/u_client.go b/u_client.go new file mode 100644 index 00000000..e3b0a408 --- /dev/null +++ b/u_client.go @@ -0,0 +1,150 @@ +package quic + +import ( + "context" + "errors" + + "github.com/quic-go/quic-go/internal/protocol" + tls "github.com/refraction-networking/utls" +) + +type uClient struct { + *client + uSpec *QUICSpec // [UQUIC] +} + +func udial( + ctx context.Context, + conn sendConn, + connIDGenerator ConnectionIDGenerator, + packetHandlers packetHandlerManager, + tlsConf *tls.Config, + config *Config, + onClose func(), + use0RTT bool, + uSpec *QUICSpec, // [UQUIC] +) (quicConn, error) { + c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT) + if err != nil { + return nil, err + } + c.packetHandlers = packetHandlers + + // [UQUIC] + if uSpec.InitialPacketSpec.DestConnIDLength > 0 { + destConnID, err := generateConnectionIDForInitialWithLength(uSpec.InitialPacketSpec.DestConnIDLength) + if err != nil { + return nil, err + } + c.destConnID = destConnID + } + c.initialPacketNumber = protocol.PacketNumber(uSpec.InitialPacketSpec.InitPacketNumber) + // [/UQUIC] + + c.tracingID = nextConnTracingID() + if c.config.Tracer != nil { + c.tracer = c.config.Tracer(context.WithValue(ctx, ConnectionTracingKey, c.tracingID), protocol.PerspectiveClient, c.destConnID) + } + if c.tracer != nil { + c.tracer.StartedConnection(c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID) + } + + // [UQUIC] + uc := &uClient{ + client: c, + uSpec: uSpec, + } + // [/UQUIC] + + if err := uc.dial(ctx); err != nil { + return nil, err + } + return uc.conn, nil +} + +func (c *uClient) dial(ctx context.Context) error { + c.logger.Infof("Starting new uQUIC connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version) + + // [UQUIC] + if c.uSpec.ClientHelloSpec == nil { + c.conn = newClientConnection( + c.sendConn, + c.packetHandlers, + c.destConnID, + c.srcConnID, + c.connIDGenerator, + c.config, + c.tlsConf, + c.initialPacketNumber, + c.use0RTT, + c.hasNegotiatedVersion, + c.tracer, + c.tracingID, + c.logger, + c.version, + ) + } else { + // [UQUIC]: use custom version of the connection + c.conn = newUClientConnection( + c.sendConn, + c.packetHandlers, + c.destConnID, + c.srcConnID, + c.connIDGenerator, + c.config, + c.tlsConf, + c.initialPacketNumber, + c.use0RTT, + c.hasNegotiatedVersion, + c.tracer, + c.tracingID, + c.logger, + c.version, + c.uSpec, + ) + } + // [/UQUIC] + + c.packetHandlers.Add(c.srcConnID, c.conn) + + errorChan := make(chan error, 1) + recreateChan := make(chan errCloseForRecreating) + go func() { + err := c.conn.run() + var recreateErr *errCloseForRecreating + if errors.As(err, &recreateErr) { + recreateChan <- *recreateErr + return + } + if c.onClose != nil { + c.onClose() + } + errorChan <- err // returns as soon as the connection is closed + }() + + // only set when we're using 0-RTT + // Otherwise, earlyConnChan will be nil. Receiving from a nil chan blocks forever. + var earlyConnChan <-chan struct{} + if c.use0RTT { + earlyConnChan = c.conn.earlyConnReady() + } + + select { + case <-ctx.Done(): + c.conn.shutdown() + return ctx.Err() + case err := <-errorChan: + return err + case recreateErr := <-recreateChan: + c.initialPacketNumber = recreateErr.nextPacketNumber + c.version = recreateErr.nextVersion + c.hasNegotiatedVersion = true + return c.dial(ctx) + case <-earlyConnChan: + // ready to send 0-RTT data + return nil + case <-c.conn.HandshakeComplete(): + // handshake successfully completed + return nil + } +} diff --git a/u_conn_id_manager.go b/u_conn_id_manager.go new file mode 100644 index 00000000..9a9d1e93 --- /dev/null +++ b/u_conn_id_manager.go @@ -0,0 +1,6 @@ +package quic + +// [UQUIC] +func (h *connIDManager) SetConnectionIDLimit(limit uint64) { + h.connectionIDLimit = limit +} diff --git a/u_connection.go b/u_connection.go new file mode 100644 index 00000000..9a1803e4 --- /dev/null +++ b/u_connection.go @@ -0,0 +1,170 @@ +package quic + +import ( + "context" + + "github.com/quic-go/quic-go/internal/ackhandler" + "github.com/quic-go/quic-go/internal/handshake" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" + tls "github.com/refraction-networking/utls" +) + +// [UQUIC] +var newUClientConnection = func( + conn sendConn, + runner connRunner, + destConnID protocol.ConnectionID, + srcConnID protocol.ConnectionID, + connIDGenerator ConnectionIDGenerator, + conf *Config, + tlsConf *tls.Config, + initialPacketNumber protocol.PacketNumber, + enable0RTT bool, + hasNegotiatedVersion bool, + tracer logging.ConnectionTracer, + tracingID uint64, + logger utils.Logger, + v protocol.VersionNumber, + // chs *tls.ClientHelloSpec, + // initPktNbrLen PacketNumberLen, + // qfs QUICFrames, + // udpDatagramMinSize int, + uSpec *QUICSpec, // [UQUIC] +) quicConn { + s := &connection{ + conn: conn, + config: conf, + origDestConnID: destConnID, + handshakeDestConnID: destConnID, + srcConnIDLen: srcConnID.Len(), + perspective: protocol.PerspectiveClient, + logID: destConnID.String(), + logger: logger, + tracer: tracer, + versionNegotiated: hasNegotiatedVersion, + version: v, + } + s.connIDManager = newConnIDManager( + destConnID, + func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) }, + runner.RemoveResetToken, + s.queueControlFrame, + ) + + s.connIDGenerator = newConnIDGenerator( + srcConnID, + nil, + func(connID protocol.ConnectionID) { runner.Add(connID, s) }, + runner.GetStatelessResetToken, + runner.Remove, + runner.Retire, + runner.ReplaceWithClosed, + s.queueControlFrame, + connIDGenerator, + ) + s.preSetup() + s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) + s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewUAckHandler( // [UQUIC] + initialPacketNumber, + getMaxPacketSize(s.conn.RemoteAddr()), + s.rttStats, + false, /* has no effect */ + s.perspective, + s.tracer, + s.logger, + ) + // [UQUIC] + if uSpec.InitialPacketSpec.InitPacketNumberLength != 0 { + ackhandler.SetInitialPacketNumberLength(s.sentPacketHandler, uSpec.InitialPacketSpec.InitPacketNumberLength) + } + + s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) + oneRTTStream := newCryptoStream() + + var params *wire.TransportParameters + + if uSpec.ClientHelloSpec != nil { + // iterate over all Extensions to set the TransportParameters + var tpSet bool + FOR_EACH_TLS_EXTENSION: + for _, ext := range uSpec.ClientHelloSpec.Extensions { + switch ext := ext.(type) { + case *tls.QUICTransportParametersExtension: + params = &wire.TransportParameters{ + InitialSourceConnectionID: srcConnID, + } + params.PopulateFromUQUIC(ext.TransportParameters) + s.connIDManager.SetConnectionIDLimit(params.ActiveConnectionIDLimit) + tpSet = true + break FOR_EACH_TLS_EXTENSION + default: + continue FOR_EACH_TLS_EXTENSION + } + } + if !tpSet { + panic("applied ClientHelloSpec must contain a QUICTransportParametersExtension to proceed") + } + } else { + // use default TransportParameters + params = &wire.TransportParameters{ + InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow), + InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow), + MaxIdleTimeout: s.config.MaxIdleTimeout, + MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams), + MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams), + MaxAckDelay: protocol.MaxAckDelayInclGranularity, + AckDelayExponent: protocol.AckDelayExponent, + DisableActiveMigration: true, + // For interoperability with quic-go versions before May 2023, this value must be set to a value + // different from protocol.DefaultActiveConnectionIDLimit. + // If set to the default value, it will be omitted from the transport parameters, which will make + // old quic-go versions interpret it as 0, instead of the default value of 2. + // See https://github.com/quic-go/quic-go/pull/3806. + ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, + InitialSourceConnectionID: srcConnID, + } + if s.config.EnableDatagrams { + params.MaxDatagramFrameSize = protocol.MaxDatagramFrameSize + } else { + params.MaxDatagramFrameSize = protocol.InvalidByteCount + } + } + + if s.tracer != nil { + s.tracer.SentTransportParameters(params) + } + cs := handshake.NewUCryptoSetupClient( + destConnID, + params, + tlsConf, + enable0RTT, + s.rttStats, + tracer, + logger, + s.version, + uSpec.ClientHelloSpec, + ) + s.cryptoStreamHandler = cs + s.cryptoStreamManager = newCryptoStreamManager(cs, s.initialStream, s.handshakeStream, oneRTTStream) + s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) + s.packer = newUPacketPacker( + newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective), + uSpec, + ) + if len(tlsConf.ServerName) > 0 { + s.tokenStoreKey = tlsConf.ServerName + } else { + s.tokenStoreKey = conn.RemoteAddr().String() + } + if s.config.TokenStore != nil { + if token := s.config.TokenStore.Pop(s.tokenStoreKey); token != nil { + s.packer.SetToken(token.data) + } + } + return s +} diff --git a/u_initial_packet_spec.go b/u_initial_packet_spec.go new file mode 100644 index 00000000..070fd292 --- /dev/null +++ b/u_initial_packet_spec.go @@ -0,0 +1,205 @@ +package quic + +import ( + "bytes" + "crypto/rand" + "errors" + + "github.com/gaukas/clienthellod" + "github.com/quic-go/quic-go/quicvarint" +) + +type InitialPacketSpec struct { + // SrcConnIDLength specifies how many bytes should the SrcConnID be + SrcConnIDLength int + + // DestConnIDLength specifies how many bytes should the DestConnID be + DestConnIDLength int + + // InitPacketNumberLength specifies how many bytes should the InitPacketNumber + // be interpreted as. It is usually 1 or 2 bytes. If unset, UQUIC will use the + // default algorithm to compute the length which is at least 2 bytes. + InitPacketNumberLength PacketNumberLen + + // InitPacketNumber is the packet number of the first Initial packet. Following + // Initial packets, if any, will increment the Packet Number accordingly. + InitPacketNumber uint64 // [UQUIC] + + // TokenStore is used to store and retrieve tokens. If set, will override the + // one set in the Config. + TokenStore TokenStore + + // If ClientTokenLength is set when TokenStore is not set, a dummy TokenStore + // will be created to randomly generate tokens of the specified length for + // Pop() calls with any key and silently drop any Put() calls. + // + // However, the tokens will not be stored anywhere and are expected to be + // invalid since not assigned by the server. + ClientTokenLength int + + // QUICFrames specifies a list of QUIC frames to be sent in the first Initial + // packet. + // + // If nil, it will be treated as a list with only a single QUICFrameCrypto. + FrameOrder QUICFrames +} + +func (ps *InitialPacketSpec) UpdateConfig(conf *Config) { + conf.TokenStore = ps.getTokenStore() +} + +func (ps *InitialPacketSpec) getTokenStore() TokenStore { + if ps.TokenStore != nil { + return ps.TokenStore + } + + if ps.ClientTokenLength > 0 { + return &dummyTokenStore{ + tokenLength: ps.ClientTokenLength, + } + } + + return nil +} + +type dummyTokenStore struct { + tokenLength int +} + +func (d *dummyTokenStore) Pop(key string) (token *ClientToken) { + var data []byte = make([]byte, d.tokenLength) + rand.Read(data) + + return &ClientToken{ + data: data, + } +} + +func (d *dummyTokenStore) Put(_ string, _ *ClientToken) { + // Do nothing +} + +type QUICFrames []QUICFrame + +func (qfs QUICFrames) MarshalWithCryptoData(cryptoData []byte) (payload []byte, err error) { + if len(qfs) == 0 { // If no frames specified, send a single crypto frame + qfs = QUICFrames{QUICFrameCrypto{0, 0}} + return qfs.MarshalWithCryptoData(cryptoData) + } + + for _, frame := range qfs { + var frameBytes []byte + if offset, length, cryptoOK := frame.CryptoFrameInfo(); cryptoOK { + if length == 0 { + // calculate length: from offset to the end of cryptoData + length = len(cryptoData) - offset + } + frameBytes = []byte{0x06} // CRYPTO frame type + frameBytes = quicvarint.Append(frameBytes, uint64(offset)) + frameBytes = quicvarint.Append(frameBytes, uint64(length)) + frameCryptoData := make([]byte, length) + copy(frameCryptoData, cryptoData[offset:]) // copy at most length bytes + frameBytes = append(frameBytes, frameCryptoData...) + } else { // Handle none crypto frames: read and append to payload + frameBytes, err = frame.Read() + if err != nil { + return nil, err + } + } + payload = append(payload, frameBytes...) + } + return payload, nil +} + +func (qfs QUICFrames) MarshalWithFrames(frames []byte) (payload []byte, err error) { + // parse frames + r := bytes.NewReader(frames) + qchframes, err := clienthellod.ReadAllFrames(r) + if err != nil { + return nil, err + } + + // parse crypto data + cryptoData, err := clienthellod.ReassembleCRYPTOFrames(qchframes) + if err != nil { + return nil, err + } + + // marshal + return qfs.MarshalWithCryptoData(cryptoData) +} + +type QUICFrame interface { + // None crypto frames should return false for cryptoOK + CryptoFrameInfo() (offset, length int, cryptoOK bool) + + // None crypto frames should return the byte representation of the frame. + // Crypto frames' behavior is undefined and unused. + Read() ([]byte, error) +} + +// QUICFrameCrypto is used to specify the crypto frames containing the TLS ClientHello +// to be sent in the first Initial packet. +type QUICFrameCrypto struct { + // Offset is used to specify the starting offset of the crypto frame. + // Used when sending multiple crypto frames in a single packet. + // + // Multiple crypto frames in a single packet must not overlap and must + // make up an entire crypto stream continuously. + Offset int + + // Length is used to specify the length of the crypto frame. + // + // Must be set if it is NOT the last crypto frame in a packet. + Length int +} + +// CryptoFrameInfo() implements the QUICFrame interface. +// +// Crypto frames are later replaced by the crypto message using the information +// returned by this function. +func (q QUICFrameCrypto) CryptoFrameInfo() (offset, length int, cryptoOK bool) { + return q.Offset, q.Length, true +} + +// Read() implements the QUICFrame interface. +// +// Crypto frames are later replaced by the crypto message, so they are not Read()-able. +func (q QUICFrameCrypto) Read() ([]byte, error) { + return nil, errors.New("crypto frames are not Read()-able") +} + +// QUICFramePadding is used to specify the padding frames to be sent in the first Initial +// packet. +type QUICFramePadding struct { + // Length is used to specify the length of the padding frame. + Length int +} + +// CryptoFrameInfo() implements the QUICFrame interface. +func (q QUICFramePadding) CryptoFrameInfo() (offset, length int, cryptoOK bool) { + return 0, 0, false +} + +// Read() implements the QUICFrame interface. +// +// Padding simply returns a slice of bytes of the specified length filled with 0. +func (q QUICFramePadding) Read() ([]byte, error) { + return make([]byte, q.Length), nil +} + +// QUICFramePing is used to specify the ping frames to be sent in the first Initial +// packet. +type QUICFramePing struct{} + +// CryptoFrameInfo() implements the QUICFrame interface. +func (q QUICFramePing) CryptoFrameInfo() (offset, length int, cryptoOK bool) { + return 0, 0, false +} + +// Read() implements the QUICFrame interface. +// +// Ping simply returns a slice of bytes of size 1 with value 0x01(PING). +func (q QUICFramePing) Read() ([]byte, error) { + return []byte{0x01}, nil +} diff --git a/u_quic_spec_test.go b/u_initial_packet_spec_test.go similarity index 100% rename from u_quic_spec_test.go rename to u_initial_packet_spec_test.go diff --git a/u_packet_packer.go b/u_packet_packer.go new file mode 100644 index 00000000..0d984507 --- /dev/null +++ b/u_packet_packer.go @@ -0,0 +1,243 @@ +package quic + +import ( + "fmt" + + "github.com/quic-go/quic-go/internal/handshake" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" +) + +// uPacketPacker is an extended packetPacker which is used +// to customize some of the packetPacker's behaviors for +// UQUIC. +type uPacketPacker struct { + *packetPacker + + // initPktNbrLen PacketNumberLen + // qfs QUICFrames // [UQUIC] uses QUICFrames to customize encrypted frames + // udpDatagramMinSize int + uSpec *QUICSpec // [UQUIC] +} + +func newUPacketPacker( + packetPacker *packetPacker, + uSpec *QUICSpec, // [UQUIC] +) *uPacketPacker { + return &uPacketPacker{ + packetPacker: packetPacker, + uSpec: uSpec, // [UQUIC] + } +} + +// PackCoalescedPacket packs a new packet. +// It packs an Initial / Handshake if there is data to send in these packet number spaces. +// It should only be called before the handshake is confirmed. +func (p *uPacketPacker) PackCoalescedPacket(onlyAck bool, maxPacketSize protocol.ByteCount, v protocol.VersionNumber) (*coalescedPacket, error) { + var ( + initialHdr, handshakeHdr, zeroRTTHdr *wire.ExtendedHeader + initialPayload, handshakePayload, zeroRTTPayload, oneRTTPayload payload + oneRTTPacketNumber protocol.PacketNumber + oneRTTPacketNumberLen protocol.PacketNumberLen + ) + // Try packing an Initial packet. + initialSealer, err := p.cryptoSetup.GetInitialSealer() + if err != nil && err != handshake.ErrKeysDropped { + return nil, err + } + var size protocol.ByteCount + if initialSealer != nil { + initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, onlyAck, true, v) + if initialPayload.length > 0 { + size += p.longHeaderPacketLength(initialHdr, initialPayload, v) + protocol.ByteCount(initialSealer.Overhead()) + } + + // // [UQUIC] + // if len(initialPayload.frames) > 0 { + // fmt.Printf("onlyAck: %t, PackCoalescedPacket: %v\n", onlyAck, initialPayload.frames[0].Frame) + // } + } + + // Add a Handshake packet. + var handshakeSealer sealer + if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) { + var err error + handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if handshakeSealer != nil { + handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, onlyAck, size == 0, v) + if handshakePayload.length > 0 { + s := p.longHeaderPacketLength(handshakeHdr, handshakePayload, v) + protocol.ByteCount(handshakeSealer.Overhead()) + size += s + } + } + } + + // Add a 0-RTT / 1-RTT packet. + var zeroRTTSealer sealer + var oneRTTSealer handshake.ShortHeaderSealer + var connID protocol.ConnectionID + var kp protocol.KeyPhaseBit + if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) { + var err error + oneRTTSealer, err = p.cryptoSetup.Get1RTTSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if err == nil { // 1-RTT + kp = oneRTTSealer.KeyPhase() + connID = p.getDestConnID() + oneRTTPacketNumber, oneRTTPacketNumberLen = p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + hdrLen := wire.ShortHeaderLen(connID, oneRTTPacketNumberLen) + oneRTTPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, hdrLen, maxPacketSize-size, onlyAck, size == 0, v) + if oneRTTPayload.length > 0 { + size += p.shortHeaderPacketLength(connID, oneRTTPacketNumberLen, oneRTTPayload) + protocol.ByteCount(oneRTTSealer.Overhead()) + } + } else if p.perspective == protocol.PerspectiveClient && !onlyAck { // 0-RTT packets can't contain ACK frames + var err error + zeroRTTSealer, err = p.cryptoSetup.Get0RTTSealer() + if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { + return nil, err + } + if zeroRTTSealer != nil { + zeroRTTHdr, zeroRTTPayload = p.maybeGetAppDataPacketFor0RTT(zeroRTTSealer, maxPacketSize-size, v) + if zeroRTTPayload.length > 0 { + size += p.longHeaderPacketLength(zeroRTTHdr, zeroRTTPayload, v) + protocol.ByteCount(zeroRTTSealer.Overhead()) + } + } + } + } + + if initialPayload.length == 0 && handshakePayload.length == 0 && zeroRTTPayload.length == 0 && oneRTTPayload.length == 0 { + return nil, nil + } + + buffer := getPacketBuffer() + packet := &coalescedPacket{ + buffer: buffer, + longHdrPackets: make([]*longHeaderPacket, 0, 3), + } + if initialPayload.length > 0 { + if onlyAck || len(initialPayload.frames) == 0 { + // TODO: uQUIC should send Initial Packet if requested. + // However, it should be otherwise configurable whether to request + // to send Initial Packet or not. See quic-go#4007 + padding := p.initialPaddingLen(initialPayload.frames, size, maxPacketSize) + cont, err := p.appendLongHeaderPacket(buffer, initialHdr, initialPayload, padding, protocol.EncryptionInitial, initialSealer, v) + if err != nil { + return nil, err + } + packet.longHdrPackets = append(packet.longHdrPackets, cont) + } else { // [UQUIC] + cont, err := p.appendInitialPacket(buffer, initialHdr, initialPayload, protocol.EncryptionInitial, initialSealer, v) + if err != nil { + return nil, err + } + + packet.longHdrPackets = append(packet.longHdrPackets, cont) + } + } + if handshakePayload.length > 0 { + cont, err := p.appendLongHeaderPacket(buffer, handshakeHdr, handshakePayload, 0, protocol.EncryptionHandshake, handshakeSealer, v) + if err != nil { + return nil, err + } + packet.longHdrPackets = append(packet.longHdrPackets, cont) + } + if zeroRTTPayload.length > 0 { + longHdrPacket, err := p.appendLongHeaderPacket(buffer, zeroRTTHdr, zeroRTTPayload, 0, protocol.Encryption0RTT, zeroRTTSealer, v) + if err != nil { + return nil, err + } + packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) + } else if oneRTTPayload.length > 0 { + shp, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, maxPacketSize, oneRTTSealer, false, v) + if err != nil { + return nil, err + } + packet.shortHdrPacket = &shp + } + return packet, nil +} + +// [UQUIC] +func (p *uPacketPacker) appendInitialPacket(buffer *packetBuffer, header *wire.ExtendedHeader, pl payload, encLevel protocol.EncryptionLevel, sealer sealer, v protocol.VersionNumber) (*longHeaderPacket, error) { + // Shouldn't need this? + // if p.uSpec.InitialPacketSpec.InitPacketNumberLength > 0 { + // header.PacketNumberLen = p.uSpec.InitialPacketSpec.InitPacketNumberLength + // } + + uPayload, err := p.MarshalInitialPacketPayload(pl, v) + if err != nil { + return nil, err + } + + pnLen := protocol.ByteCount(header.PacketNumberLen) + header.Length = pnLen + protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(len(uPayload)) + + startLen := len(buffer.Data) + raw := buffer.Data[startLen:] // [UQUIC] the raw here is a sub-slice of buffer.Data, latter's len < size + + raw, err = header.Append(raw, v) + if err != nil { + return nil, err + } + payloadOffset := protocol.ByteCount(len(raw)) + raw = append(raw, uPayload...) + + // fmt.Printf("Payload: %x\n", raw[payloadOffset:]) + + // fmt.Printf("Pre-Encryption: %x\n", raw) + + raw = p.encryptPacket(raw, sealer, header.PacketNumber, payloadOffset, pnLen) + buffer.Data = buffer.Data[:len(buffer.Data)+len(raw)] + + // fmt.Printf("Post-Encryption: %x\n", raw) + + // [UQUIC] + // append zero to buffer.Data until min size is reached + minUDPSize := p.uSpec.UDPDatagramMinSize + if minUDPSize == 0 { + minUDPSize = DefaultUDPDatagramMinSize + } + if len(buffer.Data) < minUDPSize { + buffer.Data = append(buffer.Data, make([]byte, minUDPSize-len(buffer.Data))...) + } + + if pn := p.pnManager.PopPacketNumber(encLevel); pn != header.PacketNumber { + return nil, fmt.Errorf("packetPacker BUG: Peeked and Popped packet numbers do not match: expected %d, got %d", pn, header.PacketNumber) + } + return &longHeaderPacket{ + header: header, + ack: pl.ack, + frames: pl.frames, + streamFrames: pl.streamFrames, + length: protocol.ByteCount(len(raw)), + }, nil +} + +func (p *uPacketPacker) MarshalInitialPacketPayload(pl payload, v protocol.VersionNumber) ([]byte, error) { + var originalFrameBytes []byte + + for _, f := range pl.frames { + var err error + // only append crypto frames + if _, ok := f.Frame.(*wire.CryptoFrame); !ok { + continue + } + + originalFrameBytes, err = f.Frame.Append(originalFrameBytes, v) + if err != nil { + return nil, err + } + } + + uPayload, err := p.uSpec.InitialPacketSpec.FrameOrder.MarshalWithFrames(originalFrameBytes) + if err != nil { + return nil, err + } + + return uPayload, nil +} diff --git a/u_quic_spec.go b/u_quic_spec.go index 4b4999bb..a7ef41dd 100644 --- a/u_quic_spec.go +++ b/u_quic_spec.go @@ -1,200 +1,18 @@ package quic -import ( - "bytes" - "crypto/rand" - "errors" +import tls "github.com/refraction-networking/utls" - "github.com/gaukas/clienthellod" - "github.com/quic-go/quic-go/quicvarint" +const ( + DefaultUDPDatagramMinSize = 1200 ) type QUICSpec struct { - // SrcConnIDLength specifies how many bytes should the SrcConnID be - SrcConnIDLength int + InitialPacketSpec InitialPacketSpec + ClientHelloSpec *tls.ClientHelloSpec - // DestConnIDLength specifies how many bytes should the DestConnID be - DstConnIDLength int - - // InitPacketNumberLength specifies how many bytes should the InitPacketNumber - // be interpreted as. It is usually 1 or 2 bytes. If unset, UQUIC will use the - // default algorithm to compute the length which is at least 2 bytes. - InitPacketNumberLength PacketNumberLen - - // InitPacketNumber is the packet number of the first Initial packet. Following - // Initial packets, if any, will increment the Packet Number accordingly. - InitPacketNumber uint64 // [UQUIC] - - // TokenStore is used to store and retrieve tokens. If set, will override the - // one set in the Config. - TokenStore TokenStore - - // If ClientTokenLength is set when TokenStore is not set, a dummy TokenStore - // will be created to randomly generate tokens of the specified length for - // Pop() calls with any key and silently drop any Put() calls. - // - // However, the tokens will not be stored anywhere and are expected to be - // invalid since not assigned by the server. - ClientTokenLength int - - // QUICFrames specifies a list of QUIC frames to be sent in the first Initial - // packet. - // - // If nil, it will be treated as a list with only a single QUICFrameCrypto. - QUICFrames []QUICFrame + UDPDatagramMinSize int } -func (s *QUICSpec) getTokenStore() TokenStore { - if s.TokenStore != nil { - return s.TokenStore - } - - if s.ClientTokenLength > 0 { - return &dummyTokenStore{ - tokenLength: s.ClientTokenLength, - } - } - - return nil -} - -type dummyTokenStore struct { - tokenLength int -} - -func (d *dummyTokenStore) Pop(key string) (token *ClientToken) { - var data []byte = make([]byte, d.tokenLength) - rand.Read(data) - - return &ClientToken{ - data: data, - } -} - -func (d *dummyTokenStore) Put(_ string, _ *ClientToken) { - // Do nothing -} - -type QUICFrames []QUICFrame - -func (qfs QUICFrames) MarshalWithCryptoData(cryptoData []byte) (payload []byte, err error) { - if len(qfs) == 0 { // If no frames specified, send a single crypto frame - payload = make([]byte, len(cryptoData)+1) - } - - for _, frame := range qfs { - var frameBytes []byte - if offset, length, cryptoOK := frame.CryptoFrameInfo(); cryptoOK { - if length == 0 { - // calculate length: from offset to the end of cryptoData - length = len(cryptoData) - offset - } - frameBytes = []byte{0x06} // CRYPTO frame type - frameBytes = quicvarint.Append(frameBytes, uint64(offset)) - frameBytes = quicvarint.Append(frameBytes, uint64(length)) - frameCryptoData := make([]byte, length) - copy(frameCryptoData, cryptoData[offset:]) // copy at most length bytes - frameBytes = append(frameBytes, frameCryptoData...) - } else { // Handle none crypto frames: read and append to payload - frameBytes, err = frame.Read() - if err != nil { - return nil, err - } - } - payload = append(payload, frameBytes...) - } - return payload, nil -} - -func (qfs QUICFrames) MarshalWithFrames(frames []byte) (payload []byte, err error) { - // parse frames - r := bytes.NewReader(frames) - qchframes, err := clienthellod.ReadAllFrames(r) - if err != nil { - return nil, err - } - - // parse crypto data - cryptoData, err := clienthellod.ReassembleCRYPTOFrames(qchframes) - if err != nil { - return nil, err - } - - // marshal - return qfs.MarshalWithCryptoData(cryptoData) -} - -type QUICFrame interface { - // None crypto frames should return false for cryptoOK - CryptoFrameInfo() (offset, length int, cryptoOK bool) - - // None crypto frames should return the byte representation of the frame. - // Crypto frames' behavior is undefined and unused. - Read() ([]byte, error) -} - -// QUICFrameCrypto is used to specify the crypto frames containing the TLS ClientHello -// to be sent in the first Initial packet. -type QUICFrameCrypto struct { - // Offset is used to specify the starting offset of the crypto frame. - // Used when sending multiple crypto frames in a single packet. - // - // Multiple crypto frames in a single packet must not overlap and must - // make up an entire crypto stream continuously. - Offset int - - // Length is used to specify the length of the crypto frame. - // - // Must be set if it is NOT the last crypto frame in a packet. - Length int -} - -// CryptoFrameInfo() implements the QUICFrame interface. -// -// Crypto frames are later replaced by the crypto message using the information -// returned by this function. -func (q QUICFrameCrypto) CryptoFrameInfo() (offset, length int, cryptoOK bool) { - return q.Offset, q.Length, true -} - -// Read() implements the QUICFrame interface. -// -// Crypto frames are later replaced by the crypto message, so they are not Read()-able. -func (q QUICFrameCrypto) Read() ([]byte, error) { - return nil, errors.New("crypto frames are not Read()-able") -} - -// QUICFramePadding is used to specify the padding frames to be sent in the first Initial -// packet. -type QUICFramePadding struct { - // Length is used to specify the length of the padding frame. - Length int -} - -// CryptoFrameInfo() implements the QUICFrame interface. -func (q QUICFramePadding) CryptoFrameInfo() (offset, length int, cryptoOK bool) { - return 0, 0, false -} - -// Read() implements the QUICFrame interface. -// -// Padding simply returns a slice of bytes of the specified length filled with 0. -func (q QUICFramePadding) Read() ([]byte, error) { - return make([]byte, q.Length), nil -} - -// QUICFramePing is used to specify the ping frames to be sent in the first Initial -// packet. -type QUICFramePing struct{} - -// CryptoFrameInfo() implements the QUICFrame interface. -func (q QUICFramePing) CryptoFrameInfo() (offset, length int, cryptoOK bool) { - return 0, 0, false -} - -// Read() implements the QUICFrame interface. -// -// Ping simply returns a slice of bytes of size 1 with value 0x01(PING). -func (q QUICFramePing) Read() ([]byte, error) { - return []byte{0x01}, nil +func (s *QUICSpec) UpdateConfig(config *Config) { + s.InitialPacketSpec.UpdateConfig(config) } diff --git a/u_transport.go b/u_transport.go new file mode 100644 index 00000000..488d41fa --- /dev/null +++ b/u_transport.go @@ -0,0 +1,87 @@ +package quic + +import ( + "context" + "net" + + "github.com/quic-go/quic-go/internal/protocol" + tls "github.com/refraction-networking/utls" +) + +type UTransport struct { + *Transport + + QUICSpec *QUICSpec // [UQUIC] using ptr to avoid copying +} + +// Dial dials a new connection to a remote host (not using 0-RTT). +func (t *UTransport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) { + if err := validateConfig(conf); err != nil { + return nil, err + } + conf = populateConfig(conf) + + // [UQUIC] + // Override the default connection ID generator if the user has specified a length in QUICSpec. + if t.QUICSpec != nil { + if t.QUICSpec.InitialPacketSpec.SrcConnIDLength != 0 { + t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.QUICSpec.InitialPacketSpec.SrcConnIDLength} + } else { + t.ConnectionIDGenerator = &protocol.ExpEmptyConnectionIDGenerator{} + } + } + // [/UQUIC] + + if err := t.init(t.isSingleUse); err != nil { + return nil, err + } + var onClose func() + if t.isSingleUse { + onClose = func() { t.Close() } + } + tlsConf = tlsConf.Clone() + tlsConf.MinVersion = tls.VersionTLS13 + + return udial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false, t.QUICSpec) +} + +// DialEarly dials a new connection, attempting to use 0-RTT if possible. +func (t *UTransport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) { + if err := validateConfig(conf); err != nil { + return nil, err + } + conf = populateConfig(conf) + + // [UQUIC] + // Override the default connection ID generator if the user has specified a length in QUICSpec. + if t.QUICSpec != nil { + if t.QUICSpec.InitialPacketSpec.SrcConnIDLength != 0 { + t.ConnectionIDGenerator = &protocol.DefaultConnectionIDGenerator{ConnLen: t.QUICSpec.InitialPacketSpec.SrcConnIDLength} + } else { + t.ConnectionIDGenerator = &protocol.ExpEmptyConnectionIDGenerator{} + } + } + // [/UQUIC] + + if err := t.init(t.isSingleUse); err != nil { + return nil, err + } + var onClose func() + if t.isSingleUse { + onClose = func() { t.Close() } + } + tlsConf = tlsConf.Clone() + tlsConf.MinVersion = tls.VersionTLS13 + + return udial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true, t.QUICSpec) +} + +func (ut *UTransport) MakeDialer() func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *Config) (EarlyConnection, error) { + return func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *Config) (EarlyConnection, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + return ut.DialEarly(ctx, udpAddr, tlsCfg, cfg) + } +}