mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-04 04:37:36 +03:00
Merge branch 'upstream' into sync-upstream
This commit is contained in:
commit
856bc02b8f
130 changed files with 1364 additions and 463 deletions
104
transport.go
104
transport.go
|
@ -6,14 +6,14 @@ import (
|
|||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
tls "github.com/refraction-networking/utls"
|
||||
|
||||
"github.com/refraction-networking/uquic/internal/wire"
|
||||
|
||||
"github.com/refraction-networking/uquic/internal/protocol"
|
||||
"github.com/refraction-networking/uquic/internal/utils"
|
||||
"github.com/refraction-networking/uquic/internal/wire"
|
||||
"github.com/refraction-networking/uquic/logging"
|
||||
)
|
||||
|
||||
|
@ -86,6 +86,9 @@ type Transport struct {
|
|||
createdConn bool
|
||||
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial
|
||||
|
||||
readingNonQUICPackets atomic.Bool
|
||||
nonQUICPackets chan receivedPacket
|
||||
|
||||
logger utils.Logger
|
||||
}
|
||||
|
||||
|
@ -149,26 +152,15 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen
|
|||
|
||||
// Dial dials a new connection to a remote host (not using 0-RTT).
|
||||
func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
|
||||
if err := validateConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf = populateConfig(conf)
|
||||
|
||||
if err := t.init(t.isSingleUse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var onClose func()
|
||||
if t.isSingleUse {
|
||||
onClose = func() { t.Close() }
|
||||
}
|
||||
tlsConf = tlsConf.Clone()
|
||||
tlsConf.MinVersion = tls.VersionTLS13
|
||||
|
||||
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false)
|
||||
return t.dial(ctx, addr, "", tlsConf, conf, false)
|
||||
}
|
||||
|
||||
// DialEarly dials a new connection, attempting to use 0-RTT if possible.
|
||||
func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
|
||||
return t.dial(ctx, addr, "", tlsConf, conf, true)
|
||||
}
|
||||
|
||||
func (t *Transport) dial(ctx context.Context, addr net.Addr, host string, tlsConf *tls.Config, conf *Config, use0RTT bool) (EarlyConnection, error) {
|
||||
if err := validateConfig(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -183,8 +175,8 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
|
|||
}
|
||||
tlsConf = tlsConf.Clone()
|
||||
tlsConf.MinVersion = tls.VersionTLS13
|
||||
|
||||
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
|
||||
setTLSConfigServerName(tlsConf, addr, host)
|
||||
return dial(ctx, newSendConn(t.conn, addr, packetInfo{}, utils.DefaultLogger), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, use0RTT)
|
||||
}
|
||||
|
||||
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
||||
|
@ -200,7 +192,6 @@ func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
|||
return
|
||||
}
|
||||
}
|
||||
t.conn = conn
|
||||
|
||||
t.logger = utils.DefaultLogger // TODO: make this configurable
|
||||
t.conn = conn
|
||||
|
@ -234,7 +225,7 @@ func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
|
|||
if err := t.init(false); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return t.conn.WritePacket(b, uint16(len(b)), addr, nil)
|
||||
return t.conn.WritePacket(b, addr, nil)
|
||||
}
|
||||
|
||||
func (t *Transport) enqueueClosePacket(p closePacket) {
|
||||
|
@ -252,7 +243,7 @@ func (t *Transport) runSendQueue() {
|
|||
case <-t.listening:
|
||||
return
|
||||
case p := <-t.closeQueue:
|
||||
t.conn.WritePacket(p.payload, uint16(len(p.payload)), p.addr, p.info.OOB())
|
||||
t.conn.WritePacket(p.payload, p.addr, p.info.OOB())
|
||||
case p := <-t.statelessResetQueue:
|
||||
t.sendStatelessReset(p)
|
||||
}
|
||||
|
@ -347,6 +338,13 @@ func (t *Transport) listen(conn rawConn) {
|
|||
}
|
||||
|
||||
func (t *Transport) handlePacket(p receivedPacket) {
|
||||
if len(p.data) == 0 {
|
||||
return
|
||||
}
|
||||
if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) {
|
||||
t.handleNonQUICPacket(p)
|
||||
return
|
||||
}
|
||||
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
|
||||
if err != nil {
|
||||
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
|
||||
|
@ -413,7 +411,7 @@ func (t *Transport) sendStatelessReset(p receivedPacket) {
|
|||
rand.Read(data)
|
||||
data[0] = (data[0] & 0x7f) | 0x40
|
||||
data = append(data, token[:]...)
|
||||
if _, err := t.conn.WritePacket(data, uint16(len(data)), p.remoteAddr, p.info.OOB()); err != nil {
|
||||
if _, err := t.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil {
|
||||
t.logger.Debugf("Error sending Stateless Reset to %s: %s", p.remoteAddr, err)
|
||||
}
|
||||
}
|
||||
|
@ -435,3 +433,61 @@ func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *Transport) handleNonQUICPacket(p receivedPacket) {
|
||||
// Strictly speaking, this is racy,
|
||||
// but we only care about receiving packets at some point after ReadNonQUICPacket has been called.
|
||||
if !t.readingNonQUICPackets.Load() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case t.nonQUICPackets <- p:
|
||||
default:
|
||||
if t.Tracer != nil {
|
||||
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const maxQueuedNonQUICPackets = 32
|
||||
|
||||
// ReadNonQUICPacket reads non-QUIC packets received on the underlying connection.
|
||||
// The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0.
|
||||
// Note that this is stricter than the detection logic defined in RFC 9443.
|
||||
func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) {
|
||||
if err := t.init(false); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if !t.readingNonQUICPackets.Load() {
|
||||
t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets)
|
||||
t.readingNonQUICPackets.Store(true)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0, nil, ctx.Err()
|
||||
case p := <-t.nonQUICPackets:
|
||||
n := copy(b, p.data)
|
||||
return n, p.remoteAddr, nil
|
||||
case <-t.listening:
|
||||
return 0, nil, errors.New("closed")
|
||||
}
|
||||
}
|
||||
|
||||
func setTLSConfigServerName(tlsConf *tls.Config, addr net.Addr, host string) {
|
||||
// If no ServerName is set, infer the ServerName from the host we're connecting to.
|
||||
if tlsConf.ServerName != "" {
|
||||
return
|
||||
}
|
||||
if host == "" {
|
||||
if udpAddr, ok := addr.(*net.UDPAddr); ok {
|
||||
tlsConf.ServerName = udpAddr.IP.String()
|
||||
return
|
||||
}
|
||||
}
|
||||
h, _, err := net.SplitHostPort(host)
|
||||
if err != nil { // This happens if the host doesn't contain a port number.
|
||||
tlsConf.ServerName = host
|
||||
return
|
||||
}
|
||||
tlsConf.ServerName = h
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue