From 2b0a03a9883c62145d078187d77d95a0016636c4 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 22 Apr 2023 12:29:26 +0200 Subject: [PATCH] set the QUIC version for integration tests using a command line flag --- .github/workflows/integration.yml | 17 +- http3/client.go | 6 +- http3/client_test.go | 2 +- integrationtests/self/close_test.go | 115 +- integrationtests/self/conn_id_test.go | 32 +- integrationtests/self/datagram_test.go | 303 ++-- integrationtests/self/drop_test.go | 135 +- integrationtests/self/early_data_test.go | 96 +- integrationtests/self/handshake_drop_test.go | 191 ++- integrationtests/self/handshake_rtt_test.go | 42 +- integrationtests/self/handshake_test.go | 167 +- integrationtests/self/hotswap_test.go | 135 +- integrationtests/self/http_test.go | 686 ++++---- integrationtests/self/mitm_test.go | 840 +++++----- integrationtests/self/multiplex_test.go | 389 +++-- integrationtests/self/rtt_test.go | 159 +- integrationtests/self/self_suite_test.go | 23 +- integrationtests/self/stream_test.go | 261 ++- integrationtests/self/uni_stream_test.go | 10 +- integrationtests/self/zero_rtt_test.go | 1404 ++++++++--------- .../versionnegotiation/rtt_test.go | 53 + 21 files changed, 2471 insertions(+), 2595 deletions(-) create mode 100644 integrationtests/versionnegotiation/rtt_test.go diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index d13a5c7c..0dbf928b 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -20,19 +20,24 @@ jobs: - run: go version - name: set qlogger if: env.DEBUG == 'true' - run: echo "QLOGFLAG=-- -qlog" >> $GITHUB_ENV - - name: Run tests + run: echo "QLOGFLAG= -qlog" >> $GITHUB_ENV + - name: Run other tests run: | go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace -skip-package self,versionnegotiation integrationtests - go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/versionnegotiation ${{ env.QLOGFLAG }} - go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self ${{ env.QLOGFLAG }} + go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/versionnegotiation -- ${{ env.QLOGFLAG }} + - name: Run self tests, using draft-29 + run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self -- -version=draft29 ${{ env.QLOGFLAG }} + - name: Run self tests, using QUIC v1 + run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self -- -version=1 ${{ env.QLOGFLAG }} + - name: Run self tests, using QUIC v2 + run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self -- -version=2 ${{ env.QLOGFLAG }} - name: Run tests (32 bit) env: GOARCH: 386 run: | go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace -skip-package self,versionnegotiation integrationtests - go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/versionnegotiation ${{ env.QLOGFLAG }} - go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self ${{ env.QLOGFLAG }} + go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/versionnegotiation -- ${{ env.QLOGFLAG }} + go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self -- ${{ env.QLOGFLAG }} - name: save qlogs if: ${{ always() && env.DEBUG == 'true' }} uses: actions/upload-artifact@v2 diff --git a/http3/client.go b/http3/client.go index d89f2090..82af3aff 100644 --- a/http3/client.go +++ b/http3/client.go @@ -33,7 +33,6 @@ const ( var defaultQuicConfig = &quic.Config{ MaxIncomingStreams: -1, // don't allow the server to create bidirectional streams KeepAlivePeriod: 10 * time.Second, - Versions: []protocol.VersionNumber{protocol.Version1}, } type dialFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) @@ -74,9 +73,10 @@ var _ roundTripCloser = &client{} func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (roundTripCloser, error) { if conf == nil { conf = defaultQuicConfig.Clone() - } else if len(conf.Versions) == 0 { + } + if len(conf.Versions) == 0 { conf = conf.Clone() - conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} + conf.Versions = []quic.VersionNumber{protocol.SupportedVersions[0]} } if len(conf.Versions) != 1 { return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") diff --git a/http3/client_test.go b/http3/client_test.go index 9f249008..d0f68295 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -65,7 +65,7 @@ var _ = Describe("Client", func() { Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlyConnection, error) { - Expect(quicConf).To(Equal(defaultQuicConfig)) + Expect(quicConf.MaxIncomingStreams).To(Equal(defaultQuicConfig.MaxIncomingStreams)) Expect(tlsConf.NextProtos).To(Equal([]string{NextProtoH3})) Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1})) dialAddrCalled = true diff --git a/integrationtests/self/close_test.go b/integrationtests/self/close_test.go index afa9e603..c2f1439a 100644 --- a/integrationtests/self/close_test.go +++ b/integrationtests/self/close_test.go @@ -9,74 +9,67 @@ import ( "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" - "github.com/quic-go/quic-go/internal/protocol" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("Connection ID lengths tests", func() { - for _, v := range protocol.SupportedVersions { - version := v + It("retransmits the CONNECTION_CLOSE packet", func() { + server, err := quic.ListenAddr( + "localhost:0", + getTLSConfig(), + getQuicConfig(&quic.Config{ + DisablePathMTUDiscovery: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) - Context(fmt.Sprintf("with QUIC version %s", version), func() { - It("retransmits the CONNECTION_CLOSE packet", func() { - server, err := quic.ListenAddr( - "localhost:0", - getTLSConfig(), - getQuicConfig(&quic.Config{ - DisablePathMTUDiscovery: true, - }), - ) - Expect(err).ToNot(HaveOccurred()) - - var drop atomic.Bool - dropped := make(chan []byte, 100) - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration { - return 5 * time.Millisecond // 10ms RTT - }, - DropPacket: func(dir quicproxy.Direction, b []byte) bool { - if drop := drop.Load(); drop && dir == quicproxy.DirectionOutgoing { - dropped <- b - return true - } - return false - }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() - - conn, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - getTLSClientConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - - sconn, err := server.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - time.Sleep(100 * time.Millisecond) - drop.Store(true) - sconn.CloseWithError(1337, "closing") - - // send 100 packets - for i := 0; i < 100; i++ { - str, err := conn.OpenStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - time.Sleep(time.Millisecond) + var drop atomic.Bool + dropped := make(chan []byte, 100) + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + DelayPacket: func(dir quicproxy.Direction, _ []byte) time.Duration { + return 5 * time.Millisecond // 10ms RTT + }, + DropPacket: func(dir quicproxy.Direction, b []byte) bool { + if drop := drop.Load(); drop && dir == quicproxy.DirectionOutgoing { + dropped <- b + return true } - // Expect retransmissions of the CONNECTION_CLOSE for the - // 1st, 2nd, 4th, 8th, 16th, 32th, 64th packet: 7 in total (+1 for the original packet) - Eventually(dropped).Should(HaveLen(8)) - first := <-dropped - for len(dropped) > 0 { - Expect(<-dropped).To(Equal(first)) // these packets are all identical - } - }) + return false + }, }) - } + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + conn, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + + sconn, err := server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + time.Sleep(100 * time.Millisecond) + drop.Store(true) + sconn.CloseWithError(1337, "closing") + + // send 100 packets + for i := 0; i < 100; i++ { + str, err := conn.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + time.Sleep(time.Millisecond) + } + // Expect retransmissions of the CONNECTION_CLOSE for the + // 1st, 2nd, 4th, 8th, 16th, 32th, 64th packet: 7 in total (+1 for the original packet) + Eventually(dropped).Should(HaveLen(8)) + first := <-dropped + for len(dropped) > 0 { + Expect(<-dropped).To(Equal(first)) // these packets are all identical + } + }) }) diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index 35d4ee29..21b413e6 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -32,9 +32,7 @@ func (c *connIDGenerator) ConnectionIDLen() int { } var _ = Describe("Connection ID lengths tests", func() { - randomConnIDLen := func() int { - return 4 + int(mrand.Int31n(15)) - } + randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) } runServer := func(conf *quic.Config) quic.Listener { GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", conf.ConnectionIDLength))) @@ -77,46 +75,32 @@ var _ = Describe("Connection ID lengths tests", func() { } It("downloads a file using a 0-byte connection ID for the client", func() { - serverConf := getQuicConfig(&quic.Config{ - ConnectionIDLength: randomConnIDLen(), - Versions: []protocol.VersionNumber{protocol.Version1}, - }) - clientConf := getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{protocol.Version1}, - }) - + serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()}) ln := runServer(serverConf) defer ln.Close() - runClient(ln.Addr(), clientConf) + + runClient(ln.Addr(), getQuicConfig(nil)) }) It("downloads a file when both client and server use a random connection ID length", func() { - serverConf := getQuicConfig(&quic.Config{ - ConnectionIDLength: randomConnIDLen(), - Versions: []protocol.VersionNumber{protocol.Version1}, - }) - clientConf := getQuicConfig(&quic.Config{ - ConnectionIDLength: randomConnIDLen(), - Versions: []protocol.VersionNumber{protocol.Version1}, - }) - + serverConf := getQuicConfig(&quic.Config{ConnectionIDLength: randomConnIDLen()}) ln := runServer(serverConf) defer ln.Close() - runClient(ln.Addr(), clientConf) + + runClient(ln.Addr(), getQuicConfig(nil)) }) It("downloads a file when both client and server use a custom connection ID generator", func() { serverConf := getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()}, }) clientConf := getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{protocol.Version1}, ConnectionIDGenerator: &connIDGenerator{length: randomConnIDLen()}, }) ln := runServer(serverConf) defer ln.Close() + runClient(ln.Addr(), clientConf) }) }) diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go index db5a6758..40c8eb96 100644 --- a/integrationtests/self/datagram_test.go +++ b/integrationtests/self/datagram_test.go @@ -12,7 +12,6 @@ import ( "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" . "github.com/onsi/ginkgo/v2" @@ -20,175 +19,157 @@ import ( ) var _ = Describe("Datagram test", func() { - for _, v := range protocol.SupportedVersions { - version := v + const num = 100 - Context(fmt.Sprintf("with QUIC version %s", version), func() { - const num = 100 + var ( + proxy *quicproxy.QuicProxy + serverConn, clientConn *net.UDPConn + dropped, total int32 + ) - var ( - proxy *quicproxy.QuicProxy - serverConn, clientConn *net.UDPConn - dropped, total int32 - ) + startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + serverConn, err = net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + ln, err := quic.Listen( + serverConn, + getTLSConfig(), + getQuicConfig(&quic.Config{EnableDatagrams: enableDatagram}), + ) + Expect(err).ToNot(HaveOccurred()) - startServerAndProxy := func(enableDatagram, expectDatagramSupport bool) { - addr, err := net.ResolveUDPAddr("udp", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - serverConn, err = net.ListenUDP("udp", addr) - Expect(err).ToNot(HaveOccurred()) - ln, err := quic.Listen( - serverConn, - getTLSConfig(), - getQuicConfig(&quic.Config{ - EnableDatagrams: enableDatagram, - Versions: []protocol.VersionNumber{version}, - }), - ) - Expect(err).ToNot(HaveOccurred()) + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) - go func() { - defer GinkgoRecover() - conn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - - if expectDatagramSupport { - Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) - - if enableDatagram { - var wg sync.WaitGroup - wg.Add(num) - for i := 0; i < num; i++ { - go func(i int) { - defer GinkgoRecover() - defer wg.Done() - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, uint64(i)) - Expect(conn.SendMessage(b)).To(Succeed()) - }(i) - } - wg.Wait() - } - } else { - Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) - } - }() - - serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - // drop 10% of Short Header packets sent from the server - DropPacket: func(dir quicproxy.Direction, packet []byte) bool { - if dir == quicproxy.DirectionIncoming { - return false - } - // don't drop Long Header packets - if wire.IsLongHeaderPacket(packet[0]) { - return false - } - drop := mrand.Int()%10 == 0 - if drop { - atomic.AddInt32(&dropped, 1) - } - atomic.AddInt32(&total, 1) - return drop - }, - }) - Expect(err).ToNot(HaveOccurred()) - } - - BeforeEach(func() { - addr, err := net.ResolveUDPAddr("udp", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - clientConn, err = net.ListenUDP("udp", addr) - Expect(err).ToNot(HaveOccurred()) - }) - - AfterEach(func() { - Expect(proxy.Close()).To(Succeed()) - }) - - It("sends datagrams", func() { - startServerAndProxy(true, true) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) - Expect(err).ToNot(HaveOccurred()) - conn, err := quic.Dial( - clientConn, - raddr, - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - getTLSClientConfig(), - getQuicConfig(&quic.Config{ - EnableDatagrams: true, - Versions: []protocol.VersionNumber{version}, - }), - ) - Expect(err).ToNot(HaveOccurred()) + if expectDatagramSupport { Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) - var counter int - for { - // Close the connection if no message is received for 100 ms. - timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { - conn.CloseWithError(0, "") - }) - if _, err := conn.ReceiveMessage(); err != nil { - break + + if enableDatagram { + var wg sync.WaitGroup + wg.Add(num) + for i := 0; i < num; i++ { + go func(i int) { + defer GinkgoRecover() + defer wg.Done() + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, uint64(i)) + Expect(conn.SendMessage(b)).To(Succeed()) + }(i) } - timer.Stop() - counter++ + wg.Wait() } - - numDropped := int(atomic.LoadInt32(&dropped)) - expVal := num - numDropped - fmt.Fprintf(GinkgoWriter, "Dropped %d out of %d packets.\n", numDropped, atomic.LoadInt32(&total)) - fmt.Fprintf(GinkgoWriter, "Received %d out of %d sent datagrams.\n", counter, num) - Expect(counter).To(And( - BeNumerically(">", expVal*9/10), - BeNumerically("<", num), - )) - }) - - It("server can disable datagram", func() { - startServerAndProxy(false, true) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) - Expect(err).ToNot(HaveOccurred()) - conn, err := quic.Dial( - clientConn, - raddr, - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - getTLSClientConfig(), - getQuicConfig(&quic.Config{ - EnableDatagrams: true, - Versions: []protocol.VersionNumber{version}, - }), - ) - Expect(err).ToNot(HaveOccurred()) + } else { Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + } + }() - conn.CloseWithError(0, "") - <-time.After(10 * time.Millisecond) - }) - - It("client can disable datagram", func() { - startServerAndProxy(false, true) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) - Expect(err).ToNot(HaveOccurred()) - conn, err := quic.Dial( - clientConn, - raddr, - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - getTLSClientConfig(), - getQuicConfig(&quic.Config{ - EnableDatagrams: true, - Versions: []protocol.VersionNumber{version}, - }), - ) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) - - Expect(conn.SendMessage([]byte{0})).To(HaveOccurred()) - conn.CloseWithError(0, "") - <-time.After(10 * time.Millisecond) - }) + serverPort := ln.Addr().(*net.UDPAddr).Port + proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + // drop 10% of Short Header packets sent from the server + DropPacket: func(dir quicproxy.Direction, packet []byte) bool { + if dir == quicproxy.DirectionIncoming { + return false + } + // don't drop Long Header packets + if wire.IsLongHeaderPacket(packet[0]) { + return false + } + drop := mrand.Int()%10 == 0 + if drop { + atomic.AddInt32(&dropped, 1) + } + atomic.AddInt32(&total, 1) + return drop + }, }) + Expect(err).ToNot(HaveOccurred()) } + + BeforeEach(func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + clientConn, err = net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + Expect(proxy.Close()).To(Succeed()) + }) + + It("sends datagrams", func() { + startServerAndProxy(true, true) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + Expect(err).ToNot(HaveOccurred()) + conn, err := quic.Dial( + clientConn, + raddr, + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + getQuicConfig(&quic.Config{EnableDatagrams: true}), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + var counter int + for { + // Close the connection if no message is received for 100 ms. + timer := time.AfterFunc(scaleDuration(100*time.Millisecond), func() { + conn.CloseWithError(0, "") + }) + if _, err := conn.ReceiveMessage(); err != nil { + break + } + timer.Stop() + counter++ + } + + numDropped := int(atomic.LoadInt32(&dropped)) + expVal := num - numDropped + fmt.Fprintf(GinkgoWriter, "Dropped %d out of %d packets.\n", numDropped, atomic.LoadInt32(&total)) + fmt.Fprintf(GinkgoWriter, "Received %d out of %d sent datagrams.\n", counter, num) + Expect(counter).To(And( + BeNumerically(">", expVal*9/10), + BeNumerically("<", num), + )) + }) + + It("server can disable datagram", func() { + startServerAndProxy(false, true) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + Expect(err).ToNot(HaveOccurred()) + conn, err := quic.Dial( + clientConn, + raddr, + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + getQuicConfig(&quic.Config{EnableDatagrams: true}), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + + conn.CloseWithError(0, "") + <-time.After(10 * time.Millisecond) + }) + + It("client can disable datagram", func() { + startServerAndProxy(false, true) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxy.LocalPort())) + Expect(err).ToNot(HaveOccurred()) + conn, err := quic.Dial( + clientConn, + raddr, + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + getQuicConfig(&quic.Config{EnableDatagrams: true}), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + + Expect(conn.SendMessage([]byte{0})).To(HaveOccurred()) + conn.CloseWithError(0, "") + <-time.After(10 * time.Millisecond) + }) }) diff --git a/integrationtests/self/drop_test.go b/integrationtests/self/drop_test.go index d15a3a3c..52eea516 100644 --- a/integrationtests/self/drop_test.go +++ b/integrationtests/self/drop_test.go @@ -10,7 +10,6 @@ import ( "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" - "github.com/quic-go/quic-go/internal/protocol" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -26,12 +25,12 @@ var _ = Describe("Drop Tests", func() { ln quic.Listener ) - startListenerAndProxy := func(dropCallback quicproxy.DropCallback, version protocol.VersionNumber) { + startListenerAndProxy := func(dropCallback quicproxy.DropCallback) { var err error ln, err = quic.ListenAddr( "localhost:0", getTLSConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) serverPort := ln.Addr().(*net.UDPAddr).Port @@ -51,79 +50,73 @@ var _ = Describe("Drop Tests", func() { Expect(ln.Close()).To(Succeed()) }) - for _, v := range protocol.SupportedVersions { - version := v + for _, d := range directions { + direction := d - Context(fmt.Sprintf("with QUIC version %s", version), func() { - for _, d := range directions { - direction := d + // The purpose of this test is to create a lot of tails, by sending 1 byte messages. + // The interval, the length of the drop period, and the time when the drop period starts are randomized. + // To cover different scenarios, repeat this test a few times. + for rep := 0; rep < 3; rep++ { + It(fmt.Sprintf("sends short messages, dropping packets in %s direction", direction), func() { + const numMessages = 15 - // The purpose of this test is to create a lot of tails, by sending 1 byte messages. - // The interval, the length of the drop period, and the time when the drop period starts are randomized. - // To cover different scenarios, repeat this test a few times. - for rep := 0; rep < 3; rep++ { - It(fmt.Sprintf("sends short messages, dropping packets in %s direction", direction), func() { - const numMessages = 15 + messageInterval := randomDuration(10*time.Millisecond, 100*time.Millisecond) + dropDuration := randomDuration(messageInterval*3/2, 2*time.Second) + dropDelay := randomDuration(25*time.Millisecond, numMessages*messageInterval/2) // makes sure we don't interfere with the handshake + fmt.Fprintf(GinkgoWriter, "Sending a message every %s, %d times.\n", messageInterval, numMessages) + fmt.Fprintf(GinkgoWriter, "Dropping packets for %s, after a delay of %s\n", dropDuration, dropDelay) + startTime := time.Now() - messageInterval := randomDuration(10*time.Millisecond, 100*time.Millisecond) - dropDuration := randomDuration(messageInterval*3/2, 2*time.Second) - dropDelay := randomDuration(25*time.Millisecond, numMessages*messageInterval/2) // makes sure we don't interfere with the handshake - fmt.Fprintf(GinkgoWriter, "Sending a message every %s, %d times.\n", messageInterval, numMessages) - fmt.Fprintf(GinkgoWriter, "Dropping packets for %s, after a delay of %s\n", dropDuration, dropDelay) - startTime := time.Now() + var numDroppedPackets int32 + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + if !d.Is(direction) { + return false + } + drop := time.Now().After(startTime.Add(dropDelay)) && time.Now().Before(startTime.Add(dropDelay).Add(dropDuration)) + if drop { + atomic.AddInt32(&numDroppedPackets, 1) + } + return drop + }) - var numDroppedPackets int32 - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { - if !d.Is(direction) { - return false - } - drop := time.Now().After(startTime.Add(dropDelay)) && time.Now().Before(startTime.Add(dropDelay).Add(dropDuration)) - if drop { - atomic.AddInt32(&numDroppedPackets, 1) - } - return drop - }, version) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - conn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.OpenStream() - Expect(err).ToNot(HaveOccurred()) - for i := uint8(1); i <= numMessages; i++ { - n, err := str.Write([]byte{i}) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(1)) - time.Sleep(messageInterval) - } - <-done - Expect(conn.CloseWithError(0, "")).To(Succeed()) - }() - - conn, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - getTLSClientConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenStream() + Expect(err).ToNot(HaveOccurred()) + for i := uint8(1); i <= numMessages; i++ { + n, err := str.Write([]byte{i}) Expect(err).ToNot(HaveOccurred()) - defer conn.CloseWithError(0, "") - str, err := conn.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - for i := uint8(1); i <= numMessages; i++ { - b := []byte{0} - n, err := str.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(n).To(Equal(1)) - Expect(b[0]).To(Equal(i)) - } - close(done) - numDropped := atomic.LoadInt32(&numDroppedPackets) - fmt.Fprintf(GinkgoWriter, "Dropped %d packets.\n", numDropped) - Expect(numDropped).To(BeNumerically(">", 0)) - }) + Expect(n).To(Equal(1)) + time.Sleep(messageInterval) + } + <-done + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }() + + conn, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") + str, err := conn.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + for i := uint8(1); i <= numMessages; i++ { + b := []byte{0} + n, err := str.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(1)) + Expect(b[0]).To(Equal(i)) } - } - }) + close(done) + numDropped := atomic.LoadInt32(&numDroppedPackets) + fmt.Fprintf(GinkgoWriter, "Dropped %d packets.\n", numDropped) + Expect(numDropped).To(BeNumerically(">", 0)) + }) + } } }) diff --git a/integrationtests/self/early_data_test.go b/integrationtests/self/early_data_test.go index 063f8622..5b96b2cb 100644 --- a/integrationtests/self/early_data_test.go +++ b/integrationtests/self/early_data_test.go @@ -9,7 +9,6 @@ import ( "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" - "github.com/quic-go/quic-go/internal/protocol" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -17,56 +16,51 @@ import ( var _ = Describe("early data", func() { const rtt = 80 * time.Millisecond - for _, v := range protocol.SupportedVersions { - version := v - Context(fmt.Sprintf("with QUIC version %s", version), func() { - It("sends 0.5-RTT data", func() { - ln, err := quic.ListenAddrEarly( - "localhost:0", - getTLSConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - conn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write([]byte("early data")) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - // make sure the Write finished before the handshake completed - Expect(conn.HandshakeComplete()).ToNot(BeClosed()) - Eventually(conn.Context().Done()).Should(BeClosed()) - }() - serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DelayPacket: func(quicproxy.Direction, []byte) time.Duration { - return rtt / 2 - }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() - - conn, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - getTLSClientConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal([]byte("early data"))) - conn.CloseWithError(0, "") - Eventually(done).Should(BeClosed()) - }) + It("sends 0.5-RTT data", func() { + ln, err := quic.ListenAddrEarly( + "localhost:0", + getTLSConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("early data")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + // make sure the Write finished before the handshake completed + Expect(conn.HandshakeComplete()).ToNot(BeClosed()) + Eventually(conn.Context().Done()).Should(BeClosed()) + }() + serverPort := ln.Addr().(*net.UDPAddr).Port + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { + return rtt / 2 + }, }) - } + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + conn, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + getTLSClientConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("early data"))) + conn.CloseWithError(0, "") + Eventually(done).Should(BeClosed()) + }) }) diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 2e3a278d..ebb8ee14 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -15,7 +15,6 @@ import ( "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" . "github.com/onsi/ginkgo/v2" @@ -27,7 +26,7 @@ var directions = []quicproxy.Direction{quicproxy.DirectionIncoming, quicproxy.Di type applicationProtocol struct { name string - run func(protocol.VersionNumber) + run func() } var _ = Describe("Handshake drop tests", func() { @@ -39,11 +38,10 @@ var _ = Describe("Handshake drop tests", func() { data := GeneratePRData(5000) const timeout = 2 * time.Minute - startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) { + startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool) { conf := getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, HandshakeIdleTimeout: timeout, - Versions: []protocol.VersionNumber{version}, RequireAddressValidation: func(net.Addr) bool { return doRetry }, }) var tlsConf *tls.Config @@ -68,7 +66,7 @@ var _ = Describe("Handshake drop tests", func() { clientSpeaksFirst := &applicationProtocol{ name: "client speaks first", - run: func(version protocol.VersionNumber) { + run: func() { serverConnChan := make(chan quic.Connection) go func() { defer GinkgoRecover() @@ -88,7 +86,6 @@ var _ = Describe("Handshake drop tests", func() { getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, HandshakeIdleTimeout: timeout, - Versions: []protocol.VersionNumber{version}, }), ) Expect(err).ToNot(HaveOccurred()) @@ -107,7 +104,7 @@ var _ = Describe("Handshake drop tests", func() { serverSpeaksFirst := &applicationProtocol{ name: "server speaks first", - run: func(version protocol.VersionNumber) { + run: func() { serverConnChan := make(chan quic.Connection) go func() { defer GinkgoRecover() @@ -126,7 +123,6 @@ var _ = Describe("Handshake drop tests", func() { getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, HandshakeIdleTimeout: timeout, - Versions: []protocol.VersionNumber{version}, }), ) Expect(err).ToNot(HaveOccurred()) @@ -145,7 +141,7 @@ var _ = Describe("Handshake drop tests", func() { nobodySpeaks := &applicationProtocol{ name: "nobody speaks", - run: func(version protocol.VersionNumber) { + run: func() { serverConnChan := make(chan quic.Connection) go func() { defer GinkgoRecover() @@ -159,7 +155,6 @@ var _ = Describe("Handshake drop tests", func() { getQuicConfig(&quic.Config{ MaxIdleTimeout: timeout, HandshakeIdleTimeout: timeout, - Versions: []protocol.VersionNumber{version}, }), ) Expect(err).ToNot(HaveOccurred()) @@ -176,106 +171,100 @@ var _ = Describe("Handshake drop tests", func() { Expect(proxy.Close()).To(Succeed()) }) - for _, v := range protocol.SupportedVersions { - version := v + for _, d := range directions { + direction := d - Context(fmt.Sprintf("with QUIC version %s", version), func() { - for _, d := range directions { - direction := d + for _, dr := range []bool{true, false} { + doRetry := dr + desc := "when using Retry" + if !dr { + desc = "when not using Retry" + } - for _, dr := range []bool{true, false} { - doRetry := dr - desc := "when using Retry" - if !dr { - desc = "when not using Retry" - } + Context(desc, func() { + for _, lcc := range []bool{false, true} { + longCertChain := lcc - Context(desc, func() { - for _, lcc := range []bool{false, true} { - longCertChain := lcc + Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() { + for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} { + app := a - Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() { - for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} { - app := a + Context(app.name, func() { + It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { + var incoming, outgoing int32 + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + var p int32 + //nolint:exhaustive + switch d { + case quicproxy.DirectionIncoming: + p = atomic.AddInt32(&incoming, 1) + case quicproxy.DirectionOutgoing: + p = atomic.AddInt32(&outgoing, 1) + } + return p == 1 && d.Is(direction) + }, doRetry, longCertChain) + app.run() + }) - Context(app.name, func() { - It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { - var incoming, outgoing int32 - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { - var p int32 - //nolint:exhaustive - switch d { - case quicproxy.DirectionIncoming: - p = atomic.AddInt32(&incoming, 1) - case quicproxy.DirectionOutgoing: - p = atomic.AddInt32(&outgoing, 1) + It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { + var incoming, outgoing int32 + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + var p int32 + //nolint:exhaustive + switch d { + case quicproxy.DirectionIncoming: + p = atomic.AddInt32(&incoming, 1) + case quicproxy.DirectionOutgoing: + p = atomic.AddInt32(&outgoing, 1) + } + return p == 2 && d.Is(direction) + }, doRetry, longCertChain) + app.run() + }) + + It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() { + const maxSequentiallyDropped = 10 + var mx sync.Mutex + var incoming, outgoing int + + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + drop := mrand.Int63n(int64(3)) == 0 + + mx.Lock() + defer mx.Unlock() + // never drop more than 10 consecutive packets + if d.Is(quicproxy.DirectionIncoming) { + if drop { + incoming++ + if incoming > maxSequentiallyDropped { + drop = false } - return p == 1 && d.Is(direction) - }, doRetry, longCertChain, version) - app.run(version) - }) - - It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { - var incoming, outgoing int32 - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { - var p int32 - //nolint:exhaustive - switch d { - case quicproxy.DirectionIncoming: - p = atomic.AddInt32(&incoming, 1) - case quicproxy.DirectionOutgoing: - p = atomic.AddInt32(&outgoing, 1) + } + if !drop { + incoming = 0 + } + } + if d.Is(quicproxy.DirectionOutgoing) { + if drop { + outgoing++ + if outgoing > maxSequentiallyDropped { + drop = false } - return p == 2 && d.Is(direction) - }, doRetry, longCertChain, version) - app.run(version) - }) - - It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() { - const maxSequentiallyDropped = 10 - var mx sync.Mutex - var incoming, outgoing int - - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { - drop := mrand.Int63n(int64(3)) == 0 - - mx.Lock() - defer mx.Unlock() - // never drop more than 10 consecutive packets - if d.Is(quicproxy.DirectionIncoming) { - if drop { - incoming++ - if incoming > maxSequentiallyDropped { - drop = false - } - } - if !drop { - incoming = 0 - } - } - if d.Is(quicproxy.DirectionOutgoing) { - if drop { - outgoing++ - if outgoing > maxSequentiallyDropped { - drop = false - } - } - if !drop { - outgoing = 0 - } - } - return drop - }, doRetry, longCertChain, version) - app.run(version) - }) - }) - } + } + if !drop { + outgoing = 0 + } + } + return drop + }, doRetry, longCertChain) + app.run() + }) }) } }) } - } - }) + }) + } It("establishes a connection when the ClientHello is larger than 1 MTU (e.g. post-quantum)", func() { origAdditionalTransportParametersClient := wire.AdditionalTransportParametersClient @@ -294,8 +283,8 @@ var _ = Describe("Handshake drop tests", func() { return false } return mrand.Intn(3) == 0 - }, false, false, version) - clientSpeaksFirst.run(version) + }, false, false) + clientSpeaksFirst.run() }) } }) diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index 9478b23f..5fcc5546 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -10,7 +10,6 @@ import ( "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" - "github.com/quic-go/quic-go/internal/protocol" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -53,35 +52,6 @@ var _ = Describe("Handshake RTT tests", func() { )) } - It("fails when there's no matching version, after 1 RTT", func() { - if len(protocol.SupportedVersions) == 1 { - Skip("Test requires at least 2 supported versions.") - } - serverConfig.Versions = protocol.SupportedVersions[:1] - ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - - runProxy(ln.Addr()) - startTime := time.Now() - _, err = quic.DialAddr( - proxy.LocalAddr().String(), - getTLSClientConfig(), - getQuicConfig(&quic.Config{Versions: protocol.SupportedVersions[1:2]}), - ) - Expect(err).To(HaveOccurred()) - expectDurationInRTTs(startTime, 1) - }) - - var clientConfig *quic.Config - - BeforeEach(func() { - serverConfig.Versions = []protocol.VersionNumber{protocol.Version1} - clientConfig = getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{protocol.Version1}}) - clientConfig := getTLSClientConfig() - clientConfig.InsecureSkipVerify = true - }) - // 1 RTT for verifying the source address // 1 RTT for the TLS handshake It("is forward-secure after 2 RTTs", func() { @@ -95,7 +65,7 @@ var _ = Describe("Handshake RTT tests", func() { _, err = quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), - clientConfig, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) expectDurationInRTTs(startTime, 2) @@ -111,7 +81,7 @@ var _ = Describe("Handshake RTT tests", func() { _, err = quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), - clientConfig, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) expectDurationInRTTs(startTime, 1) @@ -128,7 +98,7 @@ var _ = Describe("Handshake RTT tests", func() { _, err = quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), - clientConfig, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) expectDurationInRTTs(startTime, 2) @@ -138,6 +108,7 @@ var _ = Describe("Handshake RTT tests", func() { ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) Expect(err).ToNot(HaveOccurred()) go func() { + defer GinkgoRecover() conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) str, err := conn.OpenUniStream() @@ -153,7 +124,7 @@ var _ = Describe("Handshake RTT tests", func() { conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), - clientConfig, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptUniStream(context.Background()) @@ -168,6 +139,7 @@ var _ = Describe("Handshake RTT tests", func() { ln, err := quic.ListenAddrEarly("localhost:0", serverTLSConfig, serverConfig) Expect(err).ToNot(HaveOccurred()) go func() { + defer GinkgoRecover() conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) // Check the ALPN now. This is probably what an application would do. @@ -186,7 +158,7 @@ var _ = Describe("Handshake RTT tests", func() { conn, err := quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), - clientConfig, + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptUniStream(context.Background()) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 7f046c4d..b6d62f64 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -128,101 +128,88 @@ var _ = Describe("Handshake tests", func() { }) Context("Certificate validation", func() { - for _, v := range protocol.SupportedVersions { - version := v + It("accepts the certificate", func() { + runServer(getTLSConfig()) + _, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + }) - Context(fmt.Sprintf("using %s", version), func() { - var clientConfig *quic.Config + It("works with a long certificate chain", func() { + runServer(getTLSConfigWithLongCertChain()) + _, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + }) - BeforeEach(func() { - serverConfig.Versions = []protocol.VersionNumber{version} - clientConfig = getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}) - }) + It("errors if the server name doesn't match", func() { + runServer(getTLSConfig()) + conn, err := net.ListenUDP("udp", nil) + Expect(err).ToNot(HaveOccurred()) + _, err = quic.Dial( + conn, + server.Addr(), + "foo.bar", + getTLSClientConfig(), + getQuicConfig(nil), + ) + Expect(err).To(HaveOccurred()) + var transportErr *quic.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) + Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) + }) - It("accepts the certificate", func() { - runServer(getTLSConfig()) - _, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - getTLSClientConfig(), - clientConfig, - ) - Expect(err).ToNot(HaveOccurred()) - }) + It("fails the handshake if the client fails to provide the requested client cert", func() { + tlsConf := getTLSConfig() + tlsConf.ClientAuth = tls.RequireAndVerifyClientCert + runServer(tlsConf) - It("works with a long certificate chain", func() { - runServer(getTLSConfigWithLongCertChain()) - _, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - getTLSClientConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - }) + conn, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + getQuicConfig(nil), + ) + // Usually, the error will occur after the client already finished the handshake. + // However, there's a race condition here. The server's CONNECTION_CLOSE might be + // received before the connection is returned, so we might already get the error while dialing. + if err == nil { + errChan := make(chan error) + go func() { + defer GinkgoRecover() + _, err := conn.AcceptStream(context.Background()) + errChan <- err + }() + Eventually(errChan).Should(Receive(&err)) + } + Expect(err).To(HaveOccurred()) + var transportErr *quic.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) + Expect(transportErr.Error()).To(ContainSubstring("tls: bad certificate")) + }) - It("errors if the server name doesn't match", func() { - runServer(getTLSConfig()) - conn, err := net.ListenUDP("udp", nil) - Expect(err).ToNot(HaveOccurred()) - _, err = quic.Dial( - conn, - server.Addr(), - "foo.bar", - getTLSClientConfig(), - clientConfig, - ) - Expect(err).To(HaveOccurred()) - var transportErr *quic.TransportError - Expect(errors.As(err, &transportErr)).To(BeTrue()) - Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) - Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) - }) - - It("fails the handshake if the client fails to provide the requested client cert", func() { - tlsConf := getTLSConfig() - tlsConf.ClientAuth = tls.RequireAndVerifyClientCert - runServer(tlsConf) - - conn, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - getTLSClientConfig(), - clientConfig, - ) - // Usually, the error will occur after the client already finished the handshake. - // However, there's a race condition here. The server's CONNECTION_CLOSE might be - // received before the connection is returned, so we might already get the error while dialing. - if err == nil { - errChan := make(chan error) - go func() { - defer GinkgoRecover() - _, err := conn.AcceptStream(context.Background()) - errChan <- err - }() - Eventually(errChan).Should(Receive(&err)) - } - Expect(err).To(HaveOccurred()) - var transportErr *quic.TransportError - Expect(errors.As(err, &transportErr)).To(BeTrue()) - Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) - Expect(transportErr.Error()).To(ContainSubstring("tls: bad certificate")) - }) - - It("uses the ServerName in the tls.Config", func() { - runServer(getTLSConfig()) - tlsConf := getTLSClientConfig() - tlsConf.ServerName = "foo.bar" - _, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), - tlsConf, - clientConfig, - ) - Expect(err).To(HaveOccurred()) - var transportErr *quic.TransportError - Expect(errors.As(err, &transportErr)).To(BeTrue()) - Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) - Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) - }) - }) - } + It("uses the ServerName in the tls.Config", func() { + runServer(getTLSConfig()) + tlsConf := getTLSClientConfig() + tlsConf.ServerName = "foo.bar" + _, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + tlsConf, + getQuicConfig(nil), + ) + Expect(err).To(HaveOccurred()) + var transportErr *quic.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) + Expect(transportErr.Error()).To(ContainSubstring("x509: certificate is valid for localhost, not foo.bar")) + }) }) Context("rate limiting", func() { diff --git a/integrationtests/self/hotswap_test.go b/integrationtests/self/hotswap_test.go index 25ac18df..6eda72dc 100644 --- a/integrationtests/self/hotswap_test.go +++ b/integrationtests/self/hotswap_test.go @@ -3,7 +3,6 @@ package self_test import ( "context" "crypto/tls" - "fmt" "io" "net" "net/http" @@ -13,7 +12,6 @@ import ( "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/testdata" . "github.com/onsi/ginkgo/v2" @@ -71,8 +69,6 @@ var _ = Describe("HTTP3 Server hotswap test", func() { port string ) - versions := protocol.SupportedVersions - BeforeEach(func() { mux1 = http.NewServeMux() mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) { @@ -89,17 +85,17 @@ var _ = Describe("HTTP3 Server hotswap test", func() { server1 = &http3.Server{ Handler: mux1, TLSConfig: testdata.GetTLSConfig(), - QuicConfig: getQuicConfig(&quic.Config{Versions: versions}), + QuicConfig: getQuicConfig(nil), } server2 = &http3.Server{ Handler: mux2, TLSConfig: testdata.GetTLSConfig(), - QuicConfig: getQuicConfig(&quic.Config{Versions: versions}), + QuicConfig: getQuicConfig(nil), } tlsConf := http3.ConfigureTLSConfig(testdata.GetTLSConfig()) - quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(&quic.Config{Versions: versions})) + quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(nil)) ln = &listenerWrapper{EarlyListener: quicln} Expect(err).NotTo(HaveOccurred()) port = strconv.Itoa(ln.Addr().(*net.UDPAddr).Port) @@ -109,78 +105,69 @@ var _ = Describe("HTTP3 Server hotswap test", func() { Expect(ln.Close()).NotTo(HaveOccurred()) }) - for _, v := range versions { - version := v + BeforeEach(func() { + client = &http.Client{ + Transport: &http3.RoundTripper{ + TLSClientConfig: &tls.Config{ + RootCAs: testdata.GetRootCA(), + }, + DisableCompression: true, + QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}), + }, + } + }) - Context(fmt.Sprintf("with QUIC version %s", version), func() { - BeforeEach(func() { - client = &http.Client{ - Transport: &http3.RoundTripper{ - TLSClientConfig: &tls.Config{ - RootCAs: testdata.GetRootCA(), - }, - DisableCompression: true, - QuicConfig: getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - MaxIdleTimeout: 10 * time.Second, - }), - }, - } - }) + It("hotswap works", func() { + // open first server and make single request to it + fake1 := ln.Faker() + stoppedServing1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + server1.ServeListener(fake1) + close(stoppedServing1) + }() - It("hotswap works", func() { - // open first server and make single request to it - fake1 := ln.Faker() - stoppedServing1 := make(chan struct{}) - go func() { - defer GinkgoRecover() - server1.ServeListener(fake1) - close(stoppedServing1) - }() + resp, err := client.Get("https://localhost:" + port + "/hello1") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World 1!\n")) - resp, err := client.Get("https://localhost:" + port + "/hello1") - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) - Expect(err).ToNot(HaveOccurred()) - Expect(string(body)).To(Equal("Hello, World 1!\n")) + // open second server with same underlying listener, + // make sure it opened and both servers are currently running + fake2 := ln.Faker() + stoppedServing2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + server2.ServeListener(fake2) + close(stoppedServing2) + }() - // open second server with same underlying listener, - // make sure it opened and both servers are currently running - fake2 := ln.Faker() - stoppedServing2 := make(chan struct{}) - go func() { - defer GinkgoRecover() - server2.ServeListener(fake2) - close(stoppedServing2) - }() + Consistently(stoppedServing1).ShouldNot(BeClosed()) + Consistently(stoppedServing2).ShouldNot(BeClosed()) - Consistently(stoppedServing1).ShouldNot(BeClosed()) - Consistently(stoppedServing2).ShouldNot(BeClosed()) + // now close first server, no errors should occur here + // and only the fake listener should be closed + Expect(server1.Close()).NotTo(HaveOccurred()) + Eventually(stoppedServing1).Should(BeClosed()) + Expect(fake1.closed).To(Equal(int32(1))) + Expect(fake2.closed).To(Equal(int32(0))) + Expect(ln.listenerClosed).ToNot(BeTrue()) + Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred()) - // now close first server, no errors should occur here - // and only the fake listener should be closed - Expect(server1.Close()).NotTo(HaveOccurred()) - Eventually(stoppedServing1).Should(BeClosed()) - Expect(fake1.closed).To(Equal(int32(1))) - Expect(fake2.closed).To(Equal(int32(0))) - Expect(ln.listenerClosed).ToNot(BeTrue()) - Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred()) + // verify that new connections are being initiated from the second server now + resp, err = client.Get("https://localhost:" + port + "/hello2") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err = io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World 2!\n")) - // verify that new connections are being initiated from the second server now - resp, err = client.Get("https://localhost:" + port + "/hello2") - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - body, err = io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) - Expect(err).ToNot(HaveOccurred()) - Expect(string(body)).To(Equal("Hello, World 2!\n")) - - // close the other server - both the fake and the actual listeners must close now - Expect(server2.Close()).NotTo(HaveOccurred()) - Eventually(stoppedServing2).Should(BeClosed()) - Expect(fake2.closed).To(Equal(int32(1))) - Expect(ln.listenerClosed).To(BeTrue()) - }) - }) - } + // close the other server - both the fake and the actual listeners must close now + Expect(server2.Close()).NotTo(HaveOccurred()) + Eventually(stoppedServing2).Should(BeClosed()) + Expect(fake2.closed).To(Equal(int32(1))) + Expect(ln.listenerClosed).To(BeTrue()) + }) }) diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 0adfff86..0344b472 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -46,8 +46,6 @@ var _ = Describe("HTTP tests", func() { port string ) - versions := protocol.SupportedVersions - BeforeEach(func() { mux = http.NewServeMux() mux.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { @@ -83,7 +81,7 @@ var _ = Describe("HTTP tests", func() { server = &http3.Server{ Handler: mux, TLSConfig: testdata.GetTLSConfig(), - QuicConfig: getQuicConfig(&quic.Config{Versions: versions}), + QuicConfig: getQuicConfig(nil), } addr, err := net.ResolveUDPAddr("udp", "0.0.0.0:0") @@ -106,362 +104,354 @@ var _ = Describe("HTTP tests", func() { Eventually(stoppedServing).Should(BeClosed()) }) - for _, v := range versions { - version := v + BeforeEach(func() { + client = &http.Client{ + Transport: &http3.RoundTripper{ + TLSClientConfig: &tls.Config{ + RootCAs: testdata.GetRootCA(), + }, + DisableCompression: true, + QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}), + }, + } + }) - Context(fmt.Sprintf("with QUIC version %s", version), func() { - BeforeEach(func() { - client = &http.Client{ - Transport: &http3.RoundTripper{ - TLSClientConfig: &tls.Config{ - RootCAs: testdata.GetRootCA(), - }, - DisableCompression: true, - QuicConfig: getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - MaxIdleTimeout: 10 * time.Second, - }), - }, - } - }) + It("downloads a hello", func() { + resp, err := client.Get("https://localhost:" + port + "/hello") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World!\n")) + }) - It("downloads a hello", func() { - resp, err := client.Get("https://localhost:" + port + "/hello") + It("downloads concurrently", func() { + group, ctx := errgroup.WithContext(context.Background()) + for i := 0; i < 2; i++ { + group.Go(func() error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://localhost:"+port+"/hello", nil) + Expect(err).ToNot(HaveOccurred()) + resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) Expect(err).ToNot(HaveOccurred()) Expect(string(body)).To(Equal("Hello, World!\n")) + + return nil }) + } - It("downloads concurrently", func() { - group, ctx := errgroup.WithContext(context.Background()) - for i := 0; i < 2; i++ { - group.Go(func() error { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://localhost:"+port+"/hello", nil) - Expect(err).ToNot(HaveOccurred()) - resp, err := client.Do(req) - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) - Expect(err).ToNot(HaveOccurred()) - Expect(string(body)).To(Equal("Hello, World!\n")) + err := group.Wait() + Expect(err).ToNot(HaveOccurred()) + }) - return nil - }) + It("sets and gets request headers", func() { + handlerCalled := make(chan struct{}) + mux.HandleFunc("/headers/request", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + Expect(r.Header.Get("foo")).To(Equal("bar")) + Expect(r.Header.Get("lorem")).To(Equal("ipsum")) + close(handlerCalled) + }) + + req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/headers/request", nil) + Expect(err).ToNot(HaveOccurred()) + req.Header.Set("foo", "bar") + req.Header.Set("lorem", "ipsum") + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + Eventually(handlerCalled).Should(BeClosed()) + }) + + It("sets and gets response headers", func() { + mux.HandleFunc("/headers/response", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + w.Header().Set("foo", "bar") + w.Header().Set("lorem", "ipsum") + }) + + resp, err := client.Get("https://localhost:" + port + "/headers/response") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Header.Get("foo")).To(Equal("bar")) + Expect(resp.Header.Get("lorem")).To(Equal("ipsum")) + }) + + It("downloads a small file", func() { + resp, err := client.Get("https://localhost:" + port + "/prdata") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(body).To(Equal(PRData)) + }) + + It("downloads a large file", func() { + resp, err := client.Get("https://localhost:" + port + "/prdatalong") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 20*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(body).To(Equal(PRDataLong)) + }) + + It("downloads many hellos", func() { + const num = 150 + + for i := 0; i < num; i++ { + resp, err := client.Get("https://localhost:" + port + "/hello") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World!\n")) + } + }) + + It("downloads many files, if the response is not read", func() { + const num = 150 + + for i := 0; i < num; i++ { + resp, err := client.Get("https://localhost:" + port + "/prdata") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Body.Close()).To(Succeed()) + } + }) + + It("posts a small message", func() { + resp, err := client.Post( + "https://localhost:"+port+"/echo", + "text/plain", + bytes.NewReader([]byte("Hello, world!")), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(body).To(Equal([]byte("Hello, world!"))) + }) + + It("uploads a file", func() { + resp, err := client.Post( + "https://localhost:"+port+"/echo", + "text/plain", + bytes.NewReader(PRData), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(body).To(Equal(PRData)) + }) + + It("uses gzip compression", func() { + mux.HandleFunc("/gzipped/hello", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + Expect(r.Header.Get("Accept-Encoding")).To(Equal("gzip")) + w.Header().Set("Content-Encoding", "gzip") + w.Header().Set("foo", "bar") + + gw := gzip.NewWriter(w) + defer gw.Close() + gw.Write([]byte("Hello, World!\n")) + }) + + client.Transport.(*http3.RoundTripper).DisableCompression = false + resp, err := client.Get("https://localhost:" + port + "/gzipped/hello") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + Expect(resp.Uncompressed).To(BeTrue()) + + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) + Expect(err).ToNot(HaveOccurred()) + Expect(string(body)).To(Equal("Hello, World!\n")) + }) + + It("cancels requests", func() { + handlerCalled := make(chan struct{}) + mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + defer close(handlerCalled) + for { + if _, err := w.Write([]byte("foobar")); err != nil { + Expect(r.Context().Done()).To(BeClosed()) + var strErr *quic.StreamError + Expect(errors.As(err, &strErr)).To(BeTrue()) + Expect(strErr.ErrorCode).To(Equal(quic.StreamErrorCode(0x10c))) + return } - - err := group.Wait() - Expect(err).ToNot(HaveOccurred()) - }) - - It("sets and gets request headers", func() { - handlerCalled := make(chan struct{}) - mux.HandleFunc("/headers/request", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - Expect(r.Header.Get("foo")).To(Equal("bar")) - Expect(r.Header.Get("lorem")).To(Equal("ipsum")) - close(handlerCalled) - }) - - req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/headers/request", nil) - Expect(err).ToNot(HaveOccurred()) - req.Header.Set("foo", "bar") - req.Header.Set("lorem", "ipsum") - resp, err := client.Do(req) - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - Eventually(handlerCalled).Should(BeClosed()) - }) - - It("sets and gets response headers", func() { - mux.HandleFunc("/headers/response", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - w.Header().Set("foo", "bar") - w.Header().Set("lorem", "ipsum") - }) - - resp, err := client.Get("https://localhost:" + port + "/headers/response") - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - Expect(resp.Header.Get("foo")).To(Equal("bar")) - Expect(resp.Header.Get("lorem")).To(Equal("ipsum")) - }) - - It("downloads a small file", func() { - resp, err := client.Get("https://localhost:" + port + "/prdata") - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second)) - Expect(err).ToNot(HaveOccurred()) - Expect(body).To(Equal(PRData)) - }) - - It("downloads a large file", func() { - resp, err := client.Get("https://localhost:" + port + "/prdatalong") - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 20*time.Second)) - Expect(err).ToNot(HaveOccurred()) - Expect(body).To(Equal(PRDataLong)) - }) - - It("downloads many hellos", func() { - const num = 150 - - for i := 0; i < num; i++ { - resp, err := client.Get("https://localhost:" + port + "/hello") - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) - Expect(err).ToNot(HaveOccurred()) - Expect(string(body)).To(Equal("Hello, World!\n")) - } - }) - - It("downloads many files, if the response is not read", func() { - const num = 150 - - for i := 0; i < num; i++ { - resp, err := client.Get("https://localhost:" + port + "/prdata") - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - Expect(resp.Body.Close()).To(Succeed()) - } - }) - - It("posts a small message", func() { - resp, err := client.Post( - "https://localhost:"+port+"/echo", - "text/plain", - bytes.NewReader([]byte("Hello, world!")), - ) - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second)) - Expect(err).ToNot(HaveOccurred()) - Expect(body).To(Equal([]byte("Hello, world!"))) - }) - - It("uploads a file", func() { - resp, err := client.Post( - "https://localhost:"+port+"/echo", - "text/plain", - bytes.NewReader(PRData), - ) - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 5*time.Second)) - Expect(err).ToNot(HaveOccurred()) - Expect(body).To(Equal(PRData)) - }) - - It("uses gzip compression", func() { - mux.HandleFunc("/gzipped/hello", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - Expect(r.Header.Get("Accept-Encoding")).To(Equal("gzip")) - w.Header().Set("Content-Encoding", "gzip") - w.Header().Set("foo", "bar") - - gw := gzip.NewWriter(w) - defer gw.Close() - gw.Write([]byte("Hello, World!\n")) - }) - - client.Transport.(*http3.RoundTripper).DisableCompression = false - resp, err := client.Get("https://localhost:" + port + "/gzipped/hello") - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - Expect(resp.Uncompressed).To(BeTrue()) - - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) - Expect(err).ToNot(HaveOccurred()) - Expect(string(body)).To(Equal("Hello, World!\n")) - }) - - It("cancels requests", func() { - handlerCalled := make(chan struct{}) - mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - defer close(handlerCalled) - for { - if _, err := w.Write([]byte("foobar")); err != nil { - Expect(r.Context().Done()).To(BeClosed()) - var strErr *quic.StreamError - Expect(errors.As(err, &strErr)).To(BeTrue()) - Expect(strErr.ErrorCode).To(Equal(quic.StreamErrorCode(0x10c))) - return - } - } - }) - - req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/cancel", nil) - Expect(err).ToNot(HaveOccurred()) - ctx, cancel := context.WithCancel(context.Background()) - req = req.WithContext(ctx) - resp, err := client.Do(req) - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - cancel() - Eventually(handlerCalled).Should(BeClosed()) - _, err = resp.Body.Read([]byte{0}) - Expect(err).To(HaveOccurred()) - }) - - It("allows streamed HTTP requests", func() { - done := make(chan struct{}) - mux.HandleFunc("/echoline", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - defer close(done) - w.WriteHeader(200) - w.(http.Flusher).Flush() - reader := bufio.NewReader(r.Body) - for { - msg, err := reader.ReadString('\n') - if err != nil { - return - } - _, err = w.Write([]byte(msg)) - Expect(err).ToNot(HaveOccurred()) - w.(http.Flusher).Flush() - } - }) - - r, w := io.Pipe() - req, err := http.NewRequest("PUT", "https://localhost:"+port+"/echoline", r) - Expect(err).ToNot(HaveOccurred()) - rsp, err := client.Do(req) - Expect(err).ToNot(HaveOccurred()) - Expect(rsp.StatusCode).To(Equal(200)) - - reader := bufio.NewReader(rsp.Body) - for i := 0; i < 5; i++ { - msg := fmt.Sprintf("Hello world, %d!\n", i) - fmt.Fprint(w, msg) - msgRcvd, err := reader.ReadString('\n') - Expect(err).ToNot(HaveOccurred()) - Expect(msgRcvd).To(Equal(msg)) - } - Expect(req.Body.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("allows taking over the stream", func() { - mux.HandleFunc("/httpstreamer", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - w.WriteHeader(200) - w.(http.Flusher).Flush() - - str := r.Body.(http3.HTTPStreamer).HTTPStream() - str.Write([]byte("foobar")) - - // Do this in a Go routine, so that the handler returns early. - // This way, we can also check that the HTTP/3 doesn't close the stream. - go func() { - defer GinkgoRecover() - _, err := io.Copy(str, str) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - }() - }) - - req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/httpstreamer", nil) - Expect(err).ToNot(HaveOccurred()) - rsp, err := client.Transport.(*http3.RoundTripper).RoundTripOpt(req, http3.RoundTripOpt{DontCloseRequestStream: true}) - Expect(err).ToNot(HaveOccurred()) - Expect(rsp.StatusCode).To(Equal(200)) - - str := rsp.Body.(http3.HTTPStreamer).HTTPStream() - b := make([]byte, 6) - _, err = io.ReadFull(str, b) - Expect(err).ToNot(HaveOccurred()) - Expect(b).To(Equal([]byte("foobar"))) - - data := GeneratePRData(8 * 1024) - _, err = str.Write(data) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - repl, err := io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(repl).To(Equal(data)) - }) - - It("supports read deadlines", func() { - if !go120 { - Skip("This test requires Go 1.20+") - } - - mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - err := setReadDeadline(w, time.Now().Add(deadlineDelay)) - Expect(err).ToNot(HaveOccurred()) - - body, err := io.ReadAll(r.Body) - Expect(err).To(MatchError(os.ErrDeadlineExceeded)) - Expect(body).To(ContainSubstring("aa")) - - w.Write([]byte("ok")) - }) - - expectedEnd := time.Now().Add(deadlineDelay) - resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a')) - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) - Expect(err).ToNot(HaveOccurred()) - Expect(time.Now().After(expectedEnd)).To(BeTrue()) - Expect(string(body)).To(Equal("ok")) - }) - - It("supports write deadlines", func() { - if !go120 { - Skip("This test requires Go 1.20+") - } - - mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) { - defer GinkgoRecover() - err := setWriteDeadline(w, time.Now().Add(deadlineDelay)) - Expect(err).ToNot(HaveOccurred()) - - _, err = io.Copy(w, neverEnding('a')) - Expect(err).To(MatchError(os.ErrDeadlineExceeded)) - }) - - expectedEnd := time.Now().Add(deadlineDelay) - - resp, err := client.Get("https://localhost:" + port + "/write-deadline") - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(200)) - - body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) - Expect(err).ToNot(HaveOccurred()) - Expect(time.Now().After(expectedEnd)).To(BeTrue()) - Expect(string(body)).To(ContainSubstring("aa")) - }) - - if version != protocol.VersionDraft29 { - It("serves other QUIC connections", func() { - tlsConf := testdata.GetTLSConfig() - tlsConf.NextProtos = []string{"h3"} - ln, err := quic.ListenAddr("localhost:0", tlsConf, nil) - Expect(err).ToNot(HaveOccurred()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - conn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - Expect(server.ServeQUICConn(conn)).To(Succeed()) - }() - - resp, err := client.Get(fmt.Sprintf("https://localhost:%d/hello", ln.Addr().(*net.UDPAddr).Port)) - Expect(err).ToNot(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(http.StatusOK)) - client.Transport.(io.Closer).Close() - Eventually(done).Should(BeClosed()) - }) } }) - } + + req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/cancel", nil) + Expect(err).ToNot(HaveOccurred()) + ctx, cancel := context.WithCancel(context.Background()) + req = req.WithContext(ctx) + resp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + cancel() + Eventually(handlerCalled).Should(BeClosed()) + _, err = resp.Body.Read([]byte{0}) + Expect(err).To(HaveOccurred()) + }) + + It("allows streamed HTTP requests", func() { + done := make(chan struct{}) + mux.HandleFunc("/echoline", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + defer close(done) + w.WriteHeader(200) + w.(http.Flusher).Flush() + reader := bufio.NewReader(r.Body) + for { + msg, err := reader.ReadString('\n') + if err != nil { + return + } + _, err = w.Write([]byte(msg)) + Expect(err).ToNot(HaveOccurred()) + w.(http.Flusher).Flush() + } + }) + + r, w := io.Pipe() + req, err := http.NewRequest("PUT", "https://localhost:"+port+"/echoline", r) + Expect(err).ToNot(HaveOccurred()) + rsp, err := client.Do(req) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.StatusCode).To(Equal(200)) + + reader := bufio.NewReader(rsp.Body) + for i := 0; i < 5; i++ { + msg := fmt.Sprintf("Hello world, %d!\n", i) + fmt.Fprint(w, msg) + msgRcvd, err := reader.ReadString('\n') + Expect(err).ToNot(HaveOccurred()) + Expect(msgRcvd).To(Equal(msg)) + } + Expect(req.Body.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + It("allows taking over the stream", func() { + mux.HandleFunc("/httpstreamer", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + w.WriteHeader(200) + w.(http.Flusher).Flush() + + str := r.Body.(http3.HTTPStreamer).HTTPStream() + str.Write([]byte("foobar")) + + // Do this in a Go routine, so that the handler returns early. + // This way, we can also check that the HTTP/3 doesn't close the stream. + go func() { + defer GinkgoRecover() + _, err := io.Copy(str, str) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + }) + + req, err := http.NewRequest(http.MethodGet, "https://localhost:"+port+"/httpstreamer", nil) + Expect(err).ToNot(HaveOccurred()) + rsp, err := client.Transport.(*http3.RoundTripper).RoundTripOpt(req, http3.RoundTripOpt{DontCloseRequestStream: true}) + Expect(err).ToNot(HaveOccurred()) + Expect(rsp.StatusCode).To(Equal(200)) + + str := rsp.Body.(http3.HTTPStreamer).HTTPStream() + b := make([]byte, 6) + _, err = io.ReadFull(str, b) + Expect(err).ToNot(HaveOccurred()) + Expect(b).To(Equal([]byte("foobar"))) + + data := GeneratePRData(8 * 1024) + _, err = str.Write(data) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + repl, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(repl).To(Equal(data)) + }) + + It("serves other QUIC connections", func() { + if version == protocol.VersionDraft29 { + Skip("This test only works on RFC versions") + } + tlsConf := testdata.GetTLSConfig() + tlsConf.NextProtos = []string{"h3"} + ln, err := quic.ListenAddr("localhost:0", tlsConf, nil) + Expect(err).ToNot(HaveOccurred()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(server.ServeQUICConn(conn)).To(Succeed()) + }() + + resp, err := client.Get(fmt.Sprintf("https://localhost:%d/hello", ln.Addr().(*net.UDPAddr).Port)) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + client.Transport.(io.Closer).Close() + Eventually(done).Should(BeClosed()) + }) + + It("supports read deadlines", func() { + if !go120 { + Skip("This test requires Go 1.20+") + } + + mux.HandleFunc("/read-deadline", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + err := setReadDeadline(w, time.Now().Add(deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + + body, err := io.ReadAll(r.Body) + Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + Expect(body).To(ContainSubstring("aa")) + + w.Write([]byte("ok")) + }) + + expectedEnd := time.Now().Add(deadlineDelay) + resp, err := client.Post("https://localhost:"+port+"/read-deadline", "text/plain", neverEnding('a')) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + Expect(time.Now().After(expectedEnd)).To(BeTrue()) + Expect(string(body)).To(Equal("ok")) + }) + + It("supports write deadlines", func() { + if !go120 { + Skip("This test requires Go 1.20+") + } + + mux.HandleFunc("/write-deadline", func(w http.ResponseWriter, r *http.Request) { + defer GinkgoRecover() + err := setWriteDeadline(w, time.Now().Add(deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + + _, err = io.Copy(w, neverEnding('a')) + Expect(err).To(MatchError(os.ErrDeadlineExceeded)) + }) + + expectedEnd := time.Now().Add(deadlineDelay) + + resp, err := client.Get("https://localhost:" + port + "/write-deadline") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(200)) + + body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 2*deadlineDelay)) + Expect(err).ToNot(HaveOccurred()) + Expect(time.Now().After(expectedEnd)).To(BeTrue()) + Expect(string(body)).To(ContainSubstring("aa")) + }) }) diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index 9ebcb2dd..1afb08f0 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -22,451 +22,435 @@ import ( ) var _ = Describe("MITM test", func() { - for _, v := range protocol.SupportedVersions { - version := v + const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it - Context(fmt.Sprintf("with QUIC version %s", version), func() { - const connIDLen = 6 // explicitly set the connection ID length, so the proxy can parse it + var ( + serverUDPConn, clientUDPConn *net.UDPConn + serverConn quic.Connection + serverConfig *quic.Config + ) - var ( - serverUDPConn, clientUDPConn *net.UDPConn - serverConn quic.Connection - serverConfig *quic.Config - ) + startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + serverUDPConn, err = net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + ln, err := quic.Listen(serverUDPConn, getTLSConfig(), serverConfig) + Expect(err).ToNot(HaveOccurred()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + var err error + serverConn, err = ln.Accept(context.Background()) + if err != nil { + return + } + str, err := serverConn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(PRData) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + serverPort := ln.Addr().(*net.UDPAddr).Port + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: delayCb, + DropPacket: dropCb, + }) + Expect(err).ToNot(HaveOccurred()) + return proxy.LocalPort(), func() { + proxy.Close() + ln.Close() + serverUDPConn.Close() + <-done + } + } - startServerAndProxy := func(delayCb quicproxy.DelayCallback, dropCb quicproxy.DropCallback) (proxyPort int, closeFn func()) { - addr, err := net.ResolveUDPAddr("udp", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - serverUDPConn, err = net.ListenUDP("udp", addr) - Expect(err).ToNot(HaveOccurred()) - ln, err := quic.Listen(serverUDPConn, getTLSConfig(), serverConfig) - Expect(err).ToNot(HaveOccurred()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - var err error - serverConn, err = ln.Accept(context.Background()) - if err != nil { - return + BeforeEach(func() { + serverConfig = getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}) + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + clientUDPConn, err = net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + }) + + Context("unsuccessful attacks", func() { + AfterEach(func() { + Eventually(serverConn.Context().Done()).Should(BeClosed()) + // Test shutdown is tricky due to the proxy. Just wait for a bit. + time.Sleep(50 * time.Millisecond) + Expect(clientUDPConn.Close()).To(Succeed()) + }) + + Context("injecting invalid packets", func() { + const rtt = 20 * time.Millisecond + + sendRandomPacketsOfSameType := func(conn net.PacketConn, remoteAddr net.Addr, raw []byte) { + defer GinkgoRecover() + const numPackets = 10 + ticker := time.NewTicker(rtt / numPackets) + defer ticker.Stop() + + if wire.IsLongHeaderPacket(raw[0]) { + hdr, _, _, err := wire.ParsePacket(raw) + Expect(err).ToNot(HaveOccurred()) + replyHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + DestConnectionID: hdr.DestConnectionID, + SrcConnectionID: hdr.SrcConnectionID, + Type: hdr.Type, + Version: hdr.Version, + }, + PacketNumber: protocol.PacketNumber(mrand.Int31n(math.MaxInt32 / 4)), + PacketNumberLen: protocol.PacketNumberLen(mrand.Int31n(4) + 1), } - str, err := serverConn.OpenUniStream() + + for i := 0; i < numPackets; i++ { + payloadLen := mrand.Int31n(100) + replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1)) + b, err := replyHdr.Append(nil, hdr.Version) + Expect(err).ToNot(HaveOccurred()) + r := make([]byte, payloadLen) + mrand.Read(r) + b = append(b, r...) + if _, err := conn.WriteTo(b, remoteAddr); err != nil { + return + } + <-ticker.C + } + } else { + connID, err := wire.ParseConnectionID(raw, connIDLen) Expect(err).ToNot(HaveOccurred()) - _, err = str.Write(PRData) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - }() - serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DelayPacket: delayCb, - DropPacket: dropCb, - }) - Expect(err).ToNot(HaveOccurred()) - return proxy.LocalPort(), func() { - proxy.Close() - ln.Close() - serverUDPConn.Close() - <-done + _, pn, pnLen, _, err := wire.ParseShortHeader(raw, connIDLen) + if err != nil { // normally, ParseShortHeader is called after decrypting the header + Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) + } + for i := 0; i < numPackets; i++ { + b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2))) + Expect(err).ToNot(HaveOccurred()) + payloadLen := mrand.Int31n(100) + r := make([]byte, payloadLen) + mrand.Read(r) + b = append(b, r...) + if _, err := conn.WriteTo(b, remoteAddr); err != nil { + return + } + <-ticker.C + } } } - BeforeEach(func() { - serverConfig = getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - ConnectionIDLength: connIDLen, - }) - addr, err := net.ResolveUDPAddr("udp", "localhost:0") + runTest := func(delayCb quicproxy.DelayCallback) { + proxyPort, closeFn := startServerAndProxy(delayCb, nil) + defer closeFn() + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) Expect(err).ToNot(HaveOccurred()) - clientUDPConn, err = net.ListenUDP("udp", addr) + conn, err := quic.Dial( + clientUDPConn, + raddr, + fmt.Sprintf("localhost:%d", proxyPort), + getTLSClientConfig(), + getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}), + ) Expect(err).ToNot(HaveOccurred()) - }) + str, err := conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(PRData)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + } - Context("unsuccessful attacks", func() { - AfterEach(func() { - Eventually(serverConn.Context().Done()).Should(BeClosed()) - // Test shutdown is tricky due to the proxy. Just wait for a bit. - time.Sleep(50 * time.Millisecond) - Expect(clientUDPConn.Close()).To(Succeed()) - }) - - Context("injecting invalid packets", func() { - const rtt = 20 * time.Millisecond - - sendRandomPacketsOfSameType := func(conn net.PacketConn, remoteAddr net.Addr, raw []byte) { + It("downloads a message when the packets are injected towards the server", func() { + delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { + if dir == quicproxy.DirectionIncoming { defer GinkgoRecover() - const numPackets = 10 - ticker := time.NewTicker(rtt / numPackets) - defer ticker.Stop() - - if wire.IsLongHeaderPacket(raw[0]) { - hdr, _, _, err := wire.ParsePacket(raw) - Expect(err).ToNot(HaveOccurred()) - replyHdr := &wire.ExtendedHeader{ - Header: wire.Header{ - DestConnectionID: hdr.DestConnectionID, - SrcConnectionID: hdr.SrcConnectionID, - Type: hdr.Type, - Version: hdr.Version, - }, - PacketNumber: protocol.PacketNumber(mrand.Int31n(math.MaxInt32 / 4)), - PacketNumberLen: protocol.PacketNumberLen(mrand.Int31n(4) + 1), - } - - for i := 0; i < numPackets; i++ { - payloadLen := mrand.Int31n(100) - replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1)) - b, err := replyHdr.Append(nil, version) - Expect(err).ToNot(HaveOccurred()) - r := make([]byte, payloadLen) - mrand.Read(r) - b = append(b, r...) - if _, err := conn.WriteTo(b, remoteAddr); err != nil { - return - } - <-ticker.C - } - } else { - connID, err := wire.ParseConnectionID(raw, connIDLen) - Expect(err).ToNot(HaveOccurred()) - _, pn, pnLen, _, err := wire.ParseShortHeader(raw, connIDLen) - if err != nil { // normally, ParseShortHeader is called after decrypting the header - Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) - } - for i := 0; i < numPackets; i++ { - b, err := wire.AppendShortHeader(nil, connID, pn, pnLen, protocol.KeyPhaseBit(mrand.Intn(2))) - Expect(err).ToNot(HaveOccurred()) - payloadLen := mrand.Int31n(100) - r := make([]byte, payloadLen) - mrand.Read(r) - b = append(b, r...) - if _, err := conn.WriteTo(b, remoteAddr); err != nil { - return - } - <-ticker.C - } - } + go sendRandomPacketsOfSameType(clientUDPConn, serverUDPConn.LocalAddr(), raw) } - - runTest := func(delayCb quicproxy.DelayCallback) { - proxyPort, closeFn := startServerAndProxy(delayCb, nil) - defer closeFn() - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) - Expect(err).ToNot(HaveOccurred()) - conn, err := quic.Dial( - clientUDPConn, - raddr, - fmt.Sprintf("localhost:%d", proxyPort), - getTLSClientConfig(), - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - ConnectionIDLength: connIDLen, - }), - ) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal(PRData)) - Expect(conn.CloseWithError(0, "")).To(Succeed()) - } - - It("downloads a message when the packets are injected towards the server", func() { - delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { - if dir == quicproxy.DirectionIncoming { - defer GinkgoRecover() - go sendRandomPacketsOfSameType(clientUDPConn, serverUDPConn.LocalAddr(), raw) - } - return rtt / 2 - } - runTest(delayCb) - }) - - It("downloads a message when the packets are injected towards the client", func() { - delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { - if dir == quicproxy.DirectionOutgoing { - defer GinkgoRecover() - go sendRandomPacketsOfSameType(serverUDPConn, clientUDPConn.LocalAddr(), raw) - } - return rtt / 2 - } - runTest(delayCb) - }) - }) - - runTest := func(dropCb quicproxy.DropCallback) { - proxyPort, closeFn := startServerAndProxy(nil, dropCb) - defer closeFn() - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) - Expect(err).ToNot(HaveOccurred()) - conn, err := quic.Dial( - clientUDPConn, - raddr, - fmt.Sprintf("localhost:%d", proxyPort), - getTLSClientConfig(), - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - ConnectionIDLength: connIDLen, - }), - ) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal(PRData)) - Expect(conn.CloseWithError(0, "")).To(Succeed()) + return rtt / 2 } - - Context("duplicating packets", func() { - It("downloads a message when packets are duplicated towards the server", func() { - dropCb := func(dir quicproxy.Direction, raw []byte) bool { - defer GinkgoRecover() - if dir == quicproxy.DirectionIncoming { - _, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr()) - Expect(err).ToNot(HaveOccurred()) - } - return false - } - runTest(dropCb) - }) - - It("downloads a message when packets are duplicated towards the client", func() { - dropCb := func(dir quicproxy.Direction, raw []byte) bool { - defer GinkgoRecover() - if dir == quicproxy.DirectionOutgoing { - _, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr()) - Expect(err).ToNot(HaveOccurred()) - } - return false - } - runTest(dropCb) - }) - }) - - Context("corrupting packets", func() { - const idleTimeout = time.Second - - var numCorrupted, numPackets int32 - - BeforeEach(func() { - numCorrupted = 0 - numPackets = 0 - serverConfig.MaxIdleTimeout = idleTimeout - }) - - AfterEach(func() { - num := atomic.LoadInt32(&numCorrupted) - fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, atomic.LoadInt32(&numPackets)) - Expect(num).To(BeNumerically(">=", 1)) - // If the packet containing the CONNECTION_CLOSE is corrupted, - // we have to wait for the connection to time out. - Eventually(serverConn.Context().Done(), 3*idleTimeout).Should(BeClosed()) - }) - - It("downloads a message when packet are corrupted towards the server", func() { - const interval = 4 // corrupt every 4th packet (stochastically) - dropCb := func(dir quicproxy.Direction, raw []byte) bool { - defer GinkgoRecover() - if dir == quicproxy.DirectionIncoming { - atomic.AddInt32(&numPackets, 1) - if mrand.Intn(interval) == 0 { - pos := mrand.Intn(len(raw)) - raw[pos] = byte(mrand.Intn(256)) - _, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr()) - Expect(err).ToNot(HaveOccurred()) - atomic.AddInt32(&numCorrupted, 1) - return true - } - } - return false - } - runTest(dropCb) - }) - - It("downloads a message when packet are corrupted towards the client", func() { - const interval = 10 // corrupt every 10th packet (stochastically) - dropCb := func(dir quicproxy.Direction, raw []byte) bool { - defer GinkgoRecover() - if dir == quicproxy.DirectionOutgoing { - atomic.AddInt32(&numPackets, 1) - if mrand.Intn(interval) == 0 { - pos := mrand.Intn(len(raw)) - raw[pos] = byte(mrand.Intn(256)) - _, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr()) - Expect(err).ToNot(HaveOccurred()) - atomic.AddInt32(&numCorrupted, 1) - return true - } - } - return false - } - runTest(dropCb) - }) - }) + runTest(delayCb) }) - Context("successful injection attacks", func() { - // These tests demonstrate that the QUIC protocol is vulnerable to injection attacks before the handshake - // finishes. In particular, an adversary who can intercept packets coming from one endpoint and send a reply - // that arrives before the real reply can tear down the connection in multiple ways. - - const rtt = 20 * time.Millisecond - - runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) { - proxyPort, closeFn := startServerAndProxy(delayCb, nil) - raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) - Expect(err).ToNot(HaveOccurred()) - _, err = quic.Dial( - clientUDPConn, - raddr, - fmt.Sprintf("localhost:%d", proxyPort), - getTLSClientConfig(), - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - ConnectionIDLength: connIDLen, - HandshakeIdleTimeout: 2 * time.Second, - }), - ) - return closeFn, err + It("downloads a message when the packets are injected towards the client", func() { + delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { + if dir == quicproxy.DirectionOutgoing { + defer GinkgoRecover() + go sendRandomPacketsOfSameType(serverUDPConn, clientUDPConn.LocalAddr(), raw) + } + return rtt / 2 } - - // fails immediately because client connection closes when it can't find compatible version - It("fails when a forged version negotiation packet is sent to client", func() { - done := make(chan struct{}) - delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { - if dir == quicproxy.DirectionIncoming { - defer GinkgoRecover() - - hdr, _, _, err := wire.ParsePacket(raw) - Expect(err).ToNot(HaveOccurred()) - - if hdr.Type != protocol.PacketTypeInitial { - return 0 - } - - // Create fake version negotiation packet with no supported versions - versions := []protocol.VersionNumber{} - packet := wire.ComposeVersionNegotiation( - protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()), - protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()), - versions, - ) - - // Send the packet - _, err = serverUDPConn.WriteTo(packet, clientUDPConn.LocalAddr()) - Expect(err).ToNot(HaveOccurred()) - close(done) - } - return rtt / 2 - } - closeFn, err := runTest(delayCb) - defer closeFn() - Expect(err).To(HaveOccurred()) - vnErr := &quic.VersionNegotiationError{} - Expect(errors.As(err, &vnErr)).To(BeTrue()) - Eventually(done).Should(BeClosed()) - }) - - // times out, because client doesn't accept subsequent real retry packets from server - // as it has already accepted a retry. - // TODO: determine behavior when server does not send Retry packets - It("fails when a forged Retry packet with modified srcConnID is sent to client", func() { - serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } - var initialPacketIntercepted bool - done := make(chan struct{}) - delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { - if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted { - defer GinkgoRecover() - defer close(done) - - hdr, _, _, err := wire.ParsePacket(raw) - Expect(err).ToNot(HaveOccurred()) - - if hdr.Type != protocol.PacketTypeInitial { - return 0 - } - - initialPacketIntercepted = true - fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12}) - retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version) - - _, err = serverUDPConn.WriteTo(retryPacket, clientUDPConn.LocalAddr()) - Expect(err).ToNot(HaveOccurred()) - } - return rtt / 2 - } - closeFn, err := runTest(delayCb) - defer closeFn() - Expect(err).To(HaveOccurred()) - Expect(err.(net.Error).Timeout()).To(BeTrue()) - Eventually(done).Should(BeClosed()) - }) - - // times out, because client doesn't accept real retry packets from server because - // it has already accepted an initial. - // TODO: determine behavior when server does not send Retry packets - It("fails when a forged initial packet is sent to client", func() { - done := make(chan struct{}) - var injected bool - delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { - if dir == quicproxy.DirectionIncoming { - defer GinkgoRecover() - - hdr, _, _, err := wire.ParsePacket(raw) - Expect(err).ToNot(HaveOccurred()) - if hdr.Type != protocol.PacketTypeInitial || injected { - return 0 - } - defer close(done) - injected = true - initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil) - _, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr()) - Expect(err).ToNot(HaveOccurred()) - } - return rtt - } - closeFn, err := runTest(delayCb) - defer closeFn() - Expect(err).To(HaveOccurred()) - Expect(err.(net.Error).Timeout()).To(BeTrue()) - Eventually(done).Should(BeClosed()) - }) - - // client connection closes immediately on receiving ack for unsent packet - It("fails when a forged initial packet with ack for unsent packet is sent to client", func() { - done := make(chan struct{}) - var injected bool - delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { - if dir == quicproxy.DirectionIncoming { - defer GinkgoRecover() - - hdr, _, _, err := wire.ParsePacket(raw) - Expect(err).ToNot(HaveOccurred()) - if hdr.Type != protocol.PacketTypeInitial || injected { - return 0 - } - defer close(done) - injected = true - // Fake Initial with ACK for packet 2 (unsent) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} - initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ack}) - _, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr()) - Expect(err).ToNot(HaveOccurred()) - } - return rtt - } - closeFn, err := runTest(delayCb) - defer closeFn() - Expect(err).To(HaveOccurred()) - var transportErr *quic.TransportError - Expect(errors.As(err, &transportErr)).To(BeTrue()) - Expect(transportErr.ErrorCode).To(Equal(quic.ProtocolViolation)) - Expect(transportErr.ErrorMessage).To(ContainSubstring("received ACK for an unsent packet")) - Eventually(done).Should(BeClosed()) - }) + runTest(delayCb) }) }) - } + + runTest := func(dropCb quicproxy.DropCallback) { + proxyPort, closeFn := startServerAndProxy(nil, dropCb) + defer closeFn() + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) + Expect(err).ToNot(HaveOccurred()) + conn, err := quic.Dial( + clientUDPConn, + raddr, + fmt.Sprintf("localhost:%d", proxyPort), + getTLSClientConfig(), + getQuicConfig(&quic.Config{ConnectionIDLength: connIDLen}), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(PRData)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + } + + Context("duplicating packets", func() { + It("downloads a message when packets are duplicated towards the server", func() { + dropCb := func(dir quicproxy.Direction, raw []byte) bool { + defer GinkgoRecover() + if dir == quicproxy.DirectionIncoming { + _, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + } + return false + } + runTest(dropCb) + }) + + It("downloads a message when packets are duplicated towards the client", func() { + dropCb := func(dir quicproxy.Direction, raw []byte) bool { + defer GinkgoRecover() + if dir == quicproxy.DirectionOutgoing { + _, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + } + return false + } + runTest(dropCb) + }) + }) + + Context("corrupting packets", func() { + const idleTimeout = time.Second + + var numCorrupted, numPackets int32 + + BeforeEach(func() { + numCorrupted = 0 + numPackets = 0 + serverConfig.MaxIdleTimeout = idleTimeout + }) + + AfterEach(func() { + num := atomic.LoadInt32(&numCorrupted) + fmt.Fprintf(GinkgoWriter, "Corrupted %d of %d packets.", num, atomic.LoadInt32(&numPackets)) + Expect(num).To(BeNumerically(">=", 1)) + // If the packet containing the CONNECTION_CLOSE is corrupted, + // we have to wait for the connection to time out. + Eventually(serverConn.Context().Done(), 3*idleTimeout).Should(BeClosed()) + }) + + It("downloads a message when packet are corrupted towards the server", func() { + const interval = 4 // corrupt every 4th packet (stochastically) + dropCb := func(dir quicproxy.Direction, raw []byte) bool { + defer GinkgoRecover() + if dir == quicproxy.DirectionIncoming { + atomic.AddInt32(&numPackets, 1) + if mrand.Intn(interval) == 0 { + pos := mrand.Intn(len(raw)) + raw[pos] = byte(mrand.Intn(256)) + _, err := clientUDPConn.WriteTo(raw, serverUDPConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + atomic.AddInt32(&numCorrupted, 1) + return true + } + } + return false + } + runTest(dropCb) + }) + + It("downloads a message when packet are corrupted towards the client", func() { + const interval = 10 // corrupt every 10th packet (stochastically) + dropCb := func(dir quicproxy.Direction, raw []byte) bool { + defer GinkgoRecover() + if dir == quicproxy.DirectionOutgoing { + atomic.AddInt32(&numPackets, 1) + if mrand.Intn(interval) == 0 { + pos := mrand.Intn(len(raw)) + raw[pos] = byte(mrand.Intn(256)) + _, err := serverUDPConn.WriteTo(raw, clientUDPConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + atomic.AddInt32(&numCorrupted, 1) + return true + } + } + return false + } + runTest(dropCb) + }) + }) + }) + + Context("successful injection attacks", func() { + // These tests demonstrate that the QUIC protocol is vulnerable to injection attacks before the handshake + // finishes. In particular, an adversary who can intercept packets coming from one endpoint and send a reply + // that arrives before the real reply can tear down the connection in multiple ways. + + const rtt = 20 * time.Millisecond + + runTest := func(delayCb quicproxy.DelayCallback) (closeFn func(), err error) { + proxyPort, closeFn := startServerAndProxy(delayCb, nil) + raddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("localhost:%d", proxyPort)) + Expect(err).ToNot(HaveOccurred()) + _, err = quic.Dial( + clientUDPConn, + raddr, + fmt.Sprintf("localhost:%d", proxyPort), + getTLSClientConfig(), + getQuicConfig(&quic.Config{ + ConnectionIDLength: connIDLen, + HandshakeIdleTimeout: 2 * time.Second, + }), + ) + return closeFn, err + } + + // fails immediately because client connection closes when it can't find compatible version + It("fails when a forged version negotiation packet is sent to client", func() { + done := make(chan struct{}) + delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { + if dir == quicproxy.DirectionIncoming { + defer GinkgoRecover() + + hdr, _, _, err := wire.ParsePacket(raw) + Expect(err).ToNot(HaveOccurred()) + + if hdr.Type != protocol.PacketTypeInitial { + return 0 + } + + // Create fake version negotiation packet with no supported versions + versions := []protocol.VersionNumber{} + packet := wire.ComposeVersionNegotiation( + protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()), + protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()), + versions, + ) + + // Send the packet + _, err = serverUDPConn.WriteTo(packet, clientUDPConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + close(done) + } + return rtt / 2 + } + closeFn, err := runTest(delayCb) + defer closeFn() + Expect(err).To(HaveOccurred()) + vnErr := &quic.VersionNegotiationError{} + Expect(errors.As(err, &vnErr)).To(BeTrue()) + Eventually(done).Should(BeClosed()) + }) + + // times out, because client doesn't accept subsequent real retry packets from server + // as it has already accepted a retry. + // TODO: determine behavior when server does not send Retry packets + It("fails when a forged Retry packet with modified srcConnID is sent to client", func() { + serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } + var initialPacketIntercepted bool + done := make(chan struct{}) + delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { + if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted { + defer GinkgoRecover() + defer close(done) + + hdr, _, _, err := wire.ParsePacket(raw) + Expect(err).ToNot(HaveOccurred()) + + if hdr.Type != protocol.PacketTypeInitial { + return 0 + } + + initialPacketIntercepted = true + fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12}) + retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version) + + _, err = serverUDPConn.WriteTo(retryPacket, clientUDPConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + } + return rtt / 2 + } + closeFn, err := runTest(delayCb) + defer closeFn() + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + Eventually(done).Should(BeClosed()) + }) + + // times out, because client doesn't accept real retry packets from server because + // it has already accepted an initial. + // TODO: determine behavior when server does not send Retry packets + It("fails when a forged initial packet is sent to client", func() { + done := make(chan struct{}) + var injected bool + delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { + if dir == quicproxy.DirectionIncoming { + defer GinkgoRecover() + + hdr, _, _, err := wire.ParsePacket(raw) + Expect(err).ToNot(HaveOccurred()) + if hdr.Type != protocol.PacketTypeInitial || injected { + return 0 + } + defer close(done) + injected = true + initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, nil) + _, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + } + return rtt + } + closeFn, err := runTest(delayCb) + defer closeFn() + Expect(err).To(HaveOccurred()) + Expect(err.(net.Error).Timeout()).To(BeTrue()) + Eventually(done).Should(BeClosed()) + }) + + // client connection closes immediately on receiving ack for unsent packet + It("fails when a forged initial packet with ack for unsent packet is sent to client", func() { + done := make(chan struct{}) + var injected bool + delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { + if dir == quicproxy.DirectionIncoming { + defer GinkgoRecover() + + hdr, _, _, err := wire.ParsePacket(raw) + Expect(err).ToNot(HaveOccurred()) + if hdr.Type != protocol.PacketTypeInitial || injected { + return 0 + } + defer close(done) + injected = true + // Fake Initial with ACK for packet 2 (unsent) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}} + initialPacket := testutils.ComposeInitialPacket(hdr.DestConnectionID, hdr.SrcConnectionID, hdr.Version, hdr.DestConnectionID, []wire.Frame{ack}) + _, err = serverUDPConn.WriteTo(initialPacket, clientUDPConn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + } + return rtt + } + closeFn, err := runTest(delayCb) + defer closeFn() + Expect(err).To(HaveOccurred()) + var transportErr *quic.TransportError + Expect(errors.As(err, &transportErr)).To(BeTrue()) + Expect(transportErr.ErrorCode).To(Equal(quic.ProtocolViolation)) + Expect(transportErr.ErrorMessage).To(ContainSubstring("received ACK for an unsent packet")) + Eventually(done).Should(BeClosed()) + }) + }) }) diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index 2a699502..55b69887 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -9,213 +9,206 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/protocol" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("Multiplexing", func() { - for _, v := range protocol.SupportedVersions { - version := v - - Context(fmt.Sprintf("with QUIC version %s", version), func() { - runServer := func(ln quic.Listener) { + runServer := func(ln quic.Listener) { + go func() { + defer GinkgoRecover() + for { + conn, err := ln.Accept(context.Background()) + if err != nil { + return + } go func() { defer GinkgoRecover() - for { - conn, err := ln.Accept(context.Background()) - if err != nil { - return - } - go func() { - defer GinkgoRecover() - str, err := conn.OpenStream() - Expect(err).ToNot(HaveOccurred()) - defer str.Close() - _, err = str.Write(PRData) - Expect(err).ToNot(HaveOccurred()) - }() - } + str, err := conn.OpenStream() + Expect(err).ToNot(HaveOccurred()) + defer str.Close() + _, err = str.Write(PRData) + Expect(err).ToNot(HaveOccurred()) }() } - - dial := func(pconn net.PacketConn, addr net.Addr) { - conn, err := quic.Dial( - pconn, - addr, - fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port), - getTLSClientConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - defer conn.CloseWithError(0, "") - str, err := conn.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal(PRData)) - } - - Context("multiplexing clients on the same conn", func() { - getListener := func() quic.Listener { - ln, err := quic.ListenAddr( - "localhost:0", - getTLSConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - return ln - } - - It("multiplexes connections to the same server", func() { - server := getListener() - runServer(server) - defer server.Close() - - addr, err := net.ResolveUDPAddr("udp", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - conn, err := net.ListenUDP("udp", addr) - Expect(err).ToNot(HaveOccurred()) - defer conn.Close() - - done1 := make(chan struct{}) - done2 := make(chan struct{}) - go func() { - defer GinkgoRecover() - dial(conn, server.Addr()) - close(done1) - }() - go func() { - defer GinkgoRecover() - dial(conn, server.Addr()) - close(done2) - }() - timeout := 30 * time.Second - if debugLog() { - timeout = time.Minute - } - Eventually(done1, timeout).Should(BeClosed()) - Eventually(done2, timeout).Should(BeClosed()) - }) - - It("multiplexes connections to different servers", func() { - server1 := getListener() - runServer(server1) - defer server1.Close() - server2 := getListener() - runServer(server2) - defer server2.Close() - - addr, err := net.ResolveUDPAddr("udp", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - conn, err := net.ListenUDP("udp", addr) - Expect(err).ToNot(HaveOccurred()) - defer conn.Close() - - done1 := make(chan struct{}) - done2 := make(chan struct{}) - go func() { - defer GinkgoRecover() - dial(conn, server1.Addr()) - close(done1) - }() - go func() { - defer GinkgoRecover() - dial(conn, server2.Addr()) - close(done2) - }() - timeout := 30 * time.Second - if debugLog() { - timeout = time.Minute - } - Eventually(done1, timeout).Should(BeClosed()) - Eventually(done2, timeout).Should(BeClosed()) - }) - }) - - Context("multiplexing server and client on the same conn", func() { - It("connects to itself", func() { - addr, err := net.ResolveUDPAddr("udp", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - conn, err := net.ListenUDP("udp", addr) - Expect(err).ToNot(HaveOccurred()) - defer conn.Close() - - server, err := quic.Listen( - conn, - getTLSConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - runServer(server) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - dial(conn, server.Addr()) - close(done) - }() - timeout := 30 * time.Second - if debugLog() { - timeout = time.Minute - } - Eventually(done, timeout).Should(BeClosed()) - }) - - It("runs a server and client on the same conn", func() { - if runtime.GOOS == "linux" { - Skip("This test would require setting of iptables rules, see https://stackoverflow.com/questions/23859164/linux-udp-socket-sendto-operation-not-permitted.") - } - addr1, err := net.ResolveUDPAddr("udp", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - conn1, err := net.ListenUDP("udp", addr1) - Expect(err).ToNot(HaveOccurred()) - defer conn1.Close() - - addr2, err := net.ResolveUDPAddr("udp", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - conn2, err := net.ListenUDP("udp", addr2) - Expect(err).ToNot(HaveOccurred()) - defer conn2.Close() - - server1, err := quic.Listen( - conn1, - getTLSConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - runServer(server1) - defer server1.Close() - - server2, err := quic.Listen( - conn2, - getTLSConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - runServer(server2) - defer server2.Close() - - done1 := make(chan struct{}) - done2 := make(chan struct{}) - go func() { - defer GinkgoRecover() - dial(conn2, server1.Addr()) - close(done1) - }() - go func() { - defer GinkgoRecover() - dial(conn1, server2.Addr()) - close(done2) - }() - timeout := 30 * time.Second - if debugLog() { - timeout = time.Minute - } - Eventually(done1, timeout).Should(BeClosed()) - Eventually(done2, timeout).Should(BeClosed()) - }) - }) - }) + }() } + + dial := func(pconn net.PacketConn, addr net.Addr) { + conn, err := quic.Dial( + pconn, + addr, + fmt.Sprintf("localhost:%d", addr.(*net.UDPAddr).Port), + getTLSClientConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") + str, err := conn.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(PRData)) + } + + Context("multiplexing clients on the same conn", func() { + getListener := func() quic.Listener { + ln, err := quic.ListenAddr( + "localhost:0", + getTLSConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + return ln + } + + It("multiplexes connections to the same server", func() { + server := getListener() + runServer(server) + defer server.Close() + + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + + done1 := make(chan struct{}) + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + dial(conn, server.Addr()) + close(done1) + }() + go func() { + defer GinkgoRecover() + dial(conn, server.Addr()) + close(done2) + }() + timeout := 30 * time.Second + if debugLog() { + timeout = time.Minute + } + Eventually(done1, timeout).Should(BeClosed()) + Eventually(done2, timeout).Should(BeClosed()) + }) + + It("multiplexes connections to different servers", func() { + server1 := getListener() + runServer(server1) + defer server1.Close() + server2 := getListener() + runServer(server2) + defer server2.Close() + + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + + done1 := make(chan struct{}) + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + dial(conn, server1.Addr()) + close(done1) + }() + go func() { + defer GinkgoRecover() + dial(conn, server2.Addr()) + close(done2) + }() + timeout := 30 * time.Second + if debugLog() { + timeout = time.Minute + } + Eventually(done1, timeout).Should(BeClosed()) + Eventually(done2, timeout).Should(BeClosed()) + }) + }) + + Context("multiplexing server and client on the same conn", func() { + It("connects to itself", func() { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + defer conn.Close() + + server, err := quic.Listen( + conn, + getTLSConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + runServer(server) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + dial(conn, server.Addr()) + close(done) + }() + timeout := 30 * time.Second + if debugLog() { + timeout = time.Minute + } + Eventually(done, timeout).Should(BeClosed()) + }) + + It("runs a server and client on the same conn", func() { + if runtime.GOOS == "linux" { + Skip("This test would require setting of iptables rules, see https://stackoverflow.com/questions/23859164/linux-udp-socket-sendto-operation-not-permitted.") + } + addr1, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn1, err := net.ListenUDP("udp", addr1) + Expect(err).ToNot(HaveOccurred()) + defer conn1.Close() + + addr2, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn2, err := net.ListenUDP("udp", addr2) + Expect(err).ToNot(HaveOccurred()) + defer conn2.Close() + + server1, err := quic.Listen( + conn1, + getTLSConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + runServer(server1) + defer server1.Close() + + server2, err := quic.Listen( + conn2, + getTLSConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + runServer(server2) + defer server2.Close() + + done1 := make(chan struct{}) + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + dial(conn2, server1.Addr()) + close(done1) + }() + go func() { + defer GinkgoRecover() + dial(conn1, server2.Addr()) + close(done2) + }() + timeout := 30 * time.Second + if debugLog() { + timeout = time.Minute + } + Eventually(done1, timeout).Should(BeClosed()) + Eventually(done2, timeout).Should(BeClosed()) + }) + }) }) diff --git a/integrationtests/self/rtt_test.go b/integrationtests/self/rtt_test.go index a818c870..223177fd 100644 --- a/integrationtests/self/rtt_test.go +++ b/integrationtests/self/rtt_test.go @@ -9,41 +9,72 @@ import ( "github.com/quic-go/quic-go" quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" - "github.com/quic-go/quic-go/internal/protocol" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("non-zero RTT", func() { - for _, v := range protocol.SupportedVersions { - version := v - - runServer := func() quic.Listener { - ln, err := quic.ListenAddr( - "localhost:0", - getTLSConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) + runServer := func() quic.Listener { + ln, err := quic.ListenAddr( + "localhost:0", + getTLSConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) Expect(err).ToNot(HaveOccurred()) - go func() { - defer GinkgoRecover() - conn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.OpenStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write(PRData) - Expect(err).ToNot(HaveOccurred()) - str.Close() - }() - return ln - } + str, err := conn.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(PRData) + Expect(err).ToNot(HaveOccurred()) + str.Close() + }() + return ln + } + + downloadFile := func(port int) { + conn, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", port), + getTLSClientConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(PRData)) + conn.CloseWithError(0, "") + } + + for _, r := range [...]time.Duration{ + 10 * time.Millisecond, + 50 * time.Millisecond, + 100 * time.Millisecond, + 200 * time.Millisecond, + } { + rtt := r + + It(fmt.Sprintf("downloads a message with %s RTT", rtt), func() { + ln := runServer() + defer ln.Close() + serverPort := ln.Addr().(*net.UDPAddr).Port + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() - downloadFile := func(port int) { conn, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", port), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), getTLSClientConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) str, err := conn.AcceptStream(context.Background()) @@ -52,67 +83,29 @@ var _ = Describe("non-zero RTT", func() { Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(PRData)) conn.CloseWithError(0, "") - } + }) + } - Context(fmt.Sprintf("with QUIC version %s", version), func() { - for _, r := range [...]time.Duration{ - 10 * time.Millisecond, - 50 * time.Millisecond, - 100 * time.Millisecond, - 200 * time.Millisecond, - } { - rtt := r + for _, r := range [...]time.Duration{ + 10 * time.Millisecond, + 40 * time.Millisecond, + } { + rtt := r - It(fmt.Sprintf("downloads a message with %s RTT", rtt), func() { - ln := runServer() - defer ln.Close() - serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DelayPacket: func(quicproxy.Direction, []byte) time.Duration { - return rtt / 2 - }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() + It(fmt.Sprintf("downloads a message with %s RTT, with reordering", rtt), func() { + ln := runServer() + defer ln.Close() + serverPort := ln.Addr().(*net.UDPAddr).Port + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: func(quicproxy.Direction, []byte) time.Duration { + return randomDuration(rtt/2, rtt*3/2) / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() - conn, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - getTLSClientConfig(), - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal(PRData)) - conn.CloseWithError(0, "") - }) - } - - for _, r := range [...]time.Duration{ - 10 * time.Millisecond, - 40 * time.Millisecond, - } { - rtt := r - - It(fmt.Sprintf("downloads a message with %s RTT, with reordering", rtt), func() { - ln := runServer() - defer ln.Close() - serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DelayPacket: func(quicproxy.Direction, []byte) time.Duration { - return randomDuration(rtt/2, rtt*3/2) / 2 - }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() - - downloadFile(proxy.LocalPort()) - }) - } + downloadFile(proxy.LocalPort()) }) } }) diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 83d1b9e0..015fd328 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "crypto/x509" "flag" + "fmt" "log" mrand "math/rand" "os" @@ -18,6 +19,7 @@ import ( "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/integrationtests/tools" + "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/logging" @@ -80,13 +82,15 @@ func (b *syncedBuffer) Reset() { } var ( - logFileName string // the log file set in the ginkgo flags - logBufOnce sync.Once - logBuf *syncedBuffer + logFileName string // the log file set in the ginkgo flags + logBufOnce sync.Once + logBuf *syncedBuffer + versionParam string qlogTracer logging.Tracer enableQlog bool + version quic.VersionNumber tlsConfig *tls.Config tlsConfigLongChain *tls.Config tlsClientConfig *tls.Config @@ -96,6 +100,7 @@ var ( // to set call ginkgo -- -logfile=log.txt func init() { flag.StringVar(&logFileName, "logfile", "", "log file") + flag.StringVar(&versionParam, "version", "1", "QUIC version") flag.BoolVar(&enableQlog, "qlog", false, "enable qlog") ca, caPrivateKey, err := tools.GenerateCA() @@ -133,6 +138,18 @@ var _ = BeforeSuite(func() { if enableQlog { qlogTracer = tools.NewQlogger(GinkgoWriter) } + switch versionParam { + case "1": + version = quic.Version1 + case "2": + version = quic.Version2 + case "draft29": + version = quic.VersionDraft29 + default: + Fail(fmt.Sprintf("unknown QUIC version: %s", versionParam)) + } + fmt.Printf("Using QUIC version: %s\n", version) + protocol.SupportedVersions = []quic.VersionNumber{version} }) func getTLSConfig() *tls.Config { diff --git a/integrationtests/self/stream_test.go b/integrationtests/self/stream_test.go index 5dea3fda..622fad29 100644 --- a/integrationtests/self/stream_test.go +++ b/integrationtests/self/stream_test.go @@ -7,11 +7,9 @@ import ( "net" "sync" - "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/protocol" - . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "github.com/quic-go/quic-go" ) var _ = Describe("Bidirectional streams", func() { @@ -20,144 +18,133 @@ var _ = Describe("Bidirectional streams", func() { var ( server quic.Listener serverAddr string - qconf *quic.Config ) - for _, v := range []protocol.VersionNumber{protocol.Version1} { - version := v + BeforeEach(func() { + var err error + server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + serverAddr = fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) + }) - Context(fmt.Sprintf("with QUIC %s", version), func() { - BeforeEach(func() { - var err error - qconf = &quic.Config{ - Versions: []protocol.VersionNumber{version}, - MaxIncomingStreams: 0, - } - server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(qconf)) + AfterEach(func() { + server.Close() + }) + + runSendingPeer := func(conn quic.Connection) { + var wg sync.WaitGroup + wg.Add(numStreams) + for i := 0; i < numStreams; i++ { + str, err := conn.OpenStreamSync(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data := GeneratePRData(25 * i) + go func() { + defer GinkgoRecover() + _, err := str.Write(data) Expect(err).ToNot(HaveOccurred()) - serverAddr = fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) - }) - - AfterEach(func() { - server.Close() - }) - - runSendingPeer := func(conn quic.Connection) { - var wg sync.WaitGroup - wg.Add(numStreams) - for i := 0; i < numStreams; i++ { - str, err := conn.OpenStreamSync(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data := GeneratePRData(25 * i) - go func() { - defer GinkgoRecover() - _, err := str.Write(data) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - }() - go func() { - defer GinkgoRecover() - defer wg.Done() - dataRead, err := io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(dataRead).To(Equal(data)) - }() - } - wg.Wait() - } - - runReceivingPeer := func(conn quic.Connection) { - var wg sync.WaitGroup - wg.Add(numStreams) - for i := 0; i < numStreams; i++ { - str, err := conn.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - go func() { - defer GinkgoRecover() - defer wg.Done() - // shouldn't use io.Copy here - // we should read from the stream as early as possible, to free flow control credit - data, err := io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write(data) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - }() - } - wg.Wait() - } - - It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() { - var conn quic.Connection - go func() { - defer GinkgoRecover() - var err error - conn, err = server.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - runReceivingPeer(conn) - }() - - client, err := quic.DialAddr( - serverAddr, - getTLSClientConfig(), - getQuicConfig(qconf), - ) + Expect(str.Close()).To(Succeed()) + }() + go func() { + defer GinkgoRecover() + defer wg.Done() + dataRead, err := io.ReadAll(str) Expect(err).ToNot(HaveOccurred()) - runSendingPeer(client) - }) - - It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { - go func() { - defer GinkgoRecover() - conn, err := server.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - runSendingPeer(conn) - conn.CloseWithError(0, "") - }() - - client, err := quic.DialAddr( - serverAddr, - getTLSClientConfig(), - getQuicConfig(qconf), - ) - Expect(err).ToNot(HaveOccurred()) - runReceivingPeer(client) - Eventually(client.Context().Done()).Should(BeClosed()) - }) - - It(fmt.Sprintf("client and server opening %d each and sending data to the peer", numStreams), func() { - done1 := make(chan struct{}) - go func() { - defer GinkgoRecover() - conn, err := server.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - runReceivingPeer(conn) - close(done) - }() - runSendingPeer(conn) - <-done - close(done1) - }() - - client, err := quic.DialAddr( - serverAddr, - getTLSClientConfig(), - getQuicConfig(qconf), - ) - Expect(err).ToNot(HaveOccurred()) - done2 := make(chan struct{}) - go func() { - defer GinkgoRecover() - runSendingPeer(client) - close(done2) - }() - runReceivingPeer(client) - <-done1 - <-done2 - }) - }) + Expect(dataRead).To(Equal(data)) + }() + } + wg.Wait() } + + runReceivingPeer := func(conn quic.Connection) { + var wg sync.WaitGroup + wg.Add(numStreams) + for i := 0; i < numStreams; i++ { + str, err := conn.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + go func() { + defer GinkgoRecover() + defer wg.Done() + // shouldn't use io.Copy here + // we should read from the stream as early as possible, to free flow control credit + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(data) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + } + wg.Wait() + } + + It(fmt.Sprintf("client opening %d streams to a server", numStreams), func() { + var conn quic.Connection + go func() { + defer GinkgoRecover() + var err error + conn, err = server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + runReceivingPeer(conn) + }() + + client, err := quic.DialAddr( + serverAddr, + getTLSClientConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + runSendingPeer(client) + }) + + It(fmt.Sprintf("server opening %d streams to a client", numStreams), func() { + go func() { + defer GinkgoRecover() + conn, err := server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + runSendingPeer(conn) + conn.CloseWithError(0, "") + }() + + client, err := quic.DialAddr( + serverAddr, + getTLSClientConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + runReceivingPeer(client) + Eventually(client.Context().Done()).Should(BeClosed()) + }) + + It(fmt.Sprintf("client and server opening %d each and sending data to the peer", numStreams), func() { + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + conn, err := server.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + runReceivingPeer(conn) + close(done) + }() + runSendingPeer(conn) + <-done + close(done1) + }() + + client, err := quic.DialAddr( + serverAddr, + getTLSClientConfig(), + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + runSendingPeer(client) + close(done2) + }() + runReceivingPeer(client) + <-done1 + <-done2 + }) }) diff --git a/integrationtests/self/uni_stream_test.go b/integrationtests/self/uni_stream_test.go index 770ce167..0f73c1f0 100644 --- a/integrationtests/self/uni_stream_test.go +++ b/integrationtests/self/uni_stream_test.go @@ -20,13 +20,11 @@ var _ = Describe("Unidirectional Streams", func() { var ( server quic.Listener serverAddr string - qconf *quic.Config ) BeforeEach(func() { var err error - qconf = &quic.Config{Versions: []protocol.VersionNumber{protocol.Version1}} - server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(qconf)) + server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) serverAddr = fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port) }) @@ -81,7 +79,7 @@ var _ = Describe("Unidirectional Streams", func() { client, err := quic.DialAddr( serverAddr, getTLSClientConfig(), - getQuicConfig(qconf), + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) runSendingPeer(client) @@ -99,7 +97,7 @@ var _ = Describe("Unidirectional Streams", func() { client, err := quic.DialAddr( serverAddr, getTLSClientConfig(), - getQuicConfig(qconf), + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) runReceivingPeer(client) @@ -125,7 +123,7 @@ var _ = Describe("Unidirectional Streams", func() { client, err := quic.DialAddr( serverAddr, getTLSClientConfig(), - getQuicConfig(qconf), + getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) done2 := make(chan struct{}) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 8f6813ed..cc7b4264 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -25,773 +25,749 @@ import ( var _ = Describe("0-RTT", func() { rtt := scaleDuration(5 * time.Millisecond) - for _, v := range protocol.SupportedVersions { - version := v - - Context(fmt.Sprintf("with QUIC version %s", version), func() { - runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) { - var num0RTTPackets uint32 // to be used as an atomic - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { - for len(data) > 0 { - if !wire.IsLongHeaderPacket(data[0]) { - break - } - hdr, _, rest, err := wire.ParsePacket(data) - Expect(err).ToNot(HaveOccurred()) - if hdr.Type == protocol.PacketType0RTT { - atomic.AddUint32(&num0RTTPackets, 1) - break - } - data = rest - } - return rtt / 2 - }, - }) - Expect(err).ToNot(HaveOccurred()) - - return proxy, &num0RTTPackets - } - - dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) { - tlsConf := getTLSConfig() - if serverConf == nil { - serverConf = getQuicConfig(nil) - serverConf.Versions = []protocol.VersionNumber{version} + runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) { + var num0RTTPackets uint32 // to be used as an atomic + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { + for len(data) > 0 { + if !wire.IsLongHeaderPacket(data[0]) { + break + } + hdr, _, rest, err := wire.ParsePacket(data) + Expect(err).ToNot(HaveOccurred()) + if hdr.Type == protocol.PacketType0RTT { + atomic.AddUint32(&num0RTTPackets, 1) + break + } + data = rest } - serverConf.Allow0RTT = func(addr net.Addr) bool { return true } - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - serverConf, - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), - DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() + return proxy, &num0RTTPackets + } - // dial the first connection in order to receive a session ticket - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - conn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - <-conn.Context().Done() - }() + dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) { + tlsConf := getTLSConfig() + if serverConf == nil { + serverConf = getQuicConfig(nil) + } + serverConf.Allow0RTT = func(addr net.Addr) bool { return true } + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + serverConf, + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() - clientConf := getTLSClientConfig() - gets := make(chan string, 100) - puts := make(chan string, 100) - clientConf.ClientSessionCache = newClientSessionCache(gets, puts) - conn, err := quic.DialAddr( - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - Eventually(puts).Should(Receive()) - // received the session ticket. We're done here. - Expect(conn.CloseWithError(0, "")).To(Succeed()) - Eventually(done).Should(BeClosed()) - return tlsConf, clientConf + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + // dial the first connection in order to receive a session ticket + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + <-conn.Context().Done() + }() + + clientConf := getTLSClientConfig() + gets := make(chan string, 100) + puts := make(chan string, 100) + clientConf.ClientSessionCache = newClientSessionCache(gets, puts) + conn, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + Eventually(puts).Should(Receive()) + // received the session ticket. We're done here. + Expect(conn.CloseWithError(0, "")).To(Succeed()) + Eventually(done).Should(BeClosed()) + return tlsConf, clientConf + } + + transfer0RTTData := func( + ln quic.EarlyListener, + proxyPort int, + clientTLSConf *tls.Config, + clientConf *quic.Config, + testdata []byte, // data to transfer + ) { + // accept the second connection, and receive the data sent in 0-RTT + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(testdata)) + Expect(str.Close()).To(Succeed()) + Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) + <-conn.Context().Done() + close(done) + }() + + if clientConf == nil { + clientConf = getQuicConfig(nil) + } + conn, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxyPort), + clientTLSConf, + clientConf, + ) + Expect(err).ToNot(HaveOccurred()) + defer conn.CloseWithError(0, "") + str, err := conn.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(testdata) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + <-conn.HandshakeComplete() + Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) + io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn + conn.CloseWithError(0, "") + Eventually(done).Should(BeClosed()) + Eventually(conn.Context().Done()).Should(BeClosed()) + } + + check0RTTRejected := func( + ln quic.EarlyListener, + proxyPort int, + clientConf *tls.Config, + ) { + conn, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxyPort), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(make([]byte, 3000)) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + Expect(conn.ConnectionState().TLS.Used0RTT).To(BeFalse()) + + // make sure the server doesn't process the data + ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) + defer cancel() + serverConn, err := ln.Accept(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().TLS.Used0RTT).To(BeFalse()) + _, err = serverConn.AcceptUniStream(ctx) + Expect(err).To(Equal(context.DeadlineExceeded)) + Expect(serverConn.CloseWithError(0, "")).To(Succeed()) + Eventually(conn.Context().Done()).Should(BeClosed()) + } + + // can be used to extract 0-RTT from a packetTracer + get0RTTPackets := func(packets []packet) []protocol.PacketNumber { + var zeroRTTPackets []protocol.PacketNumber + for _, p := range packets { + if p.hdr.Type == protocol.PacketType0RTT { + zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber) } + } + return zeroRTTPackets + } - transfer0RTTData := func( - ln quic.EarlyListener, - proxyPort int, - clientTLSConf *tls.Config, - clientConf *quic.Config, - testdata []byte, // data to transfer - ) { - // accept the second connection, and receive the data sent in 0-RTT - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - conn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.AcceptStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal(testdata)) - Expect(str.Close()).To(Succeed()) - Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) - <-conn.Context().Done() - close(done) - }() + for _, l := range []int{0, 15} { + connIDLen := l - if clientConf == nil { - clientConf = getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}) - } - conn, err := quic.DialAddrEarly( - fmt.Sprintf("localhost:%d", proxyPort), - clientTLSConf, - clientConf, - ) - Expect(err).ToNot(HaveOccurred()) - defer conn.CloseWithError(0, "") - str, err := conn.OpenStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write(testdata) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - <-conn.HandshakeComplete() - Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) - io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn - conn.CloseWithError(0, "") - Eventually(done).Should(BeClosed()) - Eventually(conn.Context().Done()).Should(BeClosed()) - } + It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() { + tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil) - check0RTTRejected := func( - ln quic.EarlyListener, - proxyPort int, - clientConf *tls.Config, - ) { - conn, err := quic.DialAddrEarly( - fmt.Sprintf("localhost:%d", proxyPort), - clientConf, - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write(make([]byte, 3000)) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - Expect(conn.ConnectionState().TLS.Used0RTT).To(BeFalse()) + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: func(addr net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() - // make sure the server doesn't process the data - ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) - defer cancel() - serverConn, err := ln.Accept(ctx) - Expect(err).ToNot(HaveOccurred()) - Expect(serverConn.ConnectionState().TLS.Used0RTT).To(BeFalse()) - _, err = serverConn.AcceptUniStream(ctx) - Expect(err).To(Equal(context.DeadlineExceeded)) - Expect(serverConn.CloseWithError(0, "")).To(Succeed()) - Eventually(conn.Context().Done()).Should(BeClosed()) - } + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() - // can be used to extract 0-RTT from a packetTracer - get0RTTPackets := func(packets []packet) []protocol.PacketNumber { - var zeroRTTPackets []protocol.PacketNumber - for _, p := range packets { - if p.hdr.Type == protocol.PacketType0RTT { - zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber) + transfer0RTTData( + ln, + proxy.LocalPort(), + clientTLSConf, + &quic.Config{ConnectionIDLength: connIDLen}, + PRData, + ) + + var numNewConnIDs int + for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, f := range p.frames { + if _, ok := f.(*logging.NewConnectionIDFrame); ok { + numNewConnIDs++ } } - return zeroRTTPackets + } + if connIDLen == 0 { + Expect(numNewConnIDs).To(BeZero()) + } else { + Expect(numNewConnIDs).ToNot(BeZero()) } - for _, l := range []int{0, 15} { - connIDLen := l + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) + sort.Slice(zeroRTTPackets, func(i, j int) bool { return zeroRTTPackets[i] < zeroRTTPackets[j] }) + Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) + }) + } - It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() { - tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil) + // Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets. + It("waits for a connection until the handshake is done", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Allow0RTT: func(addr net.Addr) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() + zeroRTTData := GeneratePRData(5 << 10) + oneRTTData := PRData - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() - transfer0RTTData( - ln, - proxy.LocalPort(), - clientTLSConf, - &quic.Config{ - ConnectionIDLength: connIDLen, - Versions: []protocol.VersionNumber{version}, - }, - PRData, - ) + // now accept the second connection, and receive the 0-RTT data + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(zeroRTTData)) + str, err = conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err = io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(oneRTTData)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }() - var numNewConnIDs int - for _, p := range tracer.getRcvdLongHeaderPackets() { - for _, f := range p.frames { - if _, ok := f.(*logging.NewConnectionIDFrame); ok { - numNewConnIDs++ - } - } - } - if connIDLen == 0 { - Expect(numNewConnIDs).To(BeZero()) - } else { - Expect(numNewConnIDs).ToNot(BeZero()) - } + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() - num0RTT := atomic.LoadUint32(num0RTTPackets) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) - Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) - sort.Slice(zeroRTTPackets, func(i, j int) bool { return zeroRTTPackets[i] < zeroRTTPackets[j] }) - Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) - }) + conn, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + firstStr, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = firstStr.Write(zeroRTTData) + Expect(err).ToNot(HaveOccurred()) + Expect(firstStr.Close()).To(Succeed()) + + // wait for the handshake to complete + Eventually(conn.HandshakeComplete()).Should(BeClosed()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(PRData) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + <-conn.Context().Done() + + // check that 0-RTT packets only contain STREAM frames for the first stream + var num0RTT int + for _, p := range tracer.getRcvdLongHeaderPackets() { + if p.hdr.Header.Type != protocol.PacketType0RTT { + continue } + for _, f := range p.frames { + sf, ok := f.(*logging.StreamFrame) + if !ok { + continue + } + num0RTT++ + Expect(sf.StreamID).To(Equal(firstStr.StreamID())) + } + } + fmt.Fprintf(GinkgoWriter, "received %d STREAM frames in 0-RTT packets\n", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + }) - // Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets. - It("waits for a connection until the handshake is done", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + It("transfers 0-RTT data, when 0-RTT packets are lost", func() { + var ( + num0RTTPackets uint32 // to be used as an atomic + num0RTTDropped uint32 + ) - zeroRTTData := GeneratePRData(5 << 10) - oneRTTData := PRData + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Allow0RTT: func(net.Addr) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() - // now accept the second connection, and receive the 0-RTT data - go func() { - defer GinkgoRecover() - conn, err := ln.Accept(context.Background()) + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { + if wire.IsLongHeaderPacket(data[0]) { + hdr, _, _, err := wire.ParsePacket(data) Expect(err).ToNot(HaveOccurred()) - str, err := conn.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal(zeroRTTData)) - str, err = conn.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err = io.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal(oneRTTData)) - Expect(conn.CloseWithError(0, "")).To(Succeed()) - }() - - proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - conn, err := quic.DialAddrEarly( - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - firstStr, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - _, err = firstStr.Write(zeroRTTData) - Expect(err).ToNot(HaveOccurred()) - Expect(firstStr.Close()).To(Succeed()) - - // wait for the handshake to complete - Eventually(conn.HandshakeComplete()).Should(BeClosed()) - str, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write(PRData) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - <-conn.Context().Done() - - // check that 0-RTT packets only contain STREAM frames for the first stream - var num0RTT int - for _, p := range tracer.getRcvdLongHeaderPackets() { - if p.hdr.Header.Type != protocol.PacketType0RTT { - continue - } - for _, f := range p.frames { - sf, ok := f.(*logging.StreamFrame) - if !ok { - continue - } - num0RTT++ - Expect(sf.StreamID).To(Equal(firstStr.StreamID())) + if hdr.Type == protocol.PacketType0RTT { + atomic.AddUint32(&num0RTTPackets, 1) } } - fmt.Fprintf(GinkgoWriter, "received %d STREAM frames in 0-RTT packets\n", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - }) - - It("transfers 0-RTT data, when 0-RTT packets are lost", func() { - var ( - num0RTTPackets uint32 // to be used as an atomic - num0RTTDropped uint32 - ) - - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - - tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Allow0RTT: func(net.Addr) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }), - ) + return rtt / 2 + }, + DropPacket: func(_ quicproxy.Direction, data []byte) bool { + if !wire.IsLongHeaderPacket(data[0]) { + return false + } + hdr, _, _, err := wire.ParsePacket(data) Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), - DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { - if wire.IsLongHeaderPacket(data[0]) { - hdr, _, _, err := wire.ParsePacket(data) - Expect(err).ToNot(HaveOccurred()) - if hdr.Type == protocol.PacketType0RTT { - atomic.AddUint32(&num0RTTPackets, 1) - } - } - return rtt / 2 - }, - DropPacket: func(_ quicproxy.Direction, data []byte) bool { - if !wire.IsLongHeaderPacket(data[0]) { - return false - } - hdr, _, _, err := wire.ParsePacket(data) - Expect(err).ToNot(HaveOccurred()) - if hdr.Type == protocol.PacketType0RTT { - // drop 25% of the 0-RTT packets - drop := mrand.Intn(4) == 0 - if drop { - atomic.AddUint32(&num0RTTDropped, 1) - } - return drop - } - return false - }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() - - transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData) - - num0RTT := atomic.LoadUint32(&num0RTTPackets) - numDropped := atomic.LoadUint32(&num0RTTDropped) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) - Expect(numDropped).ToNot(BeZero()) - Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) - }) - - It("retransmits all 0-RTT data when the server performs a Retry", func() { - var mutex sync.Mutex - var firstConnID, secondConnID *protocol.ConnectionID - var firstCounter, secondCounter protocol.ByteCount - - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - - countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) { - for len(data) > 0 { - hdr, _, rest, err := wire.ParsePacket(data) - if err != nil { - return - } - data = rest - if hdr.Type == protocol.PacketType0RTT { - n += hdr.Length - 16 /* AEAD tag */ - } + if hdr.Type == protocol.PacketType0RTT { + // drop 25% of the 0-RTT packets + drop := mrand.Intn(4) == 0 + if drop { + atomic.AddUint32(&num0RTTDropped, 1) } + return drop + } + return false + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData) + + num0RTT := atomic.LoadUint32(&num0RTTPackets) + numDropped := atomic.LoadUint32(&num0RTTDropped) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) + Expect(numDropped).ToNot(BeZero()) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) + }) + + It("retransmits all 0-RTT data when the server performs a Retry", func() { + var mutex sync.Mutex + var firstConnID, secondConnID *protocol.ConnectionID + var firstCounter, secondCounter protocol.ByteCount + + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) { + for len(data) > 0 { + hdr, _, rest, err := wire.ParsePacket(data) + if err != nil { return } + data = rest + if hdr.Type == protocol.PacketType0RTT { + n += hdr.Length - 16 /* AEAD tag */ + } + } + return + } - tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - RequireAddressValidation: func(net.Addr) bool { return true }, - Allow0RTT: func(net.Addr) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }), - ) + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + RequireAddressValidation: func(net.Addr) bool { return true }, + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { + connID, err := wire.ParseConnectionID(data, 0) Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), - DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { - connID, err := wire.ParseConnectionID(data, 0) - Expect(err).ToNot(HaveOccurred()) - - mutex.Lock() - defer mutex.Unlock() - - if zeroRTTBytes := countZeroRTTBytes(data); zeroRTTBytes > 0 { - if firstConnID == nil { - firstConnID = &connID - firstCounter += zeroRTTBytes - } else if firstConnID != nil && *firstConnID == connID { - Expect(secondConnID).To(BeNil()) - firstCounter += zeroRTTBytes - } else if secondConnID == nil { - secondConnID = &connID - secondCounter += zeroRTTBytes - } else if secondConnID != nil && *secondConnID == connID { - secondCounter += zeroRTTBytes - } else { - Fail("received 3 connection IDs on 0-RTT packets") - } - } - return rtt / 2 - }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() - - transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, GeneratePRData(5000)) // ~5 packets mutex.Lock() defer mutex.Unlock() - Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra - Expect(secondCounter).To(BeNumerically("~", firstCounter, 20)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) - Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5)) - Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) + + if zeroRTTBytes := countZeroRTTBytes(data); zeroRTTBytes > 0 { + if firstConnID == nil { + firstConnID = &connID + firstCounter += zeroRTTBytes + } else if firstConnID != nil && *firstConnID == connID { + Expect(secondConnID).To(BeNil()) + firstCounter += zeroRTTBytes + } else if secondConnID == nil { + secondConnID = &connID + secondCounter += zeroRTTBytes + } else if secondConnID != nil && *secondConnID == connID { + secondCounter += zeroRTTBytes + } else { + Fail("received 3 connection IDs on 0-RTT packets") + } + } + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, GeneratePRData(5000)) // ~5 packets + + mutex.Lock() + defer mutex.Unlock() + Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra + Expect(secondCounter).To(BeNumerically("~", firstCounter, 20)) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5)) + Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) + }) + + It("doesn't reject 0-RTT when the server's transport stream limit increased", func() { + const maxStreams = 1 + tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + MaxIncomingUniStreams: maxStreams, + })) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + MaxIncomingUniStreams: maxStreams + 1, + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + conn, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + // The client remembers the old limit and refuses to open a new stream. + _, err = conn.OpenUniStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("too many open streams")) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err = conn.OpenUniStreamSync(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) + + It("rejects 0-RTT when the server's stream limit decreased", func() { + const maxStreams = 42 + tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + MaxIncomingStreams: maxStreams, + })) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + MaxIncomingStreams: maxStreams - 1, + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + check0RTTRejected(ln, proxy.LocalPort(), clientConf) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + }) + + It("rejects 0-RTT when the ALPN changed", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + // now close the listener and dial new connection with a different ALPN + clientConf.NextProtos = []string{"new-alpn"} + tlsConf.NextProtos = []string{"new-alpn"} + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + check0RTTRejected(ln, proxy.LocalPort(), clientConf) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + }) + + It("rejects 0-RTT when the application doesn't allow it", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + // now close the listener and dial new connection with a different ALPN + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: func(net.Addr) bool { return false }, // application rejects 0-RTT + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + check0RTTRejected(ln, proxy.LocalPort(), clientConf) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + }) + + DescribeTable("flow control limits", + func(addFlowControlLimit func(*quic.Config, uint64)) { + tracer := newPacketTracer() + firstConf := getQuicConfig(&quic.Config{Allow0RTT: func(net.Addr) bool { return true }}) + addFlowControlLimit(firstConf, 3) + tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) + + secondConf := getQuicConfig(&quic.Config{ + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }) + addFlowControlLimit(secondConf, 100) + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + secondConf, + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() - It("doesn't reject 0-RTT when the server's transport stream limit increased", func() { - const maxStreams = 1 - tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ - MaxIncomingUniStreams: maxStreams, - })) - - tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - MaxIncomingUniStreams: maxStreams + 1, - Allow0RTT: func(net.Addr) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - conn, err := quic.DialAddrEarly( - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write([]byte("foobar")) + conn, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + written := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(written) + _, err := str.Write([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) - // The client remembers the old limit and refuses to open a new stream. - _, err = conn.OpenUniStream() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("too many open streams")) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - _, err = conn.OpenUniStreamSync(ctx) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) - Expect(conn.CloseWithError(0, "")).To(Succeed()) - }) + }() - It("rejects 0-RTT when the server's stream limit decreased", func() { - const maxStreams = 42 - tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ - MaxIncomingStreams: maxStreams, - })) + Eventually(written).Should(BeClosed()) - tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - MaxIncomingStreams: maxStreams - 1, - Allow0RTT: func(net.Addr) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - check0RTTRejected(ln, proxy.LocalPort(), clientConf) + serverConn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + rstr, err := serverConn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(rstr) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + Expect(serverConn.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(serverConn.CloseWithError(0, "")).To(Succeed()) + Eventually(conn.Context().Done()).Should(BeClosed()) - // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := atomic.LoadUint32(num0RTTPackets) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) - }) - - It("rejects 0-RTT when the ALPN changed", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - - // now close the listener and dial new connection with a different ALPN - clientConf.NextProtos = []string{"new-alpn"} - tlsConf.NextProtos = []string{"new-alpn"} - tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Allow0RTT: func(net.Addr) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - check0RTTRejected(ln, proxy.LocalPort(), clientConf) - - // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := atomic.LoadUint32(num0RTTPackets) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) - }) - - It("rejects 0-RTT when the application doesn't allow it", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - - // now close the listener and dial new connection with a different ALPN - tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Allow0RTT: func(net.Addr) bool { return false }, // application rejects 0-RTT - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - check0RTTRejected(ln, proxy.LocalPort(), clientConf) - - // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := atomic.LoadUint32(num0RTTPackets) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) - }) - - DescribeTable("flow control limits", - func(addFlowControlLimit func(*quic.Config, uint64)) { - tracer := newPacketTracer() - firstConf := getQuicConfig(&quic.Config{ - Allow0RTT: func(net.Addr) bool { return true }, - Versions: []protocol.VersionNumber{version}, - }) - addFlowControlLimit(firstConf, 3) - tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) - - secondConf := getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Allow0RTT: func(net.Addr) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }) - addFlowControlLimit(secondConf, 100) - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - secondConf, - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - conn, err := quic.DialAddrEarly( - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - str, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - written := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(written) - _, err := str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - }() - - Eventually(written).Should(BeClosed()) - - serverConn, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - rstr, err := serverConn.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := io.ReadAll(rstr) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal([]byte("foobar"))) - Expect(serverConn.ConnectionState().TLS.Used0RTT).To(BeTrue()) - Expect(serverConn.CloseWithError(0, "")).To(Succeed()) - Eventually(conn.Context().Done()).Should(BeClosed()) - - var processedFirst bool - for _, p := range tracer.getRcvdLongHeaderPackets() { - for _, f := range p.frames { - if sf, ok := f.(*logging.StreamFrame); ok { - if !processedFirst { - // The first STREAM should have been sent in a 0-RTT packet. - // Due to the flow control limit, the STREAM frame was limit to the first 3 bytes. - Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT)) - Expect(sf.Length).To(BeEquivalentTo(3)) - processedFirst = true - } else { - Fail("STREAM was shouldn't have been sent in 0-RTT") - } - } + var processedFirst bool + for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, f := range p.frames { + if sf, ok := f.(*logging.StreamFrame); ok { + if !processedFirst { + // The first STREAM should have been sent in a 0-RTT packet. + // Due to the flow control limit, the STREAM frame was limit to the first 3 bytes. + Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT)) + Expect(sf.Length).To(BeEquivalentTo(3)) + processedFirst = true + } else { + Fail("STREAM was shouldn't have been sent in 0-RTT") } } - }, - Entry("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialStreamReceiveWindow = limit }), - Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }), - ) - - for _, l := range []int{0, 15} { - connIDLen := l - - It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) - // now dial new connection with different transport parameters - tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - MaxIncomingUniStreams: 1, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - conn, err := quic.DialAddrEarly( - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - // The client remembers that it was allowed to open 2 uni-directional streams. - firstStr, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - written := make(chan struct{}, 2) - go func() { - defer GinkgoRecover() - defer func() { written <- struct{}{} }() - _, err := firstStr.Write([]byte("first flight")) - Expect(err).ToNot(HaveOccurred()) - }() - secondStr, err := conn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - go func() { - defer GinkgoRecover() - defer func() { written <- struct{}{} }() - _, err := secondStr.Write([]byte("first flight")) - Expect(err).ToNot(HaveOccurred()) - }() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - _, err = conn.AcceptStream(ctx) - Expect(err).To(MatchError(quic.Err0RTTRejected)) - Eventually(written).Should(Receive()) - Eventually(written).Should(Receive()) - _, err = firstStr.Write([]byte("foobar")) - Expect(err).To(MatchError(quic.Err0RTTRejected)) - _, err = conn.OpenUniStream() - Expect(err).To(MatchError(quic.Err0RTTRejected)) - - _, err = conn.AcceptStream(ctx) - Expect(err).To(Equal(quic.Err0RTTRejected)) - - newConn := conn.NextConnection() - str, err := newConn.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - _, err = newConn.OpenUniStream() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("too many open streams")) - _, err = str.Write([]byte("second flight")) - Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - Expect(conn.CloseWithError(0, "")).To(Succeed()) - - // The client should send 0-RTT packets, but the server doesn't process them. - num0RTT := atomic.LoadUint32(num0RTTPackets) - fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) - Expect(num0RTT).ToNot(BeZero()) - Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) - }) + } } + }, + Entry("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialStreamReceiveWindow = limit }), + Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }), + ) - It("queues 0-RTT packets, if the Initial is delayed", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + for _, l := range []int{0, 15} { + connIDLen := l - tracer := newPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Allow0RTT: func(net.Addr) bool { return true }, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }), - ) + It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + // now dial new connection with different transport parameters + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + MaxIncomingUniStreams: 1, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + conn, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + // The client remembers that it was allowed to open 2 uni-directional streams. + firstStr, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + written := make(chan struct{}, 2) + go func() { + defer GinkgoRecover() + defer func() { written <- struct{}{} }() + _, err := firstStr.Write([]byte("first flight")) Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: ln.Addr().String(), - DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { - if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client - return rtt/2 + rtt - } - return rtt / 2 - }, - }) + }() + secondStr, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + go func() { + defer GinkgoRecover() + defer func() { written <- struct{}{} }() + _, err := secondStr.Write([]byte("first flight")) Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() + }() - transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err = conn.AcceptStream(ctx) + Expect(err).To(MatchError(quic.Err0RTTRejected)) + Eventually(written).Should(Receive()) + Eventually(written).Should(Receive()) + _, err = firstStr.Write([]byte("foobar")) + Expect(err).To(MatchError(quic.Err0RTTRejected)) + _, err = conn.OpenUniStream() + Expect(err).To(MatchError(quic.Err0RTTRejected)) - Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) - zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) - Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) - Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) - }) + _, err = conn.AcceptStream(ctx) + Expect(err).To(Equal(quic.Err0RTTRejected)) + + newConn := conn.NextConnection() + str, err := newConn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = newConn.OpenUniStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("too many open streams")) + _, err = str.Write([]byte("second flight")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) }) } + + It("queues 0-RTT packets, if the Initial is delayed", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: ln.Addr().String(), + DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { + if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client + return rtt/2 + rtt + } + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + transfer0RTTData(ln, proxy.LocalPort(), clientConf, nil, PRData) + + Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) + Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) + }) }) diff --git a/integrationtests/versionnegotiation/rtt_test.go b/integrationtests/versionnegotiation/rtt_test.go new file mode 100644 index 00000000..7f8186ee --- /dev/null +++ b/integrationtests/versionnegotiation/rtt_test.go @@ -0,0 +1,53 @@ +package versionnegotiation + +import ( + "time" + + "github.com/quic-go/quic-go" + quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" + "github.com/quic-go/quic-go/internal/protocol" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Handshake RTT tests", func() { + const rtt = 400 * time.Millisecond + + expectDurationInRTTs := func(startTime time.Time, num int) { + testDuration := time.Since(startTime) + rtts := float32(testDuration) / float32(rtt) + Expect(rtts).To(SatisfyAll( + BeNumerically(">=", num), + BeNumerically("<", num+1), + )) + } + + It("fails when there's no matching version, after 1 RTT", func() { + if len(protocol.SupportedVersions) == 1 { + Skip("Test requires at least 2 supported versions.") + } + + serverConfig := &quic.Config{} + serverConfig.Versions = protocol.SupportedVersions[:1] + ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + // start the proxy + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: ln.Addr().String(), + DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, + }) + Expect(err).ToNot(HaveOccurred()) + + startTime := time.Now() + _, err = quic.DialAddr( + proxy.LocalAddr().String(), + getTLSClientConfig(), + maybeAddQlogTracer(&quic.Config{Versions: protocol.SupportedVersions[1:2]}), + ) + Expect(err).To(HaveOccurred()) + expectDurationInRTTs(startTime, 1) + }) +})