add Allow0RTT opt in the quic.Config to control 0-RTT on the server side (#3635)

This commit is contained in:
Marten Seemann 2023-01-04 16:18:11 -08:00 committed by GitHub
parent 421893b1c4
commit b52d34008f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 98 additions and 51 deletions

View file

@ -135,6 +135,7 @@ func populateConfig(config *Config, defaultConnIDLen int) *Config {
EnableDatagrams: config.EnableDatagrams,
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
DisableVersionNegotiationPackets: config.DisableVersionNegotiationPackets,
Allow0RTT: config.Allow0RTT,
Tracer: config.Tracer,
}
}

View file

@ -45,7 +45,7 @@ var _ = Describe("Config", func() {
}
switch fn := typ.Field(i).Name; fn {
case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease":
case "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT":
// Can't compare functions.
case "Versions":
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))

View file

@ -241,7 +241,6 @@ var newConnection = func(
conf *Config,
tlsConf *tls.Config,
tokenGenerator *handshake.TokenGenerator,
enable0RTT bool,
clientAddressValidated bool,
tracer logging.ConnectionTracer,
tracingID uint64,
@ -323,6 +322,10 @@ var newConnection = func(
if s.tracer != nil {
s.tracer.SentTransportParameters(params)
}
var allow0RTT func() bool
if conf.Allow0RTT != nil {
allow0RTT = func() bool { return conf.Allow0RTT(conn.RemoteAddr()) }
}
cs := handshake.NewCryptoSetupServer(
initialStream,
handshakeStream,
@ -340,7 +343,7 @@ var newConnection = func(
},
},
tlsConf,
enable0RTT,
allow0RTT,
s.rttStats,
tracer,
logger,

View file

@ -101,7 +101,6 @@ var _ = Describe("Connection", func() {
nil, // tls.Config
tokenGenerator,
false,
false,
tracer,
1234,
utils.DefaultLogger,

View file

@ -105,7 +105,7 @@ func main() {
&wire.TransportParameters{},
runner,
config,
false,
nil,
utils.NewRTTStats(),
nil,
utils.DefaultLogger.WithPrefix("server"),

View file

@ -390,6 +390,10 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
protocol.VersionTLS,
)
var allow0RTT func() bool
if enable0RTTServer {
allow0RTT = func() bool { return true }
}
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
server = handshake.NewCryptoSetupServer(
sInitialStream,
@ -400,7 +404,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
serverTP,
runner,
serverConf,
enable0RTTServer,
allow0RTT,
utils.NewRTTStats(),
nil,
utils.DefaultLogger.WithPrefix("server"),

View file

@ -57,6 +57,7 @@ var _ = Describe("0-RTT", func() {
serverConf = getQuicConfig(nil)
serverConf.Versions = []protocol.VersionNumber{version}
}
serverConf.Allow0RTT = func(addr net.Addr) bool { return true }
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -137,6 +138,7 @@ var _ = Describe("0-RTT", func() {
_, err = str.Write(testdata)
Expect(err).ToNot(HaveOccurred())
Expect(str.Close()).To(Succeed())
<-conn.HandshakeComplete().Done()
Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue())
Eventually(done).Should(BeClosed())
Eventually(conn.Context().Done()).Should(BeClosed())
@ -194,8 +196,9 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
Allow0RTT: func(addr net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -251,8 +254,9 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -325,8 +329,9 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -397,6 +402,7 @@ var _ = Describe("0-RTT", func() {
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return true },
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -458,6 +464,7 @@ var _ = Describe("0-RTT", func() {
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
MaxIncomingUniStreams: maxStreams + 1,
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -502,6 +509,7 @@ var _ = Describe("0-RTT", func() {
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
MaxIncomingStreams: maxStreams - 1,
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -529,8 +537,37 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
defer ln.Close()
proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port)
defer proxy.Close()
check0RTTRejected(ln, proxy.LocalPort(), clientConf)
// The client should send 0-RTT packets, but the server doesn't process them.
num0RTT := atomic.LoadUint32(num0RTTPackets)
fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT)
Expect(num0RTT).ToNot(BeZero())
Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty())
})
It("rejects 0-RTT when the application doesn't allow it", func() {
tlsConf, clientConf := dialAndReceiveSessionTicket(nil)
// now close the listener and dial new connection with a different ALPN
tracer := newPacketTracer()
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
Allow0RTT: func(net.Addr) bool { return false }, // application rejects 0-RTT
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())
@ -551,14 +588,16 @@ var _ = Describe("0-RTT", func() {
func(addFlowControlLimit func(*quic.Config, uint64)) {
tracer := newPacketTracer()
firstConf := getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
Allow0RTT: func(net.Addr) bool { return true },
Versions: []protocol.VersionNumber{version},
})
addFlowControlLimit(firstConf, 3)
tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf)
secondConf := getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
})
addFlowControlLimit(secondConf, 100)
ln, err := quic.ListenAddrEarly(
@ -709,8 +748,9 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Versions: []protocol.VersionNumber{version},
Allow0RTT: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
Expect(err).ToNot(HaveOccurred())

View file

@ -176,7 +176,6 @@ type Connection interface {
// Context returns a context that is cancelled when the connection is closed.
Context() context.Context
// ConnectionState returns basic details about the QUIC connection.
// It blocks until the handshake completes.
// Warning: This API should not be considered stable and might change soon.
ConnectionState() ConnectionState
@ -325,6 +324,11 @@ type Config struct {
// This can be useful if version information is exchanged out-of-band.
// It has no effect for a client.
DisableVersionNegotiationPackets bool
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
// When set, 0-RTT is enabled. When not set, 0-RTT is disabled.
// Only valid for the server.
// Warning: This API should not be considered stable and might change soon.
Allow0RTT func(net.Addr) bool
// Enable QUIC datagram support (RFC 9221).
EnableDatagrams bool
Tracer logging.Tracer

View file

@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"math"
"net"
"sync"
"time"
@ -115,6 +116,7 @@ type cryptoSetup struct {
clientHelloWritten bool
clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written
zeroRTTParametersChan chan<- *wire.TransportParameters
allow0RTT func() bool
rttStats *utils.RTTStats
@ -195,7 +197,7 @@ func NewCryptoSetupServer(
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
enable0RTT bool,
allow0RTT func() bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
@ -208,13 +210,14 @@ func NewCryptoSetupServer(
tp,
runner,
tlsConf,
enable0RTT,
allow0RTT != nil,
rttStats,
tracer,
logger,
protocol.PerspectiveServer,
version,
)
cs.allow0RTT = allow0RTT
cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
return cs
}
@ -267,7 +270,7 @@ func newCryptoSetup(
}
var maxEarlyData uint32
if enable0RTT {
maxEarlyData = 0xffffffff
maxEarlyData = math.MaxUint32
}
cs.extraConf = &qtls.ExtraConfig{
GetExtensions: extHandler.GetExtensions,
@ -490,13 +493,17 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
return false
}
valid := h.ourParams.ValidFor0RTT(t.Parameters)
if valid {
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
h.rttStats.SetInitialRTT(t.RTT)
} else {
if !valid {
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
return false
}
return valid
if !h.allow0RTT() {
h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.")
return false
}
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
h.rttStats.SetInitialRTT(t.RTT)
return true
}
// rejected0RTT is called for the client when the server rejects 0-RTT.

View file

@ -95,7 +95,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -177,7 +177,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -218,7 +218,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
runner,
serverConf,
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -253,7 +253,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
NewMockHandshakeRunner(mockCtrl),
serverConf,
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -378,6 +378,10 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.VersionTLS,
)
var allow0RTT func() bool
if enable0RTT {
allow0RTT = func() bool { return true }
}
var sHandshakeComplete bool
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sErrChan := make(chan error, 1)
@ -398,7 +402,7 @@ var _ = Describe("Crypto Setup TLS", func() {
serverTransportParameters,
sRunner,
serverConf,
enable0RTT,
allow0RTT,
serverRTTStats,
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -536,7 +540,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sTransportParameters,
sRunner,
serverConf,
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -591,7 +595,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -650,7 +654,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
sRunner,
serverConf,
false,
nil,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),

View file

@ -88,7 +88,6 @@ type baseServer struct {
*Config,
*tls.Config,
*handshake.TokenGenerator,
bool, /* enable 0-RTT */
bool, /* client address validated by an address validation token */
logging.ConnectionTracer,
uint64,
@ -506,7 +505,6 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
s.config,
s.tlsConf,
s.tokenGenerator,
s.acceptEarlyConns,
clientAddrIsValid,
tracer,
tracingID,

View file

@ -286,14 +286,12 @@ var _ = Describe("Server", func() {
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
enable0RTT bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
Expect(enable0RTT).To(BeFalse())
Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})))
Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad})))
Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
@ -489,14 +487,12 @@ var _ = Describe("Server", func() {
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
enable0RTT bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
Expect(enable0RTT).To(BeFalse())
Expect(origDestConnID).To(Equal(hdr.DestConnectionID))
Expect(retrySrcConnID).To(BeNil())
Expect(clientDestConnID).To(Equal(hdr.DestConnectionID))
@ -550,7 +546,6 @@ var _ = Describe("Server", func() {
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
@ -604,7 +599,6 @@ var _ = Describe("Server", func() {
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
@ -634,7 +628,6 @@ var _ = Describe("Server", func() {
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
@ -705,7 +698,6 @@ var _ = Describe("Server", func() {
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
@ -1011,7 +1003,6 @@ var _ = Describe("Server", func() {
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
@ -1084,14 +1075,12 @@ var _ = Describe("Server", func() {
_ *Config,
_ *tls.Config,
_ *handshake.TokenGenerator,
enable0RTT bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
_ protocol.VersionNumber,
) quicConn {
Expect(enable0RTT).To(BeTrue())
conn.EXPECT().handlePacket(gomock.Any())
conn.EXPECT().run().Do(func() {})
conn.EXPECT().earlyConnReady().Return(ready)
@ -1128,7 +1117,6 @@ var _ = Describe("Server", func() {
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,
@ -1191,7 +1179,6 @@ var _ = Describe("Server", func() {
_ *tls.Config,
_ *handshake.TokenGenerator,
_ bool,
_ bool,
_ logging.ConnectionTracer,
_ uint64,
_ utils.Logger,