Fix rustls hangs during handshake #103

This commit is contained in:
Nikolay Kim 2022-02-19 21:54:09 +06:00
parent cb356517a4
commit 6ae0cd002d
7 changed files with 76 additions and 8 deletions

View file

@ -1,5 +1,9 @@
# Changes
## [0.1.5] - 2022-02-19
* Fix rustls hangs during handshake #103
* Move HttpProtocol to ntex-io
## [0.1.4] - 2022-02-11

View file

@ -38,7 +38,8 @@ tls_openssl = { version="0.10", package = "openssl", optional = true }
tls_rust = { version = "0.20", package = "rustls", optional = true }
[dev-dependencies]
ntex = { version = "0.5", features = ["openssl", "rustls"] }
ntex = { version = "0.5", features = ["openssl", "rustls", "tokio"] }
log = "0.4"
env_logger = "0.9"
rustls-pemfile = { version = "0.2" }
webpki-roots = { version = "0.22" }

View file

@ -0,0 +1,47 @@
use std::io;
use ntex::{codec, connect, io::types::PeerAddr, util::Bytes, util::Either};
use tls_rust::{ClientConfig, OwnedTrustAnchor, RootCertStore};
#[ntex::main]
async fn main() -> io::Result<()> {
std::env::set_var("RUST_LOG", "trace");
env_logger::init();
// rustls config
let mut cert_store = RootCertStore::empty();
cert_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(
|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
},
));
let config = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(cert_store)
.with_no_client_auth();
// rustls connector
let connector = connect::rustls::Connector::new(config.clone());
let io = connector.connect("www.rust-lang.org:443").await.unwrap();
println!("Connected to tls server {:?}", io.query::<PeerAddr>().get());
let result = io
.send(Bytes::from_static(b"GET /\r\n\r\n"), &codec::BytesCodec)
.await
.map_err(Either::into_inner)?;
println!("Send result: {:?}", result);
let resp = io
.recv(&codec::BytesCodec)
.await
.map_err(|e| e.into_inner())?
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "disconnected"))?;
println!("Received: {:?}", resp);
println!("disconnecting");
io.shutdown().await
}

View file

@ -3,7 +3,7 @@ use std::{fs::File, io, io::BufReader, sync::Arc};
use ntex::service::{fn_service, pipeline_factory};
use ntex::{codec, io::filter, io::Io, server, util::Either};
use ntex_tls::rustls::TlsAcceptor;
use rustls_pemfile::{certs, pkcs8_private_keys};
use rustls_pemfile::{certs, rsa_private_keys};
use tls_rust::{Certificate, PrivateKey, ServerConfig};
#[ntex::main]
@ -17,7 +17,7 @@ async fn main() -> io::Result<()> {
let cert_file =
&mut BufReader::new(File::open("../ntex-tls/examples/cert.pem").unwrap());
let key_file = &mut BufReader::new(File::open("../ntex-tls/examples/key.pem").unwrap());
let keys = PrivateKey(pkcs8_private_keys(key_file).unwrap().remove(0));
let keys = PrivateKey(rsa_private_keys(key_file).unwrap().remove(0));
let cert_chain = certs(cert_file)
.unwrap()
.iter()

View file

@ -196,10 +196,14 @@ impl<F: Filter> TlsClientFilter<F> {
let filter = io.filter();
loop {
let (result, wants_read) = {
let (result, wants_read, handshaking) = {
let mut session = filter.client().session.borrow_mut();
let mut wrp = Wrapper(&filter.client().inner);
(session.complete_io(&mut wrp), session.wants_read())
(
session.complete_io(&mut wrp),
session.wants_read(),
session.is_handshaking(),
)
};
match result {
Ok(_) => {
@ -207,6 +211,10 @@ impl<F: Filter> TlsClientFilter<F> {
return Ok(io);
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if !handshaking {
filter.client().inner.handshake.set(false);
return Ok(io);
}
poll_fn(|cx| {
let read_ready = if wants_read {
match ready!(io.poll_read_ready(cx))? {

View file

@ -196,10 +196,14 @@ impl<F: Filter> TlsServerFilter<F> {
let filter = io.filter();
loop {
let (result, wants_read) = {
let (result, wants_read, handshaking) = {
let mut session = filter.server().session.borrow_mut();
let mut wrp = Wrapper(&filter.server().inner);
(session.complete_io(&mut wrp), session.wants_read())
(
session.complete_io(&mut wrp),
session.wants_read(),
session.is_handshaking(),
)
};
match result {
Ok(_) => {
@ -207,6 +211,10 @@ impl<F: Filter> TlsServerFilter<F> {
return Ok(io);
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if !handshaking {
filter.server().inner.handshake.set(false);
return Ok(io);
}
poll_fn(|cx| {
let read_ready = if wants_read {
match ready!(io.poll_read_ready(cx))? {

View file

@ -79,7 +79,7 @@ thiserror = "1.0"
# http/web framework
h2 = "0.3.9"
http = "0.2"
httparse = "1.5.1"
httparse = "1.6.0"
httpdate = "1.0"
encoding_rs = "0.8"
mime = "0.3"