Refactor ntex-connect (#314)

This commit is contained in:
Nikolay Kim 2024-03-24 15:33:08 +01:00 committed by GitHub
parent 5414e2096a
commit baabcff4a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 446 additions and 437 deletions

View file

@ -1,5 +1,11 @@
# Changes
## [1.1.0] - 2024-03-24
* Move tls connectors from ntex-connect
* Upgrade to rustls 0.23
## [1.0.0] - 2024-01-09
* Release

View file

@ -26,9 +26,11 @@ rustls = ["tls_rust"]
[dependencies]
ntex-bytes = "0.1.21"
ntex-io = "1.0.0"
ntex-util = "1.0.0"
ntex-service = "2.0.0"
ntex-io = "1.0"
ntex-util = "1.0"
ntex-service = "2.0"
ntex-net = "1.0"
log = "0.4"
# openssl

View file

@ -9,7 +9,8 @@ async fn main() -> io::Result<()> {
env_logger::init();
// rustls config
let cert_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let cert_store =
RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let config = ClientConfig::builder()
.with_root_certificates(cert_store)
.with_no_client_auth();

View file

@ -17,7 +17,9 @@ async fn main() -> io::Result<()> {
&mut BufReader::new(File::open("../ntex-tls/examples/cert.pem").unwrap());
let key_file = &mut BufReader::new(File::open("../ntex-tls/examples/key.pem").unwrap());
let keys = rustls_pemfile::private_key(key_file).unwrap().unwrap();
let cert_chain = rustls_pemfile::certs(cert_file).collect::<Result<Vec<_>, _>>().unwrap();
let cert_chain = rustls_pemfile::certs(cert_file)
.collect::<Result<Vec<_>, _>>()
.unwrap();
let tls_config = Arc::new(
ServerConfig::builder()
.with_no_client_auth()

View file

@ -0,0 +1,146 @@
use std::{fmt, io};
use ntex_bytes::PoolId;
use ntex_io::{Io, Layer};
use ntex_net::connect::{Address, Connect, ConnectError, Connector as BaseConnector};
use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory};
use tls_openssl::ssl::SslConnector as BaseSslConnector;
use super::{connect as connect_io, SslFilter};
pub struct SslConnector<T> {
connector: Pipeline<BaseConnector<T>>,
openssl: BaseSslConnector,
}
impl<T: Address> SslConnector<T> {
/// Construct new OpensslConnectService factory
pub fn new(connector: BaseSslConnector) -> Self {
SslConnector {
connector: BaseConnector::default().into(),
openssl: connector,
}
}
/// Set memory pool.
///
/// Use specified memory pool for memory allocations. By default P0
/// memory pool is used.
pub fn memory_pool(self, id: PoolId) -> Self {
let connector = self
.connector
.into_service()
.expect("Connector has been cloned")
.memory_pool(id)
.into();
Self {
connector,
openssl: self.openssl,
}
}
}
impl<T: Address> SslConnector<T> {
/// Resolve and connect to remote host
pub async fn connect<U>(&self, message: U) -> Result<Io<Layer<SslFilter>>, ConnectError>
where
Connect<T>: From<U>,
{
let message = Connect::from(message);
let host = message.host().split(':').next().unwrap().to_string();
let conn = self.connector.call(message);
let openssl = self.openssl.clone();
let io = conn.await?;
log::trace!("{}: SSL Handshake start for: {:?}", io.tag(), host);
match openssl.configure() {
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e).into()),
Ok(config) => {
let ssl = config
.into_ssl(&host)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let tag = io.tag();
match connect_io(io, ssl).await {
Ok(io) => {
log::trace!("{}: SSL Handshake success: {:?}", tag, host);
Ok(io)
}
Err(e) => {
log::trace!("{}: SSL Handshake error: {:?}", tag, e);
Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)).into())
}
}
}
}
}
}
impl<T> Clone for SslConnector<T> {
fn clone(&self) -> Self {
Self {
connector: self.connector.clone(),
openssl: self.openssl.clone(),
}
}
}
impl<T> fmt::Debug for SslConnector<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SslConnector(openssl)")
.field("connector", &self.connector)
.field("openssl", &self.openssl)
.finish()
}
}
impl<T: Address, C> ServiceFactory<Connect<T>, C> for SslConnector<T> {
type Response = Io<Layer<SslFilter>>;
type Error = ConnectError;
type Service = SslConnector<T>;
type InitError = ();
async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
Ok(self.clone())
}
}
impl<T: Address> Service<Connect<T>> for SslConnector<T> {
type Response = Io<Layer<SslFilter>>;
type Error = ConnectError;
async fn call(
&self,
req: Connect<T>,
_: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> {
self.connect(req).await
}
}
#[cfg(test)]
mod tests {
use tls_openssl::ssl::SslMethod;
use super::*;
#[ntex::test]
async fn test_openssl_connect() {
let server = ntex::server::test_server(|| {
ntex::service::fn_service(|_| async { Ok::<_, ()>(()) })
});
let ssl = BaseSslConnector::builder(SslMethod::tls()).unwrap();
let factory = SslConnector::new(ssl.build())
.memory_pool(PoolId::P5)
.clone();
let srv = factory.pipeline(&()).await.unwrap();
let result = srv
.call(Connect::new("").set_addr(Some(server.addr())))
.await;
assert!(result.is_err());
assert!(format!("{:?}", srv).contains("SslConnector"));
}
}

View file

@ -8,6 +8,9 @@ use tls_openssl::x509::X509;
use crate::{PskIdentity, Servername};
mod connect;
pub use self::connect::SslConnector;
mod accept;
pub use self::accept::{SslAcceptor, SslAcceptorService};

View file

@ -64,11 +64,17 @@ impl FilterLayer for TlsClientFilter {
buf.with_dst(|dst| {
loop {
let mut cursor = io::Cursor::new(&src);
let n = session.read_tls(&mut cursor)?;
let n = match session.read_tls(&mut cursor) {
Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
break
}
Err(err) => return Err(err),
};
src.split_to(n);
let state = session
.process_new_packets()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let new_b = state.plaintext_bytes_to_read();
if new_b > 0 {
@ -92,18 +98,26 @@ impl FilterLayer for TlsClientFilter {
fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
buf.with_src(|src| {
if let Some(src) = src {
let mut session = self.session.borrow_mut();
let mut io = Wrapper(buf);
let mut session = self.session.borrow_mut();
loop {
'outer: loop {
if !src.is_empty() {
src.split_to(session.writer().write(src)?);
}
if session.wants_write() {
session.complete_io(&mut io)?;
} else {
break;
}
while session.wants_write() {
match session.write_tls(&mut io) {
Ok(0) => continue 'outer,
Ok(_) => continue,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
break
}
Err(err) => return Err(err),
}
}
break;
}
}
Ok(())

View file

@ -0,0 +1,157 @@
use std::{fmt, io, sync::Arc};
use ntex_bytes::PoolId;
use ntex_io::{Io, Layer};
use ntex_net::connect::{Address, Connect, ConnectError, Connector as BaseConnector};
use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory};
use tls_rust::{pki_types::ServerName, ClientConfig};
use super::TlsClientFilter;
/// Rustls connector factory
pub struct TlsConnector<T> {
connector: Pipeline<BaseConnector<T>>,
config: Arc<ClientConfig>,
}
impl<T: Address> From<Arc<ClientConfig>> for TlsConnector<T> {
fn from(config: Arc<ClientConfig>) -> Self {
TlsConnector {
config,
connector: BaseConnector::default().into(),
}
}
}
impl<T: Address> TlsConnector<T> {
pub fn new(config: ClientConfig) -> Self {
TlsConnector {
config: Arc::new(config),
connector: BaseConnector::default().into(),
}
}
/// Set memory pool.
///
/// Use specified memory pool for memory allocations. By default P0
/// memory pool is used.
pub fn memory_pool(self, id: PoolId) -> Self {
let connector = self
.connector
.into_service()
.unwrap()
.memory_pool(id)
.into();
Self {
connector,
config: self.config,
}
}
}
impl<T: Address> TlsConnector<T> {
/// Resolve and connect to remote host
pub async fn connect<U>(
&self,
message: U,
) -> Result<Io<Layer<TlsClientFilter>>, ConnectError>
where
Connect<T>: From<U>,
{
let req = Connect::from(message);
let host = req.host().split(':').next().unwrap().to_owned();
let io = self.connector.call(req).await?;
log::trace!("{}: SSL Handshake start for: {:?}", io.tag(), host);
let tag = io.tag();
let config = self.config.clone();
let host = ServerName::try_from(host)
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?;
match TlsClientFilter::create(io, config, host.clone()).await {
Ok(io) => {
log::trace!("{}: TLS Handshake success: {:?}", tag, &host);
Ok(io)
}
Err(e) => {
log::trace!("{}: TLS Handshake error: {:?}", tag, e);
Err(e.into())
}
}
}
}
impl<T> Clone for TlsConnector<T> {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
connector: self.connector.clone(),
}
}
}
impl<T> fmt::Debug for TlsConnector<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TlsConnector(rustls)")
.field("connector", &self.connector)
.finish()
}
}
impl<T: Address, C> ServiceFactory<Connect<T>, C> for TlsConnector<T> {
type Response = Io<Layer<TlsClientFilter>>;
type Error = ConnectError;
type Service = TlsConnector<T>;
type InitError = ();
async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
Ok(self.clone())
}
}
impl<T: Address> Service<Connect<T>> for TlsConnector<T> {
type Response = Io<Layer<TlsClientFilter>>;
type Error = ConnectError;
async fn call(
&self,
req: Connect<T>,
_: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> {
self.connect(req).await
}
}
#[cfg(test)]
mod tests {
use tls_rust::RootCertStore;
use super::*;
use ntex_util::future::lazy;
#[ntex::test]
async fn test_rustls_connect() {
let server = ntex::server::test_server(|| {
ntex::service::fn_service(|_| async { Ok::<_, ()>(()) })
});
let cert_store =
RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let config = ClientConfig::builder()
.with_root_certificates(cert_store)
.with_no_client_auth();
let _ = TlsConnector::<&'static str>::new(config.clone()).clone();
let factory = TlsConnector::from(Arc::new(config))
.memory_pool(PoolId::P5)
.clone();
let srv = factory.pipeline(&()).await.unwrap();
// always ready
assert!(lazy(|cx| srv.poll_ready(cx)).await.is_ready());
let result = srv
.call(Connect::new("www.rust-lang.org").set_addr(Some(server.addr())))
.await;
assert!(result.is_err());
}
}

View file

@ -6,10 +6,12 @@ use tls_rust::pki_types::CertificateDer;
mod accept;
mod client;
mod connect;
mod server;
pub use accept::{TlsAcceptor, TlsAcceptorService};
pub use self::accept::{TlsAcceptor, TlsAcceptorService};
pub use self::client::TlsClientFilter;
pub use self::connect::TlsConnector;
pub use self::server::TlsServerFilter;
/// Connection's peer cert

View file

@ -72,11 +72,17 @@ impl FilterLayer for TlsServerFilter {
buf.with_dst(|dst| {
loop {
let mut cursor = io::Cursor::new(&src);
let n = session.read_tls(&mut cursor)?;
let n = match session.read_tls(&mut cursor) {
Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
break
}
Err(err) => return Err(err),
};
src.split_to(n);
let state = session
.process_new_packets()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let new_b = state.plaintext_bytes_to_read();
if new_b > 0 {
@ -100,18 +106,26 @@ impl FilterLayer for TlsServerFilter {
fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
buf.with_src(|src| {
if let Some(src) = src {
let mut session = self.session.borrow_mut();
let mut io = Wrapper(buf);
let mut session = self.session.borrow_mut();
loop {
'outer: loop {
if !src.is_empty() {
src.split_to(session.writer().write(src)?);
}
if session.wants_write() {
session.complete_io(&mut io)?;
} else {
break;
}
while session.wants_write() {
match session.write_tls(&mut io) {
Ok(0) => continue 'outer,
Ok(_) => continue,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
break
}
Err(err) => return Err(err),
}
}
break;
}
}
Ok(())