remove OptimizeConn, add a Transport.WriteTo method instead (#3957)

* remove OptimizeConn, add a Transport.WriteTo method instead

* fix race condition in Transport.WriteTo
This commit is contained in:
Marten Seemann 2023-07-19 10:28:11 -07:00 committed by GitHub
parent 27301f791f
commit a347d664e2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 76 deletions

View file

@ -26,9 +26,10 @@ var _ = Describe("MITM test", func() {
const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it
var (
serverUDPConn, clientUDPConn net.PacketConn
serverConn quic.Connection
serverConfig *quic.Config
clientUDPConn net.PacketConn
serverTransport, clientTransport *quic.Transport
serverConn quic.Connection
serverConfig *quic.Config
)
startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) {
@ -36,13 +37,11 @@ var _ = Describe("MITM test", func() {
Expect(err).ToNot(HaveOccurred())
c, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
serverUDPConn, err = quic.OptimizeConn(c)
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: serverUDPConn,
serverTransport = &quic.Transport{
Conn: c,
ConnectionIDLength: connIDLen,
}
ln, err := tr.Listen(getTLSConfig(), serverConfig)
ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
done := make(chan struct{})
go func() {
@ -69,7 +68,7 @@ var _ = Describe("MITM test", func() {
return proxy.LocalPort(), func() {
proxy.Close()
ln.Close()
serverUDPConn.Close()
serverTransport.Close()
<-done
}
}
@ -78,10 +77,12 @@ var _ = Describe("MITM test", func() {
serverConfig = getQuicConfig(nil)
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
c, err := net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
clientUDPConn, err = quic.OptimizeConn(c)
clientUDPConn, err = net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
clientTransport = &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
})
Context("unsuccessful attacks", func() {
@ -90,12 +91,13 @@ var _ = Describe("MITM test", func() {
// Test shutdown is tricky due to the proxy. Just wait for a bit.
time.Sleep(50 * time.Millisecond)
Expect(clientUDPConn.Close()).To(Succeed())
Expect(clientTransport.Close()).To(Succeed())
})
Context("injecting invalid packets", func() {
const rtt = 20 * time.Millisecond
sendRandomPacketsOfSameType := func(conn net.PacketConn, remoteAddr net.Addr, raw []byte) {
sendRandomPacketsOfSameType := func(conn *quic.Transport, remoteAddr net.Addr, raw []byte) {
defer GinkgoRecover()
const numPackets = 10
ticker := time.NewTicker(rtt / numPackets)
@ -155,11 +157,7 @@ var _ = Describe("MITM test", func() {
defer closeFn()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
conn, err := tr.Dial(
conn, err := clientTransport.Dial(
context.Background(),
raddr,
getTLSClientConfig(),
@ -178,7 +176,7 @@ var _ = Describe("MITM test", func() {
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
if dir == quicproxy.DirectionIncoming {
defer GinkgoRecover()
go sendRandomPacketsOfSameType(clientUDPConn, serverUDPConn.LocalAddr(), raw)
go sendRandomPacketsOfSameType(clientTransport, serverTransport.Conn.LocalAddr(), raw)
}
return rtt / 2
}
@ -189,7 +187,7 @@ var _ = Describe("MITM test", func() {
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
if dir == quicproxy.DirectionOutgoing {
defer GinkgoRecover()
go sendRandomPacketsOfSameType(serverUDPConn, clientUDPConn.LocalAddr(), raw)
go sendRandomPacketsOfSameType(serverTransport, clientTransport.Conn.LocalAddr(), raw)
}
return rtt / 2
}
@ -202,11 +200,7 @@ var _ = Describe("MITM test", func() {
defer closeFn()
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
conn, err := tr.Dial(
conn, err := clientTransport.Dial(
context.Background(),
raddr,
getTLSClientConfig(),
@ -226,7 +220,7 @@ var _ = Describe("MITM test", func() {
dropCb := func(dir quicproxy.Direction, raw []byte) bool {
defer GinkgoRecover()
if dir == quicproxy.DirectionIncoming {
_, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr())
_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return false
@ -238,7 +232,7 @@ var _ = Describe("MITM test", func() {
dropCb := func(dir quicproxy.Direction, raw []byte) bool {
defer GinkgoRecover()
if dir == quicproxy.DirectionOutgoing {
_, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr())
_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return false
@ -276,7 +270,7 @@ var _ = Describe("MITM test", func() {
if rand.Intn(interval) == 0 {
pos := rand.Intn(len(raw))
raw[pos] = byte(rand.Intn(256))
_, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr())
_, err := clientTransport.WriteTo(raw, serverTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
atomic.AddInt32(&numCorrupted, 1)
return true
@ -296,7 +290,7 @@ var _ = Describe("MITM test", func() {
if rand.Intn(interval) == 0 {
pos := rand.Intn(len(raw))
raw[pos] = byte(rand.Intn(256))
_, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr())
_, err := serverTransport.WriteTo(raw, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
atomic.AddInt32(&numCorrupted, 1)
return true
@ -320,17 +314,13 @@ var _ = Describe("MITM test", func() {
proxyPort, serverCloseFn := startServerAndProxy(delayCb, nil)
raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort))
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
_, err = tr.Dial(
_, err = clientTransport.Dial(
context.Background(),
raddr,
getTLSClientConfig(),
getQuicConfig(&quic.Config{HandshakeIdleTimeout: 2 * time.Second}),
)
return func() { tr.Close(); serverCloseFn() }, err
return func() { clientTransport.Close(); serverCloseFn() }, err
}
// fails immediately because client connection closes when it can't find compatible version
@ -356,7 +346,7 @@ var _ = Describe("MITM test", func() {
)
// Send the packet
_, err = serverUDPConn.WriteTo(packet, clientUDPConn.LocalAddr())
_, err = serverTransport.WriteTo(packet, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
close(done)
}
@ -393,7 +383,7 @@ var _ = Describe("MITM test", func() {
fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12})
retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version)
_, err = serverUDPConn.WriteTo(retryPacket, clientUDPConn.LocalAddr())
_, err = serverTransport.WriteTo(retryPacket, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return rtt / 2
@ -423,7 +413,7 @@ var _ = Describe("MITM test", func() {
defer close(done)
injected = true
initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil)
_, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr())
_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return rtt
@ -453,7 +443,7 @@ var _ = Describe("MITM test", func() {
// Fake Initial with ACK for packet 2 (unsent)
ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}
initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ack})
_, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr())
_, err = serverTransport.WriteTo(initialPacket, clientTransport.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
}
return rtt