From f8313d868fcb80d77a8edc7db67bf16e6a95df33 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 4 Mar 2021 18:14:25 +0800 Subject: [PATCH] return an Err0RTTRejected when the server rejects a 0-RTT connection --- integrationtests/self/zero_rtt_test.go | 133 ++++++++++++++++++++++++- interface.go | 2 + internal/mocks/quic/early_session.go | 14 +++ mock_quic_session_test.go | 14 +++ session.go | 15 +++ 5 files changed, 174 insertions(+), 4 deletions(-) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 3ab95efb..904ea907 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -21,11 +21,22 @@ import ( ) var _ = Describe("0-RTT", func() { - const rtt = 50 * time.Millisecond + rtt := scaleDuration(5 * time.Millisecond) + for _, v := range protocol.SupportedVersions { version := v Context(fmt.Sprintf("with QUIC version %s", version), func() { + runDelayProxy := func(serverPort int) *quicproxy.QuicProxy { + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 }, + }) + Expect(err).ToNot(HaveOccurred()) + + return proxy + } + runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) { var num0RTTPackets uint32 // to be used as an atomic proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ @@ -105,6 +116,34 @@ var _ = Describe("0-RTT", func() { Eventually(done).Should(BeClosed()) } + check0RTTRejected := func( + ln quic.EarlyListener, + proxyPort int, + clientConf *tls.Config, + ) { + sess, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxyPort), + clientConf, + getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(make([]byte, 3000)) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + Expect(sess.ConnectionState().TLS.Used0RTT).To(BeFalse()) + + // make sure the server doesn't process the data + ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) + defer cancel() + serverSess, err := ln.Accept(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(serverSess.ConnectionState().TLS.Used0RTT).To(BeFalse()) + _, err = serverSess.AcceptUniStream(ctx) + Expect(err).To(Equal(context.DeadlineExceeded)) + } + It("transfers 0-RTT data", func() { ln, err := quic.ListenAddrEarly( "localhost:0", @@ -354,7 +393,7 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) defer proxy.Close() - transfer0RTTData(ln, proxy.LocalPort(), clientConf, PRData, false) + check0RTTRejected(ln, proxy.LocalPort(), clientConf) // The client should send 0-RTT packets, but the server doesn't process them. num0RTT := atomic.LoadUint32(num0RTTPackets) @@ -374,7 +413,9 @@ var _ = Describe("0-RTT", func() { ) Expect(err).ToNot(HaveOccurred()) - clientConf := dialAndReceiveSessionTicket(ln, ln.Addr().(*net.UDPAddr).Port) + delayProxy := runDelayProxy(ln.Addr().(*net.UDPAddr).Port) + defer delayProxy.Close() + clientConf := dialAndReceiveSessionTicket(ln, delayProxy.LocalPort()) // now close the listener and dial new connection with a different ALPN Expect(ln.Close()).To(Succeed()) @@ -391,7 +432,91 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) defer proxy.Close() - transfer0RTTData(ln, proxy.LocalPort(), clientConf, PRData, false) + + check0RTTRejected(ln, proxy.LocalPort(), clientConf) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + }) + + It("correctly deals with 0-RTT rejections", func() { + tlsConf := getTLSConfig() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{version}, + MaxIncomingUniStreams: 2, + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + }), + ) + Expect(err).ToNot(HaveOccurred()) + + delayProxy := runDelayProxy(ln.Addr().(*net.UDPAddr).Port) + defer delayProxy.Close() + clientConf := dialAndReceiveSessionTicket(ln, delayProxy.LocalPort()) + // now close the listener and dial new connection with different transport parameters + Expect(ln.Close()).To(Succeed()) + ln, err = quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{version}, + MaxIncomingUniStreams: 1, + }), + ) + Expect(err).ToNot(HaveOccurred()) + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + sess, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(string(data)).To(Equal("second flight")) + }() + + sess, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), + ) + Expect(err).ToNot(HaveOccurred()) + // The client remembers that it was allowed to open 2 uni-directional streams. + for i := 0; i < 2; i++ { + str, err := sess.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + go func() { + defer GinkgoRecover() + _, err = str.Write([]byte("first flight")) + Expect(err).ToNot(HaveOccurred()) + }() + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err = sess.AcceptStream(ctx) + Expect(err).To(Equal(quic.Err0RTTRejected)) + + newSess := sess.NextSession() + str, err := newSess.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = newSess.OpenUniStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("too many open streams")) + _, err = str.Write([]byte("second flight")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + + Eventually(done).Should(BeClosed()) // The client should send 0-RTT packets, but the server doesn't process them. num0RTT := atomic.LoadUint32(num0RTTPackets) diff --git a/interface.go b/interface.go index 84ae906c..b6df74ff 100644 --- a/interface.go +++ b/interface.go @@ -215,6 +215,8 @@ type EarlySession interface { // Data sent before completion of the handshake is encrypted with 1-RTT keys. // Note that the client's identity hasn't been verified yet. HandshakeComplete() context.Context + + NextSession() Session } // Config contains all configuration data needed for a QUIC server or client. diff --git a/internal/mocks/quic/early_session.go b/internal/mocks/quic/early_session.go index 3d5e5519..5d24a0b4 100644 --- a/internal/mocks/quic/early_session.go +++ b/internal/mocks/quic/early_session.go @@ -137,6 +137,20 @@ func (mr *MockEarlySessionMockRecorder) LocalAddr() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockEarlySession)(nil).LocalAddr)) } +// NextSession mocks base method. +func (m *MockEarlySession) NextSession() quic.Session { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NextSession") + ret0, _ := ret[0].(quic.Session) + return ret0 +} + +// NextSession indicates an expected call of NextSession. +func (mr *MockEarlySessionMockRecorder) NextSession() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextSession", reflect.TypeOf((*MockEarlySession)(nil).NextSession)) +} + // OpenStream mocks base method. func (m *MockEarlySession) OpenStream() (quic.Stream, error) { m.ctrl.T.Helper() diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index b36a2627..e91d0c3b 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -150,6 +150,20 @@ func (mr *MockQuicSessionMockRecorder) LocalAddr() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockQuicSession)(nil).LocalAddr)) } +// NextSession mocks base method. +func (m *MockQuicSession) NextSession() Session { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NextSession") + ret0, _ := ret[0].(Session) + return ret0 +} + +// NextSession indicates an expected call of NextSession. +func (mr *MockQuicSessionMockRecorder) NextSession() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextSession", reflect.TypeOf((*MockQuicSession)(nil).NextSession)) +} + // OpenStream mocks base method. func (m *MockQuicSession) OpenStream() (Stream, error) { m.ctrl.T.Helper() diff --git a/session.go b/session.go index 395eadfb..aedb16ab 100644 --- a/session.go +++ b/session.go @@ -1458,6 +1458,15 @@ func (s *session) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { if s.tracer != nil { s.tracer.DroppedEncryptionLevel(encLevel) } + if encLevel == protocol.Encryption0RTT { + s.streamsMap.ResetFor0RTT() + if err := s.connFlowController.Reset(); err != nil { + s.closeLocal(err) + } + if err := s.framer.Handle0RTTRejection(); err != nil { + s.closeLocal(err) + } + } } // is called for the client, when restoring transport parameters saved for 0-RTT @@ -1884,3 +1893,9 @@ func (s *session) getPerspective() protocol.Perspective { func (s *session) GetVersion() protocol.VersionNumber { return s.version } + +func (s *session) NextSession() Session { + <-s.HandshakeComplete().Done() + s.streamsMap.UseResetMaps() + return s +}