mirror of
https://github.com/refraction-networking/uquic.git
synced 2025-04-03 20:27:35 +03:00
remove OptimizeConn, add a Transport.WriteTo method instead (#3957)
* remove OptimizeConn, add a Transport.WriteTo method instead * fix race condition in Transport.WriteTo
This commit is contained in:
parent
27301f791f
commit
a347d664e2
4 changed files with 54 additions and 76 deletions
|
@ -26,9 +26,10 @@ var _ = Describe("MITM test", func() {
|
|||
const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it
|
||||
|
||||
var (
|
||||
serverUDPConn, clientUDPConn net.PacketConn
|
||||
serverConn quic.Connection
|
||||
serverConfig *quic.Config
|
||||
clientUDPConn net.PacketConn
|
||||
serverTransport, clientTransport *quic.Transport
|
||||
serverConn quic.Connection
|
||||
serverConfig *quic.Config
|
||||
)
|
||||
|
||||
startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) {
|
||||
|
@ -36,13 +37,11 @@ var _ = Describe("MITM test", func() {
|
|||
Expect(err).ToNot(HaveOccurred())
|
||||
c, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
serverUDPConn, err = quic.OptimizeConn(c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{
|
||||
Conn: serverUDPConn,
|
||||
serverTransport = &quic.Transport{
|
||||
Conn: c,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
ln, err := tr.Listen(getTLSConfig(), serverConfig)
|
||||
ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
@ -69,7 +68,7 @@ var _ = Describe("MITM test", func() {
|
|||
return proxy.LocalPort(), func() {
|
||||
proxy.Close()
|
||||
ln.Close()
|
||||
serverUDPConn.Close()
|
||||
serverTransport.Close()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
@ -78,10 +77,12 @@ var _ = Describe("MITM test", func() {
|
|||
serverConfig = getQuicConfig(nil)
|
||||
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
c, err := net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
clientUDPConn, err = quic.OptimizeConn(c)
|
||||
clientUDPConn, err = net.ListenUDP("udp", addr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
clientTransport = &quic.Transport{
|
||||
Conn: clientUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
})
|
||||
|
||||
Context("unsuccessful attacks", func() {
|
||||
|
@ -90,12 +91,13 @@ var _ = Describe("MITM test", func() {
|
|||
// Test shutdown is tricky due to the proxy. Just wait for a bit.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
Expect(clientUDPConn.Close()).To(Succeed())
|
||||
Expect(clientTransport.Close()).To(Succeed())
|
||||
})
|
||||
|
||||
Context("injecting invalid packets", func() {
|
||||
const rtt = 20 * time.Millisecond
|
||||
|
||||
sendRandomPacketsOfSameType := func(conn net.PacketConn, remoteAddr net.Addr, raw []byte) {
|
||||
sendRandomPacketsOfSameType := func(conn *quic.Transport, remoteAddr net.Addr, raw []byte) {
|
||||
defer GinkgoRecover()
|
||||
const numPackets = 10
|
||||
ticker := time.NewTicker(rtt / numPackets)
|
||||
|
@ -155,11 +157,7 @@ var _ = Describe("MITM test", func() {
|
|||
defer closeFn()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{
|
||||
Conn: clientUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
conn, err := tr.Dial(
|
||||
conn, err := clientTransport.Dial(
|
||||
context.Background(),
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
|
@ -178,7 +176,7 @@ var _ = Describe("MITM test", func() {
|
|||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
if dir == quicproxy.DirectionIncoming {
|
||||
defer GinkgoRecover()
|
||||
go sendRandomPacketsOfSameType(clientUDPConn, serverUDPConn.LocalAddr(), raw)
|
||||
go sendRandomPacketsOfSameType(clientTransport, serverTransport.Conn.LocalAddr(), raw)
|
||||
}
|
||||
return rtt / 2
|
||||
}
|
||||
|
@ -189,7 +187,7 @@ var _ = Describe("MITM test", func() {
|
|||
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
|
||||
if dir == quicproxy.DirectionOutgoing {
|
||||
defer GinkgoRecover()
|
||||
go sendRandomPacketsOfSameType(serverUDPConn, clientUDPConn.LocalAddr(), raw)
|
||||
go sendRandomPacketsOfSameType(serverTransport, clientTransport.Conn.LocalAddr(), raw)
|
||||
}
|
||||
return rtt / 2
|
||||
}
|
||||
|
@ -202,11 +200,7 @@ var _ = Describe("MITM test", func() {
|
|||
defer closeFn()
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{
|
||||
Conn: clientUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
conn, err := tr.Dial(
|
||||
conn, err := clientTransport.Dial(
|
||||
context.Background(),
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
|
@ -226,7 +220,7 @@ var _ = Describe("MITM test", func() {
|
|||
dropCb := func(dir quicproxy.Direction, raw []byte) bool {
|
||||
defer GinkgoRecover()
|
||||
if dir == quicproxy.DirectionIncoming {
|
||||
_, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr())
|
||||
_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
return false
|
||||
|
@ -238,7 +232,7 @@ var _ = Describe("MITM test", func() {
|
|||
dropCb := func(dir quicproxy.Direction, raw []byte) bool {
|
||||
defer GinkgoRecover()
|
||||
if dir == quicproxy.DirectionOutgoing {
|
||||
_, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr())
|
||||
_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
return false
|
||||
|
@ -276,7 +270,7 @@ var _ = Describe("MITM test", func() {
|
|||
if rand.Intn(interval) == 0 {
|
||||
pos := rand.Intn(len(raw))
|
||||
raw[pos] = byte(rand.Intn(256))
|
||||
_, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr())
|
||||
_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
atomic.AddInt32(&numCorrupted, 1)
|
||||
return true
|
||||
|
@ -296,7 +290,7 @@ var _ = Describe("MITM test", func() {
|
|||
if rand.Intn(interval) == 0 {
|
||||
pos := rand.Intn(len(raw))
|
||||
raw[pos] = byte(rand.Intn(256))
|
||||
_, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr())
|
||||
_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
atomic.AddInt32(&numCorrupted, 1)
|
||||
return true
|
||||
|
@ -320,17 +314,13 @@ var _ = Describe("MITM test", func() {
|
|||
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil)
|
||||
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tr := &quic.Transport{
|
||||
Conn: clientUDPConn,
|
||||
ConnectionIDLength: connIDLen,
|
||||
}
|
||||
_, err = tr.Dial(
|
||||
_, err = clientTransport.Dial(
|
||||
context.Background(),
|
||||
raddr,
|
||||
getTLSClientConfig(),
|
||||
getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}),
|
||||
)
|
||||
return func() { tr.Close(); serverCloseFn() }, err
|
||||
return func() { clientTransport.Close(); serverCloseFn() }, err
|
||||
}
|
||||
|
||||
// fails immediately because client connection closes when it can't find compatible version
|
||||
|
@ -356,7 +346,7 @@ var _ = Describe("MITM test", func() {
|
|||
)
|
||||
|
||||
// Send the packet
|
||||
_, err = serverUDPConn.WriteTo(packet, clientUDPConn.LocalAddr())
|
||||
_, err = serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
close(done)
|
||||
}
|
||||
|
@ -393,7 +383,7 @@ var _ = Describe("MITM test", func() {
|
|||
fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12})
|
||||
retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version)
|
||||
|
||||
_, err = serverUDPConn.WriteTo(retryPacket, clientUDPConn.LocalAddr())
|
||||
_, err = serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
return rtt / 2
|
||||
|
@ -423,7 +413,7 @@ var _ = Describe("MITM test", func() {
|
|||
defer close(done)
|
||||
injected = true
|
||||
initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil)
|
||||
_, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr())
|
||||
_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
return rtt
|
||||
|
@ -453,7 +443,7 @@ var _ = Describe("MITM test", func() {
|
|||
// Fake Initial with ACK for packet 2 (unsent)
|
||||
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
|
||||
initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ack})
|
||||
_, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr())
|
||||
_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
return rtt
|
||||
|
|
22
sys_conn.go
22
sys_conn.go
|
@ -27,27 +27,7 @@ type OOBCapablePacketConn interface {
|
|||
|
||||
var _ OOBCapablePacketConn = &net.UDPConn{}
|
||||
|
||||
// OptimizeConn takes a net.PacketConn and attempts to enable various optimizations that will improve QUIC performance:
|
||||
// 1. It enables the Don't Fragment (DF) bit on the IP header.
|
||||
// This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899).
|
||||
// 2. It enables reading of the ECN bits from the IP header.
|
||||
// This allows the remote node to speed up its loss detection and recovery.
|
||||
// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket.
|
||||
// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux).
|
||||
//
|
||||
// In order for this to work, the connection needs to implement the OOBCapablePacketConn interface (as a *net.UDPConn does).
|
||||
//
|
||||
// It's only necessary to call this function explicitly if the application calls WriteTo
|
||||
// after passing the connection to the Transport.
|
||||
func OptimizeConn(c net.PacketConn) (net.PacketConn, error) {
|
||||
return wrapConn(c)
|
||||
}
|
||||
|
||||
func wrapConn(pc net.PacketConn) (interface {
|
||||
net.PacketConn
|
||||
rawConn
|
||||
}, error,
|
||||
) {
|
||||
func wrapConn(pc net.PacketConn) (rawConn, error) {
|
||||
if err := setReceiveBuffer(pc); err != nil {
|
||||
if !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
setBufferWarningOnce.Do(func() {
|
||||
|
|
|
@ -230,13 +230,6 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
|
|||
return p, nil
|
||||
}
|
||||
|
||||
// WriteTo (re)implements the net.PacketConn method.
|
||||
// This is needed for users who call OptimizeConn to be able to send (non-QUIC) packets on the underlying connection.
|
||||
// With GSO enabled, this would otherwise not be needed, as the kernel requires the UDP_SEGMENT message to be set.
|
||||
func (c *oobConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
return c.WritePacket(p, uint16(len(p)), addr, nil)
|
||||
}
|
||||
|
||||
// WritePacket writes a new packet.
|
||||
// If the connection supports GSO (and we activated GSO support before),
|
||||
// it appends the UDP_SEGMENT size message to oob.
|
||||
|
|
33
transport.go
33
transport.go
|
@ -26,9 +26,16 @@ type Transport struct {
|
|||
// A single net.PacketConn can only be handled by one Transport.
|
||||
// Bad things will happen if passed to multiple Transports.
|
||||
//
|
||||
// If not done by the user, the connection is passed through OptimizeConn to enable a number of optimizations.
|
||||
// After passing the connection to the Transport, it's invalid to call ReadFrom on the connection.
|
||||
// Calling WriteTo is only valid on the connection returned by OptimizeConn.
|
||||
// A number of optimizations will be enabled if the connections implements the OOBCapablePacketConn interface,
|
||||
// as a *net.UDPConn does.
|
||||
// 1. It enables the Don't Fragment (DF) bit on the IP header.
|
||||
// This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899).
|
||||
// 2. It enables reading of the ECN bits from the IP header.
|
||||
// This allows the remote node to speed up its loss detection and recovery.
|
||||
// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket.
|
||||
// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux).
|
||||
//
|
||||
// After passing the connection to the Transport, it's invalid to call ReadFrom or WriteTo on the connection.
|
||||
Conn net.PacketConn
|
||||
|
||||
// The length of the connection ID in bytes.
|
||||
|
@ -99,7 +106,7 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error)
|
|||
return nil, errListenerAlreadySet
|
||||
}
|
||||
conf = populateServerConfig(conf)
|
||||
if err := t.init(true); err != nil {
|
||||
if err := t.init(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false)
|
||||
|
@ -128,7 +135,7 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen
|
|||
return nil, errListenerAlreadySet
|
||||
}
|
||||
conf = populateServerConfig(conf)
|
||||
if err := t.init(true); err != nil {
|
||||
if err := t.init(false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true)
|
||||
|
@ -145,7 +152,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
|
|||
return nil, err
|
||||
}
|
||||
conf = populateConfig(conf)
|
||||
if err := t.init(false); err != nil {
|
||||
if err := t.init(t.isSingleUse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var onClose func()
|
||||
|
@ -163,7 +170,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
|
|||
return nil, err
|
||||
}
|
||||
conf = populateConfig(conf)
|
||||
if err := t.init(false); err != nil {
|
||||
if err := t.init(t.isSingleUse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var onClose func()
|
||||
|
@ -175,7 +182,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
|
|||
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
|
||||
}
|
||||
|
||||
func (t *Transport) init(isServer bool) error {
|
||||
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
|
||||
t.initOnce.Do(func() {
|
||||
var conn rawConn
|
||||
if c, ok := t.Conn.(rawConn); ok {
|
||||
|
@ -203,7 +210,7 @@ func (t *Transport) init(isServer bool) error {
|
|||
t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
|
||||
} else {
|
||||
connIDLen := t.ConnectionIDLength
|
||||
if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) {
|
||||
if t.ConnectionIDLength == 0 && !allowZeroLengthConnIDs {
|
||||
connIDLen = protocol.DefaultConnectionIDLength
|
||||
}
|
||||
t.connIDLen = connIDLen
|
||||
|
@ -217,6 +224,14 @@ func (t *Transport) init(isServer bool) error {
|
|||
return t.initErr
|
||||
}
|
||||
|
||||
// WriteTo sends a packet on the underlying connection.
|
||||
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if err := t.init(false); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return t.conn.WritePacket(b, uint16(len(b)), addr, nil)
|
||||
}
|
||||
|
||||
func (t *Transport) enqueueClosePacket(p closePacket) {
|
||||
select {
|
||||
case t.closeQueue <- p:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue