mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-03 21:07:39 +03:00
Refactor filter factories (#278)
This commit is contained in:
parent
a13f677df8
commit
174b5d86f0
34 changed files with 271 additions and 657 deletions
|
@ -1,5 +1,9 @@
|
|||
# Changes
|
||||
|
||||
## [1.0.0-b.1] - 2024-01-08
|
||||
|
||||
* Refactor io tls filters
|
||||
|
||||
## [1.0.0-b.0] - 2024-01-07
|
||||
|
||||
* Use "async fn" in trait for Service definition
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "ntex-tls"
|
||||
version = "1.0.0-b.0"
|
||||
version = "1.0.0-b.1"
|
||||
authors = ["ntex contributors <team@ntex.rs>"]
|
||||
description = "An implementation of SSL streams for ntex backed by OpenSSL"
|
||||
keywords = ["network", "framework", "async", "futures"]
|
||||
|
@ -26,8 +26,8 @@ rustls = ["tls_rust"]
|
|||
|
||||
[dependencies]
|
||||
ntex-bytes = "0.1.21"
|
||||
ntex-io = "1.0.0-b.0"
|
||||
ntex-util = "1.0.0-b.0"
|
||||
ntex-io = "1.0.0-b.1"
|
||||
ntex-util = "1.0.0-b.1"
|
||||
ntex-service = "2.0.0-b.0"
|
||||
log = "0.4"
|
||||
pin-project-lite = "0.2"
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use std::{fs::File, io, io::BufReader, sync::Arc};
|
||||
|
||||
use ntex::service::{chain_factory, fn_service};
|
||||
use ntex::{codec, io::filter, io::Io, server, util::Either};
|
||||
use ntex::{codec, io::Io, server, util::Either};
|
||||
use ntex_tls::rustls::TlsAcceptor;
|
||||
use rustls_pemfile::{certs, rsa_private_keys};
|
||||
use tls_rust::{Certificate, PrivateKey, ServerConfig};
|
||||
|
@ -34,8 +34,8 @@ async fn main() -> io::Result<()> {
|
|||
// start server
|
||||
server::ServerBuilder::new()
|
||||
.bind("basic", "127.0.0.1:8443", move |_| {
|
||||
chain_factory(filter(TlsAcceptor::new(tls_config.clone()))).and_then(
|
||||
fn_service(|io: Io<_>| async move {
|
||||
chain_factory(TlsAcceptor::new(tls_config.clone())).and_then(fn_service(
|
||||
|io: Io<_>| async move {
|
||||
println!("New client is connected");
|
||||
|
||||
io.send(
|
||||
|
@ -62,8 +62,8 @@ async fn main() -> io::Result<()> {
|
|||
}
|
||||
println!("Client is disconnected");
|
||||
Ok(())
|
||||
}),
|
||||
)
|
||||
},
|
||||
))
|
||||
})?
|
||||
.workers(1)
|
||||
.run()
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use std::io;
|
||||
|
||||
use ntex::service::{chain_factory, fn_service};
|
||||
use ntex::{codec, io::filter, io::Io, server, util::Either};
|
||||
use ntex::{codec, io::Io, server, util::Either};
|
||||
use ntex_tls::openssl::{PeerCert, PeerCertChain, SslAcceptor};
|
||||
use tls_openssl::ssl::{self, SslFiletype, SslMethod, SslVerifyMode};
|
||||
|
||||
|
@ -27,7 +27,7 @@ async fn main() -> io::Result<()> {
|
|||
// start server
|
||||
server::ServerBuilder::new()
|
||||
.bind("basic", "127.0.0.1:8443", move |_| {
|
||||
chain_factory(filter(SslAcceptor::new(acceptor.clone()))).and_then(fn_service(
|
||||
chain_factory(SslAcceptor::new(acceptor.clone())).and_then(fn_service(
|
||||
|io: Io<_>| async move {
|
||||
println!("New client is connected");
|
||||
if let Some(cert) = io.query::<PeerCert>().as_ref() {
|
||||
|
|
|
@ -1,31 +1,29 @@
|
|||
use std::task::{Context, Poll};
|
||||
use std::{error::Error, marker::PhantomData};
|
||||
use std::{cell::RefCell, error::Error, fmt, io, task::Context, task::Poll};
|
||||
|
||||
use ntex_io::{Filter, FilterFactory, Io, Layer};
|
||||
use ntex_io::{Filter, Io, Layer};
|
||||
use ntex_service::{Service, ServiceCtx, ServiceFactory};
|
||||
use ntex_util::time::Millis;
|
||||
use tls_openssl::ssl::SslAcceptor;
|
||||
use ntex_util::time::{self, Millis};
|
||||
use tls_openssl::ssl;
|
||||
|
||||
use crate::counter::Counter;
|
||||
use crate::MAX_SSL_ACCEPT_COUNTER;
|
||||
|
||||
use super::{SslAcceptor as IoSslAcceptor, SslFilter};
|
||||
use super::SslFilter;
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Support `TLS` server connections via openssl package
|
||||
///
|
||||
/// `openssl` feature enables `Acceptor` type
|
||||
pub struct Acceptor<F> {
|
||||
acceptor: IoSslAcceptor,
|
||||
_t: PhantomData<F>,
|
||||
pub struct SslAcceptor {
|
||||
acceptor: ssl::SslAcceptor,
|
||||
timeout: Millis,
|
||||
}
|
||||
|
||||
impl<F> Acceptor<F> {
|
||||
impl SslAcceptor {
|
||||
/// Create default openssl acceptor service
|
||||
pub fn new(acceptor: SslAcceptor) -> Self {
|
||||
Acceptor {
|
||||
acceptor: IoSslAcceptor::new(acceptor),
|
||||
_t: PhantomData,
|
||||
pub fn new(acceptor: ssl::SslAcceptor) -> Self {
|
||||
SslAcceptor {
|
||||
acceptor,
|
||||
timeout: Millis(5_000),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -33,54 +31,69 @@ impl<F> Acceptor<F> {
|
|||
///
|
||||
/// Default is set to 5 seconds.
|
||||
pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self {
|
||||
self.acceptor.timeout(timeout);
|
||||
self.timeout = timeout.into();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> From<SslAcceptor> for Acceptor<F> {
|
||||
fn from(acceptor: SslAcceptor) -> Self {
|
||||
impl fmt::Debug for SslAcceptor {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("SslAcceptor")
|
||||
.field("timeout", &self.timeout)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ssl::SslAcceptor> for SslAcceptor {
|
||||
fn from(acceptor: ssl::SslAcceptor) -> Self {
|
||||
Self::new(acceptor)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Clone for Acceptor<F> {
|
||||
impl Clone for SslAcceptor {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
acceptor: self.acceptor.clone(),
|
||||
_t: PhantomData,
|
||||
timeout: self.timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> {
|
||||
impl<F: Filter, C> ServiceFactory<Io<F>, C> for SslAcceptor {
|
||||
type Response = Io<Layer<SslFilter, F>>;
|
||||
type Error = Box<dyn Error>;
|
||||
type Service = AcceptorService<F>;
|
||||
type Service = SslAcceptorService;
|
||||
type InitError = ();
|
||||
|
||||
async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
|
||||
MAX_SSL_ACCEPT_COUNTER.with(|conns| {
|
||||
Ok(AcceptorService {
|
||||
Ok(SslAcceptorService {
|
||||
acceptor: self.acceptor.clone(),
|
||||
timeout: self.timeout,
|
||||
conns: conns.clone(),
|
||||
_t: PhantomData,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Support `TLS` server connections via openssl package
|
||||
///
|
||||
/// `openssl` feature enables `Acceptor` type
|
||||
pub struct AcceptorService<F> {
|
||||
acceptor: IoSslAcceptor,
|
||||
pub struct SslAcceptorService {
|
||||
acceptor: ssl::SslAcceptor,
|
||||
timeout: Millis,
|
||||
conns: Counter,
|
||||
_t: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
|
||||
impl fmt::Debug for SslAcceptorService {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("SslAcceptorService")
|
||||
.field("timeout", &self.timeout)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Filter> Service<Io<F>> for SslAcceptorService {
|
||||
type Response = Io<Layer<SslFilter, F>>;
|
||||
type Error = Box<dyn Error>;
|
||||
|
||||
|
@ -94,10 +107,40 @@ impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
|
|||
|
||||
async fn call(
|
||||
&self,
|
||||
req: Io<F>,
|
||||
io: Io<F>,
|
||||
_: ServiceCtx<'_, Self>,
|
||||
) -> Result<Self::Response, Self::Error> {
|
||||
let _guard = self.conns.get();
|
||||
self.acceptor.clone().create(req).await
|
||||
let timeout = self.timeout;
|
||||
let ctx_result = ssl::Ssl::new(self.acceptor.context());
|
||||
|
||||
time::timeout(timeout, async {
|
||||
let ssl = ctx_result.map_err(super::map_to_ioerr)?;
|
||||
let inner = super::IoInner {
|
||||
source: None,
|
||||
destination: None,
|
||||
};
|
||||
let filter = SslFilter {
|
||||
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
|
||||
};
|
||||
let io = io.add_filter(filter);
|
||||
|
||||
log::debug!("Accepting tls connection");
|
||||
loop {
|
||||
let result = io.with_buf(|buf| {
|
||||
let filter = io.filter();
|
||||
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
|
||||
})?;
|
||||
if super::handle_result(&io, result).await?.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(io)
|
||||
})
|
||||
.await
|
||||
.map_err(|_| {
|
||||
io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout").into()
|
||||
})
|
||||
.and_then(|item| item)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,17 +1,15 @@
|
|||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||
use std::cell::RefCell;
|
||||
use std::{any, cmp, error::Error, fmt, io, task::Poll};
|
||||
use std::{any, cell::RefCell, cmp, error::Error, io, task::Poll};
|
||||
|
||||
use ntex_bytes::{BufMut, BytesVec};
|
||||
use ntex_io::{types, Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
|
||||
use ntex_util::{future::BoxFuture, time, time::Millis};
|
||||
use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
|
||||
use tls_openssl::ssl::{self, NameType, SslStream};
|
||||
use tls_openssl::x509::X509;
|
||||
|
||||
use crate::{PskIdentity, Servername};
|
||||
|
||||
mod accept;
|
||||
pub use self::accept::{Acceptor, AcceptorService};
|
||||
pub use self::accept::{SslAcceptor, SslAcceptorService};
|
||||
|
||||
/// Connection's peer cert
|
||||
#[derive(Debug)]
|
||||
|
@ -211,132 +209,31 @@ impl FilterLayer for SslFilter {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct SslAcceptor {
|
||||
acceptor: ssl::SslAcceptor,
|
||||
timeout: Millis,
|
||||
}
|
||||
|
||||
impl SslAcceptor {
|
||||
/// Create openssl acceptor filter factory
|
||||
pub fn new(acceptor: ssl::SslAcceptor) -> Self {
|
||||
SslAcceptor {
|
||||
acceptor,
|
||||
timeout: Millis(5_000),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set handshake timeout.
|
||||
///
|
||||
/// Default is set to 5 seconds.
|
||||
pub fn timeout<U: Into<Millis>>(&mut self, timeout: U) -> &mut Self {
|
||||
self.timeout = timeout.into();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for SslAcceptor {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
acceptor: self.acceptor.clone(),
|
||||
timeout: self.timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for SslAcceptor {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("SslAcceptor")
|
||||
.field("timeout", &self.timeout)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Filter> FilterFactory<F> for SslAcceptor {
|
||||
type Filter = SslFilter;
|
||||
|
||||
type Error = Box<dyn Error>;
|
||||
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
|
||||
|
||||
fn create(self, io: Io<F>) -> Self::Future {
|
||||
let timeout = self.timeout;
|
||||
let ctx_result = ssl::Ssl::new(self.acceptor.context());
|
||||
|
||||
Box::pin(async move {
|
||||
time::timeout(timeout, async {
|
||||
let ssl = ctx_result.map_err(map_to_ioerr)?;
|
||||
let inner = IoInner {
|
||||
source: None,
|
||||
destination: None,
|
||||
};
|
||||
let filter = SslFilter {
|
||||
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
|
||||
};
|
||||
let io = io.add_filter(filter);
|
||||
|
||||
log::debug!("Accepting tls connection");
|
||||
loop {
|
||||
let result = io.with_buf(|buf| {
|
||||
let filter = io.filter();
|
||||
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
|
||||
})?;
|
||||
if handle_result(&io, result).await?.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(io)
|
||||
})
|
||||
.await
|
||||
.map_err(|_| {
|
||||
io::Error::new(io::ErrorKind::TimedOut, "ssl handshake timeout").into()
|
||||
})
|
||||
.and_then(|item| item)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SslConnector {
|
||||
/// Create openssl connector filter factory
|
||||
pub async fn connect<F: Filter>(
|
||||
io: Io<F>,
|
||||
ssl: ssl::Ssl,
|
||||
}
|
||||
) -> Result<Io<Layer<SslFilter, F>>, io::Error> {
|
||||
let inner = IoInner {
|
||||
source: None,
|
||||
destination: None,
|
||||
};
|
||||
let filter = SslFilter {
|
||||
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
|
||||
};
|
||||
let io = io.add_filter(filter);
|
||||
|
||||
impl SslConnector {
|
||||
/// Create openssl connector filter factory
|
||||
pub fn new(ssl: ssl::Ssl) -> Self {
|
||||
SslConnector { ssl }
|
||||
loop {
|
||||
let result = io.with_buf(|buf| {
|
||||
let filter = io.filter();
|
||||
filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
|
||||
})?;
|
||||
if handle_result(&io, result).await?.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Filter> FilterFactory<F> for SslConnector {
|
||||
type Filter = SslFilter;
|
||||
|
||||
type Error = Box<dyn Error>;
|
||||
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
|
||||
|
||||
fn create(self, io: Io<F>) -> Self::Future {
|
||||
Box::pin(async move {
|
||||
let inner = IoInner {
|
||||
source: None,
|
||||
destination: None,
|
||||
};
|
||||
let filter = SslFilter {
|
||||
inner: RefCell::new(ssl::SslStream::new(self.ssl, inner)?),
|
||||
};
|
||||
let io = io.add_filter(filter);
|
||||
|
||||
loop {
|
||||
let result = io.with_buf(|buf| {
|
||||
let filter = io.filter();
|
||||
filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
|
||||
})?;
|
||||
if handle_result(&io, result).await?.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(io)
|
||||
})
|
||||
}
|
||||
Ok(io)
|
||||
}
|
||||
|
||||
async fn handle_result<T, F>(
|
||||
|
|
|
@ -1,30 +1,30 @@
|
|||
use std::task::{Context, Poll};
|
||||
use std::{io, marker::PhantomData, sync::Arc};
|
||||
use std::{io, sync::Arc};
|
||||
|
||||
use tls_rust::ServerConfig;
|
||||
|
||||
use ntex_io::{Filter, FilterFactory, Io, Layer};
|
||||
use ntex_io::{Filter, Io, Layer};
|
||||
use ntex_service::{Service, ServiceCtx, ServiceFactory};
|
||||
use ntex_util::time::Millis;
|
||||
|
||||
use super::{TlsAcceptor, TlsFilter};
|
||||
use super::TlsServerFilter;
|
||||
use crate::{counter::Counter, MAX_SSL_ACCEPT_COUNTER};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Support `SSL` connections via rustls package
|
||||
///
|
||||
/// `rust-tls` feature enables `RustlsAcceptor` type
|
||||
pub struct Acceptor<F> {
|
||||
inner: TlsAcceptor,
|
||||
_t: PhantomData<F>,
|
||||
pub struct TlsAcceptor {
|
||||
config: Arc<ServerConfig>,
|
||||
timeout: Millis,
|
||||
}
|
||||
|
||||
impl<F> Acceptor<F> {
|
||||
impl TlsAcceptor {
|
||||
/// Create rustls based `Acceptor` service factory
|
||||
pub fn new(config: Arc<ServerConfig>) -> Self {
|
||||
Acceptor {
|
||||
inner: TlsAcceptor::new(config),
|
||||
_t: PhantomData,
|
||||
Self {
|
||||
config,
|
||||
timeout: Millis(5_000),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -32,38 +32,38 @@ impl<F> Acceptor<F> {
|
|||
///
|
||||
/// Default is set to 5 seconds.
|
||||
pub fn timeout<U: Into<Millis>>(mut self, timeout: U) -> Self {
|
||||
self.inner.timeout(timeout.into());
|
||||
self.timeout = timeout.into();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> From<ServerConfig> for Acceptor<F> {
|
||||
impl From<ServerConfig> for TlsAcceptor {
|
||||
fn from(cfg: ServerConfig) -> Self {
|
||||
Self::new(Arc::new(cfg))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Clone for Acceptor<F> {
|
||||
impl Clone for TlsAcceptor {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
inner: self.inner.clone(),
|
||||
_t: PhantomData,
|
||||
config: self.config.clone(),
|
||||
timeout: self.timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> {
|
||||
type Response = Io<Layer<TlsFilter, F>>;
|
||||
impl<F: Filter, C> ServiceFactory<Io<F>, C> for TlsAcceptor {
|
||||
type Response = Io<Layer<TlsServerFilter, F>>;
|
||||
type Error = io::Error;
|
||||
type Service = AcceptorService<F>;
|
||||
type Service = TlsAcceptorService;
|
||||
type InitError = ();
|
||||
|
||||
async fn create(&self, _: C) -> Result<Self::Service, Self::InitError> {
|
||||
MAX_SSL_ACCEPT_COUNTER.with(|conns| {
|
||||
Ok(AcceptorService {
|
||||
acceptor: self.inner.clone(),
|
||||
Ok(TlsAcceptorService {
|
||||
config: self.config.clone(),
|
||||
timeout: self.timeout,
|
||||
conns: conns.clone(),
|
||||
io: PhantomData,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -71,14 +71,14 @@ impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> {
|
|||
|
||||
#[derive(Debug)]
|
||||
/// RusTLS based `Acceptor` service
|
||||
pub struct AcceptorService<F> {
|
||||
acceptor: TlsAcceptor,
|
||||
io: PhantomData<F>,
|
||||
pub struct TlsAcceptorService {
|
||||
config: Arc<ServerConfig>,
|
||||
timeout: Millis,
|
||||
conns: Counter,
|
||||
}
|
||||
|
||||
impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
|
||||
type Response = Io<Layer<TlsFilter, F>>;
|
||||
impl<F: Filter> Service<Io<F>> for TlsAcceptorService {
|
||||
type Response = Io<Layer<TlsServerFilter, F>>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
|
@ -91,10 +91,10 @@ impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
|
|||
|
||||
async fn call(
|
||||
&self,
|
||||
req: Io<F>,
|
||||
io: Io<F>,
|
||||
_: ServiceCtx<'_, Self>,
|
||||
) -> Result<Self::Response, Self::Error> {
|
||||
let _guard = self.conns.get();
|
||||
self.acceptor.clone().create(req).await
|
||||
super::TlsServerFilter::create(io, self.config.clone(), self.timeout).await
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,13 +7,11 @@ use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
|
|||
use ntex_util::ready;
|
||||
use tls_rust::{ClientConfig, ClientConnection, ServerName};
|
||||
|
||||
use crate::rustls::{TlsFilter, Wrapper};
|
||||
|
||||
use super::{PeerCert, PeerCertChain};
|
||||
use super::{PeerCert, PeerCertChain, Wrapper};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// An implementation of SSL streams
|
||||
pub(crate) struct TlsClientFilter {
|
||||
pub struct TlsClientFilter {
|
||||
session: RefCell<ClientConnection>,
|
||||
}
|
||||
|
||||
|
@ -114,22 +112,22 @@ impl FilterLayer for TlsClientFilter {
|
|||
}
|
||||
|
||||
impl TlsClientFilter {
|
||||
pub(crate) async fn create<F: Filter>(
|
||||
pub async fn create<F: Filter>(
|
||||
io: Io<F>,
|
||||
cfg: Arc<ClientConfig>,
|
||||
domain: ServerName,
|
||||
) -> Result<Io<Layer<TlsFilter, F>>, io::Error> {
|
||||
) -> Result<Io<Layer<TlsClientFilter, F>>, io::Error> {
|
||||
let session = ClientConnection::new(cfg, domain)
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
|
||||
let filter = TlsFilter::new_client(TlsClientFilter {
|
||||
let filter = TlsClientFilter {
|
||||
session: RefCell::new(session),
|
||||
});
|
||||
};
|
||||
let io = io.add_filter(filter);
|
||||
|
||||
let filter = io.filter();
|
||||
loop {
|
||||
let (result, wants_read, handshaking) = io.with_buf(|buf| {
|
||||
let mut session = filter.client().session.borrow_mut();
|
||||
let mut session = filter.session.borrow_mut();
|
||||
let mut wrp = Wrapper(buf);
|
||||
let mut result = (
|
||||
session.complete_io(&mut wrp),
|
||||
|
|
|
@ -1,21 +1,16 @@
|
|||
#![allow(clippy::type_complexity)]
|
||||
//! An implementation of SSL streams for ntex backed by OpenSSL
|
||||
use std::{any, cmp, io, sync::Arc, task::Context, task::Poll};
|
||||
use std::{cmp, io};
|
||||
|
||||
use ntex_io::{
|
||||
Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, ReadStatus, WriteBuf,
|
||||
WriteStatus,
|
||||
};
|
||||
use ntex_util::{future::BoxFuture, time::Millis};
|
||||
use tls_rust::{Certificate, ClientConfig, ServerConfig, ServerName};
|
||||
use ntex_io::WriteBuf;
|
||||
use tls_rust::Certificate;
|
||||
|
||||
mod accept;
|
||||
mod client;
|
||||
mod server;
|
||||
pub use accept::{Acceptor, AcceptorService};
|
||||
pub use accept::{TlsAcceptor, TlsAcceptorService};
|
||||
|
||||
use self::client::TlsClientFilter;
|
||||
use self::server::TlsServerFilter;
|
||||
pub use self::client::TlsClientFilter;
|
||||
pub use self::server::TlsServerFilter;
|
||||
|
||||
/// Connection's peer cert
|
||||
#[derive(Debug)]
|
||||
|
@ -25,203 +20,6 @@ pub struct PeerCert(pub Certificate);
|
|||
#[derive(Debug)]
|
||||
pub struct PeerCertChain(pub Vec<Certificate>);
|
||||
|
||||
#[derive(Debug)]
|
||||
/// An implementation of SSL streams
|
||||
pub struct TlsFilter {
|
||||
inner: InnerTlsFilter,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum InnerTlsFilter {
|
||||
Server(TlsServerFilter),
|
||||
Client(TlsClientFilter),
|
||||
}
|
||||
|
||||
impl TlsFilter {
|
||||
fn new_server(server: TlsServerFilter) -> Self {
|
||||
TlsFilter {
|
||||
inner: InnerTlsFilter::Server(server),
|
||||
}
|
||||
}
|
||||
fn new_client(client: TlsClientFilter) -> Self {
|
||||
TlsFilter {
|
||||
inner: InnerTlsFilter::Client(client),
|
||||
}
|
||||
}
|
||||
fn server(&self) -> &TlsServerFilter {
|
||||
match self.inner {
|
||||
InnerTlsFilter::Server(ref server) => server,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
fn client(&self) -> &TlsClientFilter {
|
||||
match self.inner {
|
||||
InnerTlsFilter::Client(ref server) => server,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FilterLayer for TlsFilter {
|
||||
#[inline]
|
||||
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
|
||||
match self.inner {
|
||||
InnerTlsFilter::Server(ref f) => f.query(id),
|
||||
InnerTlsFilter::Client(ref f) => f.query(id),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn shutdown(&self, buf: &WriteBuf<'_>) -> io::Result<Poll<()>> {
|
||||
match self.inner {
|
||||
InnerTlsFilter::Server(ref f) => f.shutdown(buf),
|
||||
InnerTlsFilter::Client(ref f) => f.shutdown(buf),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
|
||||
match self.inner {
|
||||
InnerTlsFilter::Server(ref f) => f.poll_read_ready(cx),
|
||||
InnerTlsFilter::Client(ref f) => f.poll_read_ready(cx),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
|
||||
match self.inner {
|
||||
InnerTlsFilter::Server(ref f) => f.poll_write_ready(cx),
|
||||
InnerTlsFilter::Client(ref f) => f.poll_write_ready(cx),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
|
||||
match self.inner {
|
||||
InnerTlsFilter::Server(ref f) => f.process_read_buf(buf),
|
||||
InnerTlsFilter::Client(ref f) => f.process_read_buf(buf),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
|
||||
match self.inner {
|
||||
InnerTlsFilter::Server(ref f) => f.process_write_buf(buf),
|
||||
InnerTlsFilter::Client(ref f) => f.process_write_buf(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TlsAcceptor {
|
||||
cfg: Arc<ServerConfig>,
|
||||
timeout: Millis,
|
||||
}
|
||||
|
||||
impl TlsAcceptor {
|
||||
/// Create openssl acceptor filter factory
|
||||
pub fn new(cfg: Arc<ServerConfig>) -> Self {
|
||||
TlsAcceptor {
|
||||
cfg,
|
||||
timeout: Millis(5_000),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set handshake timeout.
|
||||
///
|
||||
/// Default is set to 5 seconds.
|
||||
pub fn timeout<U: Into<Millis>>(&mut self, timeout: U) -> &mut Self {
|
||||
self.timeout = timeout.into();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ServerConfig> for TlsAcceptor {
|
||||
fn from(cfg: ServerConfig) -> Self {
|
||||
Self::new(Arc::new(cfg))
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for TlsAcceptor {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
cfg: self.cfg.clone(),
|
||||
timeout: self.timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Filter> FilterFactory<F> for TlsAcceptor {
|
||||
type Filter = TlsFilter;
|
||||
|
||||
type Error = io::Error;
|
||||
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, io::Error>>;
|
||||
|
||||
fn create(self, st: Io<F>) -> Self::Future {
|
||||
let cfg = self.cfg.clone();
|
||||
let timeout = self.timeout;
|
||||
|
||||
Box::pin(async move { TlsServerFilter::create(st, cfg, timeout).await })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TlsConnector {
|
||||
cfg: Arc<ClientConfig>,
|
||||
}
|
||||
|
||||
impl TlsConnector {
|
||||
/// Create openssl connector filter factory
|
||||
pub fn new(cfg: Arc<ClientConfig>) -> Self {
|
||||
TlsConnector { cfg }
|
||||
}
|
||||
|
||||
/// Set server name
|
||||
pub fn server_name(self, server_name: ServerName) -> TlsConnectorConfigured {
|
||||
TlsConnectorConfigured {
|
||||
server_name,
|
||||
cfg: self.cfg,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for TlsConnector {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
cfg: self.cfg.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TlsConnectorConfigured {
|
||||
cfg: Arc<ClientConfig>,
|
||||
server_name: ServerName,
|
||||
}
|
||||
|
||||
impl Clone for TlsConnectorConfigured {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
cfg: self.cfg.clone(),
|
||||
server_name: self.server_name.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
|
||||
type Filter = TlsFilter;
|
||||
|
||||
type Error = io::Error;
|
||||
type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, io::Error>>;
|
||||
|
||||
fn create(self, st: Io<F>) -> Self::Future {
|
||||
let cfg = self.cfg;
|
||||
let server_name = self.server_name;
|
||||
|
||||
Box::pin(async move { TlsClientFilter::create(st, cfg, server_name).await })
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct Wrapper<'a, 'b>(&'a WriteBuf<'b>);
|
||||
|
||||
impl<'a, 'b> io::Read for Wrapper<'a, 'b> {
|
||||
|
|
|
@ -7,14 +7,13 @@ use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
|
|||
use ntex_util::{ready, time, time::Millis};
|
||||
use tls_rust::{ServerConfig, ServerConnection};
|
||||
|
||||
use crate::rustls::{TlsFilter, Wrapper};
|
||||
use crate::Servername;
|
||||
|
||||
use super::{PeerCert, PeerCertChain};
|
||||
use super::{PeerCert, PeerCertChain, Wrapper};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// An implementation of SSL streams
|
||||
pub(crate) struct TlsServerFilter {
|
||||
pub struct TlsServerFilter {
|
||||
session: RefCell<ServerConnection>,
|
||||
}
|
||||
|
||||
|
@ -121,23 +120,23 @@ impl FilterLayer for TlsServerFilter {
|
|||
}
|
||||
|
||||
impl TlsServerFilter {
|
||||
pub(crate) async fn create<F: Filter>(
|
||||
pub async fn create<F: Filter>(
|
||||
io: Io<F>,
|
||||
cfg: Arc<ServerConfig>,
|
||||
timeout: Millis,
|
||||
) -> Result<Io<Layer<TlsFilter, F>>, io::Error> {
|
||||
) -> Result<Io<Layer<TlsServerFilter, F>>, io::Error> {
|
||||
time::timeout(timeout, async {
|
||||
let session = ServerConnection::new(cfg)
|
||||
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
|
||||
let filter = TlsFilter::new_server(TlsServerFilter {
|
||||
let filter = TlsServerFilter {
|
||||
session: RefCell::new(session),
|
||||
});
|
||||
};
|
||||
let io = io.add_filter(filter);
|
||||
|
||||
let filter = io.filter();
|
||||
loop {
|
||||
let (result, wants_read, handshaking) = io.with_buf(|buf| {
|
||||
let mut session = filter.server().session.borrow_mut();
|
||||
let mut session = filter.session.borrow_mut();
|
||||
let mut wrp = Wrapper(buf);
|
||||
let mut result = (
|
||||
session.complete_io(&mut wrp),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue