diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 3cbae9c6..476cd16d 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -22,21 +22,26 @@ import ( . "github.com/onsi/gomega" ) +type rcvdPacket struct { + hdr *logging.ExtendedHeader + frames []logging.Frame +} + type rcvdPacketTracer struct { connTracer closed chan struct{} - rcvdPackets []*logging.ExtendedHeader + rcvdPackets []rcvdPacket } func newRcvdPacketTracer() *rcvdPacketTracer { return &rcvdPacketTracer{closed: make(chan struct{})} } -func (t *rcvdPacketTracer) ReceivedPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ []logging.Frame) { - t.rcvdPackets = append(t.rcvdPackets, hdr) +func (t *rcvdPacketTracer) ReceivedPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, frames []logging.Frame) { + t.rcvdPackets = append(t.rcvdPackets, rcvdPacket{hdr: hdr, frames: frames}) } func (t *rcvdPacketTracer) Close() { close(t.closed) } -func (t *rcvdPacketTracer) getRcvdPackets() []*logging.ExtendedHeader { +func (t *rcvdPacketTracer) getRcvdPackets() []rcvdPacket { <-t.closed return t.rcvdPackets } @@ -187,11 +192,11 @@ var _ = Describe("0-RTT", func() { } // can be used to extract 0-RTT from a rcvdPacketTracer - get0RTTPackets := func(hdrs []*logging.ExtendedHeader) []protocol.PacketNumber { + get0RTTPackets := func(packets []rcvdPacket) []protocol.PacketNumber { var zeroRTTPackets []protocol.PacketNumber - for _, hdr := range hdrs { - if hdr.Type == protocol.PacketType0RTT { - zeroRTTPackets = append(zeroRTTPackets, hdr.PacketNumber) + for _, p := range packets { + if p.hdr.Type == protocol.PacketType0RTT { + zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber) } } return zeroRTTPackets @@ -545,6 +550,78 @@ var _ = Describe("0-RTT", func() { Expect(get0RTTPackets(tracer.getRcvdPackets())).To(BeEmpty()) }) + It("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + InitialStreamReceiveWindow: 3, + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + })) + + tracer := newRcvdPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{version}, + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + InitialStreamReceiveWindow: 100, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + sess, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + written := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(written) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + + Eventually(written).Should(BeClosed()) + + serverSess, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + rstr, err := serverSess.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(rstr) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + Expect(serverSess.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(serverSess.CloseWithError(0, "")).To(Succeed()) + Eventually(sess.Context().Done()).Should(BeClosed()) + + var processedFirst bool + for _, p := range tracer.getRcvdPackets() { + for _, f := range p.frames { + if sf, ok := f.(*logging.StreamFrame); ok { + if !processedFirst { + // The first STREAM should have been sent in a 0-RTT packet. + // Due to the flow control limit, the STREAM frame was limit to the first 3 bytes. + Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT)) + Expect(sf.Length).To(BeEquivalentTo(3)) + processedFirst = true + } else { + // All other STREAM frames can only be sent after handshake completion. + Expect(p.hdr.IsLongHeader).To(BeFalse()) + Expect(sf.Offset).ToNot(BeZero()) + } + } + } + } + }) + It("correctly deals with 0-RTT rejections", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) // now dial new connection with different transport parameters @@ -648,7 +725,7 @@ var _ = Describe("0-RTT", func() { transfer0RTTData(ln, proxy.LocalPort(), clientConf, PRData, true) - Expect(tracer.rcvdPackets[0].Type).To(Equal(protocol.PacketTypeInitial)) + Expect(tracer.rcvdPackets[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) zeroRTTPackets := get0RTTPackets(tracer.getRcvdPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index f4b3a80f..811c47ea 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -498,28 +498,43 @@ var _ = Describe("Transport Parameters", func() { Expect(p.ValidFor0RTT(saved)).To(BeTrue()) }) - It("rejects the parameters if the InitialMaxStreamDataBidiLocal changed", func() { - p.InitialMaxStreamDataBidiLocal = 0 + It("rejects the parameters if the InitialMaxStreamDataBidiLocal was reduced", func() { + p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) - It("rejects the parameters if the InitialMaxStreamDataBidiRemote changed", func() { - p.InitialMaxStreamDataBidiRemote = 0 + It("doesn't reject the parameters if the InitialMaxStreamDataBidiLocal was increased", func() { + p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataBidiRemote was reduced", func() { + p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) - It("rejects the parameters if the InitialMaxStreamDataUni changed", func() { - p.InitialMaxStreamDataUni = 0 + It("doesn't reject the parameters if the InitialMaxStreamDataBidiRemote was increased", func() { + p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataUni was reduced", func() { + p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) + It("doesn't reject the parameters if the InitialMaxStreamDataUni was increased", func() { + p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + It("rejects the parameters if the InitialMaxData changed", func() { p.InitialMaxData = 0 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) - It("rejects the parameters if the MaxBidiStreamNum changed", func() { - p.MaxBidiStreamNum = 0 + It("rejects the parameters if the MaxBidiStreamNum was reduced", func() { + p.MaxBidiStreamNum = saved.MaxBidiStreamNum - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index 6bf437dc..959d8257 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -441,9 +441,9 @@ func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error // ValidFor0RTT checks if the transport parameters match those saved in the session ticket. func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { - return p.InitialMaxStreamDataBidiLocal == saved.InitialMaxStreamDataBidiLocal && - p.InitialMaxStreamDataBidiRemote == saved.InitialMaxStreamDataBidiRemote && - p.InitialMaxStreamDataUni == saved.InitialMaxStreamDataUni && + return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && + p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && + p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && p.InitialMaxData == saved.InitialMaxData && p.MaxBidiStreamNum >= saved.MaxBidiStreamNum && p.MaxUniStreamNum >= saved.MaxUniStreamNum &&