Refactor filter factories (#278)

This commit is contained in:
Nikolay Kim 2024-01-08 15:22:38 +06:00 committed by GitHub
parent a13f677df8
commit 174b5d86f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 271 additions and 657 deletions

View file

@ -1,5 +1,11 @@
# Changes # Changes
## [1.0.0-b.1] - 2024-01-08
* Refactor io tls filters
* Remove unnecessary 'static
## [1.0.0-b.0] - 2024-01-07 ## [1.0.0-b.0] - 2024-01-07
* Use "async fn" in trait for Service definition * Use "async fn" in trait for Service definition

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-connect" name = "ntex-connect"
version = "1.0.0-b.0" version = "1.0.0-b.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "ntexwork connect utils for ntex framework" description = "ntexwork connect utils for ntex framework"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -35,9 +35,9 @@ async-std = ["ntex-rt/async-std", "ntex-async-std"]
[dependencies] [dependencies]
ntex-service = "2.0.0-b.0" ntex-service = "2.0.0-b.0"
ntex-io = "1.0.0-b.0" ntex-io = "1.0.0-b.1"
ntex-tls = "1.0.0-b.0" ntex-tls = "1.0.0-b.1"
ntex-util = "1.0.0-b.0" ntex-util = "1.0.0-b.1"
ntex-bytes = "0.1.21" ntex-bytes = "0.1.21"
ntex-http = "0.1" ntex-http = "0.1"
ntex-rt = "0.4.7" ntex-rt = "0.4.7"

View file

@ -4,9 +4,9 @@ pub use ntex_tls::openssl::SslFilter;
pub use tls_openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; pub use tls_openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod};
use ntex_bytes::PoolId; use ntex_bytes::PoolId;
use ntex_io::{FilterFactory, Io, Layer}; use ntex_io::{Io, Layer};
use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory}; use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory};
use ntex_tls::openssl::SslConnector as IoSslConnector; use ntex_tls::openssl::connect as connect_io;
use super::{Address, Connect, ConnectError, Connector as BaseConnector}; use super::{Address, Connect, ConnectError, Connector as BaseConnector};
@ -64,7 +64,7 @@ impl<T: Address> Connector<T> {
.into_ssl(&host) .into_ssl(&host)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let tag = io.tag(); let tag = io.tag();
match IoSslConnector::new(ssl).create(io).await { match connect_io(io, ssl).await {
Ok(io) => { Ok(io) => {
log::trace!("{}: SSL Handshake success: {:?}", tag, host); log::trace!("{}: SSL Handshake success: {:?}", tag, host);
Ok(io) Ok(io)
@ -97,7 +97,7 @@ impl<T> fmt::Debug for Connector<T> {
} }
} }
impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> { impl<T: Address, C> ServiceFactory<Connect<T>, C> for Connector<T> {
type Response = Io<Layer<SslFilter>>; type Response = Io<Layer<SslFilter>>;
type Error = ConnectError; type Error = ConnectError;
type Service = Connector<T>; type Service = Connector<T>;

View file

@ -97,7 +97,7 @@ impl<T> fmt::Debug for Resolver<T> {
} }
} }
impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Resolver<T> { impl<T: Address, C> ServiceFactory<Connect<T>, C> for Resolver<T> {
type Response = Connect<T>; type Response = Connect<T>;
type Error = ConnectError; type Error = ConnectError;
type Service = Resolver<T>; type Service = Resolver<T>;

View file

@ -1,25 +1,24 @@
use std::{fmt, io}; use std::{fmt, io, sync::Arc};
pub use ntex_tls::rustls::TlsFilter; pub use ntex_tls::rustls::TlsClientFilter;
pub use tls_rustls::{ClientConfig, ServerName}; pub use tls_rustls::{ClientConfig, ServerName};
use ntex_bytes::PoolId; use ntex_bytes::PoolId;
use ntex_io::{FilterFactory, Io, Layer}; use ntex_io::{Io, Layer};
use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory}; use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory};
use ntex_tls::rustls::TlsConnector;
use super::{Address, Connect, ConnectError, Connector as BaseConnector}; use super::{Address, Connect, ConnectError, Connector as BaseConnector};
/// Rustls connector factory /// Rustls connector factory
pub struct Connector<T> { pub struct Connector<T> {
connector: Pipeline<BaseConnector<T>>, connector: Pipeline<BaseConnector<T>>,
inner: TlsConnector, config: Arc<ClientConfig>,
} }
impl<T: Address> From<std::sync::Arc<ClientConfig>> for Connector<T> { impl<T: Address> From<Arc<ClientConfig>> for Connector<T> {
fn from(cfg: std::sync::Arc<ClientConfig>) -> Self { fn from(config: Arc<ClientConfig>) -> Self {
Connector { Connector {
inner: TlsConnector::new(cfg), config,
connector: BaseConnector::default().into(), connector: BaseConnector::default().into(),
} }
} }
@ -28,7 +27,7 @@ impl<T: Address> From<std::sync::Arc<ClientConfig>> for Connector<T> {
impl<T: Address> Connector<T> { impl<T: Address> Connector<T> {
pub fn new(config: ClientConfig) -> Self { pub fn new(config: ClientConfig) -> Self {
Connector { Connector {
inner: TlsConnector::new(std::sync::Arc::new(config)), config: Arc::new(config),
connector: BaseConnector::default().into(), connector: BaseConnector::default().into(),
} }
} }
@ -46,38 +45,39 @@ impl<T: Address> Connector<T> {
.into(); .into();
Self { Self {
connector, connector,
inner: self.inner, config: self.config,
} }
} }
} }
impl<T: Address + 'static> Connector<T> { impl<T: Address> Connector<T> {
/// Resolve and connect to remote host /// Resolve and connect to remote host
pub async fn connect<U>(&self, message: U) -> Result<Io<Layer<TlsFilter>>, ConnectError> pub async fn connect<U>(
&self,
message: U,
) -> Result<Io<Layer<TlsClientFilter>>, ConnectError>
where where
Connect<T>: From<U>, Connect<T>: From<U>,
{ {
let req = Connect::from(message); let req = Connect::from(message);
let host = req.host().split(':').next().unwrap().to_owned(); let host = req.host().split(':').next().unwrap().to_owned();
let conn = self.connector.call(req); let io = self.connector.call(req).await?;
let connector = self.inner.clone();
let io = conn.await?;
log::trace!("{}: SSL Handshake start for: {:?}", io.tag(), host); log::trace!("{}: SSL Handshake start for: {:?}", io.tag(), host);
let tag = io.tag(); let tag = io.tag();
let config = self.config.clone();
let host = ServerName::try_from(host.as_str()) let host = ServerName::try_from(host.as_str())
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?; .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?;
let connector = connector.server_name(host.clone());
match connector.create(io).await { match TlsClientFilter::create(io, config, host.clone()).await {
Ok(io) => { Ok(io) => {
log::trace!("{}: TLS Handshake success: {:?}", tag, &host); log::trace!("{}: TLS Handshake success: {:?}", tag, &host);
Ok(io) Ok(io)
} }
Err(e) => { Err(e) => {
log::trace!("{}: TLS Handshake error: {:?}", tag, e); log::trace!("{}: TLS Handshake error: {:?}", tag, e);
Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)).into()) Err(e.into())
} }
} }
} }
@ -86,7 +86,7 @@ impl<T: Address + 'static> Connector<T> {
impl<T> Clone for Connector<T> { impl<T> Clone for Connector<T> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
inner: self.inner.clone(), config: self.config.clone(),
connector: self.connector.clone(), connector: self.connector.clone(),
} }
} }
@ -100,8 +100,8 @@ impl<T> fmt::Debug for Connector<T> {
} }
} }
impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> { impl<T: Address, C> ServiceFactory<Connect<T>, C> for Connector<T> {
type Response = Io<Layer<TlsFilter>>; type Response = Io<Layer<TlsClientFilter>>;
type Error = ConnectError; type Error = ConnectError;
type Service = Connector<T>; type Service = Connector<T>;
type InitError = (); type InitError = ();
@ -112,7 +112,7 @@ impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
} }
impl<T: Address> Service<Connect<T>> for Connector<T> { impl<T: Address> Service<Connect<T>> for Connector<T> {
type Response = Io<Layer<TlsFilter>>; type Response = Io<Layer<TlsClientFilter>>;
type Error = ConnectError; type Error = ConnectError;
async fn call( async fn call(

View file

@ -99,7 +99,7 @@ impl<T> fmt::Debug for Connector<T> {
} }
} }
impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> { impl<T: Address, C> ServiceFactory<Connect<T>, C> for Connector<T> {
type Response = Io; type Response = Io;
type Error = ConnectError; type Error = ConnectError;
type Service = Connector<T>; type Service = Connector<T>;

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [1.0.0-b.1] - 2024-01-08
* Remove FilterFactory trait and related utils
## [1.0.0-b.0] - 2024-01-07 ## [1.0.0-b.0] - 2024-01-07
* Use "async fn" in trait for Service definition * Use "async fn" in trait for Service definition

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-io" name = "ntex-io"
version = "1.0.0-b.0" version = "1.0.0-b.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames" description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -18,7 +18,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
ntex-codec = "0.6.2" ntex-codec = "0.6.2"
ntex-bytes = "0.1.21" ntex-bytes = "0.1.21"
ntex-util = "1.0.0-b.0" ntex-util = "1.0.0-b.1"
ntex-service = "2.0.0-b.0" ntex-service = "2.0.0-b.0"
bitflags = "2.4" bitflags = "2.4"

View file

@ -2,8 +2,7 @@
#![deny(rust_2018_idioms, unreachable_pub, missing_debug_implementations)] #![deny(rust_2018_idioms, unreachable_pub, missing_debug_implementations)]
use std::{ use std::{
any::Any, any::TypeId, fmt, future::Future, io as sio, io::Error as IoError, any::Any, any::TypeId, fmt, io as sio, io::Error as IoError, task::Context, task::Poll,
task::Context, task::Poll,
}; };
pub mod testing; pub mod testing;
@ -31,7 +30,7 @@ pub use self::io::{Io, IoRef, OnDisconnect};
pub use self::seal::{IoBoxed, Sealed}; pub use self::seal::{IoBoxed, Sealed};
pub use self::tasks::{ReadContext, WriteContext}; pub use self::tasks::{ReadContext, WriteContext};
pub use self::timer::TimerHandle; pub use self::timer::TimerHandle;
pub use self::utils::{filter, seal, Decoded}; pub use self::utils::{seal, Decoded};
/// Status for read task /// Status for read task
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
@ -92,19 +91,6 @@ pub trait FilterLayer: fmt::Debug + 'static {
} }
} }
/// Creates new `Filter` values.
pub trait FilterFactory<F>: Sized {
/// The `Filter` value created by this factory
type Filter: FilterLayer;
/// Errors produced while building a filter.
type Error: fmt::Debug;
/// The future of the `FilterFactory` instance.
type Future: Future<Output = Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
/// Create and return a new filter value asynchronously.
fn create(self, st: Io<F>) -> Self::Future;
}
pub trait IoStream { pub trait IoStream {
fn start(self, _: ReadContext, _: WriteContext) -> Option<Box<dyn Handle>>; fn start(self, _: ReadContext, _: WriteContext) -> Option<Box<dyn Handle>>;
} }

