From b52d34008faef47aa7bc618ad2686f154c190f0a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 4 Jan 2023 16:18:11 -0800 Subject: [PATCH] add Allow0RTT opt in the quic.Config to control 0-RTT on the server side (#3635) --- config.go | 1 + config_test.go | 2 +- connection.go | 7 ++- connection_test.go | 1 - fuzzing/handshake/cmd/corpus.go | 2 +- fuzzing/handshake/fuzz.go | 6 ++- integrationtests/self/zero_rtt_test.go | 66 ++++++++++++++++++++----- interface.go | 6 ++- internal/handshake/crypto_setup.go | 23 ++++++--- internal/handshake/crypto_setup_test.go | 20 +++++--- server.go | 2 - server_test.go | 13 ----- 12 files changed, 98 insertions(+), 51 deletions(-) diff --git a/config.go b/config.go index 0e8cc98a..ce495dfd 100644 --- a/config.go +++ b/config.go @@ -135,6 +135,7 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config { EnableDatagrams: config.EnableDatagrams, DisablePathMTUDiscovery: config.DisablePathMTUDiscovery, DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets, + Allow0RTT: config.Allow0RTT, Tracer: config.Tracer, } } diff --git a/config_test.go b/config_test.go index e2bc8153..ec401994 100644 --- a/config_test.go +++ b/config_test.go @@ -45,7 +45,7 @@ var _ = Describe("Config", func() { } switch fn := typ.Field(i).Name; fn { - case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease": + case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT": // Can't compare functions. case "Versions": f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) diff --git a/connection.go b/connection.go index 0df45e78..b5309460 100644 --- a/connection.go +++ b/connection.go @@ -241,7 +241,6 @@ var newConnection = func( conf *Config, tlsConf *tls.Config, tokenGenerator *handshake.TokenGenerator, - enable0RTT bool, clientAddressValidated bool, tracer logging.ConnectionTracer, tracingID uint64, @@ -323,6 +322,10 @@ var newConnection = func( if s.tracer != nil { s.tracer.SentTransportParameters(params) } + var allow0RTT func() bool + if conf.Allow0RTT != nil { + allow0RTT = func() bool { return conf.Allow0RTT(conn.RemoteAddr()) } + } cs := handshake.NewCryptoSetupServer( initialStream, handshakeStream, @@ -340,7 +343,7 @@ var newConnection = func( }, }, tlsConf, - enable0RTT, + allow0RTT, s.rttStats, tracer, logger, diff --git a/connection_test.go b/connection_test.go index 45c2cd04..e22538a2 100644 --- a/connection_test.go +++ b/connection_test.go @@ -101,7 +101,6 @@ var _ = Describe("Connection", func() { nil, // tls.Config tokenGenerator, false, - false, tracer, 1234, utils.DefaultLogger, diff --git a/fuzzing/handshake/cmd/corpus.go b/fuzzing/handshake/cmd/corpus.go index d9cd1f07..7d03590d 100644 --- a/fuzzing/handshake/cmd/corpus.go +++ b/fuzzing/handshake/cmd/corpus.go @@ -105,7 +105,7 @@ func main() { &wire.TransportParameters{}, runner, config, - false, + nil, utils.NewRTTStats(), nil, utils.DefaultLogger.WithPrefix("server"), diff --git a/fuzzing/handshake/fuzz.go b/fuzzing/handshake/fuzz.go index 5faf478a..d23557ca 100644 --- a/fuzzing/handshake/fuzz.go +++ b/fuzzing/handshake/fuzz.go @@ -390,6 +390,10 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. protocol.VersionTLS, ) + var allow0RTT func() bool + if enable0RTTServer { + allow0RTT = func() bool { return true } + } sChunkChan, sInitialStream, sHandshakeStream := initStreams() server = handshake.NewCryptoSetupServer( sInitialStream, @@ -400,7 +404,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. serverTP, runner, serverConf, - enable0RTTServer, + allow0RTT, utils.NewRTTStats(), nil, utils.DefaultLogger.WithPrefix("server"), diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 055f26b4..e54de97f 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -57,6 +57,7 @@ var _ = Describe("0-RTT", func() { serverConf = getQuicConfig(nil) serverConf.Versions = []protocol.VersionNumber{version} } + serverConf.Allow0RTT = func(addr net.Addr) bool { return true } ln, err := quic.ListenAddrEarly( "localhost:0", tlsConf, @@ -137,6 +138,7 @@ var _ = Describe("0-RTT", func() { _, err = str.Write(testdata) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) + <-conn.HandshakeComplete().Done() Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) Eventually(done).Should(BeClosed()) Eventually(conn.Context().Done()).Should(BeClosed()) @@ -194,8 +196,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Allow0RTT: func(addr net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -251,8 +254,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -325,8 +329,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -397,6 +402,7 @@ var _ = Describe("0-RTT", func() { getQuicConfig(&quic.Config{ Versions: []protocol.VersionNumber{version}, RequireAddressValidation: func(net.Addr) bool { return true }, + Allow0RTT: func(net.Addr) bool { return true }, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) @@ -458,6 +464,7 @@ var _ = Describe("0-RTT", func() { getQuicConfig(&quic.Config{ Versions: []protocol.VersionNumber{version}, MaxIncomingUniStreams: maxStreams + 1, + Allow0RTT: func(net.Addr) bool { return true }, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) @@ -502,6 +509,7 @@ var _ = Describe("0-RTT", func() { getQuicConfig(&quic.Config{ Versions: []protocol.VersionNumber{version}, MaxIncomingStreams: maxStreams - 1, + Allow0RTT: func(net.Addr) bool { return true }, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) @@ -529,8 +537,37 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + check0RTTRejected(ln, proxy.LocalPort(), clientConf) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + }) + + It("rejects 0-RTT when the application doesn't allow it", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + // now close the listener and dial new connection with a different ALPN + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{version}, + Allow0RTT: func(net.Addr) bool { return false }, // application rejects 0-RTT + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) @@ -551,14 +588,16 @@ var _ = Describe("0-RTT", func() { func(addFlowControlLimit func(*quic.Config, uint64)) { tracer := newPacketTracer() firstConf := getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, + Allow0RTT: func(net.Addr) bool { return true }, + Versions: []protocol.VersionNumber{version}, }) addFlowControlLimit(firstConf, 3) tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) secondConf := getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }) addFlowControlLimit(secondConf, 100) ln, err := quic.ListenAddrEarly( @@ -709,8 +748,9 @@ var _ = Describe("0-RTT", func() { "localhost:0", tlsConf, getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + Versions: []protocol.VersionNumber{version}, + Allow0RTT: func(net.Addr) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), }), ) Expect(err).ToNot(HaveOccurred()) diff --git a/interface.go b/interface.go index 214afcf1..315ed9cc 100644 --- a/interface.go +++ b/interface.go @@ -176,7 +176,6 @@ type Connection interface { // Context returns a context that is cancelled when the connection is closed. Context() context.Context // ConnectionState returns basic details about the QUIC connection. - // It blocks until the handshake completes. // Warning: This API should not be considered stable and might change soon. ConnectionState() ConnectionState @@ -325,6 +324,11 @@ type Config struct { // This can be useful if version information is exchanged out-of-band. // It has no effect for a client. DisableVersionNegotiationPackets bool + // Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted. + // When set, 0-RTT is enabled. When not set, 0-RTT is disabled. + // Only valid for the server. + // Warning: This API should not be considered stable and might change soon. + Allow0RTT func(net.Addr) bool // Enable QUIC datagram support (RFC 9221). EnableDatagrams bool Tracer logging.Tracer diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 83325bb3..4ba01942 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "math" "net" "sync" "time" @@ -115,6 +116,7 @@ type cryptoSetup struct { clientHelloWritten bool clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written zeroRTTParametersChan chan<- *wire.TransportParameters + allow0RTT func() bool rttStats *utils.RTTStats @@ -195,7 +197,7 @@ func NewCryptoSetupServer( tp *wire.TransportParameters, runner handshakeRunner, tlsConf *tls.Config, - enable0RTT bool, + allow0RTT func() bool, rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, @@ -208,13 +210,14 @@ func NewCryptoSetupServer( tp, runner, tlsConf, - enable0RTT, + allow0RTT != nil, rttStats, tracer, logger, protocol.PerspectiveServer, version, ) + cs.allow0RTT = allow0RTT cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) return cs } @@ -267,7 +270,7 @@ func newCryptoSetup( } var maxEarlyData uint32 if enable0RTT { - maxEarlyData = 0xffffffff + maxEarlyData = math.MaxUint32 } cs.extraConf = &qtls.ExtraConfig{ GetExtensions: extHandler.GetExtensions, @@ -490,13 +493,17 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool { return false } valid := h.ourParams.ValidFor0RTT(t.Parameters) - if valid { - h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) - h.rttStats.SetInitialRTT(t.RTT) - } else { + if !valid { h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.") + return false } - return valid + if !h.allow0RTT() { + h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.") + return false + } + h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT) + h.rttStats.SetInitialRTT(t.RTT) + return true } // rejected0RTT is called for the client when the server rejects 0-RTT. diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index c041cd68..4b8b67a3 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -95,7 +95,7 @@ var _ = Describe("Crypto Setup TLS", func() { &wire.TransportParameters{StatelessResetToken: &token}, runner, testdata.GetTLSConfig(), - false, + nil, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), @@ -177,7 +177,7 @@ var _ = Describe("Crypto Setup TLS", func() { &wire.TransportParameters{StatelessResetToken: &token}, runner, testdata.GetTLSConfig(), - false, + nil, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), @@ -218,7 +218,7 @@ var _ = Describe("Crypto Setup TLS", func() { &wire.TransportParameters{StatelessResetToken: &token}, runner, serverConf, - false, + nil, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), @@ -253,7 +253,7 @@ var _ = Describe("Crypto Setup TLS", func() { &wire.TransportParameters{StatelessResetToken: &token}, NewMockHandshakeRunner(mockCtrl), serverConf, - false, + nil, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), @@ -378,6 +378,10 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.VersionTLS, ) + var allow0RTT func() bool + if enable0RTT { + allow0RTT = func() bool { return true } + } var sHandshakeComplete bool sChunkChan, sInitialStream, sHandshakeStream := initStreams() sErrChan := make(chan error, 1) @@ -398,7 +402,7 @@ var _ = Describe("Crypto Setup TLS", func() { serverTransportParameters, sRunner, serverConf, - enable0RTT, + allow0RTT, serverRTTStats, nil, utils.DefaultLogger.WithPrefix("server"), @@ -536,7 +540,7 @@ var _ = Describe("Crypto Setup TLS", func() { sTransportParameters, sRunner, serverConf, - false, + nil, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), @@ -591,7 +595,7 @@ var _ = Describe("Crypto Setup TLS", func() { &wire.TransportParameters{StatelessResetToken: &token}, sRunner, serverConf, - false, + nil, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), @@ -650,7 +654,7 @@ var _ = Describe("Crypto Setup TLS", func() { &wire.TransportParameters{StatelessResetToken: &token}, sRunner, serverConf, - false, + nil, &utils.RTTStats{}, nil, utils.DefaultLogger.WithPrefix("server"), diff --git a/server.go b/server.go index 16d4d818..066a4649 100644 --- a/server.go +++ b/server.go @@ -88,7 +88,6 @@ type baseServer struct { *Config, *tls.Config, *handshake.TokenGenerator, - bool, /* enable 0-RTT */ bool, /* client address validated by an address validation token */ logging.ConnectionTracer, uint64, @@ -506,7 +505,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro s.config, s.tlsConf, s.tokenGenerator, - s.acceptEarlyConns, clientAddrIsValid, tracer, tracingID, diff --git a/server_test.go b/server_test.go index 0da7caab..3897e1b5 100644 --- a/server_test.go +++ b/server_test.go @@ -286,14 +286,12 @@ var _ = Describe("Server", func() { _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, - enable0RTT bool, _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, ) quicConn { - Expect(enable0RTT).To(BeFalse()) Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))) Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) @@ -489,14 +487,12 @@ var _ = Describe("Server", func() { _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, - enable0RTT bool, _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, ) quicConn { - Expect(enable0RTT).To(BeFalse()) Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) Expect(retrySrcConnID).To(BeNil()) Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) @@ -550,7 +546,6 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -604,7 +599,6 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -634,7 +628,6 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -705,7 +698,6 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -1011,7 +1003,6 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -1084,14 +1075,12 @@ var _ = Describe("Server", func() { _ *Config, _ *tls.Config, _ *handshake.TokenGenerator, - enable0RTT bool, _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, _ protocol.VersionNumber, ) quicConn { - Expect(enable0RTT).To(BeTrue()) conn.EXPECT().handlePacket(gomock.Any()) conn.EXPECT().run().Do(func() {}) conn.EXPECT().earlyConnReady().Return(ready) @@ -1128,7 +1117,6 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger, @@ -1191,7 +1179,6 @@ var _ = Describe("Server", func() { _ *tls.Config, _ *handshake.TokenGenerator, _ bool, - _ bool, _ logging.ConnectionTracer, _ uint64, _ utils.Logger,