remove OptimizeConn, add a Transport.WriteTo method instead (#3957)

* remove OptimizeConn, add a Transport.WriteTo method instead

* fix race condition in Transport.WriteTo
This commit is contained in:
Marten Seemann 2023-07-19 10:28:11 -07:00 committed by GitHub
parent 27301f791f
commit a347d664e2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 76 deletions

View file

@ -26,9 +26,10 @@ var _ = Describe("MITM test", func() {
const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it
var (
serverUDPConn, clientUDPConn net.PacketConn
serverConn quic.Connection
serverConfig *quic.Config
clientUDPConn net.PacketConn
serverTransport, clientTransport *quic.Transport
serverConn quic.Connection
serverConfig *quic.Config
)
startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) {
@ -36,13 +37,11 @@ var _ = Describe("MITM test", func() {
Expect(err).ToNot(HaveOccurred())
c, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
serverUDPConn, err = quic.OptimizeConn(c)
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: serverUDPConn,
serverTransport = &quic.Transport{
Conn: c,
ConnectionIDLength: connIDLen,
}
ln, err := tr.Listen(getTLSConfig(), serverConfig)
ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@ -69,7 +68,7 @@ var _ = Describe("MITM test", func() {
return proxy.LocalPort(), func() {
proxy.Close()
ln.Close()
serverUDPConn.Close()
serverTransport.Close()
<-done
}
}
@ -78,10 +77,12 @@ var _ = Describe("MITM test", func() {
serverConfig = getQuicConfig(nil)
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
c, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
clientUDPConn, err = quic.OptimizeConn(c)
clientUDPConn, err = net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
clientTransport = &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
})
Context("unsuccessful attacks", func() {
@ -90,12 +91,13 @@ var _ = Describe("MITM test", func() {
// Test shutdown is tricky due to the proxy. Just wait for a bit.
time.Sleep(50 * time.Millisecond)
Expect(clientUDPConn.Close()).To(Succeed())
Expect(clientTransport.Close()).To(Succeed())
})
Context("injecting invalid packets", func() {
const rtt = 20 * time.Millisecond
sendRandomPacketsOfSameType := func(conn net.PacketConn, remoteAddr net.Addr, raw []byte) {
sendRandomPacketsOfSameType := func(conn *quic.Transport, remoteAddr net.Addr, raw []byte) {
defer GinkgoRecover()
const numPackets = 10
ticker := time.NewTicker(rtt / numPackets)
@ -155,11 +157,7 @@ var _ = Describe("MITM test", func() {
defer closeFn()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
conn, err := tr.Dial(
conn, err := clientTransport.Dial(
context.Background(),
raddr,
getTLSClientConfig(),
@ -178,7 +176,7 @@ var _ = Describe("MITM test", func() {
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
if dir == quicproxy.DirectionIncoming {
defer GinkgoRecover()
go sendRandomPacketsOfSameType(clientUDPConn, serverUDPConn.LocalAddr(), raw)
go sendRandomPacketsOfSameType(clientTransport, serverTransport.Conn.LocalAddr(), raw)
}
return rtt / 2
}
@ -189,7 +187,7 @@ var _ = Describe("MITM test", func() {
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
if dir == quicproxy.DirectionOutgoing {
defer GinkgoRecover()
go sendRandomPacketsOfSameType(serverUDPConn, clientUDPConn.LocalAddr(), raw)
go sendRandomPacketsOfSameType(serverTransport, clientTransport.Conn.LocalAddr(), raw)
}
return rtt / 2
}
@ -202,11 +200,7 @@ var _ = Describe("MITM test", func() {
defer closeFn()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
conn, err := tr.Dial(
conn, err := clientTransport.Dial(
context.Background(),
raddr,
getTLSClientConfig(),
@ -226,7 +220,7 @@ var _ = Describe("MITM test", func() {
dropCb := func(dir quicproxy.Direction, raw []byte) bool {
defer GinkgoRecover()
if dir == quicproxy.DirectionIncoming {
_, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr())
_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return false
@ -238,7 +232,7 @@ var _ = Describe("MITM test", func() {
dropCb := func(dir quicproxy.Direction, raw []byte) bool {
defer GinkgoRecover()
if dir == quicproxy.DirectionOutgoing {
_, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr())
_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return false
@ -276,7 +270,7 @@ var _ = Describe("MITM test", func() {
if rand.Intn(interval) == 0 {
pos := rand.Intn(len(raw))
raw[pos] = byte(rand.Intn(256))
_, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr())
_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
atomic.AddInt32(&numCorrupted, 1)
return true
@ -296,7 +290,7 @@ var _ = Describe("MITM test", func() {
if rand.Intn(interval) == 0 {
pos := rand.Intn(len(raw))
raw[pos] = byte(rand.Intn(256))
_, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr())
_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
atomic.AddInt32(&numCorrupted, 1)
return true
@ -320,17 +314,13 @@ var _ = Describe("MITM test", func() {
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
_, err = tr.Dial(
_, err = clientTransport.Dial(
context.Background(),
raddr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}),
)
return func() { tr.Close(); serverCloseFn() }, err
return func() { clientTransport.Close(); serverCloseFn() }, err
}
// fails immediately because client connection closes when it can't find compatible version
@ -356,7 +346,7 @@ var _ = Describe("MITM test", func() {
)
// Send the packet
_, err = serverUDPConn.WriteTo(packet, clientUDPConn.LocalAddr())
_, err = serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
close(done)
}
@ -393,7 +383,7 @@ var _ = Describe("MITM test", func() {
fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12})
retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version)
_, err = serverUDPConn.WriteTo(retryPacket, clientUDPConn.LocalAddr())
_, err = serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return rtt / 2
@ -423,7 +413,7 @@ var _ = Describe("MITM test", func() {
defer close(done)
injected = true
initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil)
_, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr())
_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return rtt
@ -453,7 +443,7 @@ var _ = Describe("MITM test", func() {
// Fake Initial with ACK for packet 2 (unsent)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ack})
_, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr())
_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return rtt

View file

@ -27,27 +27,7 @@ type OOBCapablePacketConn interface {
var _ OOBCapablePacketConn = &net.UDPConn{}
// OptimizeConn takes a net.PacketConn and attempts to enable various optimizations that will improve QUIC performance:
// 1. It enables the Don't Fragment (DF) bit on the IP header.
// This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899).
// 2. It enables reading of the ECN bits from the IP header.
// This allows the remote node to speed up its loss detection and recovery.
// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket.
// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux).
//
// In order for this to work, the connection needs to implement the OOBCapablePacketConn interface (as a *net.UDPConn does).
//
// It's only necessary to call this function explicitly if the application calls WriteTo
// after passing the connection to the Transport.
func OptimizeConn(c net.PacketConn) (net.PacketConn, error) {
return wrapConn(c)
}
func wrapConn(pc net.PacketConn) (interface {
net.PacketConn
rawConn
}, error,
) {
func wrapConn(pc net.PacketConn) (rawConn, error) {
if err := setReceiveBuffer(pc); err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
setBufferWarningOnce.Do(func() {

View file

@ -230,13 +230,6 @@ func (c *oobConn) ReadPacket() (receivedPacket, error) {
return p, nil
}
// WriteTo (re)implements the net.PacketConn method.
// This is needed for users who call OptimizeConn to be able to send (non-QUIC) packets on the underlying connection.
// With GSO enabled, this would otherwise not be needed, as the kernel requires the UDP_SEGMENT message to be set.
func (c *oobConn) WriteTo(p []byte, addr net.Addr) (int, error) {
return c.WritePacket(p, uint16(len(p)), addr, nil)
}
// WritePacket writes a new packet.
// If the connection supports GSO (and we activated GSO support before),
// it appends the UDP_SEGMENT size message to oob.

View file

@ -26,9 +26,16 @@ type Transport struct {
// A single net.PacketConn can only be handled by one Transport.
// Bad things will happen if passed to multiple Transports.
//
// If not done by the user, the connection is passed through OptimizeConn to enable a number of optimizations.
// After passing the connection to the Transport, it's invalid to call ReadFrom on the connection.
// Calling WriteTo is only valid on the connection returned by OptimizeConn.
// A number of optimizations will be enabled if the connections implements the OOBCapablePacketConn interface,
// as a *net.UDPConn does.
// 1. It enables the Don't Fragment (DF) bit on the IP header.
// This is required to run DPLPMTUD (Path MTU Discovery, RFC 8899).
// 2. It enables reading of the ECN bits from the IP header.
// This allows the remote node to speed up its loss detection and recovery.
// 3. It uses batched syscalls (recvmmsg) to more efficiently receive packets from the socket.
// 4. It uses Generic Segmentation Offload (GSO) to efficiently send batches of packets (on Linux).
//
// After passing the connection to the Transport, it's invalid to call ReadFrom or WriteTo on the connection.
Conn net.PacketConn
// The length of the connection ID in bytes.
@ -99,7 +106,7 @@ func (t *Transport) Listen(tlsConf *tls.Config, conf *Config) (*Listener, error)
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(true); err != nil {
if err := t.init(false); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, false)
@ -128,7 +135,7 @@ func (t *Transport) ListenEarly(tlsConf *tls.Config, conf *Config) (*EarlyListen
return nil, errListenerAlreadySet
}
conf = populateServerConfig(conf)
if err := t.init(true); err != nil {
if err := t.init(false); err != nil {
return nil, err
}
s, err := newServer(t.conn, t.handlerMap, t.connIDGenerator, tlsConf, conf, t.Tracer, t.closeServer, true)
@ -145,7 +152,7 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config
return nil, err
}
conf = populateConfig(conf)
if err := t.init(false); err != nil {
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
var onClose func()
@ -163,7 +170,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
return nil, err
}
conf = populateConfig(conf)
if err := t.init(false); err != nil {
if err := t.init(t.isSingleUse); err != nil {
return nil, err
}
var onClose func()
@ -175,7 +182,7 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C
return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true)
}
func (t *Transport) init(isServer bool) error {
func (t *Transport) init(allowZeroLengthConnIDs bool) error {
t.initOnce.Do(func() {
var conn rawConn
if c, ok := t.Conn.(rawConn); ok {
@ -203,7 +210,7 @@ func (t *Transport) init(isServer bool) error {
t.connIDLen = t.ConnectionIDGenerator.ConnectionIDLen()
} else {
connIDLen := t.ConnectionIDLength
if t.ConnectionIDLength == 0 && (!t.isSingleUse || isServer) {
if t.ConnectionIDLength == 0 && !allowZeroLengthConnIDs {
connIDLen = protocol.DefaultConnectionIDLength
}
t.connIDLen = connIDLen
@ -217,6 +224,14 @@ func (t *Transport) init(isServer bool) error {
return t.initErr
}
// WriteTo sends a packet on the underlying connection.
func (t *Transport) WriteTo(b []byte, addr net.Addr) (int, error) {
if err := t.init(false); err != nil {
return 0, err
}
return t.conn.WritePacket(b, uint16(len(b)), addr, nil)
}
func (t *Transport) enqueueClosePacket(p closePacket) {
select {
case t.closeQueue <- p: