disable address validation by default

We should provide safe defaults. Since we implement the 3x amplification
limit, disabling address validation is not unsafe, and will save 1 RTT
for every handshake for applications that don't explicitely configure
Retries.
This commit is contained in:
Marten Seemann 2022-08-11 22:03:10 +04:00
parent 7fde609eef
commit bbfb7bd493
8 changed files with 35 additions and 62 deletions

View file

@ -46,7 +46,7 @@ func populateServerConfig(config *Config) *Config {
config.MaxRetryTokenAge = protocol.RetryTokenValidity config.MaxRetryTokenAge = protocol.RetryTokenValidity
} }
if config.RequireAddressValidation == nil { if config.RequireAddressValidation == nil {
config.RequireAddressValidation = func(net.Addr) bool { return true } config.RequireAddressValidation = func(net.Addr) bool { return false }
} }
return config return config
} }

View file

@ -37,13 +37,11 @@ var _ = Describe("Handshake drop tests", func() {
startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) { startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) {
conf := getQuicConfig(&quic.Config{ conf := getQuicConfig(&quic.Config{
MaxIdleTimeout: timeout, MaxIdleTimeout: timeout,
HandshakeIdleTimeout: timeout, HandshakeIdleTimeout: timeout,
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return doRetry },
}) })
if doRetry {
conf.RequireAddressValidation = func(net.Addr) bool { return true }
}
var tlsConf *tls.Config var tlsConf *tls.Config
if longCertChain { if longCertChain {
tlsConf = getTLSConfigWithLongCertChain() tlsConf = getTLSConfigWithLongCertChain()

View file

@ -101,6 +101,7 @@ var _ = Describe("Handshake RTT tests", func() {
// 1 RTT for verifying the source address // 1 RTT for verifying the source address
// 1 RTT for the TLS handshake // 1 RTT for the TLS handshake
It("is forward-secure after 2 RTTs", func() { It("is forward-secure after 2 RTTs", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
runServerAndProxy() runServerAndProxy()
_, err := quic.DialAddr( _, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
@ -112,7 +113,6 @@ var _ = Describe("Handshake RTT tests", func() {
}) })
It("establishes a connection in 1 RTT when the server doesn't require a token", func() { It("establishes a connection in 1 RTT when the server doesn't require a token", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
runServerAndProxy() runServerAndProxy()
_, err := quic.DialAddr( _, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port), fmt.Sprintf("localhost:%d", proxy.LocalAddr().(*net.UDPAddr).Port),
@ -124,7 +124,6 @@ var _ = Describe("Handshake RTT tests", func() {
}) })
It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() { It("establishes a connection in 2 RTTs if a HelloRetryRequest is performed", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384} serverTLSConfig.CurvePreferences = []tls.CurveID{tls.CurveP384}
runServerAndProxy() runServerAndProxy()
_, err := quic.DialAddr( _, err := quic.DialAddr(

View file

@ -344,7 +344,6 @@ var _ = Describe("Handshake tests", func() {
} }
BeforeEach(func() { BeforeEach(func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
var err error var err error
// start the server, but don't call Accept // start the server, but don't call Accept
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) server, err = quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
@ -474,8 +473,6 @@ var _ = Describe("Handshake tests", func() {
Context("using tokens", func() { Context("using tokens", func() {
It("uses tokens provided in NEW_TOKEN frames", func() { It("uses tokens provided in NEW_TOKEN frames", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return false }
server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) server, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -347,6 +347,7 @@ var _ = Describe("MITM test", func() {
// as it has already accepted a retry. // as it has already accepted a retry.
// TODO: determine behavior when server does not send Retry packets // TODO: determine behavior when server does not send Retry packets
It("fails when a forged Retry packet with modified srcConnID is sent to client", func() { It("fails when a forged Retry packet with modified srcConnID is sent to client", func() {
serverConfig.RequireAddressValidation = func(net.Addr) bool { return true }
var initialPacketIntercepted bool var initialPacketIntercepted bool
done := make(chan struct{}) done := make(chan struct{})
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration { delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {

View file

@ -26,9 +26,8 @@ var _ = Describe("Packetization", func() {
"localhost:0", "localhost:0",
getTLSConfig(), getTLSConfig(),
getQuicConfig(&quic.Config{ getQuicConfig(&quic.Config{
RequireAddressValidation: func(net.Addr) bool { return false }, DisablePathMTUDiscovery: true,
DisablePathMTUDiscovery: true, Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
Tracer: newTracer(func() logging.ConnectionTracer { return serverTracer }),
}), }),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -55,9 +55,7 @@ var _ = Describe("0-RTT", func() {
dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) { dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) {
tlsConf := getTLSConfig() tlsConf := getTLSConfig()
if serverConf == nil { if serverConf == nil {
serverConf = getQuicConfig(&quic.Config{ serverConf = getQuicConfig(nil)
RequireAddressValidation: func(net.Addr) bool { return false },
})
serverConf.Versions = []protocol.VersionNumber{version} serverConf.Versions = []protocol.VersionNumber{version}
} }
ln, err := quic.ListenAddrEarly( ln, err := quic.ListenAddrEarly(
@ -197,9 +195,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0", "localhost:0",
tlsConf, tlsConf,
getQuicConfig(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false }, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}), }),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -255,9 +252,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0", "localhost:0",
tlsConf, tlsConf,
getQuicConfig(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false }, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}), }),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -400,8 +396,9 @@ var _ = Describe("0-RTT", func() {
"localhost:0", "localhost:0",
tlsConf, tlsConf,
getQuicConfig(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), RequireAddressValidation: func(net.Addr) bool { return true },
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}), }),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -452,8 +449,7 @@ var _ = Describe("0-RTT", func() {
It("doesn't reject 0-RTT when the server's transport stream limit increased", func() { It("doesn't reject 0-RTT when the server's transport stream limit increased", func() {
const maxStreams = 1 const maxStreams = 1
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
MaxIncomingUniStreams: maxStreams, MaxIncomingUniStreams: maxStreams,
RequireAddressValidation: func(net.Addr) bool { return false },
})) }))
tracer := newPacketTracer() tracer := newPacketTracer()
@ -461,10 +457,9 @@ var _ = Describe("0-RTT", func() {
"localhost:0", "localhost:0",
tlsConf, tlsConf,
getQuicConfig(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false }, MaxIncomingUniStreams: maxStreams + 1,
MaxIncomingUniStreams: maxStreams + 1, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}), }),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -498,8 +493,7 @@ var _ = Describe("0-RTT", func() {
It("rejects 0-RTT when the server's stream limit decreased", func() { It("rejects 0-RTT when the server's stream limit decreased", func() {
const maxStreams = 42 const maxStreams = 42
tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{
MaxIncomingStreams: maxStreams, MaxIncomingStreams: maxStreams,
RequireAddressValidation: func(net.Addr) bool { return false },
})) }))
tracer := newPacketTracer() tracer := newPacketTracer()
@ -507,10 +501,9 @@ var _ = Describe("0-RTT", func() {
"localhost:0", "localhost:0",
tlsConf, tlsConf,
getQuicConfig(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false }, MaxIncomingStreams: maxStreams - 1,
MaxIncomingStreams: maxStreams - 1, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}), }),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -537,9 +530,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0", "localhost:0",
tlsConf, tlsConf,
getQuicConfig(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false }, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}), }),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -560,16 +552,14 @@ var _ = Describe("0-RTT", func() {
func(addFlowControlLimit func(*quic.Config, uint64)) { func(addFlowControlLimit func(*quic.Config, uint64)) {
tracer := newPacketTracer() tracer := newPacketTracer()
firstConf := getQuicConfig(&quic.Config{ firstConf := getQuicConfig(&quic.Config{
RequireAddressValidation: func(net.Addr) bool { return false }, Versions: []protocol.VersionNumber{version},
Versions: []protocol.VersionNumber{version},
}) })
addFlowControlLimit(firstConf, 3) addFlowControlLimit(firstConf, 3)
tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf)
secondConf := getQuicConfig(&quic.Config{ secondConf := getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false }, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}) })
addFlowControlLimit(secondConf, 100) addFlowControlLimit(secondConf, 100)
ln, err := quic.ListenAddrEarly( ln, err := quic.ListenAddrEarly(
@ -722,9 +712,8 @@ var _ = Describe("0-RTT", func() {
"localhost:0", "localhost:0",
tlsConf, tlsConf,
getQuicConfig(&quic.Config{ getQuicConfig(&quic.Config{
Versions: []protocol.VersionNumber{version}, Versions: []protocol.VersionNumber{version},
RequireAddressValidation: func(net.Addr) bool { return false }, Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
Tracer: newTracer(func() logging.ConnectionTracer { return tracer }),
}), }),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View file

@ -241,8 +241,9 @@ var _ = Describe("Server", func() {
It("creates a connection when the token is accepted", func() { It("creates a connection when the token is accepted", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return true } serv.config.RequireAddressValidation = func(net.Addr) bool { return true }
raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
retryToken, err := serv.tokenGenerator.NewRetryToken( retryToken, err := serv.tokenGenerator.NewRetryToken(
&net.UDPAddr{}, raddr,
protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde},
protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad},
) )
@ -256,6 +257,7 @@ var _ = Describe("Server", func() {
Token: retryToken, Token: retryToken,
} }
p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize))
p.remoteAddr = raddr
run := make(chan struct{}) run := make(chan struct{})
var token protocol.StatelessResetToken var token protocol.StatelessResetToken
rand.Read(token[:]) rand.Read(token[:])
@ -451,7 +453,6 @@ var _ = Describe("Server", func() {
}) })
It("creates a connection, if no token is required", func() { It("creates a connection, if no token is required", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
hdr := &wire.Header{ hdr := &wire.Header{
IsLongHeader: true, IsLongHeader: true,
Type: protocol.PacketTypeInitial, Type: protocol.PacketTypeInitial,
@ -534,7 +535,6 @@ var _ = Describe("Server", func() {
}).AnyTimes() }).AnyTimes()
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes() tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
acceptConn := make(chan struct{}) acceptConn := make(chan struct{})
var counter uint32 // to be used as an atomic, so we query it in Eventually var counter uint32 // to be used as an atomic, so we query it in Eventually
serv.newConn = func( serv.newConn = func(
@ -588,7 +588,6 @@ var _ = Describe("Server", func() {
}) })
It("only creates a single connection for a duplicate Initial", func() { It("only creates a single connection for a duplicate Initial", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
var createdConn bool var createdConn bool
conn := NewMockQuicConn(mockCtrl) conn := NewMockQuicConn(mockCtrl)
serv.newConn = func( serv.newConn = func(
@ -620,8 +619,6 @@ var _ = Describe("Server", func() {
}) })
It("rejects new connection attempts if the accept queue is full", func() { It("rejects new connection attempts if the accept queue is full", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
serv.newConn = func( serv.newConn = func(
_ sendConn, _ sendConn,
runner connRunner, runner connRunner,
@ -688,8 +685,6 @@ var _ = Describe("Server", func() {
}) })
It("doesn't accept new connections if they were closed in the mean time", func() { It("doesn't accept new connections if they were closed in the mean time", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
connCreated := make(chan struct{}) connCreated := make(chan struct{})
@ -999,7 +994,6 @@ var _ = Describe("Server", func() {
}() }()
ctx, cancel := context.WithCancel(context.Background()) // handshake context ctx, cancel := context.WithCancel(context.Background()) // handshake context
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
serv.newConn = func( serv.newConn = func(
_ sendConn, _ sendConn,
runner connRunner, runner connRunner,
@ -1073,7 +1067,6 @@ var _ = Describe("Server", func() {
}() }()
ready := make(chan struct{}) ready := make(chan struct{})
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
serv.newConn = func( serv.newConn = func(
_ sendConn, _ sendConn,
runner connRunner, runner connRunner,
@ -1114,7 +1107,6 @@ var _ = Describe("Server", func() {
}) })
It("rejects new connection attempts if the accept queue is full", func() { It("rejects new connection attempts if the accept queue is full", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42}
serv.newConn = func( serv.newConn = func(
@ -1175,8 +1167,6 @@ var _ = Describe("Server", func() {
}) })
It("doesn't accept new connections if they were closed in the mean time", func() { It("doesn't accept new connections if they were closed in the mean time", func() {
serv.config.RequireAddressValidation = func(net.Addr) bool { return false }
p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
connCreated := make(chan struct{}) connCreated := make(chan struct{})