mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-03 21:07:39 +03:00
Refactor ntex-connect (#314)
This commit is contained in:
parent
5414e2096a
commit
baabcff4a6
35 changed files with 446 additions and 437 deletions
|
@ -1,5 +1,11 @@
|
|||
# Changes
|
||||
|
||||
## [1.1.0] - 2024-03-24
|
||||
|
||||
* Move tls connectors from ntex-connect
|
||||
|
||||
* Upgrade to rustls 0.23
|
||||
|
||||
## [1.0.0] - 2024-01-09
|
||||
|
||||
* Release
|
||||
|
|
|
@ -26,9 +26,11 @@ rustls = ["tls_rust"]
|
|||
|
||||
[dependencies]
|
||||
ntex-bytes = "0.1.21"
|
||||
ntex-io = "1.0.0"
|
||||
ntex-util = "1.0.0"
|
||||
ntex-service = "2.0.0"
|
||||
ntex-io = "1.0"
|
||||
ntex-util = "1.0"
|
||||
ntex-service = "2.0"
|
||||
ntex-net = "1.0"
|
||||
|
||||
log = "0.4"
|
||||
|
||||
# openssl
|
||||
|
|
|
@ -9,7 +9,8 @@ async fn main() -> io::Result<()> {
|
|||
env_logger::init();
|
||||
|
||||
// rustls config
|
||||
let cert_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
let cert_store =
|
||||
RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
let config = ClientConfig::builder()
|
||||
.with_root_certificates(cert_store)
|
||||
.with_no_client_auth();
|
||||
|
|
|
@ -17,7 +17,9 @@ async fn main() -> io::Result<()> {
|
|||
&mut BufReader::new(File::open("../ntex-tls/examples/cert.pem").unwrap());
|
||||
let key_file = &mut BufReader::new(File::open("../ntex-tls/examples/key.pem").unwrap());
|
||||
let keys = rustls_pemfile::private_key(key_file).unwrap().unwrap();
|
||||
let cert_chain = rustls_pemfile::certs(cert_file).collect::<Result<Vec<_>, _>>().unwrap();
|
||||
let cert_chain = rustls_pemfile::certs(cert_file)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.unwrap();
|
||||
let tls_config = Arc::new(
|
||||
ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
|
|
146
ntex-tls/src/openssl/connect.rs
Normal file
146
ntex-tls/src/openssl/connect.rs
Normal file
|
@ -0,0 +1,146 @@
|
|||
use std::{fmt, io};
|
||||
|
||||
use ntex_bytes::PoolId;
|
||||
use ntex_io::{Io, Layer};
|
||||
use ntex_net::connect::{Address, Connect, ConnectError, Connector as BaseConnector};
|
||||
use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory};
|
||||
use tls_openssl::ssl::SslConnector as BaseSslConnector;
|
||||
|
||||
use super::{connect as connect_io, SslFilter};
|
||||
|
||||
pub struct SslConnector<T> {
|
||||
connector: Pipeline<BaseConnector<T>>,
|
||||
openssl: BaseSslConnector,
|
||||
}
|
||||
|
||||
impl<T: Address> SslConnector<T> {
|
||||
/// Construct new OpensslConnectService factory
|
||||
pub fn new(connector: BaseSslConnector) -> Self {
|
||||
SslConnector {
|
||||
connector: BaseConnector::default().into(),
|
||||
openssl: connector,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set memory pool.
|
||||
///
|
||||
/// Use specified memory pool for memory allocations. By default P0
|
||||
/// memory pool is used.
|
||||
pub fn memory_pool(self, id: PoolId) -> Self {
|
||||
let connector = self
|
||||
.connector
|
||||
.into_service()
|
||||
.expect("Connector has been cloned")
|
||||
.memory_pool(id)
|
||||
.into();
|
||||
|
||||
Self {
|
||||
connector,
|
||||
openssl: self.openssl,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address> SslConnector<T> {
|
||||
/// Resolve and connect to remote host
|
||||
pub async fn connect<U>(&self, message: U) -> Result<Io<Layer<SslFilter>>, ConnectError>
|
||||
where
|
||||
Connect<T>: From<U>,
|
||||
{
|
||||
let message = Connect::from(message);
|
||||
let host = message.host().split(':').next().unwrap().to_string();
|
||||
let conn = self.connector.call(message);
|
||||
let openssl = self.openssl.clone();
|
||||
|
||||
let io = conn.await?;
|
||||
log::trace!("{}: SSL Handshake start for: {:?}", io.tag(), host);
|
||||
|
||||
match openssl.configure() {
|
||||
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e).into()),
|
||||
Ok(config) => {
|
||||
let ssl = config
|
||||
.into_ssl(&host)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
let tag = io.tag();
|
||||
match connect_io(io, ssl).await {
|
||||
Ok(io) => {
|
||||
log::trace!("{}: SSL Handshake success: {:?}", tag, host);
|
||||
Ok(io)
|
||||
}
|
||||
Err(e) => {
|
||||
log::trace!("{}: SSL Handshake error: {:?}", tag, e);
|
||||
Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)).into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Clone for SslConnector<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
connector: self.connector.clone(),
|
||||
openssl: self.openssl.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> fmt::Debug for SslConnector<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("SslConnector(openssl)")
|
||||
.field("connector", &self.connector)
|
||||
.field("openssl", &self.openssl)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address, C> ServiceFactory<Connect<T>, C> for SslConnector<T> {
|
||||
type Response = Io<Layer<SslFilter>>;
|
||||
type Error = ConnectError;
|
||||
type Service = SslConnector<T>;
|
||||
type InitError = ();
|
||||
|
||||
async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
|
||||
Ok(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address> Service<Connect<T>> for SslConnector<T> {
|
||||
type Response = Io<Layer<SslFilter>>;
|
||||
type Error = ConnectError;
|
||||
|
||||
async fn call(
|
||||
&self,
|
||||
req: Connect<T>,
|
||||
_: ServiceCtx<'_, Self>,
|
||||
) -> Result<Self::Response, Self::Error> {
|
||||
self.connect(req).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tls_openssl::ssl::SslMethod;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_openssl_connect() {
|
||||
let server = ntex::server::test_server(|| {
|
||||
ntex::service::fn_service(|_| async { Ok::<_, ()>(()) })
|
||||
});
|
||||
|
||||
let ssl = BaseSslConnector::builder(SslMethod::tls()).unwrap();
|
||||
let factory = SslConnector::new(ssl.build())
|
||||
.memory_pool(PoolId::P5)
|
||||
.clone();
|
||||
|
||||
let srv = factory.pipeline(&()).await.unwrap();
|
||||
let result = srv
|
||||
.call(Connect::new("").set_addr(Some(server.addr())))
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
assert!(format!("{:?}", srv).contains("SslConnector"));
|
||||
}
|
||||
}
|
|
@ -8,6 +8,9 @@ use tls_openssl::x509::X509;
|
|||
|
||||
use crate::{PskIdentity, Servername};
|
||||
|
||||
mod connect;
|
||||
pub use self::connect::SslConnector;
|
||||
|
||||
mod accept;
|
||||
pub use self::accept::{SslAcceptor, SslAcceptorService};
|
||||
|
||||
|
|
|
@ -64,11 +64,17 @@ impl FilterLayer for TlsClientFilter {
|
|||
buf.with_dst(|dst| {
|
||||
loop {
|
||||
let mut cursor = io::Cursor::new(&src);
|
||||
let n = session.read_tls(&mut cursor)?;
|
||||
let n = match session.read_tls(&mut cursor) {
|
||||
Ok(n) => n,
|
||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
||||
break
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
src.split_to(n);
|
||||
let state = session
|
||||
.process_new_packets()
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
let new_b = state.plaintext_bytes_to_read();
|
||||
if new_b > 0 {
|
||||
|
@ -92,18 +98,26 @@ impl FilterLayer for TlsClientFilter {
|
|||
fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
|
||||
buf.with_src(|src| {
|
||||
if let Some(src) = src {
|
||||
let mut session = self.session.borrow_mut();
|
||||
let mut io = Wrapper(buf);
|
||||
let mut session = self.session.borrow_mut();
|
||||
|
||||
loop {
|
||||
'outer: loop {
|
||||
if !src.is_empty() {
|
||||
src.split_to(session.writer().write(src)?);
|
||||
}
|
||||
if session.wants_write() {
|
||||
session.complete_io(&mut io)?;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
while session.wants_write() {
|
||||
match session.write_tls(&mut io) {
|
||||
Ok(0) => continue 'outer,
|
||||
Ok(_) => continue,
|
||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
||||
break
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
|
157
ntex-tls/src/rustls/connect.rs
Normal file
157
ntex-tls/src/rustls/connect.rs
Normal file
|
@ -0,0 +1,157 @@
|
|||
use std::{fmt, io, sync::Arc};
|
||||
|
||||
use ntex_bytes::PoolId;
|
||||
use ntex_io::{Io, Layer};
|
||||
use ntex_net::connect::{Address, Connect, ConnectError, Connector as BaseConnector};
|
||||
use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory};
|
||||
use tls_rust::{pki_types::ServerName, ClientConfig};
|
||||
|
||||
use super::TlsClientFilter;
|
||||
|
||||
/// Rustls connector factory
|
||||
pub struct TlsConnector<T> {
|
||||
connector: Pipeline<BaseConnector<T>>,
|
||||
config: Arc<ClientConfig>,
|
||||
}
|
||||
|
||||
impl<T: Address> From<Arc<ClientConfig>> for TlsConnector<T> {
|
||||
fn from(config: Arc<ClientConfig>) -> Self {
|
||||
TlsConnector {
|
||||
config,
|
||||
connector: BaseConnector::default().into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address> TlsConnector<T> {
|
||||
pub fn new(config: ClientConfig) -> Self {
|
||||
TlsConnector {
|
||||
config: Arc::new(config),
|
||||
connector: BaseConnector::default().into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set memory pool.
|
||||
///
|
||||
/// Use specified memory pool for memory allocations. By default P0
|
||||
/// memory pool is used.
|
||||
pub fn memory_pool(self, id: PoolId) -> Self {
|
||||
let connector = self
|
||||
.connector
|
||||
.into_service()
|
||||
.unwrap()
|
||||
.memory_pool(id)
|
||||
.into();
|
||||
Self {
|
||||
connector,
|
||||
config: self.config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address> TlsConnector<T> {
|
||||
/// Resolve and connect to remote host
|
||||
pub async fn connect<U>(
|
||||
&self,
|
||||
message: U,
|
||||
) -> Result<Io<Layer<TlsClientFilter>>, ConnectError>
|
||||
where
|
||||
Connect<T>: From<U>,
|
||||
{
|
||||
let req = Connect::from(message);
|
||||
let host = req.host().split(':').next().unwrap().to_owned();
|
||||
let io = self.connector.call(req).await?;
|
||||
|
||||
log::trace!("{}: SSL Handshake start for: {:?}", io.tag(), host);
|
||||
|
||||
let tag = io.tag();
|
||||
let config = self.config.clone();
|
||||
let host = ServerName::try_from(host)
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?;
|
||||
|
||||
match TlsClientFilter::create(io, config, host.clone()).await {
|
||||
Ok(io) => {
|
||||
log::trace!("{}: TLS Handshake success: {:?}", tag, &host);
|
||||
Ok(io)
|
||||
}
|
||||
Err(e) => {
|
||||
log::trace!("{}: TLS Handshake error: {:?}", tag, e);
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Clone for TlsConnector<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
config: self.config.clone(),
|
||||
connector: self.connector.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> fmt::Debug for TlsConnector<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("TlsConnector(rustls)")
|
||||
.field("connector", &self.connector)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address, C> ServiceFactory<Connect<T>, C> for TlsConnector<T> {
|
||||
type Response = Io<Layer<TlsClientFilter>>;
|
||||
type Error = ConnectError;
|
||||
type Service = TlsConnector<T>;
|
||||
type InitError = ();
|
||||
|
||||
async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
|
||||
Ok(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Address> Service<Connect<T>> for TlsConnector<T> {
|
||||
type Response = Io<Layer<TlsClientFilter>>;
|
||||
type Error = ConnectError;
|
||||
|
||||
async fn call(
|
||||
&self,
|
||||
req: Connect<T>,
|
||||
_: ServiceCtx<'_, Self>,
|
||||
) -> Result<Self::Response, Self::Error> {
|
||||
self.connect(req).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tls_rust::RootCertStore;
|
||||
|
||||
use super::*;
|
||||
use ntex_util::future::lazy;
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_rustls_connect() {
|
||||
let server = ntex::server::test_server(|| {
|
||||
ntex::service::fn_service(|_| async { Ok::<_, ()>(()) })
|
||||
});
|
||||
|
||||
let cert_store =
|
||||
RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
|
||||
let config = ClientConfig::builder()
|
||||
.with_root_certificates(cert_store)
|
||||
.with_no_client_auth();
|
||||
let _ = TlsConnector::<&'static str>::new(config.clone()).clone();
|
||||
let factory = TlsConnector::from(Arc::new(config))
|
||||
.memory_pool(PoolId::P5)
|
||||
.clone();
|
||||
|
||||
let srv = factory.pipeline(&()).await.unwrap();
|
||||
// always ready
|
||||
assert!(lazy(|cx| srv.poll_ready(cx)).await.is_ready());
|
||||
let result = srv
|
||||
.call(Connect::new("www.rust-lang.org").set_addr(Some(server.addr())))
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
|
@ -6,10 +6,12 @@ use tls_rust::pki_types::CertificateDer;
|
|||
|
||||
mod accept;
|
||||
mod client;
|
||||
mod connect;
|
||||
mod server;
|
||||
pub use accept::{TlsAcceptor, TlsAcceptorService};
|
||||
|
||||
pub use self::accept::{TlsAcceptor, TlsAcceptorService};
|
||||
pub use self::client::TlsClientFilter;
|
||||
pub use self::connect::TlsConnector;
|
||||
pub use self::server::TlsServerFilter;
|
||||
|
||||
/// Connection's peer cert
|
||||
|
|
|
@ -72,11 +72,17 @@ impl FilterLayer for TlsServerFilter {
|
|||
buf.with_dst(|dst| {
|
||||
loop {
|
||||
let mut cursor = io::Cursor::new(&src);
|
||||
let n = session.read_tls(&mut cursor)?;
|
||||
let n = match session.read_tls(&mut cursor) {
|
||||
Ok(n) => n,
|
||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
||||
break
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
src.split_to(n);
|
||||
let state = session
|
||||
.process_new_packets()
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
let new_b = state.plaintext_bytes_to_read();
|
||||
if new_b > 0 {
|
||||
|
@ -100,18 +106,26 @@ impl FilterLayer for TlsServerFilter {
|
|||
fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
|
||||
buf.with_src(|src| {
|
||||
if let Some(src) = src {
|
||||
let mut session = self.session.borrow_mut();
|
||||
let mut io = Wrapper(buf);
|
||||
let mut session = self.session.borrow_mut();
|
||||
|
||||
loop {
|
||||
'outer: loop {
|
||||
if !src.is_empty() {
|
||||
src.split_to(session.writer().write(src)?);
|
||||
}
|
||||
if session.wants_write() {
|
||||
session.complete_io(&mut io)?;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
while session.wants_write() {
|
||||
match session.write_tls(&mut io) {
|
||||
Ok(0) => continue 'outer,
|
||||
Ok(_) => continue,
|
||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
||||
break
|
||||
}
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue