Refactor filter factories (#278)

This commit is contained in:
Nikolay Kim 2024-01-08 15:22:38 +06:00 committed by GitHub
parent a13f677df8
commit 174b5d86f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 271 additions and 657 deletions

View file

@ -1,5 +1,11 @@
# Changes
## [1.0.0-b.1] - 2024-01-08
* Refactor io tls filters
* Remove unnecessary 'static
## [1.0.0-b.0] - 2024-01-07
* Use "async fn" in trait for Service definition

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-connect"
version = "1.0.0-b.0"
version = "1.0.0-b.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "ntexwork connect utils for ntex framework"
keywords = ["network", "framework", "async", "futures"]
@ -35,9 +35,9 @@ async-std = ["ntex-rt/async-std", "ntex-async-std"]
[dependencies]
ntex-service = "2.0.0-b.0"
ntex-io = "1.0.0-b.0"
ntex-tls = "1.0.0-b.0"
ntex-util = "1.0.0-b.0"
ntex-io = "1.0.0-b.1"
ntex-tls = "1.0.0-b.1"
ntex-util = "1.0.0-b.1"
ntex-bytes = "0.1.21"
ntex-http = "0.1"
ntex-rt = "0.4.7"

View file

@ -4,9 +4,9 @@ pub use ntex_tls::openssl::SslFilter;
pub use tls_openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod};
use ntex_bytes::PoolId;
use ntex_io::{FilterFactory, Io, Layer};
use ntex_io::{Io, Layer};
use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory};
use ntex_tls::openssl::SslConnector as IoSslConnector;
use ntex_tls::openssl::connect as connect_io;
use super::{Address, Connect, ConnectError, Connector as BaseConnector};
@ -64,7 +64,7 @@ impl<T: Address> Connector<T> {
.into_ssl(&host)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
let tag = io.tag();
match IoSslConnector::new(ssl).create(io).await {
match connect_io(io, ssl).await {
Ok(io) => {
log::trace!("{}: SSL Handshake success: {:?}", tag, host);
Ok(io)
@ -97,7 +97,7 @@ impl<T> fmt::Debug for Connector<T> {
}
}
impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
impl<T: Address, C> ServiceFactory<Connect<T>, C> for Connector<T> {
type Response = Io<Layer<SslFilter>>;
type Error = ConnectError;
type Service = Connector<T>;

View file

@ -97,7 +97,7 @@ impl<T> fmt::Debug for Resolver<T> {
}
}
impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Resolver<T> {
impl<T: Address, C> ServiceFactory<Connect<T>, C> for Resolver<T> {
type Response = Connect<T>;
type Error = ConnectError;
type Service = Resolver<T>;

View file

@ -1,25 +1,24 @@
use std::{fmt, io};
use std::{fmt, io, sync::Arc};
pub use ntex_tls::rustls::TlsFilter;
pub use ntex_tls::rustls::TlsClientFilter;
pub use tls_rustls::{ClientConfig, ServerName};
use ntex_bytes::PoolId;
use ntex_io::{FilterFactory, Io, Layer};
use ntex_io::{Io, Layer};
use ntex_service::{Pipeline, Service, ServiceCtx, ServiceFactory};
use ntex_tls::rustls::TlsConnector;
use super::{Address, Connect, ConnectError, Connector as BaseConnector};
/// Rustls connector factory
pub struct Connector<T> {
connector: Pipeline<BaseConnector<T>>,
inner: TlsConnector,
config: Arc<ClientConfig>,
}
impl<T: Address> From<std::sync::Arc<ClientConfig>> for Connector<T> {
fn from(cfg: std::sync::Arc<ClientConfig>) -> Self {
impl<T: Address> From<Arc<ClientConfig>> for Connector<T> {
fn from(config: Arc<ClientConfig>) -> Self {
Connector {
inner: TlsConnector::new(cfg),
config,
connector: BaseConnector::default().into(),
}
}
@ -28,7 +27,7 @@ impl<T: Address> From<std::sync::Arc<ClientConfig>> for Connector<T> {
impl<T: Address> Connector<T> {
pub fn new(config: ClientConfig) -> Self {
Connector {
inner: TlsConnector::new(std::sync::Arc::new(config)),
config: Arc::new(config),
connector: BaseConnector::default().into(),
}
}
@ -46,38 +45,39 @@ impl<T: Address> Connector<T> {
.into();
Self {
connector,
inner: self.inner,
config: self.config,
}
}
}
impl<T: Address + 'static> Connector<T> {
impl<T: Address> Connector<T> {
/// Resolve and connect to remote host
pub async fn connect<U>(&self, message: U) -> Result<Io<Layer<TlsFilter>>, ConnectError>
pub async fn connect<U>(
&self,
message: U,
) -> Result<Io<Layer<TlsClientFilter>>, ConnectError>
where
Connect<T>: From<U>,
{
let req = Connect::from(message);
let host = req.host().split(':').next().unwrap().to_owned();
let conn = self.connector.call(req);
let connector = self.inner.clone();
let io = self.connector.call(req).await?;
let io = conn.await?;
log::trace!("{}: SSL Handshake start for: {:?}", io.tag(), host);
let tag = io.tag();
let config = self.config.clone();
let host = ServerName::try_from(host.as_str())
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?;
let connector = connector.server_name(host.clone());
match connector.create(io).await {
match TlsClientFilter::create(io, config, host.clone()).await {
Ok(io) => {
log::trace!("{}: TLS Handshake success: {:?}", tag, &host);
Ok(io)
}
Err(e) => {
log::trace!("{}: TLS Handshake error: {:?}", tag, e);
Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)).into())
Err(e.into())
}
}
}
@ -86,7 +86,7 @@ impl<T: Address + 'static> Connector<T> {
impl<T> Clone for Connector<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
config: self.config.clone(),
connector: self.connector.clone(),
}
}
@ -100,8 +100,8 @@ impl<T> fmt::Debug for Connector<T> {
}
}
impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
type Response = Io<Layer<TlsFilter>>;
impl<T: Address, C> ServiceFactory<Connect<T>, C> for Connector<T> {
type Response = Io<Layer<TlsClientFilter>>;
type Error = ConnectError;
type Service = Connector<T>;
type InitError = ();
@ -112,7 +112,7 @@ impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
}
impl<T: Address> Service<Connect<T>> for Connector<T> {
type Response = Io<Layer<TlsFilter>>;
type Response = Io<Layer<TlsClientFilter>>;
type Error = ConnectError;
async fn call(

View file

@ -99,7 +99,7 @@ impl<T> fmt::Debug for Connector<T> {
}
}
impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
impl<T: Address, C> ServiceFactory<Connect<T>, C> for Connector<T> {
type Response = Io;
type Error = ConnectError;
type Service = Connector<T>;

View file

@ -1,5 +1,9 @@
# Changes
## [1.0.0-b.1] - 2024-01-08
* Remove FilterFactory trait and related utils
## [1.0.0-b.0] - 2024-01-07
* Use "async fn" in trait for Service definition

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-io"
version = "1.0.0-b.0"
version = "1.0.0-b.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]
@ -18,7 +18,7 @@ path = "src/lib.rs"
[dependencies]
ntex-codec = "0.6.2"
ntex-bytes = "0.1.21"
ntex-util = "1.0.0-b.0"
ntex-util = "1.0.0-b.1"
ntex-service = "2.0.0-b.0"
bitflags = "2.4"

View file

@ -2,8 +2,7 @@
#![deny(rust_2018_idioms, unreachable_pub, missing_debug_implementations)]
use std::{
any::Any, any::TypeId, fmt, future::Future, io as sio, io::Error as IoError,
task::Context, task::Poll,
any::Any, any::TypeId, fmt, io as sio, io::Error as IoError, task::Context, task::Poll,
};
pub mod testing;
@ -31,7 +30,7 @@ pub use self::io::{Io, IoRef, OnDisconnect};
pub use self::seal::{IoBoxed, Sealed};
pub use self::tasks::{ReadContext, WriteContext};
pub use self::timer::TimerHandle;
pub use self::utils::{filter, seal, Decoded};
pub use self::utils::{seal, Decoded};
/// Status for read task
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
@ -92,19 +91,6 @@ pub trait FilterLayer: fmt::Debug + 'static {
}
}
/// Creates new `Filter` values.
pub trait FilterFactory<F>: Sized {
/// The `Filter` value created by this factory
type Filter: FilterLayer;
/// Errors produced while building a filter.
type Error: fmt::Debug;
/// The future of the `FilterFactory` instance.
type Future: Future<Output = Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
/// Create and return a new filter value asynchronously.
fn create(self, st: Io<F>) -> Self::Future;
}
pub trait IoStream {
fn start(self, _: ReadContext, _: WriteContext) -> Option<Box<dyn Handle>>;
}

View file

@ -1,9 +1,7 @@
use std::{fmt, marker::PhantomData};
use ntex_service::{chain_factory, fn_service, Service, ServiceCtx, ServiceFactory};
use ntex_service::{chain_factory, fn_service, ServiceFactory};
use ntex_util::future::Ready;
use crate::{Filter, FilterFactory, Io, IoBoxed, Layer};
use crate::{Filter, Io, IoBoxed};
/// Decoded item from buffer
#[doc(hidden)]
@ -34,88 +32,13 @@ where
.and_then(srv)
}
/// Create filter factory service
pub fn filter<T, F>(filter: T) -> FilterServiceFactory<T, F>
where
T: FilterFactory<F> + Clone,
{
FilterServiceFactory {
filter,
_t: PhantomData,
}
}
pub struct FilterServiceFactory<T, F> {
filter: T,
_t: PhantomData<F>,
}
impl<T: FilterFactory<F> + fmt::Debug, F> fmt::Debug for FilterServiceFactory<T, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FilterServiceFactory")
.field("filter_factory", &self.filter)
.finish()
}
}
impl<T, F> ServiceFactory<Io<F>> for FilterServiceFactory<T, F>
where
T: FilterFactory<F> + Clone,
{
type Response = Io<Layer<T::Filter, F>>;
type Error = T::Error;
type Service = FilterService<T, F>;
type InitError = ();
async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
Ok(FilterService {
filter: self.filter.clone(),
_t: PhantomData,
})
}
}
pub struct FilterService<T, F> {
filter: T,
_t: PhantomData<F>,
}
impl<T: FilterFactory<F> + fmt::Debug, F> fmt::Debug for FilterService<T, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FilterService")
.field("filter_factory", &self.filter)
.finish()
}
}
impl<T, F> Service<Io<F>> for FilterService<T, F>
where
T: FilterFactory<F> + Clone,
{
type Response = Io<Layer<T::Filter, F>>;
type Error = T::Error;
#[inline]
async fn call(
&self,
req: Io<F>,
_: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> {
self.filter.clone().create(req).await
}
}
#[cfg(test)]
mod tests {
use std::io;
use ntex_bytes::Bytes;
use ntex_codec::BytesCodec;
use super::*;
use crate::{
buf::Stack, filter::NullFilter, testing::IoTest, FilterLayer, ReadBuf, WriteBuf,
};
use crate::{buf::Stack, filter::NullFilter, testing::IoTest};
#[ntex::test]
async fn test_utils() {
@ -140,62 +63,6 @@ mod tests {
assert_eq!(buf, b"RES".as_ref());
}
#[derive(Debug)]
pub(crate) struct TestFilter;
impl FilterLayer for TestFilter {
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
Ok(buf.nbytes())
}
fn process_write_buf(&self, _: &WriteBuf<'_>) -> io::Result<()> {
Ok(())
}
}
#[derive(Copy, Clone, Debug)]
struct TestFilterFactory;
impl<F: Filter> FilterFactory<F> for TestFilterFactory {
type Filter = TestFilter;
type Error = std::convert::Infallible;
type Future = Ready<Io<Layer<TestFilter, F>>, Self::Error>;
fn create(self, st: Io<F>) -> Self::Future {
Ready::Ok(st.add_filter(TestFilter))
}
}
#[ntex::test]
async fn test_utils_filter() {
let (_, server) = IoTest::create();
let filter_service_factory = filter::<_, crate::filter::Base>(TestFilterFactory)
.map_err(|_| ())
.map_init_err(|_| ());
assert!(format!("{:?}", filter_service_factory).contains("FilterServiceFactory"));
let svc = chain_factory(filter_service_factory)
.and_then(seal(fn_service(|io: IoBoxed| async move {
let _ = io.recv(&BytesCodec).await;
Ok::<_, ()>(())
})))
.pipeline(())
.await
.unwrap();
let _ = svc.call(Io::new(server)).await;
let (client, _) = IoTest::create();
let io = Io::new(client);
format!("{:?}", TestFilter);
let mut s = Stack::new();
s.add_layer();
let _ = s.read_buf(&io, 0, 0, |b| TestFilter.process_read_buf(b));
let _ = s.write_buf(&io, 0, |b| TestFilter.process_write_buf(b));
}
#[ntex::test]
async fn test_null_filter() {
let (_, server) = IoTest::create();

View file

@ -1,5 +1,9 @@
# Changes
## [1.0.0-b.1] - 2024-01-08
* Refactor io tls filters
## [1.0.0-b.0] - 2024-01-07
* Use "async fn" in trait for Service definition

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-tls"
version = "1.0.0-b.0"
version = "1.0.0-b.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"]
@ -26,8 +26,8 @@ rustls = ["tls_rust"]
[dependencies]
ntex-bytes = "0.1.21"
ntex-io = "1.0.0-b.0"
ntex-util = "1.0.0-b.0"
ntex-io = "1.0.0-b.1"
ntex-util = "1.0.0-b.1"
ntex-service = "2.0.0-b.0"
log = "0.4"
pin-project-lite = "0.2"

View file

@ -1,7 +1,7 @@
use std::{fs::File, io, io::BufReader, sync::Arc};
use ntex::service::{chain_factory, fn_service};
use ntex::{codec, io::filter, io::Io, server, util::Either};
use ntex::{codec, io::Io, server, util::Either};
use ntex_tls::rustls::TlsAcceptor;
use rustls_pemfile::{certs, rsa_private_keys};
use tls_rust::{Certificate, PrivateKey, ServerConfig};
@ -34,8 +34,8 @@ async fn main() -> io::Result<()> {
// start server
server::ServerBuilder::new()
.bind("basic", "127.0.0.1:8443", move |_| {
chain_factory(filter(TlsAcceptor::new(tls_config.clone()))).and_then(
fn_service(|io: Io<_>| async move {
chain_factory(TlsAcceptor::new(tls_config.clone())).and_then(fn_service(
|io: Io<_>| async move {
println!("New client is connected");
io.send(
@ -62,8 +62,8 @@ async fn main() -> io::Result<()> {
}
println!("Client is disconnected");
Ok(())
}),
)
},
))
})?
.workers(1)
.run()

View file

@ -1,7 +1,7 @@
use std::io;
use ntex::service::{chain_factory, fn_service};
use ntex::{codec, io::filter, io::Io, server, util::Either};
use ntex::{codec, io::Io, server, util::Either};
use ntex_tls::openssl::{PeerCert, PeerCertChain, SslAcceptor};
use tls_openssl::ssl::{self, SslFiletype, SslMethod, SslVerifyMode};
@ -27,7 +27,7 @@ async fn main() -> io::Result<()> {
// start server
server::ServerBuilder::new()
.bind("basic", "127.0.0.1:8443", move |_| {
chain_factory(filter(SslAcceptor::new(acceptor.clone()))).and_then(fn_service(
chain_factory(SslAcceptor::new(acceptor.clone())).and_then(fn_service(
|io: Io<_>| async move {
println!("New client is connected");
if let Some(cert) = io.query::<PeerCert>().as_ref() {

View file

@ -1,31 +1,29 @@
use std::task::{Context, Poll};
use std::{error::Error, marker::PhantomData};
use std::{cell::RefCell, error::Error, fmt, io, task::Context, task::Poll};
use ntex_io::{Filter, FilterFactory, Io, Layer};
use ntex_io::{Filter, Io, Layer};
use ntex_service::{Service, ServiceCtx, ServiceFactory};
use ntex_util::time::Millis;
use tls_openssl::ssl::SslAcceptor;
use ntex_util::time::{self, Millis};
use tls_openssl::ssl;
use crate::counter::Counter;
use crate::MAX_SSL_ACCEPT_COUNTER;
use super::{SslAcceptor as IoSslAcceptor, SslFilter};
use super::SslFilter;
#[derive(Debug)]
/// Support `TLS` server connections via openssl package
///
/// `openssl` feature enables `Acceptor` type
pub struct Acceptor<F> {
acceptor: IoSslAcceptor,
_t: PhantomData<F>,
pub struct SslAcceptor {
acceptor: ssl::SslAcceptor,
timeout: Millis,
}
impl<F> Acceptor<F> {
impl SslAcceptor {
/// Create default openssl acceptor service
pub fn new(acceptor: SslAcceptor) -> Self {
Acceptor {
acceptor: IoSslAcceptor::new(acceptor),
_t: PhantomData,
pub fn new(acceptor: ssl::SslAcceptor) -> Self {
SslAcceptor {
acceptor,
timeout: Millis(5_000),
}
}
@ -33,54 +31,69 @@ impl<F> Acceptor<F> {
///
/// Default is set to 5 seconds.
pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self {
self.acceptor.timeout(timeout);
self.timeout = timeout.into();
self
}
}
impl<F> From<SslAcceptor> for Acceptor<F> {
fn from(acceptor: SslAcceptor) -> Self {
impl fmt::Debug for SslAcceptor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SslAcceptor")
.field("timeout", &self.timeout)
.finish()
}
}
impl From<ssl::SslAcceptor> for SslAcceptor {
fn from(acceptor: ssl::SslAcceptor) -> Self {
Self::new(acceptor)
}
}
impl<F> Clone for Acceptor<F> {
impl Clone for SslAcceptor {
fn clone(&self) -> Self {
Self {
acceptor: self.acceptor.clone(),
_t: PhantomData,
timeout: self.timeout,
}
}
}
impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> {
impl<F: Filter, C> ServiceFactory<Io<F>, C> for SslAcceptor {
type Response = Io<Layer<SslFilter, F>>;
type Error = Box<dyn Error>;
type Service = AcceptorService<F>;
type Service = SslAcceptorService;
type InitError = ();
async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
MAX_SSL_ACCEPT_COUNTER.with(|conns| {
Ok(AcceptorService {
Ok(SslAcceptorService {
acceptor: self.acceptor.clone(),
timeout: self.timeout,
conns: conns.clone(),
_t: PhantomData,
})
})
}
}
#[derive(Debug)]
/// Support `TLS` server connections via openssl package
///
/// `openssl` feature enables `Acceptor` type
pub struct AcceptorService<F> {
acceptor: IoSslAcceptor,
pub struct SslAcceptorService {
acceptor: ssl::SslAcceptor,
timeout: Millis,
conns: Counter,
_t: PhantomData<F>,
}
impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
impl fmt::Debug for SslAcceptorService {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SslAcceptorService")
.field("timeout", &self.timeout)
.finish()
}
}
impl<F: Filter> Service<Io<F>> for SslAcceptorService {
type Response = Io<Layer<SslFilter, F>>;
type Error = Box<dyn Error>;
@ -94,10 +107,40 @@ impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
async fn call(
&self,
req: Io<F>,
io: Io<F>,
_: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> {
let _guard = self.conns.get();
self.acceptor.clone().create(req).await
let timeout = self.timeout;
let ctx_result = ssl::Ssl::new(self.acceptor.context());
time::timeout(timeout, async {
let ssl = ctx_result.map_err(super::map_to_ioerr)?;
let inner = super::IoInner {
source: None,
destination: None,
};
let filter = SslFilter {
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
};
let io = io.add_filter(filter);
log::debug!("Accepting tls connection");
loop {
let result = io.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
})?;
if super::handle_result(&io, result).await?.is_some() {
break;
}
}
Ok(io)
})
.await
.map_err(|_| {
io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout").into()
})
.and_then(|item| item)
}
}

View file

@ -1,17 +1,15 @@
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::cell::RefCell;
use std::{any, cmp, error::Error, fmt, io, task::Poll};
use std::{any, cell::RefCell, cmp, error::Error, io, task::Poll};
use ntex_bytes::{BufMut, BytesVec};
use ntex_io::{types, Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::BoxFuture, time, time::Millis};
use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use tls_openssl::ssl::{self, NameType, SslStream};
use tls_openssl::x509::X509;
use crate::{PskIdentity, Servername};
mod accept;
pub use self::accept::{Acceptor, AcceptorService};
pub use self::accept::{SslAcceptor, SslAcceptorService};
/// Connection's peer cert
#[derive(Debug)]
@ -211,132 +209,31 @@ impl FilterLayer for SslFilter {
}
}
pub struct SslAcceptor {
acceptor: ssl::SslAcceptor,
timeout: Millis,
}
impl SslAcceptor {
/// Create openssl acceptor filter factory
pub fn new(acceptor: ssl::SslAcceptor) -> Self {
SslAcceptor {
acceptor,
timeout: Millis(5_000),
}
}
/// Set handshake timeout.
///
/// Default is set to 5 seconds.
pub fn timeout<U: Into<Millis>>(&mut self, timeout: U) -> &mut Self {
self.timeout = timeout.into();
self
}
}
impl Clone for SslAcceptor {
fn clone(&self) -> Self {
Self {
acceptor: self.acceptor.clone(),
timeout: self.timeout,
}
}
}
impl fmt::Debug for SslAcceptor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SslAcceptor")
.field("timeout", &self.timeout)
.finish()
}
}
impl<F: Filter> FilterFactory<F> for SslAcceptor {
type Filter = SslFilter;
type Error = Box<dyn Error>;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
fn create(self, io: Io<F>) -> Self::Future {
let timeout = self.timeout;
let ctx_result = ssl::Ssl::new(self.acceptor.context());
Box::pin(async move {
time::timeout(timeout, async {
let ssl = ctx_result.map_err(map_to_ioerr)?;
let inner = IoInner {
source: None,
destination: None,
};
let filter = SslFilter {
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
};
let io = io.add_filter(filter);
log::debug!("Accepting tls connection");
loop {
let result = io.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
})?;
if handle_result(&io, result).await?.is_some() {
break;
}
}
Ok(io)
})
.await
.map_err(|_| {
io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout").into()
})
.and_then(|item| item)
})
}
}
#[derive(Debug)]
pub struct SslConnector {
/// Create openssl connector filter factory
pub async fn connect<F: Filter>(
io: Io<F>,
ssl: ssl::Ssl,
}
) -> Result<Io<Layer<SslFilter, F>>, io::Error> {
let inner = IoInner {
source: None,
destination: None,
};
let filter = SslFilter {
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
};
let io = io.add_filter(filter);
impl SslConnector {
/// Create openssl connector filter factory
pub fn new(ssl: ssl::Ssl) -> Self {
SslConnector { ssl }
loop {
let result = io.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
})?;
if handle_result(&io, result).await?.is_some() {
break;
}
}
}
impl<F: Filter> FilterFactory<F> for SslConnector {
type Filter = SslFilter;
type Error = Box<dyn Error>;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
fn create(self, io: Io<F>) -> Self::Future {
Box::pin(async move {
let inner = IoInner {
source: None,
destination: None,
};
let filter = SslFilter {
inner: RefCell::new(ssl::SslStream::new(self.ssl, inner)?),
};
let io = io.add_filter(filter);
loop {
let result = io.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
})?;
if handle_result(&io, result).await?.is_some() {
break;
}
}
Ok(io)
})
}
Ok(io)
}
async fn handle_result<T, F>(

View file

@ -1,30 +1,30 @@
use std::task::{Context, Poll};
use std::{io, marker::PhantomData, sync::Arc};
use std::{io, sync::Arc};
use tls_rust::ServerConfig;
use ntex_io::{Filter, FilterFactory, Io, Layer};
use ntex_io::{Filter, Io, Layer};
use ntex_service::{Service, ServiceCtx, ServiceFactory};
use ntex_util::time::Millis;
use super::{TlsAcceptor, TlsFilter};
use super::TlsServerFilter;
use crate::{counter::Counter, MAX_SSL_ACCEPT_COUNTER};
#[derive(Debug)]
/// Support `SSL` connections via rustls package
///
/// `rust-tls` feature enables `RustlsAcceptor` type
pub struct Acceptor<F> {
inner: TlsAcceptor,
_t: PhantomData<F>,
pub struct TlsAcceptor {
config: Arc<ServerConfig>,
timeout: Millis,
}
impl<F> Acceptor<F> {
impl TlsAcceptor {
/// Create rustls based `Acceptor` service factory
pub fn new(config: Arc<ServerConfig>) -> Self {
Acceptor {
inner: TlsAcceptor::new(config),
_t: PhantomData,
Self {
config,
timeout: Millis(5_000),
}
}
@ -32,38 +32,38 @@ impl<F> Acceptor<F> {
///
/// Default is set to 5 seconds.
pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self {
self.inner.timeout(timeout.into());
self.timeout = timeout.into();
self
}
}
impl<F> From<ServerConfig> for Acceptor<F> {
impl From<ServerConfig> for TlsAcceptor {
fn from(cfg: ServerConfig) -> Self {
Self::new(Arc::new(cfg))
}
}
impl<F> Clone for Acceptor<F> {
impl Clone for TlsAcceptor {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
_t: PhantomData,
config: self.config.clone(),
timeout: self.timeout,
}
}
}
impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> {
type Response = Io<Layer<TlsFilter, F>>;
impl<F: Filter, C> ServiceFactory<Io<F>, C> for TlsAcceptor {
type Response = Io<Layer<TlsServerFilter, F>>;
type Error = io::Error;
type Service = AcceptorService<F>;
type Service = TlsAcceptorService;
type InitError = ();
async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
MAX_SSL_ACCEPT_COUNTER.with(|conns| {
Ok(AcceptorService {
acceptor: self.inner.clone(),
Ok(TlsAcceptorService {
config: self.config.clone(),
timeout: self.timeout,
conns: conns.clone(),
io: PhantomData,
})
})
}
@ -71,14 +71,14 @@ impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> {
#[derive(Debug)]
/// RusTLS based `Acceptor` service
pub struct AcceptorService<F> {
acceptor: TlsAcceptor,
io: PhantomData<F>,
pub struct TlsAcceptorService {
config: Arc<ServerConfig>,
timeout: Millis,
conns: Counter,
}
impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
type Response = Io<Layer<TlsFilter, F>>;
impl<F: Filter> Service<Io<F>> for TlsAcceptorService {
type Response = Io<Layer<TlsServerFilter, F>>;
type Error = io::Error;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -91,10 +91,10 @@ impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
async fn call(
&self,
req: Io<F>,
io: Io<F>,
_: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> {
let _guard = self.conns.get();
self.acceptor.clone().create(req).await
super::TlsServerFilter::create(io, self.config.clone(), self.timeout).await
}
}

View file

@ -7,13 +7,11 @@ use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::ready;
use tls_rust::{ClientConfig, ClientConnection, ServerName};
use crate::rustls::{TlsFilter, Wrapper};
use super::{PeerCert, PeerCertChain};
use super::{PeerCert, PeerCertChain, Wrapper};
#[derive(Debug)]
/// An implementation of SSL streams
pub(crate) struct TlsClientFilter {
pub struct TlsClientFilter {
session: RefCell<ClientConnection>,
}
@ -114,22 +112,22 @@ impl FilterLayer for TlsClientFilter {
}
impl TlsClientFilter {
pub(crate) async fn create<F: Filter>(
pub async fn create<F: Filter>(
io: Io<F>,
cfg: Arc<ClientConfig>,
domain: ServerName,
) -> Result<Io<Layer<TlsFilter, F>>, io::Error> {
) -> Result<Io<Layer<TlsClientFilter, F>>, io::Error> {
let session = ClientConnection::new(cfg, domain)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
let filter = TlsFilter::new_client(TlsClientFilter {
let filter = TlsClientFilter {
session: RefCell::new(session),
});
};
let io = io.add_filter(filter);
let filter = io.filter();
loop {
let (result, wants_read, handshaking) = io.with_buf(|buf| {
let mut session = filter.client().session.borrow_mut();
let mut session = filter.session.borrow_mut();
let mut wrp = Wrapper(buf);
let mut result = (
session.complete_io(&mut wrp),

View file

@ -1,21 +1,16 @@
#![allow(clippy::type_complexity)]
//! An implementation of SSL streams for ntex backed by OpenSSL
use std::{any, cmp, io, sync::Arc, task::Context, task::Poll};
use std::{cmp, io};
use ntex_io::{
Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, ReadStatus, WriteBuf,
WriteStatus,
};
use ntex_util::{future::BoxFuture, time::Millis};
use tls_rust::{Certificate, ClientConfig, ServerConfig, ServerName};
use ntex_io::WriteBuf;
use tls_rust::Certificate;
mod accept;
mod client;
mod server;
pub use accept::{Acceptor, AcceptorService};
pub use accept::{TlsAcceptor, TlsAcceptorService};
use self::client::TlsClientFilter;
use self::server::TlsServerFilter;
pub use self::client::TlsClientFilter;
pub use self::server::TlsServerFilter;
/// Connection's peer cert
#[derive(Debug)]
@ -25,203 +20,6 @@ pub struct PeerCert(pub Certificate);
#[derive(Debug)]
pub struct PeerCertChain(pub Vec<Certificate>);
#[derive(Debug)]
/// An implementation of SSL streams
pub struct TlsFilter {
inner: InnerTlsFilter,
}
#[derive(Debug)]
enum InnerTlsFilter {
Server(TlsServerFilter),
Client(TlsClientFilter),
}
impl TlsFilter {
fn new_server(server: TlsServerFilter) -> Self {
TlsFilter {
inner: InnerTlsFilter::Server(server),
}
}
fn new_client(client: TlsClientFilter) -> Self {
TlsFilter {
inner: InnerTlsFilter::Client(client),
}
}
fn server(&self) -> &TlsServerFilter {
match self.inner {
InnerTlsFilter::Server(ref server) => server,
_ => unreachable!(),
}
}
fn client(&self) -> &TlsClientFilter {
match self.inner {
InnerTlsFilter::Client(ref server) => server,
_ => unreachable!(),
}
}
}
impl FilterLayer for TlsFilter {
#[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.query(id),
InnerTlsFilter::Client(ref f) => f.query(id),
}
}
#[inline]
fn shutdown(&self, buf: &WriteBuf<'_>) -> io::Result<Poll<()>> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.shutdown(buf),
InnerTlsFilter::Client(ref f) => f.shutdown(buf),
}
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.poll_read_ready(cx),
InnerTlsFilter::Client(ref f) => f.poll_read_ready(cx),
}
}
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.poll_write_ready(cx),
InnerTlsFilter::Client(ref f) => f.poll_write_ready(cx),
}
}
#[inline]
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.process_read_buf(buf),
InnerTlsFilter::Client(ref f) => f.process_read_buf(buf),
}
}
#[inline]
fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.process_write_buf(buf),
InnerTlsFilter::Client(ref f) => f.process_write_buf(buf),
}
}
}
#[derive(Debug)]
pub struct TlsAcceptor {
cfg: Arc<ServerConfig>,
timeout: Millis,
}
impl TlsAcceptor {
/// Create openssl acceptor filter factory
pub fn new(cfg: Arc<ServerConfig>) -> Self {
TlsAcceptor {
cfg,
timeout: Millis(5_000),
}
}
/// Set handshake timeout.
///
/// Default is set to 5 seconds.
pub fn timeout<U: Into<Millis>>(&mut self, timeout: U) -> &mut Self {
self.timeout = timeout.into();
self
}
}
impl From<ServerConfig> for TlsAcceptor {
fn from(cfg: ServerConfig) -> Self {
Self::new(Arc::new(cfg))
}
}
impl Clone for TlsAcceptor {
fn clone(&self) -> Self {
Self {
cfg: self.cfg.clone(),
timeout: self.timeout,
}
}
}
impl<F: Filter> FilterFactory<F> for TlsAcceptor {
type Filter = TlsFilter;
type Error = io::Error;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, io::Error>>;
fn create(self, st: Io<F>) -> Self::Future {
let cfg = self.cfg.clone();
let timeout = self.timeout;
Box::pin(async move { TlsServerFilter::create(st, cfg, timeout).await })
}
}
#[derive(Debug)]
pub struct TlsConnector {
cfg: Arc<ClientConfig>,
}
impl TlsConnector {
/// Create openssl connector filter factory
pub fn new(cfg: Arc<ClientConfig>) -> Self {
TlsConnector { cfg }
}
/// Set server name
pub fn server_name(self, server_name: ServerName) -> TlsConnectorConfigured {
TlsConnectorConfigured {
server_name,
cfg: self.cfg,
}
}
}
impl Clone for TlsConnector {
fn clone(&self) -> Self {
Self {
cfg: self.cfg.clone(),
}
}
}
#[derive(Debug)]
pub struct TlsConnectorConfigured {
cfg: Arc<ClientConfig>,
server_name: ServerName,
}
impl Clone for TlsConnectorConfigured {
fn clone(&self) -> Self {
Self {
cfg: self.cfg.clone(),
server_name: self.server_name.clone(),
}
}
}
impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
type Filter = TlsFilter;
type Error = io::Error;
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, io::Error>>;
fn create(self, st: Io<F>) -> Self::Future {
let cfg = self.cfg;
let server_name = self.server_name;
Box::pin(async move { TlsClientFilter::create(st, cfg, server_name).await })
}
}
pub(crate) struct Wrapper<'a, 'b>(&'a WriteBuf<'b>);
impl<'a, 'b> io::Read for Wrapper<'a, 'b> {

View file

@ -7,14 +7,13 @@ use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{ready, time, time::Millis};
use tls_rust::{ServerConfig, ServerConnection};
use crate::rustls::{TlsFilter, Wrapper};
use crate::Servername;
use super::{PeerCert, PeerCertChain};
use super::{PeerCert, PeerCertChain, Wrapper};
#[derive(Debug)]
/// An implementation of SSL streams
pub(crate) struct TlsServerFilter {
pub struct TlsServerFilter {
session: RefCell<ServerConnection>,
}
@ -121,23 +120,23 @@ impl FilterLayer for TlsServerFilter {
}
impl TlsServerFilter {
pub(crate) async fn create<F: Filter>(
pub async fn create<F: Filter>(
io: Io<F>,
cfg: Arc<ServerConfig>,
timeout: Millis,
) -> Result<Io<Layer<TlsFilter, F>>, io::Error> {
) -> Result<Io<Layer<TlsServerFilter, F>>, io::Error> {
time::timeout(timeout, async {
let session = ServerConnection::new(cfg)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
let filter = TlsFilter::new_server(TlsServerFilter {
let filter = TlsServerFilter {
session: RefCell::new(session),
});
};
let io = io.add_filter(filter);
let filter = io.filter();
loop {
let (result, wants_read, handshaking) = io.with_buf(|buf| {
let mut session = filter.server().session.borrow_mut();
let mut session = filter.session.borrow_mut();
let mut wrp = Wrapper(buf);
let mut result = (
session.complete_io(&mut wrp),

View file

@ -1,5 +1,9 @@
# Changes
## [1.0.0-b.1] - 2024-01-xx
* Remove unnecessary 'static
## [1.0.0-b.0] - 2024-01-07
* Use "async fn" in trait for Service definition

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-util"
version = "1.0.0-b.0"
version = "1.0.0-b.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for ntex framework"
keywords = ["network", "framework", "async", "futures"]

View file

@ -53,7 +53,7 @@ impl<R, E, F> fmt::Debug for KeepAlive<R, E, F> {
}
}
impl<R, E, F, C: 'static> ServiceFactory<R, C> for KeepAlive<R, E, F>
impl<R, E, F, C> ServiceFactory<R, C> for KeepAlive<R, E, F>
where
F: Fn() -> E + Clone,
{

View file

@ -1,5 +1,9 @@
# Changes
## [1.0.0-b.1] - 2024-01-08
* Refactor io tls filters
## [1.0.0-b.0] - 2024-01-07
* Use "async fn" in trait for Service definition

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "1.0.0-b.0"
version = "1.0.0-b.1"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"
@ -49,7 +49,7 @@ async-std = ["ntex-rt/async-std", "ntex-async-std", "ntex-connect/async-std"]
[dependencies]
ntex-codec = "0.6.2"
ntex-connect = "1.0.0-b.0"
ntex-connect = "1.0.0-b.1"
ntex-http = "0.1.11"
ntex-router = "0.5.2"
ntex-service = "2.0.0-b.0"
@ -58,8 +58,8 @@ ntex-util = "1.0.0-b.0"
ntex-bytes = "0.1.21"
ntex-h2 = "0.5.0-b.0"
ntex-rt = "0.4.11"
ntex-io = "1.0.0-b.0"
ntex-tls = "1.0.0-b.0"
ntex-io = "1.0.0-b.1"
ntex-tls = "1.0.0-b.1"
ntex-tokio = { version = "0.4.0-b.0", optional = true }
ntex-glommio = { version = "0.4.0-b.0", optional = true }
ntex-async-std = { version = "0.4.0-b.0", optional = true }

View file

@ -21,7 +21,7 @@ pub struct Encoder<B> {
fut: Option<JoinHandle<Result<ContentEncoder, io::Error>>>,
}
impl<B: MessageBody + 'static> Encoder<B> {
impl<B: MessageBody> Encoder<B> {
pub fn response(
encoding: ContentEncoding,
head: &mut ResponseHead,

View file

@ -47,8 +47,8 @@ where
#[cfg(feature = "openssl")]
mod openssl {
use ntex_tls::openssl::{Acceptor, SslFilter};
use tls_openssl::ssl::SslAcceptor;
use ntex_tls::openssl::{SslAcceptor, SslFilter};
use tls_openssl::ssl;
use super::*;
use crate::{io::Layer, server::SslError};
@ -72,14 +72,14 @@ mod openssl {
/// Create openssl based service
pub fn openssl(
self,
acceptor: SslAcceptor,
acceptor: ssl::SslAcceptor,
) -> impl ServiceFactory<
Io<F>,
Response = (),
Error = SslError<DispatchError>,
InitError = (),
> {
Acceptor::new(acceptor)
SslAcceptor::new(acceptor)
.timeout(self.cfg.ssl_handshake_timeout)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!())
@ -92,13 +92,13 @@ mod openssl {
mod rustls {
use std::fmt;
use ntex_tls::rustls::{Acceptor, TlsFilter};
use ntex_tls::rustls::{TlsAcceptor, TlsServerFilter};
use tls_rustls::ServerConfig;
use super::*;
use crate::{io::Layer, server::SslError};
impl<F, S, B, X, U> H1Service<Layer<TlsFilter, F>, S, B, X, U>
impl<F, S, B, X, U> H1Service<Layer<TlsServerFilter, F>, S, B, X, U>
where
F: Filter,
S: ServiceFactory<Request> + 'static,
@ -109,7 +109,7 @@ mod rustls {
X: ServiceFactory<Request, Response = Request> + 'static,
X::Error: ResponseError,
X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Io<Layer<TlsFilter, F>>, Codec), Response = ()>
U: ServiceFactory<(Request, Io<Layer<TlsServerFilter, F>>, Codec), Response = ()>
+ 'static,
U::Error: fmt::Display + Error,
U::InitError: fmt::Debug,
@ -124,7 +124,7 @@ mod rustls {
Error = SslError<DispatchError>,
InitError = (),
> {
Acceptor::from(config)
TlsAcceptor::from(config)
.timeout(self.cfg.ssl_handshake_timeout)
.map_err(|e| SslError::Ssl(Box::new(e)))
.map_init_err(|_| panic!())

View file

@ -44,8 +44,8 @@ where
#[cfg(feature = "openssl")]
mod openssl {
use ntex_tls::openssl::{Acceptor, SslFilter};
use tls_openssl::ssl::SslAcceptor;
use ntex_tls::openssl::{SslAcceptor, SslFilter};
use tls_openssl::ssl;
use crate::{io::Layer, server::SslError};
@ -62,14 +62,14 @@ mod openssl {
/// Create ssl based service
pub fn openssl(
self,
acceptor: SslAcceptor,
acceptor: ssl::SslAcceptor,
) -> impl ServiceFactory<
Io<F>,
Response = (),
Error = SslError<DispatchError>,
InitError = S::InitError,
> {
Acceptor::new(acceptor)
SslAcceptor::new(acceptor)
.timeout(self.cfg.ssl_handshake_timeout)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!())
@ -80,13 +80,13 @@ mod openssl {
#[cfg(feature = "rustls")]
mod rustls {
use ntex_tls::rustls::{Acceptor, TlsFilter};
use ntex_tls::rustls::{TlsAcceptor, TlsServerFilter};
use tls_rustls::ServerConfig;
use super::*;
use crate::{io::Layer, server::SslError};
impl<F, S, B> H2Service<Layer<TlsFilter, F>, S, B>
impl<F, S, B> H2Service<Layer<TlsServerFilter, F>, S, B>
where
F: Filter,
S: ServiceFactory<Request> + 'static,
@ -107,7 +107,7 @@ mod rustls {
let protos = vec!["h2".to_string().into()];
config.alpn_protocols = protos;
Acceptor::from(config)
TlsAcceptor::from(config)
.timeout(self.cfg.ssl_handshake_timeout)
.map_err(|e| SslError::Ssl(Box::new(e)))
.map_init_err(|_| panic!())

View file

@ -139,8 +139,8 @@ where
#[cfg(feature = "openssl")]
mod openssl {
use ntex_tls::openssl::{Acceptor, SslFilter};
use tls_openssl::ssl::SslAcceptor;
use ntex_tls::openssl::{SslAcceptor, SslFilter};
use tls_openssl::ssl;
use super::*;
use crate::{io::Layer, server::SslError};
@ -164,14 +164,14 @@ mod openssl {
/// Create openssl based service
pub fn openssl(
self,
acceptor: SslAcceptor,
acceptor: ssl::SslAcceptor,
) -> impl ServiceFactory<
Io<F>,
Response = (),
Error = SslError<DispatchError>,
InitError = (),
> {
Acceptor::new(acceptor)
SslAcceptor::new(acceptor)
.timeout(self.cfg.ssl_handshake_timeout)
.map_err(SslError::Ssl)
.map_init_err(|_| panic!())
@ -182,13 +182,13 @@ mod openssl {
#[cfg(feature = "rustls")]
mod rustls {
use ntex_tls::rustls::{Acceptor, TlsFilter};
use ntex_tls::rustls::{TlsAcceptor, TlsServerFilter};
use tls_rustls::ServerConfig;
use super::*;
use crate::{io::Layer, server::SslError};
impl<F, S, B, X, U> HttpService<Layer<TlsFilter, F>, S, B, X, U>
impl<F, S, B, X, U> HttpService<Layer<TlsServerFilter, F>, S, B, X, U>
where
F: Filter,
S: ServiceFactory<Request> + 'static,
@ -199,8 +199,10 @@ mod rustls {
X: ServiceFactory<Request, Response = Request> + 'static,
X::Error: ResponseError,
X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Io<Layer<TlsFilter, F>>, h1::Codec), Response = ()>
+ 'static,
U: ServiceFactory<
(Request, Io<Layer<TlsServerFilter, F>>, h1::Codec),
Response = (),
> + 'static,
U::Error: fmt::Display + error::Error,
U::InitError: fmt::Debug,
{
@ -217,7 +219,7 @@ mod rustls {
let protos = vec!["h2".to_string().into(), "http/1.1".to_string().into()];
config.alpn_protocols = protos;
Acceptor::from(config)
TlsAcceptor::from(config)
.timeout(self.cfg.ssl_handshake_timeout)
.map_err(|e| SslError::Ssl(Box::new(e)))
.map_init_err(|_| panic!())

View file

@ -106,7 +106,6 @@ pub struct DefaultHeadersMiddleware<S> {
impl<S, E> Service<WebRequest<E>> for DefaultHeadersMiddleware<S>
where
S: Service<WebRequest<E>, Response = WebResponse>,
E: 'static,
{
type Response = WebResponse;
type Error = S::Error;

View file

@ -538,7 +538,7 @@ where
pub fn rustls(
&mut self,
config: std::sync::Arc<rustls::ClientConfig>,
) -> WsClientBuilder<Layer<rustls::TlsFilter>, rustls::Connector<Uri>> {
) -> WsClientBuilder<Layer<rustls::TlsClientFilter>, rustls::Connector<Uri>> {
self.connector(rustls::Connector::from(config))
}

View file

@ -20,4 +20,4 @@ pub use self::frame::Parser;
pub use self::handshake::{handshake, handshake_response, verify_handshake};
pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
pub use self::sink::WsSink;
pub use self::transport::{WsTransport, WsTransportFactory};
pub use self::transport::{WsTransport, WsTransportService};

View file

@ -2,8 +2,9 @@
use std::{cell::Cell, cmp, io, task::Poll};
use crate::codec::{Decoder, Encoder};
use crate::io::{Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use crate::util::{BufMut, PoolRef, Ready};
use crate::io::{Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use crate::service::{Service, ServiceCtx};
use crate::util::{BufMut, PoolRef};
use super::{CloseCode, CloseReason, Codec, Frame, Item, Message};
@ -174,25 +175,27 @@ impl FilterLayer for WsTransport {
}
#[derive(Clone, Debug)]
/// WebSockets transport factory
pub struct WsTransportFactory {
/// WebSockets transport service
pub struct WsTransportService {
codec: Codec,
}
impl WsTransportFactory {
/// Create websockets transport factory
impl WsTransportService {
/// Create websockets transport service
pub fn new(codec: Codec) -> Self {
Self { codec }
}
}
impl<F: Filter> FilterFactory<F> for WsTransportFactory {
type Filter = WsTransport;
impl<F: Filter> Service<Io<F>> for WsTransportService {
type Response = Io<Layer<WsTransport, F>>;
type Error = io::Error;
type Future = Ready<Io<Layer<Self::Filter, F>>, Self::Error>;
fn create(self, io: Io<F>) -> Self::Future {
Ready::Ok(WsTransport::create(io, self.codec))
async fn call(
&self,
io: Io<F>,
_: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> {
Ok(WsTransport::create(io, self.codec.clone()))
}
}

View file

@ -84,7 +84,7 @@ async fn test_openssl_string() {
assert!(res.is_ok());
Ok(io)
}))
.and_then(openssl::Acceptor::new(ssl_acceptor()))
.and_then(openssl::SslAcceptor::new(ssl_acceptor()))
.and_then(fn_service(|io: Io<_>| async move {
io.send(Bytes::from_static(b"test"), &BytesCodec)
.await
@ -127,7 +127,7 @@ async fn test_openssl_read_before_error() {
assert!(res.is_ok());
Ok(io)
}))
.and_then(openssl::Acceptor::new(ssl_acceptor()))
.and_then(openssl::SslAcceptor::new(ssl_acceptor()))
.and_then(fn_service(|io: Io<_>| async move {
io.send(Bytes::from_static(b"test"), &Rc::new(BytesCodec))
.await
@ -168,7 +168,7 @@ async fn test_rustls_string() {
assert!(res.is_ok());
Ok(io)
}))
.and_then(rustls::Acceptor::new(tls_acceptor()))
.and_then(rustls::TlsAcceptor::new(tls_acceptor()))
.and_then(fn_service(|io: Io<_>| async move {
assert!(io.query::<PeerCert>().as_ref().is_none());
assert!(io.query::<PeerCertChain>().as_ref().is_none());