diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 476cd16d..72b02e36 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -19,6 +19,7 @@ import ( "github.com/lucas-clemente/quic-go/logging" . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" ) @@ -550,77 +551,84 @@ var _ = Describe("0-RTT", func() { Expect(get0RTTPackets(tracer.getRcvdPackets())).To(BeEmpty()) }) - It("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ - InitialStreamReceiveWindow: 3, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - })) + DescribeTable("flow control limits", + func(addFlowControlLimit func(*quic.Config, uint64)) { + tracer := newRcvdPacketTracer() + firstConf := getQuicConfig(&quic.Config{ + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + Versions: []protocol.VersionNumber{version}, + }) + addFlowControlLimit(firstConf, 3) + tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) - tracer := newRcvdPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - InitialStreamReceiveWindow: 100, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - sess, err := quic.DialAddrEarly( - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - written := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(written) - _, err := str.Write([]byte("foobar")) + secondConf := getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{version}, + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }) + addFlowControlLimit(secondConf, 100) + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + secondConf, + ) Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - }() + defer ln.Close() + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() - Eventually(written).Should(BeClosed()) + sess, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + written := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(written) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() - serverSess, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - rstr, err := serverSess.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := ioutil.ReadAll(rstr) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal([]byte("foobar"))) - Expect(serverSess.ConnectionState().TLS.Used0RTT).To(BeTrue()) - Expect(serverSess.CloseWithError(0, "")).To(Succeed()) - Eventually(sess.Context().Done()).Should(BeClosed()) + Eventually(written).Should(BeClosed()) - var processedFirst bool - for _, p := range tracer.getRcvdPackets() { - for _, f := range p.frames { - if sf, ok := f.(*logging.StreamFrame); ok { - if !processedFirst { - // The first STREAM should have been sent in a 0-RTT packet. - // Due to the flow control limit, the STREAM frame was limit to the first 3 bytes. - Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT)) - Expect(sf.Length).To(BeEquivalentTo(3)) - processedFirst = true - } else { - // All other STREAM frames can only be sent after handshake completion. - Expect(p.hdr.IsLongHeader).To(BeFalse()) - Expect(sf.Offset).ToNot(BeZero()) + serverSess, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + rstr, err := serverSess.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(rstr) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + Expect(serverSess.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(serverSess.CloseWithError(0, "")).To(Succeed()) + Eventually(sess.Context().Done()).Should(BeClosed()) + + var processedFirst bool + for _, p := range tracer.getRcvdPackets() { + for _, f := range p.frames { + if sf, ok := f.(*logging.StreamFrame); ok { + if !processedFirst { + // The first STREAM should have been sent in a 0-RTT packet. + // Due to the flow control limit, the STREAM frame was limit to the first 3 bytes. + Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT)) + Expect(sf.Length).To(BeEquivalentTo(3)) + processedFirst = true + } else { + // All other STREAM frames can only be sent after handshake completion. + Expect(p.hdr.IsLongHeader).To(BeFalse()) + Expect(sf.Offset).ToNot(BeZero()) + } } } } - } - }) + }, + Entry("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialStreamReceiveWindow = limit }), + Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }), + ) It("correctly deals with 0-RTT rejections", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 8a663289..d09673e9 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -776,7 +776,7 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(client.ConnectionState().Used0RTT).To(BeTrue()) }) - It("rejects 0-RTT, whent the transport parameters changed", func() { + It("rejects 0-RTT, when the transport parameters changed", func() { csc := mocktls.NewMockClientSessionCache(mockCtrl) var state *tls.ClientSessionState receivedSessionTicket := make(chan struct{}) @@ -810,7 +810,7 @@ var _ = Describe("Crypto Setup TLS", func() { clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( clientConf, serverConf, clientRTTStats, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData + 1}, + &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData - 1}, true, ) Expect(clientErr).ToNot(HaveOccurred()) diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index 811c47ea..283fdab6 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -528,11 +528,16 @@ var _ = Describe("Transport Parameters", func() { Expect(p.ValidFor0RTT(saved)).To(BeTrue()) }) - It("rejects the parameters if the InitialMaxData changed", func() { - p.InitialMaxData = 0 + It("rejects the parameters if the InitialMaxData was reduced", func() { + p.InitialMaxData = saved.InitialMaxData - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) + It("doesn't reject the parameters if the InitialMaxData was increased", func() { + p.InitialMaxData = saved.InitialMaxData + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + It("rejects the parameters if the MaxBidiStreamNum was reduced", func() { p.MaxBidiStreamNum = saved.MaxBidiStreamNum - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index 959d8257..1f1085bc 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -444,7 +444,7 @@ func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && - p.InitialMaxData == saved.InitialMaxData && + p.InitialMaxData >= saved.InitialMaxData && p.MaxBidiStreamNum >= saved.MaxBidiStreamNum && p.MaxUniStreamNum >= saved.MaxUniStreamNum && p.ActiveConnectionIDLimit == saved.ActiveConnectionIDLimit