View file

@ -1,9 +1,7 @@
use std::{fmt, marker::PhantomData}; use ntex_service::{chain_factory, fn_service, ServiceFactory};
use ntex_service::{chain_factory, fn_service, Service, ServiceCtx, ServiceFactory};
use ntex_util::future::Ready; use ntex_util::future::Ready;
use crate::{Filter, FilterFactory, Io, IoBoxed, Layer}; use crate::{Filter, Io, IoBoxed};
/// Decoded item from buffer /// Decoded item from buffer
#[doc(hidden)] #[doc(hidden)]
@ -34,88 +32,13 @@ where
.and_then(srv) .and_then(srv)
} }
/// Create filter factory service
pub fn filter<T, F>(filter: T) -> FilterServiceFactory<T, F>
where
T: FilterFactory<F> + Clone,
{
FilterServiceFactory {
filter,
_t: PhantomData,
}
}
pub struct FilterServiceFactory<T, F> {
filter: T,
_t: PhantomData<F>,
}
impl<T: FilterFactory<F> + fmt::Debug, F> fmt::Debug for FilterServiceFactory<T, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FilterServiceFactory")
.field("filter_factory", &self.filter)
.finish()
}
}
impl<T, F> ServiceFactory<Io<F>> for FilterServiceFactory<T, F>
where
T: FilterFactory<F> + Clone,
{
type Response = Io<Layer<T::Filter, F>>;
type Error = T::Error;
type Service = FilterService<T, F>;
type InitError = ();
async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
Ok(FilterService {
filter: self.filter.clone(),
_t: PhantomData,
})
}
}
pub struct FilterService<T, F> {
filter: T,
_t: PhantomData<F>,
}
impl<T: FilterFactory<F> + fmt::Debug, F> fmt::Debug for FilterService<T, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FilterService")
.field("filter_factory", &self.filter)
.finish()
}
}
impl<T, F> Service<Io<F>> for FilterService<T, F>
where
T: FilterFactory<F> + Clone,
{
type Response = Io<Layer<T::Filter, F>>;
type Error = T::Error;
#[inline]
async fn call(
&self,
req: Io<F>,
_: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> {
self.filter.clone().create(req).await
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::io;
use ntex_bytes::Bytes; use ntex_bytes::Bytes;
use ntex_codec::BytesCodec; use ntex_codec::BytesCodec;
use super::*; use super::*;
use crate::{ use crate::{buf::Stack, filter::NullFilter, testing::IoTest};
buf::Stack, filter::NullFilter, testing::IoTest, FilterLayer, ReadBuf, WriteBuf,
};
#[ntex::test] #[ntex::test]
async fn test_utils() { async fn test_utils() {
@ -140,62 +63,6 @@ mod tests {
assert_eq!(buf, b"RES".as_ref()); assert_eq!(buf, b"RES".as_ref());
} }
#[derive(Debug)]
pub(crate) struct TestFilter;
impl FilterLayer for TestFilter {
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
Ok(buf.nbytes())
}
fn process_write_buf(&self, _: &WriteBuf<'_>) -> io::Result<()> {
Ok(())
}
}
#[derive(Copy, Clone, Debug)]
struct TestFilterFactory;
impl<F: Filter> FilterFactory<F> for TestFilterFactory {
type Filter = TestFilter;
type Error = std::convert::Infallible;
type Future = Ready<Io<Layer<TestFilter, F>>, Self::Error>;
fn create(self, st: Io<F>) -> Self::Future {
Ready::Ok(st.add_filter(TestFilter))
}
}
#[ntex::test]
async fn test_utils_filter() {
let (_, server) = IoTest::create();
let filter_service_factory = filter::<_, crate::filter::Base>(TestFilterFactory)
.map_err(|_| ())
.map_init_err(|_| ());
assert!(format!("{:?}", filter_service_factory).contains("FilterServiceFactory"));
let svc = chain_factory(filter_service_factory)
.and_then(seal(fn_service(|io: IoBoxed| async move {
let _ = io.recv(&BytesCodec).await;
Ok::<_, ()>(())
})))
.pipeline(())
.await
.unwrap();
let _ = svc.call(Io::new(server)).await;
let (client, _) = IoTest::create();
let io = Io::new(client);
format!("{:?}", TestFilter);
let mut s = Stack::new();
s.add_layer();
let _ = s.read_buf(&io, 0, 0, |b| TestFilter.process_read_buf(b));
let _ = s.write_buf(&io, 0, |b| TestFilter.process_write_buf(b));
}
#[ntex::test] #[ntex::test]
async fn test_null_filter() { async fn test_null_filter() {
let (_, server) = IoTest::create(); let (_, server) = IoTest::create();

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [1.0.0-b.1] - 2024-01-08
* Refactor io tls filters
## [1.0.0-b.0] - 2024-01-07 ## [1.0.0-b.0] - 2024-01-07
* Use "async fn" in trait for Service definition * Use "async fn" in trait for Service definition

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-tls" name = "ntex-tls"
version = "1.0.0-b.0" version = "1.0.0-b.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL" description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -26,8 +26,8 @@ rustls = ["tls_rust"]
[dependencies] [dependencies]
ntex-bytes = "0.1.21" ntex-bytes = "0.1.21"
ntex-io = "1.0.0-b.0" ntex-io = "1.0.0-b.1"
ntex-util = "1.0.0-b.0" ntex-util = "1.0.0-b.1"
ntex-service = "2.0.0-b.0" ntex-service = "2.0.0-b.0"
log = "0.4" log = "0.4"
pin-project-lite = "0.2" pin-project-lite = "0.2"

View file

@ -1,7 +1,7 @@
use std::{fs::File, io, io::BufReader, sync::Arc}; use std::{fs::File, io, io::BufReader, sync::Arc};
use ntex::service::{chain_factory, fn_service}; use ntex::service::{chain_factory, fn_service};
use ntex::{codec, io::filter, io::Io, server, util::Either}; use ntex::{codec, io::Io, server, util::Either};
use ntex_tls::rustls::TlsAcceptor; use ntex_tls::rustls::TlsAcceptor;
use rustls_pemfile::{certs, rsa_private_keys}; use rustls_pemfile::{certs, rsa_private_keys};
use tls_rust::{Certificate, PrivateKey, ServerConfig}; use tls_rust::{Certificate, PrivateKey, ServerConfig};
@ -34,8 +34,8 @@ async fn main() -> io::Result<()> {
// start server // start server
server::ServerBuilder::new() server::ServerBuilder::new()
.bind("basic", "127.0.0.1:8443", move |_| { .bind("basic", "127.0.0.1:8443", move |_| {
chain_factory(filter(TlsAcceptor::new(tls_config.clone()))).and_then( chain_factory(TlsAcceptor::new(tls_config.clone())).and_then(fn_service(
fn_service(|io: Io<_>| async move { |io: Io<_>| async move {
println!("New client is connected"); println!("New client is connected");
io.send( io.send(
@ -62,8 +62,8 @@ async fn main() -> io::Result<()> {
} }
println!("Client is disconnected"); println!("Client is disconnected");
Ok(()) Ok(())
}), },
) ))
})? })?
.workers(1) .workers(1)
.run() .run()

View file

@ -1,7 +1,7 @@
use std::io; use std::io;
use ntex::service::{chain_factory, fn_service}; use ntex::service::{chain_factory, fn_service};
use ntex::{codec, io::filter, io::Io, server, util::Either}; use ntex::{codec, io::Io, server, util::Either};
use ntex_tls::openssl::{PeerCert, PeerCertChain, SslAcceptor}; use ntex_tls::openssl::{PeerCert, PeerCertChain, SslAcceptor};
use tls_openssl::ssl::{self, SslFiletype, SslMethod, SslVerifyMode}; use tls_openssl::ssl::{self, SslFiletype, SslMethod, SslVerifyMode};
@ -27,7 +27,7 @@ async fn main() -> io::Result<()> {
// start server // start server
server::ServerBuilder::new() server::ServerBuilder::new()
.bind("basic", "127.0.0.1:8443", move |_| { .bind("basic", "127.0.0.1:8443", move |_| {
chain_factory(filter(SslAcceptor::new(acceptor.clone()))).and_then(fn_service( chain_factory(SslAcceptor::new(acceptor.clone())).and_then(fn_service(
|io: Io<_>| async move { |io: Io<_>| async move {
println!("New client is connected"); println!("New client is connected");
if let Some(cert) = io.query::<PeerCert>().as_ref() { if let Some(cert) = io.query::<PeerCert>().as_ref() {

View file

@ -1,31 +1,29 @@
use std::task::{Context, Poll}; use std::{cell::RefCell, error::Error, fmt, io, task::Context, task::Poll};
use std::{error::Error, marker::PhantomData};
use ntex_io::{Filter, FilterFactory, Io, Layer}; use ntex_io::{Filter, Io, Layer};
use ntex_service::{Service, ServiceCtx, ServiceFactory}; use ntex_service::{Service, ServiceCtx, ServiceFactory};
use ntex_util::time::Millis; use ntex_util::time::{self, Millis};
use tls_openssl::ssl::SslAcceptor; use tls_openssl::ssl;
use crate::counter::Counter; use crate::counter::Counter;
use crate::MAX_SSL_ACCEPT_COUNTER; use crate::MAX_SSL_ACCEPT_COUNTER;
use super::{SslAcceptor as IoSslAcceptor, SslFilter}; use super::SslFilter;
#[derive(Debug)]
/// Support `TLS` server connections via openssl package /// Support `TLS` server connections via openssl package
/// ///
/// `openssl` feature enables `Acceptor` type /// `openssl` feature enables `Acceptor` type
pub struct Acceptor<F> { pub struct SslAcceptor {
acceptor: IoSslAcceptor, acceptor: ssl::SslAcceptor,
_t: PhantomData<F>, timeout: Millis,
} }
impl<F> Acceptor<F> { impl SslAcceptor {
/// Create default openssl acceptor service /// Create default openssl acceptor service
pub fn new(acceptor: SslAcceptor) -> Self { pub fn new(acceptor: ssl::SslAcceptor) -> Self {
Acceptor { SslAcceptor {
acceptor: IoSslAcceptor::new(acceptor), acceptor,
_t: PhantomData, timeout: Millis(5_000),
} }
} }
@ -33,54 +31,69 @@ impl<F> Acceptor<F> {
/// ///
/// Default is set to 5 seconds. /// Default is set to 5 seconds.
pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self { pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self {
self.acceptor.timeout(timeout); self.timeout = timeout.into();
self self
} }
} }
impl<F> From<SslAcceptor> for Acceptor<F> { impl fmt::Debug for SslAcceptor {
fn from(acceptor: SslAcceptor) -> Self { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SslAcceptor")
.field("timeout", &self.timeout)
.finish()
}
}
impl From<ssl::SslAcceptor> for SslAcceptor {
fn from(acceptor: ssl::SslAcceptor) -> Self {
Self::new(acceptor) Self::new(acceptor)
} }
} }
impl<F> Clone for Acceptor<F> { impl Clone for SslAcceptor {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
_t: PhantomData, timeout: self.timeout,
} }
} }
} }
impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> { impl<F: Filter, C> ServiceFactory<Io<F>, C> for SslAcceptor {
type Response = Io<Layer<SslFilter, F>>; type Response = Io<Layer<SslFilter, F>>;
type Error = Box<dyn Error>; type Error = Box<dyn Error>;
type Service = AcceptorService<F>; type Service = SslAcceptorService;
type InitError = (); type InitError = ();
async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> { async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
MAX_SSL_ACCEPT_COUNTER.with(|conns| { MAX_SSL_ACCEPT_COUNTER.with(|conns| {
Ok(AcceptorService { Ok(SslAcceptorService {
acceptor: self.acceptor.clone(), acceptor: self.acceptor.clone(),
timeout: self.timeout,
conns: conns.clone(), conns: conns.clone(),
_t: PhantomData,
}) })
}) })
} }
} }
#[derive(Debug)]
/// Support `TLS` server connections via openssl package /// Support `TLS` server connections via openssl package
/// ///
/// `openssl` feature enables `Acceptor` type /// `openssl` feature enables `Acceptor` type
pub struct AcceptorService<F> { pub struct SslAcceptorService {
acceptor: IoSslAcceptor, acceptor: ssl::SslAcceptor,
timeout: Millis,
conns: Counter, conns: Counter,
_t: PhantomData<F>,
} }
impl<F: Filter> Service<Io<F>> for AcceptorService<F> { impl fmt::Debug for SslAcceptorService {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SslAcceptorService")
.field("timeout", &self.timeout)
.finish()
}
}
impl<F: Filter> Service<Io<F>> for SslAcceptorService {
type Response = Io<Layer<SslFilter, F>>; type Response = Io<Layer<SslFilter, F>>;
type Error = Box<dyn Error>; type Error = Box<dyn Error>;
@ -94,10 +107,40 @@ impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
async fn call( async fn call(
&self, &self,
req: Io<F>, io: Io<F>,
_: ServiceCtx<'_, Self>, _: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> { ) -> Result<Self::Response, Self::Error> {
let _guard = self.conns.get(); let timeout = self.timeout;
self.acceptor.clone().create(req).await let ctx_result = ssl::Ssl::new(self.acceptor.context());
time::timeout(timeout, async {
let ssl = ctx_result.map_err(super::map_to_ioerr)?;
let inner = super::IoInner {
source: None,
destination: None,
};
let filter = SslFilter {
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
};
let io = io.add_filter(filter);
log::debug!("Accepting tls connection");
loop {
let result = io.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
})?;
if super::handle_result(&io, result).await?.is_some() {
break;
}
}
Ok(io)
})
.await
.map_err(|_| {
io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout").into()
})
.and_then(|item| item)
} }
} }

View file

@ -1,17 +1,15 @@
//! An implementation of SSL streams for ntex backed by OpenSSL //! An implementation of SSL streams for ntex backed by OpenSSL
use std::cell::RefCell; use std::{any, cell::RefCell, cmp, error::Error, io, task::Poll};
use std::{any, cmp, error::Error, fmt, io, task::Poll};
use ntex_bytes::{BufMut, BytesVec}; use ntex_bytes::{BufMut, BytesVec};
use ntex_io::{types, Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf}; use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::BoxFuture, time, time::Millis};
use tls_openssl::ssl::{self, NameType, SslStream}; use tls_openssl::ssl::{self, NameType, SslStream};
use tls_openssl::x509::X509; use tls_openssl::x509::X509;
use crate::{PskIdentity, Servername}; use crate::{PskIdentity, Servername};
mod accept; mod accept;
pub use self::accept::{Acceptor, AcceptorService}; pub use self::accept::{SslAcceptor, SslAcceptorService};
/// Connection's peer cert /// Connection's peer cert
#[derive(Debug)] #[derive(Debug)]
@ -211,132 +209,31 @@ impl FilterLayer for SslFilter {
} }
} }
pub struct SslAcceptor { /// Create openssl connector filter factory
acceptor: ssl::SslAcceptor, pub async fn connect<F: Filter>(
timeout: Millis, io: Io<F>,
}
impl SslAcceptor {
/// Create openssl acceptor filter factory
pub fn new(acceptor: ssl::SslAcceptor) -> Self {
SslAcceptor {
acceptor,
timeout: Millis(5_000),
}
}
/// Set handshake timeout.
///
/// Default is set to 5 seconds.
pub fn timeout<U: Into<Millis>>(&mut self, timeout: U) -> &mut Self {
self.timeout = timeout.into();
self
}
}
impl Clone for SslAcceptor {
fn clone(&self) -> Self {
Self {
acceptor: self.acceptor.clone(),
timeout: self.timeout,
}
}
}
impl fmt::Debug for SslAcceptor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SslAcceptor")
.field("timeout", &self.timeout)
.finish()
}
}
impl<F: Filter> FilterFactory<F> for SslAcceptor {
type Filter = SslFilter;
type Error = Box<dyn Error>;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
fn create(self, io: Io<F>) -> Self::Future {
let timeout = self.timeout;
let ctx_result = ssl::Ssl::new(self.acceptor.context());
Box::pin(async move {
time::timeout(timeout, async {
let ssl = ctx_result.map_err(map_to_ioerr)?;
let inner = IoInner {
source: None,
destination: None,
};
let filter = SslFilter {
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
};
let io = io.add_filter(filter);
log::debug!("Accepting tls connection");
loop {
let result = io.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
})?;
if handle_result(&io, result).await?.is_some() {
break;
}
}
Ok(io)
})
.await
.map_err(|_| {
io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout").into()
})
.and_then(|item| item)
})
}
}
#[derive(Debug)]
pub struct SslConnector {
ssl: ssl::Ssl, ssl: ssl::Ssl,
} ) -> Result<Io<Layer<SslFilter, F>>, io::Error> {
let inner = IoInner {
source: None,
destination: None,
};
let filter = SslFilter {
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
};
let io = io.add_filter(filter);
impl SslConnector { loop {
/// Create openssl connector filter factory let result = io.with_buf(|buf| {
pub fn new(ssl: ssl::Ssl) -> Self { let filter = io.filter();
SslConnector { ssl } filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
})?;
if handle_result(&io, result).await?.is_some() {
break;
}
} }
}
impl<F: Filter> FilterFactory<F> for SslConnector { Ok(io)
type Filter = SslFilter;
type Error = Box<dyn Error>;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
fn create(self, io: Io<F>) -> Self::Future {
Box::pin(async move {
let inner = IoInner {
source: None,
destination: None,
};
let filter = SslFilter {
inner: RefCell::new(ssl::SslStream::new(self.ssl, inner)?),
};
let io = io.add_filter(filter);
loop {
let result = io.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
})?;
if handle_result(&io, result).await?.is_some() {
break;
}
}
Ok(io)
})
}
} }
async fn handle_result<T, F>( async fn handle_result<T, F>(

View file

@ -1,30 +1,30 @@
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{io, marker::PhantomData, sync::Arc}; use std::{io, sync::Arc};
use tls_rust::ServerConfig; use tls_rust::ServerConfig;
use ntex_io::{Filter, FilterFactory, Io, Layer}; use ntex_io::{Filter, Io, Layer};
use ntex_service::{Service, ServiceCtx, ServiceFactory}; use ntex_service::{Service, ServiceCtx, ServiceFactory};
use ntex_util::time::Millis; use ntex_util::time::Millis;
use super::{TlsAcceptor, TlsFilter}; use super::TlsServerFilter;
use crate::{counter::Counter, MAX_SSL_ACCEPT_COUNTER}; use crate::{counter::Counter, MAX_SSL_ACCEPT_COUNTER};
#[derive(Debug)] #[derive(Debug)]
/// Support `SSL` connections via rustls package /// Support `SSL` connections via rustls package
/// ///
/// `rust-tls` feature enables `RustlsAcceptor` type /// `rust-tls` feature enables `RustlsAcceptor` type
pub struct Acceptor<F> { pub struct TlsAcceptor {
inner: TlsAcceptor, config: Arc<ServerConfig>,
_t: PhantomData<F>, timeout: Millis,
} }
impl<F> Acceptor<F> { impl TlsAcceptor {
/// Create rustls based `Acceptor` service factory /// Create rustls based `Acceptor` service factory
pub fn new(config: Arc<ServerConfig>) -> Self { pub fn new(config: Arc<ServerConfig>) -> Self {
Acceptor { Self {
inner: TlsAcceptor::new(config), config,
_t: PhantomData, timeout: Millis(5_000),
} }
} }
@ -32,38 +32,38 @@ impl<F> Acceptor<F> {
/// ///
/// Default is set to 5 seconds. /// Default is set to 5 seconds.
pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self { pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self {
self.inner.timeout(timeout.into()); self.timeout = timeout.into();
self self
} }
} }
impl<F> From<ServerConfig> for Acceptor<F> { impl From<ServerConfig> for TlsAcceptor {
fn from(cfg: ServerConfig) -> Self { fn from(cfg: ServerConfig) -> Self {
Self::new(Arc::new(cfg)) Self::new(Arc::new(cfg))
} }
} }
impl<F> Clone for Acceptor<F> { impl Clone for TlsAcceptor {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
inner: self.inner.clone(), config: self.config.clone(),
_t: PhantomData, timeout: self.timeout,
} }
} }
} }
impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> { impl<F: Filter, C> ServiceFactory<Io<F>, C> for TlsAcceptor {
type Response = Io<Layer<TlsFilter, F>>; type Response = Io<Layer<TlsServerFilter, F>>;
type Error = io::Error; type Error = io::Error;
type Service = AcceptorService<F>; type Service = TlsAcceptorService;
type InitError = (); type InitError = ();
async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> { async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
MAX_SSL_ACCEPT_COUNTER.with(|conns| { MAX_SSL_ACCEPT_COUNTER.with(|conns| {
Ok(AcceptorService { Ok(TlsAcceptorService {
acceptor: self.inner.clone(), config: self.config.clone(),
timeout: self.timeout,
conns: conns.clone(), conns: conns.clone(),
io: PhantomData,
}) })
}) })
} }
@ -71,14 +71,14 @@ impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> {
#[derive(Debug)] #[derive(Debug)]
/// RusTLS based `Acceptor` service /// RusTLS based `Acceptor` service
pub struct AcceptorService<F> { pub struct TlsAcceptorService {
acceptor: TlsAcceptor, config: Arc<ServerConfig>,
io: PhantomData<F>, timeout: Millis,
conns: Counter, conns: Counter,
} }
impl<F: Filter> Service<Io<F>> for AcceptorService<F> { impl<F: Filter> Service<Io<F>> for TlsAcceptorService {
type Response = Io<Layer<TlsFilter, F>>; type Response = Io<Layer<TlsServerFilter, F>>;
type Error = io::Error; type Error = io::Error;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -91,10 +91,10 @@ impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
async fn call( async fn call(
&self, &self,
req: Io<F>, io: Io<F>,
_: ServiceCtx<'_, Self>, _: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> { ) -> Result<Self::Response, Self::Error> {
let _guard = self.conns.get(); let _guard = self.conns.get();
self.acceptor.clone().create(req).await super::TlsServerFilter::create(io, self.config.clone(), self.timeout).await
} }
} }

View file

@ -7,13 +7,11 @@ use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::ready; use ntex_util::ready;
use tls_rust::{ClientConfig, ClientConnection, ServerName}; use tls_rust::{ClientConfig, ClientConnection, ServerName};
use crate::rustls::{TlsFilter, Wrapper}; use super::{PeerCert, PeerCertChain, Wrapper};
use super::{PeerCert, PeerCertChain};
#[derive(Debug)] #[derive(Debug)]
/// An implementation of SSL streams /// An implementation of SSL streams
pub(crate) struct TlsClientFilter { pub struct TlsClientFilter {
session: RefCell<ClientConnection>, session: RefCell<ClientConnection>,
} }
@ -114,22 +112,22 @@ impl FilterLayer for TlsClientFilter {
} }
impl TlsClientFilter { impl TlsClientFilter {
pub(crate) async fn create<F: Filter>( pub async fn create<F: Filter>(
io: Io<F>, io: Io<F>,
cfg: Arc<ClientConfig>, cfg: Arc<ClientConfig>,
domain: ServerName, domain: ServerName,
) -> Result<Io<Layer<TlsFilter, F>>, io::Error> { ) -> Result<Io<Layer<TlsClientFilter, F>>, io::Error> {
let session = ClientConnection::new(cfg, domain) let session = ClientConnection::new(cfg, domain)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
let filter = TlsFilter::new_client(TlsClientFilter { let filter = TlsClientFilter {
session: RefCell::new(session), session: RefCell::new(session),
}); };
let io = io.add_filter(filter); let io = io.add_filter(filter);
let filter = io.filter(); let filter = io.filter();
loop { loop {
let (result, wants_read, handshaking) = io.with_buf(|buf| { let (result, wants_read, handshaking) = io.with_buf(|buf| {
let mut session = filter.client().session.borrow_mut(); let mut session = filter.session.borrow_mut();
let mut wrp = Wrapper(buf); let mut wrp = Wrapper(buf);
let mut result = ( let mut result = (
session.complete_io(&mut wrp), session.complete_io(&mut wrp),

View file

@ -1,21 +1,16 @@
#![allow(clippy::type_complexity)]
//! An implementation of SSL streams for ntex backed by OpenSSL //! An implementation of SSL streams for ntex backed by OpenSSL
use std::{any, cmp, io, sync::Arc, task::Context, task::Poll}; use std::{cmp, io};
use ntex_io::{ use ntex_io::WriteBuf;
Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, ReadStatus, WriteBuf, use tls_rust::Certificate;
WriteStatus,
};
use ntex_util::{future::BoxFuture, time::Millis};
use tls_rust::{Certificate, ClientConfig, ServerConfig, ServerName};
mod accept; mod accept;
mod client; mod client;
mod server; mod server;
pub use accept::{Acceptor, AcceptorService}; pub use accept::{TlsAcceptor, TlsAcceptorService};
use self::client::TlsClientFilter; pub use self::client::TlsClientFilter;
use self::server::TlsServerFilter; pub use self::server::TlsServerFilter;
/// Connection's peer cert /// Connection's peer cert
#[derive(Debug)] #[derive(Debug)]
@ -25,203 +20,6 @@ pub struct PeerCert(pub Certificate);
#[derive(Debug)] #[derive(Debug)]
pub struct PeerCertChain(pub Vec<Certificate>); pub struct PeerCertChain(pub Vec<Certificate>);
#[derive(Debug)]
/// An implementation of SSL streams
pub struct TlsFilter {
inner: InnerTlsFilter,
}
#[derive(Debug)]
enum InnerTlsFilter {
Server(TlsServerFilter),
Client(TlsClientFilter),
}
impl TlsFilter {
fn new_server(server: TlsServerFilter) -> Self {
TlsFilter {
inner: InnerTlsFilter::Server(server),
}
}
fn new_client(client: TlsClientFilter) -> Self {
TlsFilter {
inner: InnerTlsFilter::Client(client),
}
}
fn server(&self) -> &TlsServerFilter {
match self.inner {
InnerTlsFilter::Server(ref server) => server,
_ => unreachable!(),
}
}
fn client(&self) -> &TlsClientFilter {
match self.inner {
InnerTlsFilter::Client(ref server) => server,
_ => unreachable!(),
}
}
}
impl FilterLayer for TlsFilter {
#[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.query(id),
InnerTlsFilter::Client(ref f) => f.query(id),
}
}
#[inline]
fn shutdown(&self, buf: &WriteBuf<'_>) -> io::Result<Poll<()>> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.shutdown(buf),
InnerTlsFilter::Client(ref f) => f.shutdown(buf),
}
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.poll_read_ready(cx),
InnerTlsFilter::Client(ref f) => f.poll_read_ready(cx),
}
}
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.poll_write_ready(cx),
InnerTlsFilter::Client(ref f) => f.poll_write_ready(cx),
}
}
#[inline]
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.process_read_buf(buf),
InnerTlsFilter::Client(ref f) => f.process_read_buf(buf),
}
}
#[inline]
fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.process_write_buf(buf),
InnerTlsFilter::Client(ref f) => f.process_write_buf(buf),
}
}
}
#[derive(Debug)]
pub struct TlsAcceptor {
cfg: Arc<ServerConfig>,
timeout: Millis,
}
impl TlsAcceptor {
/// Create openssl acceptor filter factory
pub fn new(cfg: Arc<ServerConfig>) -> Self {
TlsAcceptor {
cfg,
timeout: Millis(5_000),
}
}
/// Set handshake timeout.
///
/// Default is set to 5 seconds.
pub fn timeout<U: Into<Millis>>(&mut self, timeout: U) -> &mut Self {
self.timeout = timeout.into();
self
}
}
impl From<ServerConfig> for TlsAcceptor {
fn from(cfg: ServerConfig) -> Self {
Self::new(Arc::new(cfg))
}
}
impl Clone for TlsAcceptor {
fn clone(&self) -> Self {
Self {
cfg: self.cfg.clone(),
timeout: self.timeout,
}
}
}
impl<F: Filter> FilterFactory<F> for TlsAcceptor {
type Filter = TlsFilter;
type Error = io::Error;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, io::Error>>;
fn create(self, st: Io<F>) -> Self::Future {
let cfg = self.cfg.clone();
let timeout = self.timeout;
Box::pin(async move { TlsServerFilter::create(st, cfg, timeout).await })
}
}
#[derive(Debug)]
pub struct TlsConnector {
cfg: Arc<ClientConfig>,
}
impl TlsConnector {
/// Create openssl connector filter factory
pub fn new(cfg: Arc<ClientConfig>) -> Self {
TlsConnector { cfg }
}
/// Set server name
pub fn server_name(self, server_name: ServerName) -> TlsConnectorConfigured {
TlsConnectorConfigured {
server_name,
cfg: self.cfg,
}
}
}
impl Clone for TlsConnector {
fn clone(&self) -> Self {
Self {
cfg: self.cfg.clone(),
}
}
}
#[derive(Debug)]
pub struct TlsConnectorConfigured {
cfg: Arc<ClientConfig>,
server_name: ServerName,
}
impl Clone for TlsConnectorConfigured {
fn clone(&self) -> Self {
Self {
cfg: self.cfg.clone(),
server_name: self.server_name.clone(),
}
}
}
impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
type Filter = TlsFilter;
type Error = io::Error;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, io::Error>>;
fn create(self, st: Io<F>) -> Self::Future {
let cfg = self.cfg;
let server_name = self.server_name;
Box::pin(async move { TlsClientFilter::create(st, cfg, server_name).await })
}
}
pub(crate) struct Wrapper<'a, 'b>(&'a WriteBuf<'b>); pub(crate) struct Wrapper<'a, 'b>(&'a WriteBuf<'b>);
impl<'a, 'b> io::Read for Wrapper<'a, 'b> { impl<'a, 'b> io::Read for Wrapper<'a, 'b> {

View file

@ -7,14 +7,13 @@ use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{ready, time, time::Millis}; use ntex_util::{ready, time, time::Millis};
use tls_rust::{ServerConfig, ServerConnection}; use tls_rust::{ServerConfig, ServerConnection};
use crate::rustls::{TlsFilter, Wrapper};
use crate::Servername; use crate::Servername;
use super::{PeerCert, PeerCertChain}; use super::{PeerCert, PeerCertChain, Wrapper};
#[derive(Debug)] #[derive(Debug)]
/// An implementation of SSL streams /// An implementation of SSL streams
pub(crate) struct TlsServerFilter { pub struct TlsServerFilter {
session: RefCell<ServerConnection>, session: RefCell<ServerConnection>,
} }
@ -121,23 +120,23 @@ impl FilterLayer for TlsServerFilter {
} }
impl TlsServerFilter { impl TlsServerFilter {
pub(crate) async fn create<F: Filter>( pub async fn create<F: Filter>(
io: Io<F>, io: Io<F>,
cfg: Arc<ServerConfig>, cfg: Arc<ServerConfig>,
timeout: Millis, timeout: Millis,
) -> Result<Io<Layer<TlsFilter, F>>, io::Error> { ) -> Result<Io<Layer<TlsServerFilter, F>>, io::Error> {
time::timeout(timeout, async { time::timeout(timeout, async {
let session = ServerConnection::new(cfg) let session = ServerConnection::new(cfg)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
let filter = TlsFilter::new_server(TlsServerFilter { let filter = TlsServerFilter {
session: RefCell::new(session), session: RefCell::new(session),
}); };
let io = io.add_filter(filter); let io = io.add_filter(filter);
let filter = io.filter(); let filter = io.filter();
loop { loop {
let (result, wants_read, handshaking) = io.with_buf(|buf| { let (result, wants_read, handshaking) = io.with_buf(|buf| {
let mut session = filter.server().session.borrow_mut(); let mut session = filter.session.borrow_mut();
let mut wrp = Wrapper(buf); let mut wrp = Wrapper(buf);
let mut result = ( let mut result = (
session.complete_io(&mut wrp), session.complete_io(&mut wrp),

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [1.0.0-b.1] - 2024-01-xx
* Remove unnecessary 'static
## [1.0.0-b.0] - 2024-01-07 ## [1.0.0-b.0] - 2024-01-07
* Use "async fn" in trait for Service definition * Use "async fn" in trait for Service definition

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-util" name = "ntex-util"
version = "1.0.0-b.0" version = "1.0.0-b.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for ntex framework" description = "Utilities for ntex framework"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]

View file

@ -53,7 +53,7 @@ impl<R, E, F> fmt::Debug for KeepAlive<R, E, F> {
} }
} }
impl<R, E, F, C: 'static> ServiceFactory<R, C> for KeepAlive<R, E, F> impl<R, E, F, C> ServiceFactory<R, C> for KeepAlive<R, E, F>
where where
F: Fn() -> E + Clone, F: Fn() -> E + Clone,
{ {

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [1.0.0-b.1] - 2024-01-08
* Refactor io tls filters
## [1.0.0-b.0] - 2024-01-07 ## [1.0.0-b.0] - 2024-01-07
* Use "async fn" in trait for Service definition * Use "async fn" in trait for Service definition

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex" name = "ntex"
version = "1.0.0-b.0" version = "1.0.0-b.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services" description = "Framework for composable network services"
readme = "README.md" readme = "README.md"
@ -49,7 +49,7 @@ async-std = ["ntex-rt/async-std", "ntex-async-std", "ntex-connect/async-std"]
[dependencies] [dependencies]
ntex-codec = "0.6.2" ntex-codec = "0.6.2"
ntex-connect = "1.0.0-b.0" ntex-connect = "1.0.0-b.1"
ntex-http = "0.1.11" ntex-http = "0.1.11"
ntex-router = "0.5.2" ntex-router = "0.5.2"
ntex-service = "2.0.0-b.0" ntex-service = "2.0.0-b.0"
@ -58,8 +58,8 @@ ntex-util = "1.0.0-b.0"
ntex-bytes = "0.1.21" ntex-bytes = "0.1.21"
ntex-h2 = "0.5.0-b.0" ntex-h2 = "0.5.0-b.0"
ntex-rt = "0.4.11" ntex-rt = "0.4.11"
ntex-io = "1.0.0-b.0" ntex-io = "1.0.0-b.1"
ntex-tls = "1.0.0-b.0" ntex-tls = "1.0.0-b.1"
ntex-tokio = { version = "0.4.0-b.0", optional = true } ntex-tokio = { version = "0.4.0-b.0", optional = true }
ntex-glommio = { version = "0.4.0-b.0", optional = true } ntex-glommio = { version = "0.4.0-b.0", optional = true }
ntex-async-std = { version = "0.4.0-b.0", optional = true } ntex-async-std = { version = "0.4.0-b.0", optional = true }

View file

@ -21,7 +21,7 @@ pub struct Encoder<B> {
fut: Option<JoinHandle<Result<ContentEncoder, io::Error>>>, fut: Option<JoinHandle<Result<ContentEncoder, io::Error>>>,
} }
impl<B: MessageBody + 'static> Encoder<B> { impl<B: MessageBody> Encoder<B> {
pub fn response( pub fn response(
encoding: ContentEncoding, encoding: ContentEncoding,
head: &mut ResponseHead, head: &mut ResponseHead,

View file

@ -47,8 +47,8 @@ where
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
mod openssl { mod openssl {
use ntex_tls::openssl::{Acceptor, SslFilter}; use ntex_tls::openssl::{SslAcceptor, SslFilter};
use tls_openssl::ssl::SslAcceptor; use tls_openssl::ssl;
use super::*; use super::*;
use crate::{io::Layer, server::SslError}; use crate::{io::Layer, server::SslError};
@ -72,14 +72,14 @@ mod openssl {
/// Create openssl based service /// Create openssl based service
pub fn openssl( pub fn openssl(
self, self,
acceptor: SslAcceptor, acceptor: ssl::SslAcceptor,
) -> impl ServiceFactory< ) -> impl ServiceFactory<
Io<F>, Io<F>,
Response = (), Response = (),
Error = SslError<DispatchError>, Error = SslError<DispatchError>,
InitError = (), InitError = (),
> { > {
Acceptor::new(acceptor) SslAcceptor::new(acceptor)
.timeout(self.cfg.ssl_handshake_timeout) .timeout(self.cfg.ssl_handshake_timeout)
.map_err(SslError::Ssl) .map_err(SslError::Ssl)
.map_init_err(|_| panic!()) .map_init_err(|_| panic!())
@ -92,13 +92,13 @@ mod openssl {
mod rustls { mod rustls {
use std::fmt; use std::fmt;
use ntex_tls::rustls::{Acceptor, TlsFilter}; use ntex_tls::rustls::{TlsAcceptor, TlsServerFilter};
use tls_rustls::ServerConfig; use tls_rustls::ServerConfig;
use super::*; use super::*;
use crate::{io::Layer, server::SslError}; use crate::{io::Layer, server::SslError};
impl<F, S, B, X, U> H1Service<Layer<TlsFilter, F>, S, B, X, U> impl<F, S, B, X, U> H1Service<Layer<TlsServerFilter, F>, S, B, X, U>
where where
F: Filter, F: Filter,
S: ServiceFactory<Request> + 'static, S: ServiceFactory<Request> + 'static,
@ -109,7 +109,7 @@ mod rustls {
X: ServiceFactory<Request, Response = Request> + 'static, X: ServiceFactory<Request, Response = Request> + 'static,
X::Error: ResponseError, X::Error: ResponseError,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Io<Layer<TlsFilter, F>>, Codec), Response = ()> U: ServiceFactory<(Request, Io<Layer<TlsServerFilter, F>>, Codec), Response = ()>
+ 'static, + 'static,
U::Error: fmt::Display + Error, U::Error: fmt::Display + Error,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
@ -124,7 +124,7 @@ mod rustls {
Error = SslError<DispatchError>, Error = SslError<DispatchError>,
InitError = (), InitError = (),
> { > {
Acceptor::from(config) TlsAcceptor::from(config)
.timeout(self.cfg.ssl_handshake_timeout) .timeout(self.cfg.ssl_handshake_timeout)
.map_err(|e| SslError::Ssl(Box::new(e))) .map_err(|e| SslError::Ssl(Box::new(e)))
.map_init_err(|_| panic!()) .map_init_err(|_| panic!())

View file

@ -44,8 +44,8 @@ where
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
mod openssl { mod openssl {
use ntex_tls::openssl::{Acceptor, SslFilter}; use ntex_tls::openssl::{SslAcceptor, SslFilter};
use tls_openssl::ssl::SslAcceptor; use tls_openssl::ssl;
use crate::{io::Layer, server::SslError}; use crate::{io::Layer, server::SslError};
@ -62,14 +62,14 @@ mod openssl {
/// Create ssl based service /// Create ssl based service
pub fn openssl( pub fn openssl(
self, self,
acceptor: SslAcceptor, acceptor: ssl::SslAcceptor,
) -> impl ServiceFactory< ) -> impl ServiceFactory<
Io<F>, Io<F>,
Response = (), Response = (),
Error = SslError<DispatchError>, Error = SslError<DispatchError>,
InitError = S::InitError, InitError = S::InitError,
> { > {
Acceptor::new(acceptor) SslAcceptor::new(acceptor)
.timeout(self.cfg.ssl_handshake_timeout) .timeout(self.cfg.ssl_handshake_timeout)
.map_err(SslError::Ssl) .map_err(SslError::Ssl)
.map_init_err(|_| panic!()) .map_init_err(|_| panic!())
@ -80,13 +80,13 @@ mod openssl {
#[cfg(feature = "rustls")] #[cfg(feature = "rustls")]
mod rustls { mod rustls {
use ntex_tls::rustls::{Acceptor, TlsFilter}; use ntex_tls::rustls::{TlsAcceptor, TlsServerFilter};
use tls_rustls::ServerConfig; use tls_rustls::ServerConfig;
use super::*; use super::*;
use crate::{io::Layer, server::SslError}; use crate::{io::Layer, server::SslError};
impl<F, S, B> H2Service<Layer<TlsFilter, F>, S, B> impl<F, S, B> H2Service<Layer<TlsServerFilter, F>, S, B>
where where
F: Filter, F: Filter,
S: ServiceFactory<Request> + 'static, S: ServiceFactory<Request> + 'static,
@ -107,7 +107,7 @@ mod rustls {
let protos = vec!["h2".to_string().into()]; let protos = vec!["h2".to_string().into()];
config.alpn_protocols = protos; config.alpn_protocols = protos;
Acceptor::from(config) TlsAcceptor::from(config)
.timeout(self.cfg.ssl_handshake_timeout) .timeout(self.cfg.ssl_handshake_timeout)
.map_err(|e| SslError::Ssl(Box::new(e))) .map_err(|e| SslError::Ssl(Box::new(e)))
.map_init_err(|_| panic!()) .map_init_err(|_| panic!())

View file

@ -139,8 +139,8 @@ where
#[cfg(feature = "openssl")] #[cfg(feature = "openssl")]
mod openssl { mod openssl {
use ntex_tls::openssl::{Acceptor, SslFilter}; use ntex_tls::openssl::{SslAcceptor, SslFilter};
use tls_openssl::ssl::SslAcceptor; use tls_openssl::ssl;
use super::*; use super::*;
use crate::{io::Layer, server::SslError}; use crate::{io::Layer, server::SslError};
@ -164,14 +164,14 @@ mod openssl {
/// Create openssl based service /// Create openssl based service
pub fn openssl( pub fn openssl(
self, self,
acceptor: SslAcceptor, acceptor: ssl::SslAcceptor,
) -> impl ServiceFactory< ) -> impl ServiceFactory<
Io<F>, Io<F>,
Response = (), Response = (),
Error = SslError<DispatchError>, Error = SslError<DispatchError>,
InitError = (), InitError = (),
> { > {
Acceptor::new(acceptor) SslAcceptor::new(acceptor)
.timeout(self.cfg.ssl_handshake_timeout) .timeout(self.cfg.ssl_handshake_timeout)
.map_err(SslError::Ssl) .map_err(SslError::Ssl)
.map_init_err(|_| panic!()) .map_init_err(|_| panic!())
@ -182,13 +182,13 @@ mod openssl {
#[cfg(feature = "rustls")] #[cfg(feature = "rustls")]
mod rustls { mod rustls {
use ntex_tls::rustls::{Acceptor, TlsFilter}; use ntex_tls::rustls::{TlsAcceptor, TlsServerFilter};
use tls_rustls::ServerConfig; use tls_rustls::ServerConfig;
use super::*; use super::*;
use crate::{io::Layer, server::SslError}; use crate::{io::Layer, server::SslError};
impl<F, S, B, X, U> HttpService<Layer<TlsFilter, F>, S, B, X, U> impl<F, S, B, X, U> HttpService<Layer<TlsServerFilter, F>, S, B, X, U>
where where
F: Filter, F: Filter,
S: ServiceFactory<Request> + 'static, S: ServiceFactory<Request> + 'static,
@ -199,8 +199,10 @@ mod rustls {
X: ServiceFactory<Request, Response = Request> + 'static, X: ServiceFactory<Request, Response = Request> + 'static,
X::Error: ResponseError, X::Error: ResponseError,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Io<Layer<TlsFilter, F>>, h1::Codec), Response = ()> U: ServiceFactory<
+ 'static, (Request, Io<Layer<TlsServerFilter, F>>, h1::Codec),
Response = (),
> + 'static,
U::Error: fmt::Display + error::Error, U::Error: fmt::Display + error::Error,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {
@ -217,7 +219,7 @@ mod rustls {
let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()]; let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()];
config.alpn_protocols = protos; config.alpn_protocols = protos;
Acceptor::from(config) TlsAcceptor::from(config)
.timeout(self.cfg.ssl_handshake_timeout) .timeout(self.cfg.ssl_handshake_timeout)
.map_err(|e| SslError::Ssl(Box::new(e))) .map_err(|e| SslError::Ssl(Box::new(e)))
.map_init_err(|_| panic!()) .map_init_err(|_| panic!())

View file

@ -106,7 +106,6 @@ pub struct DefaultHeadersMiddleware<S> {
impl<S, E> Service<WebRequest<E>> for DefaultHeadersMiddleware<S> impl<S, E> Service<WebRequest<E>> for DefaultHeadersMiddleware<S>
where where
S: Service<WebRequest<E>, Response = WebResponse>, S: Service<WebRequest<E>, Response = WebResponse>,
E: 'static,
{ {
type Response = WebResponse; type Response = WebResponse;
type Error = S::Error; type Error = S::Error;

View file

@ -538,7 +538,7 @@ where
pub fn rustls( pub fn rustls(
&mut self, &mut self,
config: std::sync::Arc<rustls::ClientConfig>, config: std::sync::Arc<rustls::ClientConfig>,
) -> WsClientBuilder<Layer<rustls::TlsFilter>, rustls::Connector<Uri>> { ) -> WsClientBuilder<Layer<rustls::TlsClientFilter>, rustls::Connector<Uri>> {
self.connector(rustls::Connector::from(config)) self.connector(rustls::Connector::from(config))
} }

View file

@ -20,4 +20,4 @@ pub use self::frame::Parser;
pub use self::handshake::{handshake, handshake_response, verify_handshake}; pub use self::handshake::{handshake, handshake_response, verify_handshake};
pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
pub use self::sink::WsSink; pub use self::sink::WsSink;
pub use self::transport::{WsTransport, WsTransportFactory}; pub use self::transport::{WsTransport, WsTransportService};

View file

@ -2,8 +2,9 @@
use std::{cell::Cell, cmp, io, task::Poll}; use std::{cell::Cell, cmp, io, task::Poll};
use crate::codec::{Decoder, Encoder}; use crate::codec::{Decoder, Encoder};
use crate::io::{Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf}; use crate::io::{Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use crate::util::{BufMut, PoolRef, Ready}; use crate::service::{Service, ServiceCtx};
use crate::util::{BufMut, PoolRef};
use super::{CloseCode, CloseReason, Codec, Frame, Item, Message}; use super::{CloseCode, CloseReason, Codec, Frame, Item, Message};
@ -174,25 +175,27 @@ impl FilterLayer for WsTransport {
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
/// WebSockets transport factory /// WebSockets transport service
pub struct WsTransportFactory { pub struct WsTransportService {
codec: Codec, codec: Codec,
} }
impl WsTransportFactory { impl WsTransportService {
/// Create websockets transport factory /// Create websockets transport service
pub fn new(codec: Codec) -> Self { pub fn new(codec: Codec) -> Self {
Self { codec } Self { codec }
} }
} }
impl<F: Filter> FilterFactory<F> for WsTransportFactory { impl<F: Filter> Service<Io<F>> for WsTransportService {
type Filter = WsTransport; type Response = Io<Layer<WsTransport, F>>;
type Error = io::Error; type Error = io::Error;
type Future = Ready<Io<Layer<Self::Filter, F>>, Self::Error>;
fn create(self, io: Io<F>) -> Self::Future { async fn call(
Ready::Ok(WsTransport::create(io, self.codec)) &self,
io: Io<F>,
_: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> {
Ok(WsTransport::create(io, self.codec.clone()))
} }
} }

View file

@ -84,7 +84,7 @@ async fn test_openssl_string() {
assert!(res.is_ok()); assert!(res.is_ok());
Ok(io) Ok(io)
})) }))
.and_then(openssl::Acceptor::new(ssl_acceptor())) .and_then(openssl::SslAcceptor::new(ssl_acceptor()))
.and_then(fn_service(|io: Io<_>| async move { .and_then(fn_service(|io: Io<_>| async move {
io.send(Bytes::from_static(b"test"), &BytesCodec) io.send(Bytes::from_static(b"test"), &BytesCodec)
.await .await
@ -127,7 +127,7 @@ async fn test_openssl_read_before_error() {
assert!(res.is_ok()); assert!(res.is_ok());
Ok(io) Ok(io)
})) }))
.and_then(openssl::Acceptor::new(ssl_acceptor())) .and_then(openssl::SslAcceptor::new(ssl_acceptor()))
.and_then(fn_service(|io: Io<_>| async move { .and_then(fn_service(|io: Io<_>| async move {
io.send(Bytes::from_static(b"test"), &Rc::new(BytesCodec)) io.send(Bytes::from_static(b"test"), &Rc::new(BytesCodec))
.await .await
@ -168,7 +168,7 @@ async fn test_rustls_string() {
assert!(res.is_ok()); assert!(res.is_ok());
Ok(io) Ok(io)
})) }))
.and_then(rustls::Acceptor::new(tls_acceptor())) .and_then(rustls::TlsAcceptor::new(tls_acceptor()))
.and_then(fn_service(|io: Io<_>| async move { .and_then(fn_service(|io: Io<_>| async move {
assert!(io.query::<PeerCert>().as_ref().is_none()); assert!(io.query::<PeerCert>().as_ref().is_none());
assert!(io.query::<PeerCertChain>().as_ref().is_none()); assert!(io.query::<PeerCertChain>().as_ref().is_none());