make Config.Allow0RTT a bool, not a callback

This commit is contained in:
Marten Seemann 2023-03-22 13:25:09 +13:00
parent bc7cb706c5
commit 7a0ef5f867
10 changed files with 35 additions and 49 deletions

View file

@ -46,7 +46,7 @@ var _ = Describe("Config", func() {
}
switch fn := typ.Field(i).Name; fn {
case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Allow0RTT":
case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease":
// Can't compare functions.
case "Versions":
f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
@ -86,6 +86,8 @@ var _ = Describe("Config", func() {
f.Set(reflect.ValueOf(true))
case "DisablePathMTUDiscovery":
f.Set(reflect.ValueOf(true))
case "Allow0RTT":
f.Set(reflect.ValueOf(true))
case "Tracer":
f.Set(reflect.ValueOf(mocklogging.NewMockTracer(mockCtrl)))
default:

View file

@ -324,10 +324,6 @@ var newConnection = func(
if s.tracer != nil {
s.tracer.SentTransportParameters(params)
}
var allow0RTT func() bool
if conf.Allow0RTT != nil {
allow0RTT = func() bool { return conf.Allow0RTT(conn.RemoteAddr()) }
}
cs := handshake.NewCryptoSetupServer(
initialStream,
handshakeStream,
@ -345,7 +341,7 @@ var newConnection = func(
},
},
tlsConf,
allow0RTT,
conf.Allow0RTT,
s.rttStats,
tracer,
logger,

View file

@ -105,7 +105,7 @@ func main() {
&wire.TransportParameters{ActiveConnectionIDLimit: 2},
runner,
config,
nil,
false,
utils.NewRTTStats(),
nil,
utils.DefaultLogger.WithPrefix("server"),

View file

@ -390,10 +390,6 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
protocol.Version1,
)
var allow0RTT func() bool
if enable0RTTServer {
allow0RTT = func() bool { return true }
}
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
server = handshake.NewCryptoSetupServer(
sInitialStream,
@ -404,7 +400,7 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.
serverTP,
runner,
serverConf,
allow0RTT,
enable0RTTServer,
utils.NewRTTStats(),
nil,
utils.DefaultLogger.WithPrefix("server"),

View file

@ -288,7 +288,7 @@ func (s *Server) serveConn(tlsConf *tls.Config, conn net.PacketConn) error {
baseConf := ConfigureTLSConfig(tlsConf)
quicConf := s.QuicConfig
if quicConf == nil {
quicConf = &quic.Config{Allow0RTT: func(net.Addr) bool { return true }}
quicConf = &quic.Config{Allow0RTT: true}
} else {
quicConf = s.QuicConfig.Clone()
}

View file

@ -54,7 +54,7 @@ var _ = Describe("0-RTT", func() {
if serverConf == nil {
serverConf = getQuicConfig(nil)
}
serverConf.Allow0RTT = func(addr net.Addr) bool { return true }
serverConf.Allow0RTT = true
ln, err := quic.ListenAddrEarly(
"localhost:0",
tlsConf,
@ -222,7 +222,7 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: func(addr net.Addr) bool { return true },
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -276,7 +276,7 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: func(net.Addr) bool { return true },
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -358,7 +358,7 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: func(net.Addr) bool { return true },
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -434,7 +434,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
RequireAddressValidation: func(net.Addr) bool { return true },
Allow0RTT: func(net.Addr) bool { return true },
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -495,7 +495,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingUniStreams: maxStreams + 1,
Allow0RTT: func(net.Addr) bool { return true },
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -540,7 +540,7 @@ var _ = Describe("0-RTT", func() {
tlsConf,
getQuicConfig(&quic.Config{
MaxIncomingStreams: maxStreams - 1,
Allow0RTT: func(net.Addr) bool { return true },
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -568,7 +568,7 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: func(net.Addr) bool { return true },
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -595,7 +595,7 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: func(net.Addr) bool { return false }, // application rejects 0-RTT
Allow0RTT: false, // application rejects 0-RTT
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)
@ -616,12 +616,12 @@ var _ = Describe("0-RTT", func() {
DescribeTable("flow control limits",
func(addFlowControlLimit func(*quic.Config, uint64)) {
tracer := newPacketTracer()
firstConf := getQuicConfig(&quic.Config{Allow0RTT: func(net.Addr) bool { return true }})
firstConf := getQuicConfig(&quic.Config{Allow0RTT: true})
addFlowControlLimit(firstConf, 3)
tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf)
secondConf := getQuicConfig(&quic.Config{
Allow0RTT: func(net.Addr) bool { return true },
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
})
addFlowControlLimit(secondConf, 100)
@ -774,7 +774,7 @@ var _ = Describe("0-RTT", func() {
"localhost:0",
tlsConf,
getQuicConfig(&quic.Config{
Allow0RTT: func(net.Addr) bool { return true },
Allow0RTT: true,
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}),
)

View file

@ -318,10 +318,8 @@ type Config struct {
// It has no effect for a client.
DisableVersionNegotiationPackets bool
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
// When set, 0-RTT is enabled. When not set, 0-RTT is disabled.
// Only valid for the server.
// Warning: This API should not be considered stable and might change soon.
Allow0RTT func(net.Addr) bool
Allow0RTT bool
// Enable QUIC datagram support (RFC 9221).
EnableDatagrams bool
Tracer logging.Tracer

View file

@ -116,7 +116,7 @@ type cryptoSetup struct {
clientHelloWritten bool
clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written
zeroRTTParametersChan chan<- *wire.TransportParameters
allow0RTT func() bool
allow0RTT bool
rttStats *utils.RTTStats
@ -197,7 +197,7 @@ func NewCryptoSetupServer(
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
allow0RTT func() bool,
allow0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
@ -210,14 +210,13 @@ func NewCryptoSetupServer(
tp,
runner,
tlsConf,
allow0RTT != nil,
allow0RTT,
rttStats,
tracer,
logger,
protocol.PerspectiveServer,
version,
)
cs.allow0RTT = allow0RTT
cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf)
return cs
}
@ -253,6 +252,7 @@ func newCryptoSetup(
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
runner: runner,
allow0RTT: enable0RTT,
ourParams: tp,
paramsChan: extHandler.TransportParameters(),
rttStats: rttStats,
@ -503,7 +503,7 @@ func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
return false
}
if !h.allow0RTT() {
if !h.allow0RTT {
h.logger.Debugf("0-RTT not allowed. Rejecting 0-RTT.")
return false
}

View file

@ -95,7 +95,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -177,7 +177,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
runner,
testdata.GetTLSConfig(),
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -218,7 +218,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
runner,
serverConf,
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -253,7 +253,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{StatelessResetToken: &token},
NewMockHandshakeRunner(mockCtrl),
serverConf,
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -378,10 +378,6 @@ var _ = Describe("Crypto Setup TLS", func() {
protocol.Version1,
)
var allow0RTT func() bool
if enable0RTT {
allow0RTT = func() bool { return true }
}
var sHandshakeComplete bool
sChunkChan, sInitialStream, sHandshakeStream := initStreams()
sErrChan := make(chan error, 1)
@ -402,7 +398,7 @@ var _ = Describe("Crypto Setup TLS", func() {
serverTransportParameters,
sRunner,
serverConf,
allow0RTT,
enable0RTT,
serverRTTStats,
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -541,7 +537,7 @@ var _ = Describe("Crypto Setup TLS", func() {
sTransportParameters,
sRunner,
serverConf,
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -596,7 +592,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
sRunner,
serverConf,
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),
@ -655,7 +651,7 @@ var _ = Describe("Crypto Setup TLS", func() {
&wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token},
sRunner,
serverConf,
nil,
false,
&utils.RTTStats{},
nil,
utils.DefaultLogger.WithPrefix("server"),

View file

@ -46,6 +46,7 @@ func main() {
// a quic.Config that doesn't do a Retry
quicConf := &quic.Config{
RequireAddressValidation: func(net.Addr) bool { return testcase == "retry" },
Allow0RTT: testcase == "zerortt",
Tracer: qlog.NewTracer(getLogWriter),
}
cert, err := tls.LoadX509KeyPair("/certs/cert.pem", "/certs/priv.key")
@ -59,10 +60,7 @@ func main() {
}
switch testcase {
case "zerortt":
quicConf.Allow0RTT = func(net.Addr) bool { return true }
fallthrough
case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect":
case "versionnegotiation", "handshake", "retry", "transfer", "resumption", "multiconnect", "zerortt":
err = runHTTP09Server(quicConf)
case "chacha20":
reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)