diff --git a/integrationtests/self/handshake_rtt_test.go b/integrationtests/self/handshake_rtt_test.go index 60f173ca..1c965299 100644 --- a/integrationtests/self/handshake_rtt_test.go +++ b/integrationtests/self/handshake_rtt_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "fmt" + "io" "net" "time" @@ -18,54 +19,33 @@ import ( var _ = Describe("Handshake RTT tests", func() { var ( proxy *quicproxy.QuicProxy - server quic.Listener serverConfig *quic.Config serverTLSConfig *tls.Config - testStartedAt time.Time - acceptStopped chan struct{} ) - rtt := 400 * time.Millisecond + const rtt = 400 * time.Millisecond BeforeEach(func() { - acceptStopped = make(chan struct{}) serverConfig = getQuicConfig(nil) serverTLSConfig = getTLSConfig() }) AfterEach(func() { Expect(proxy.Close()).To(Succeed()) - Expect(server.Close()).To(Succeed()) - <-acceptStopped }) - runServerAndProxy := func() { + runProxy := func(serverAddr net.Addr) { var err error - // start the server - server, err = quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) - Expect(err).ToNot(HaveOccurred()) // start the proxy proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: server.Addr().String(), + RemoteAddr: serverAddr.String(), DelayPacket: func(_ quicproxy.Direction, _ []byte) time.Duration { return rtt / 2 }, }) Expect(err).ToNot(HaveOccurred()) - - testStartedAt = time.Now() - - go func() { - defer GinkgoRecover() - defer close(acceptStopped) - for { - if _, err := server.Accept(context.Background()); err != nil { - return - } - } - }() } - expectDurationInRTTs := func(num int) { - testDuration := time.Since(testStartedAt) + expectDurationInRTTs := func(startTime time.Time, num int) { + testDuration := time.Since(startTime) rtts := float32(testDuration) / float32(rtt) Expect(rtts).To(SatisfyAll( BeNumerically(">=", num), @@ -78,15 +58,19 @@ var _ = Describe("Handshake RTT tests", func() { Skip("Test requires at least 2 supported versions.") } serverConfig.Versions = protocol.SupportedVersions[:1] - runServerAndProxy() - clientConfig := getQuicConfig(&quic.Config{Versions: protocol.SupportedVersions[1:2]}) - _, err := quic.DialAddr( + ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + runProxy(ln.Addr()) + startTime := time.Now() + _, err = quic.DialAddr( proxy.LocalAddr().String(), getTLSClientConfig(), - clientConfig, + getQuicConfig(&quic.Config{Versions: protocol.SupportedVersions[1:2]}), ) Expect(err).To(HaveOccurred()) - expectDurationInRTTs(1) + expectDurationInRTTs(startTime, 1) }) var clientConfig *quic.Config @@ -102,36 +86,114 @@ var _ = Describe("Handshake RTT tests", func() { // 1 RTT for the TLS handshake It("is forward-secure after 2 RTTs", func() { serverConfig.RequireAddressValidation = func(net.Addr) bool { return true } - runServerAndProxy() - _, err := quic.DialAddr( + ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + runProxy(ln.Addr()) + startTime := time.Now() + _, err = quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), clientConfig, ) Expect(err).ToNot(HaveOccurred()) - expectDurationInRTTs(2) + expectDurationInRTTs(startTime, 2) }) It("establishes a connection in 1 RTT when the server doesn't require a token", func() { - runServerAndProxy() - _, err := quic.DialAddr( + ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + runProxy(ln.Addr()) + startTime := time.Now() + _, err = quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), clientConfig, ) Expect(err).ToNot(HaveOccurred()) - expectDurationInRTTs(1) + expectDurationInRTTs(startTime, 1) }) It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() { serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384} - runServerAndProxy() - _, err := quic.DialAddr( + ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + runProxy(ln.Addr()) + startTime := time.Now() + _, err = quic.DialAddr( fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), getTLSClientConfig(), clientConfig, ) Expect(err).ToNot(HaveOccurred()) - expectDurationInRTTs(2) + expectDurationInRTTs(startTime, 2) + }) + + It("receives the first message from the server after 2 RTTs, when the server uses ListenAddr", func() { + ln, err := quic.ListenAddr("localhost:0", serverTLSConfig, serverConfig) + Expect(err).ToNot(HaveOccurred()) + go func() { + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + defer ln.Close() + + runProxy(ln.Addr()) + startTime := time.Now() + conn, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), + getTLSClientConfig(), + clientConfig, + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + expectDurationInRTTs(startTime, 2) + }) + + It("receives the first message from the server after 1 RTT, when the server uses ListenAddrEarly", func() { + ln, err := quic.ListenAddrEarly("localhost:0", serverTLSConfig, serverConfig) + Expect(err).ToNot(HaveOccurred()) + go func() { + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + // Check the ALPN now. This is probably what an application would do. + // It makes sure that ConnectionState does not block until the handshake completes. + Expect(conn.ConnectionState().TLS.NegotiatedProtocol).To(Equal(alpn)) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + defer ln.Close() + + runProxy(ln.Addr()) + startTime := time.Now() + conn, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), + getTLSClientConfig(), + clientConfig, + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + expectDurationInRTTs(startTime, 1) }) })