From 6880f88089d442d6eb739bcbb84875588896dc82 Mon Sep 17 00:00:00 2001 From: Ameagari <47713057+tanghaowillow@users.noreply.github.com> Date: Sat, 19 Aug 2023 10:16:16 +0800 Subject: [PATCH] save the max_datagram_frame_size transport parameter in the session ticket (#4013) * Add MaxDatagramFrameSize parameter in session ticket * fix gofumpt issues * Update integrationtests/self/zero_rtt_test.go Co-authored-by: Marten Seemann * fix: correct comparsion of max_datagram_frame_size * test: use constant MaxDatagramFrameSize for session ticket test * fix grammar --------- Co-authored-by: Marten Seemann --- integrationtests/self/zero_rtt_oldgo_test.go | 112 ++++++++++++++++++ integrationtests/self/zero_rtt_test.go | 115 +++++++++++++++++++ internal/handshake/session_ticket_test.go | 2 + internal/wire/transport_parameter_test.go | 24 ++++ internal/wire/transport_parameters.go | 10 ++ 5 files changed, 263 insertions(+) diff --git a/integrationtests/self/zero_rtt_oldgo_test.go b/integrationtests/self/zero_rtt_oldgo_test.go index beaf351e..e5809286 100644 --- a/integrationtests/self/zero_rtt_oldgo_test.go +++ b/integrationtests/self/zero_rtt_oldgo_test.go @@ -801,4 +801,116 @@ var _ = Describe("0-RTT", func() { Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) }) + + It("sends 0-RTT datagrams", func() { + tlsConf, clientTLSConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + EnableDatagrams: true, + })) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + EnableDatagrams: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + // second connection + sentMessage := GeneratePRData(100) + var receivedMessage []byte + received := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(received) + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + receivedMessage, err = conn.ReceiveMessage(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + }() + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientTLSConf, + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.SendMessage(sentMessage)).To(Succeed()) + <-conn.HandshakeComplete() + <-received + + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + Expect(receivedMessage).To(Equal(sentMessage)) + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(zeroRTTPackets).To(HaveLen(1)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) + + It("rejects 0-RTT datagrams when the server doesn't support datagrams anymore", func() { + tlsConf, clientTLSConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + EnableDatagrams: true, + })) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + EnableDatagrams: false, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + // second connection + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + _, err = conn.ReceiveMessage(context.Background()) + Expect(err.Error()).To(Equal("datagram support disabled")) + <-conn.HandshakeComplete() + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + }() + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientTLSConf, + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) + // the client can temporarily send datagrams but the server doesn't process them. + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.SendMessage(make([]byte, 100))).To(Succeed()) + <-conn.HandshakeComplete() + + Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) }) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 2de283ac..011687ae 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -938,4 +938,119 @@ var _ = Describe("0-RTT", func() { ) Expect(restored).To(BeTrue()) }) + + It("sends 0-RTT datagrams", func() { + tlsConf := getTLSConfig() + clientTLSConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), clientTLSConf) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + EnableDatagrams: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + // second connection + sentMessage := GeneratePRData(100) + var receivedMessage []byte + received := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(received) + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + receivedMessage, err = conn.ReceiveMessage(context.Background()) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + }() + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientTLSConf, + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.SendMessage(sentMessage)).To(Succeed()) + <-conn.HandshakeComplete() + <-received + + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + Expect(receivedMessage).To(Equal(sentMessage)) + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(zeroRTTPackets).To(HaveLen(1)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) + + It("rejects 0-RTT datagrams when the server doesn't support datagrams anymore", func() { + tlsConf := getTLSConfig() + clientTLSConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), clientTLSConf) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + EnableDatagrams: false, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + // second connection + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + _, err = conn.ReceiveMessage(context.Background()) + Expect(err.Error()).To(Equal("datagram support disabled")) + <-conn.HandshakeComplete() + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + }() + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientTLSConf, + getQuicConfig(&quic.Config{ + EnableDatagrams: true, + }), + ) + Expect(err).ToNot(HaveOccurred()) + // the client can temporarily send datagrams but the server doesn't process them. + Expect(conn.ConnectionState().SupportsDatagrams).To(BeTrue()) + Expect(conn.SendMessage(make([]byte, 100))).To(Succeed()) + <-conn.HandshakeComplete() + + Expect(conn.ConnectionState().SupportsDatagrams).To(BeFalse()) + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) }) diff --git a/internal/handshake/session_ticket_test.go b/internal/handshake/session_ticket_test.go index 67e33b22..6f004de9 100644 --- a/internal/handshake/session_ticket_test.go +++ b/internal/handshake/session_ticket_test.go @@ -17,6 +17,7 @@ var _ = Describe("Session Ticket", func() { InitialMaxStreamDataBidiLocal: 1, InitialMaxStreamDataBidiRemote: 2, ActiveConnectionIDLimit: 10, + MaxDatagramFrameSize: 20, }, RTT: 1337 * time.Microsecond, } @@ -25,6 +26,7 @@ var _ = Describe("Session Ticket", func() { Expect(t.Parameters.InitialMaxStreamDataBidiLocal).To(BeEquivalentTo(1)) Expect(t.Parameters.InitialMaxStreamDataBidiRemote).To(BeEquivalentTo(2)) Expect(t.Parameters.ActiveConnectionIDLimit).To(BeEquivalentTo(10)) + Expect(t.Parameters.MaxDatagramFrameSize).To(BeEquivalentTo(20)) Expect(t.RTT).To(Equal(1337 * time.Microsecond)) }) diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index 9dd306b7..2fb79539 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -503,6 +503,7 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), ActiveConnectionIDLimit: 2 + getRandomValueUpTo(math.MaxInt64-2), + MaxDatagramFrameSize: protocol.ByteCount(getRandomValueUpTo(int64(protocol.MaxDatagramFrameSize))), } Expect(params.ValidFor0RTT(params)).To(BeTrue()) b := params.MarshalForSessionTicket(nil) @@ -515,6 +516,7 @@ var _ = Describe("Transport Parameters", func() { Expect(tp.MaxBidiStreamNum).To(Equal(params.MaxBidiStreamNum)) Expect(tp.MaxUniStreamNum).To(Equal(params.MaxUniStreamNum)) Expect(tp.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) + Expect(tp.MaxDatagramFrameSize).To(Equal(params.MaxDatagramFrameSize)) }) It("rejects the parameters if it can't parse them", func() { @@ -540,6 +542,7 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: 5, MaxUniStreamNum: 6, ActiveConnectionIDLimit: 7, + MaxDatagramFrameSize: 1000, } BeforeEach(func() { @@ -611,6 +614,16 @@ var _ = Describe("Transport Parameters", func() { p.ActiveConnectionIDLimit = 0 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) + + It("accepts the parameters if the MaxDatagramFrameSize was increased", func() { + p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the MaxDatagramFrameSize reduced", func() { + p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize - 1 + Expect(p.ValidFor0RTT(saved)).To(BeFalse()) + }) }) Context("client checks the parameters after successfully sending 0-RTT data", func() { @@ -623,6 +636,7 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: 5, MaxUniStreamNum: 6, ActiveConnectionIDLimit: 7, + MaxDatagramFrameSize: 1000, } BeforeEach(func() { @@ -699,6 +713,16 @@ var _ = Describe("Transport Parameters", func() { p.ActiveConnectionIDLimit = saved.ActiveConnectionIDLimit + 1 Expect(p.ValidForUpdate(saved)).To(BeTrue()) }) + + It("rejects the parameters if the MaxDatagramFrameSize reduced", func() { + p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize - 1 + Expect(p.ValidForUpdate(saved)).To(BeFalse()) + }) + + It("doesn't reject the parameters if the MaxDatagramFrameSize increased", func() { + p.MaxDatagramFrameSize = saved.MaxDatagramFrameSize + 1 + Expect(p.ValidForUpdate(saved)).To(BeTrue()) + }) }) }) }) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index dc0aa22f..7226521b 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -454,6 +454,10 @@ func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte { b = p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) // initial_max_uni_streams b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + // max_datagram_frame_size + if p.MaxDatagramFrameSize != protocol.InvalidByteCount { + b = p.marshalVarintParam(b, maxDatagramFrameSizeParameterID, uint64(p.MaxDatagramFrameSize)) + } // active_connection_id_limit return p.marshalVarintParam(b, activeConnectionIDLimitParameterID, p.ActiveConnectionIDLimit) } @@ -472,6 +476,9 @@ func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error // ValidFor0RTT checks if the transport parameters match those saved in the session ticket. func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { + if saved.MaxDatagramFrameSize != protocol.InvalidByteCount && (p.MaxDatagramFrameSize == protocol.InvalidByteCount || p.MaxDatagramFrameSize < saved.MaxDatagramFrameSize) { + return false + } return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && @@ -484,6 +491,9 @@ func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { // ValidForUpdate checks that the new transport parameters don't reduce limits after resuming a 0-RTT connection. // It is only used on the client side. func (p *TransportParameters) ValidForUpdate(saved *TransportParameters) bool { + if saved.MaxDatagramFrameSize != protocol.InvalidByteCount && (p.MaxDatagramFrameSize == protocol.InvalidByteCount || p.MaxDatagramFrameSize < saved.MaxDatagramFrameSize) { + return false + } return p.ActiveConnectionIDLimit >= saved.ActiveConnectionIDLimit && p.InitialMaxData >= saved.InitialMaxData && p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal &&