Refactor ntex-io (#164)

* Refactor Io and Filter types
This commit is contained in:
Nikolay Kim 2023-01-23 14:42:00 +06:00 committed by GitHub
parent 8cbd8758a5
commit 83d05d81ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 1615 additions and 1495 deletions

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-async-std" name = "ntex-async-std"
version = "0.2.0" version = "0.2.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "async-std intergration for ntex framework" description = "async-std intergration for ntex framework"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -16,8 +16,8 @@ name = "ntex_async_std"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
ntex-bytes = "0.1.11" ntex-bytes = "0.1.19"
ntex-io = "0.2.0" ntex-io = "0.2.1"
ntex-util = "0.2.0" ntex-util = "0.2.0"
async-oneshot = "0.5.0" async-oneshot = "0.5.0"
log = "0.4" log = "0.4"

View file

@ -1,4 +1,4 @@
use std::{any, future::Future, io, pin::Pin, task::Context, task::Poll}; use std::{any, cell::RefCell, future::Future, io, pin::Pin, task::Context, task::Poll};
use async_std::io::{Read, Write}; use async_std::io::{Read, Write};
use ntex_bytes::{Buf, BufMut, BytesVec}; use ntex_bytes::{Buf, BufMut, BytesVec};
@ -30,35 +30,31 @@ impl Handle for TcpStream {
/// Read io task /// Read io task
struct ReadTask { struct ReadTask {
io: TcpStream, io: RefCell<TcpStream>,
state: ReadContext, state: ReadContext,
} }
impl ReadTask { impl ReadTask {
/// Create new read io task /// Create new read io task
fn new(io: TcpStream, state: ReadContext) -> Self { fn new(io: TcpStream, state: ReadContext) -> Self {
Self { io, state } Self {
state,
io: RefCell::new(io),
}
} }
} }
impl Future for ReadTask { impl Future for ReadTask {
type Output = (); type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut(); let this = self.as_ref();
loop { this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) { match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => { ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut buf = this.state.get_read_buf();
let io = &mut this.io;
let (hw, lw) = pool.read_params().unpack();
// read data from socket // read data from socket
let mut new_bytes = 0; let mut io = self.io.borrow_mut();
let mut close = false;
let mut pending = false;
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = buf.remaining_mut(); let remaining = buf.remaining_mut();
@ -66,52 +62,31 @@ impl Future for ReadTask {
buf.reserve(hw - remaining); buf.reserve(hw - remaining);
} }
match poll_read_buf(Pin::new(&mut io.0), cx, &mut buf) { return match poll_read_buf(Pin::new(&mut io.0), cx, buf) {
Poll::Pending => { Poll::Pending => Poll::Pending,
pending = true;
break;
}
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
log::trace!("async-std stream is disconnected"); log::trace!("async-std stream is disconnected");
close = true; Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else { } else {
new_bytes += n; Poll::Pending
if new_bytes <= hw {
continue;
}
} }
break;
} }
Poll::Ready(Err(err)) => { Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err); log::trace!("async-std read task failed on io {:?}", err);
this.state.release_read_buf(buf, new_bytes); Poll::Ready(Err(err))
this.state.close(Some(err));
return Poll::Ready(());
} }
} };
} }
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
this.state.release_read_buf(buf, new_bytes);
return if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
} }
ReadStatus::Terminate => { ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown"); log::trace!("read task is instructed to shutdown");
return Poll::Ready(()); Poll::Ready(Ok(()))
} }
} }
} })
} }
} }
@ -358,10 +333,6 @@ pub fn poll_read_buf<T: Read>(
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut BytesVec, buf: &mut BytesVec,
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<usize>> {
if !buf.has_remaining_mut() {
return Poll::Ready(Ok(0));
}
let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [u8]) }; let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [u8]) };
let n = ready!(io.poll_read(cx, dst))?; let n = ready!(io.poll_read(cx, dst))?;
@ -389,35 +360,31 @@ mod unixstream {
/// Read io task /// Read io task
struct ReadTask { struct ReadTask {
io: UnixStream, io: RefCell<UnixStream>,
state: ReadContext, state: ReadContext,
} }
impl ReadTask { impl ReadTask {
/// Create new read io task /// Create new read io task
fn new(io: UnixStream, state: ReadContext) -> Self { fn new(io: UnixStream, state: ReadContext) -> Self {
Self { io, state } Self {
state,
io: RefCell::new(io),
}
} }
} }
impl Future for ReadTask { impl Future for ReadTask {
type Output = (); type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut(); let this = self.as_ref();
loop { this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) { match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => { ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut buf = this.state.get_read_buf();
let io = &mut this.io;
let (hw, lw) = pool.read_params().unpack();
// read data from socket // read data from socket
let mut new_bytes = 0; let mut io = this.io.borrow_mut();
let mut close = false;
let mut pending = false;
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = buf.remaining_mut(); let remaining = buf.remaining_mut();
@ -425,52 +392,31 @@ mod unixstream {
buf.reserve(hw - remaining); buf.reserve(hw - remaining);
} }
match poll_read_buf(Pin::new(&mut io.0), cx, &mut buf) { return match poll_read_buf(Pin::new(&mut io.0), cx, buf) {
Poll::Pending => { Poll::Pending => Poll::Pending,
pending = true;
break;
}
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
log::trace!("async-std stream is disconnected"); log::trace!("async-std stream is disconnected");
close = true; Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else { } else {
new_bytes += n; Poll::Pending
if new_bytes <= hw {
continue;
}
} }
break;
} }
Poll::Ready(Err(err)) => { Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err); log::trace!("read task failed on io {:?}", err);
this.state.release_read_buf(buf, new_bytes); Poll::Ready(Err(err))
this.state.close(Some(err));
return Poll::Ready(());
} }
} };
} }
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
this.state.release_read_buf(buf, new_bytes);
return if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
} }
ReadStatus::Terminate => { ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown"); log::trace!("read task is instructed to shutdown");
return Poll::Ready(()); Poll::Ready(Ok(()))
} }
} }
} })
} }
} }

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.1.19] (2023-01-23)
* Add PollRef::resize_read_buf() and PollRef::resize_write_buf() helpers
## [0.1.18] (2022-12-13) ## [0.1.18] (2022-12-13)
* Add Bytes<&Bytes> for Bytes impl * Add Bytes<&Bytes> for Bytes impl

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-bytes" name = "ntex-bytes"
version = "0.1.18" version = "0.1.19"
license = "MIT" license = "MIT"
authors = ["Nikolay Kim <fafhrd91@gmail.com>", "Carl Lerche <me@carllerche.com>"] authors = ["Nikolay Kim <fafhrd91@gmail.com>", "Carl Lerche <me@carllerche.com>"]
description = "Types and traits for working with bytes (bytes crate fork)" description = "Types and traits for working with bytes (bytes crate fork)"

View file

@ -6,7 +6,7 @@ use std::{cell::Cell, cell::RefCell, fmt, future::Future, mem, pin::Pin, ptr, rc
use futures_core::task::__internal::AtomicWaker; use futures_core::task::__internal::AtomicWaker;
use crate::{BytesMut, BytesVec}; use crate::{BufMut, BytesMut, BytesVec};
pub struct Pool { pub struct Pool {
idx: Cell<usize>, idx: Cell<usize>,
@ -293,6 +293,17 @@ impl PoolRef {
} }
} }
#[doc(hidden)]
#[inline]
/// Resize read buffer
pub fn resize_read_buf(self, buf: &mut BytesVec) {
let (hw, lw) = self.0.write_wm.get().unpack();
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
}
#[doc(hidden)] #[doc(hidden)]
#[inline] #[inline]
/// Release read buffer, buf must be allocated from this pool /// Release read buffer, buf must be allocated from this pool
@ -318,6 +329,17 @@ impl PoolRef {
} }
} }
#[doc(hidden)]
#[inline]
/// Resize write buffer
pub fn resize_write_buf(self, buf: &mut BytesVec) {
let (hw, lw) = self.0.write_wm.get().unpack();
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
}
#[doc(hidden)] #[doc(hidden)]
#[inline] #[inline]
/// Release write buffer, buf must be allocated from this pool /// Release write buffer, buf must be allocated from this pool

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.2.1] - 2023-01-23
* Use new Io object
## [0.2.0] - 2023-01-04 ## [0.2.0] - 2023-01-04
* Release * Release

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-connect" name = "ntex-connect"
version = "0.2.0" version = "0.2.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "ntexwork connect utils for ntex framework" description = "ntexwork connect utils for ntex framework"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -35,18 +35,18 @@ async-std = ["ntex-rt/async-std", "ntex-async-std"]
[dependencies] [dependencies]
ntex-service = "1.0.0" ntex-service = "1.0.0"
ntex-bytes = "0.1.18" ntex-bytes = "0.1.19"
ntex-http = "0.1.8" ntex-http = "0.1.8"
ntex-io = "0.2.0" ntex-io = "0.2.1"
ntex-rt = "0.4.7" ntex-rt = "0.4.7"
ntex-tls = "0.2.0" ntex-tls = "0.2.1"
ntex-util = "0.2.0" ntex-util = "0.2.0"
log = "0.4" log = "0.4"
thiserror = "1.0" thiserror = "1.0"
ntex-tokio = { version = "0.2.0", optional = true } ntex-tokio = { version = "0.2.1", optional = true }
ntex-glommio = { version = "0.2.0", optional = true } ntex-glommio = { version = "0.2.1", optional = true }
ntex-async-std = { version = "0.2.0", optional = true } ntex-async-std = { version = "0.2.1", optional = true }
# openssl # openssl
tls-openssl = { version="0.10", package = "openssl", optional = true } tls-openssl = { version="0.10", package = "openssl", optional = true }

View file

@ -4,7 +4,7 @@ pub use ntex_tls::openssl::SslFilter;
pub use tls_openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; pub use tls_openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod};
use ntex_bytes::PoolId; use ntex_bytes::PoolId;
use ntex_io::{Base, Io}; use ntex_io::{FilterFactory, Io, Layer};
use ntex_service::{Service, ServiceFactory}; use ntex_service::{Service, ServiceFactory};
use ntex_tls::openssl::SslConnector as IoSslConnector; use ntex_tls::openssl::SslConnector as IoSslConnector;
use ntex_util::future::{BoxFuture, Ready}; use ntex_util::future::{BoxFuture, Ready};
@ -39,7 +39,7 @@ impl<T> Connector<T> {
impl<T: Address> Connector<T> { impl<T: Address> Connector<T> {
/// Resolve and connect to remote host /// Resolve and connect to remote host
pub async fn connect<U>(&self, message: U) -> Result<Io<SslFilter<Base>>, ConnectError> pub async fn connect<U>(&self, message: U) -> Result<Io<Layer<SslFilter>>, ConnectError>
where where
Connect<T>: From<U>, Connect<T>: From<U>,
{ {
@ -57,7 +57,7 @@ impl<T: Address> Connector<T> {
let ssl = config let ssl = config
.into_ssl(&host) .into_ssl(&host)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
match io.add_filter(IoSslConnector::new(ssl)).await { match IoSslConnector::new(ssl).create(io).await {
Ok(io) => { Ok(io) => {
trace!("SSL Handshake success: {:?}", host); trace!("SSL Handshake success: {:?}", host);
Ok(io) Ok(io)
@ -82,7 +82,7 @@ impl<T> Clone for Connector<T> {
} }
impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> { impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
type Response = Io<SslFilter<Base>>; type Response = Io<Layer<SslFilter>>;
type Error = ConnectError; type Error = ConnectError;
type Service = Connector<T>; type Service = Connector<T>;
type InitError = (); type InitError = ();
@ -95,7 +95,7 @@ impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
} }
impl<T: Address> Service<Connect<T>> for Connector<T> { impl<T: Address> Service<Connect<T>> for Connector<T> {
type Response = Io<SslFilter<Base>>; type Response = Io<Layer<SslFilter>>;
type Error = ConnectError; type Error = ConnectError;
type Future<'f> = BoxFuture<'f, Result<Self::Response, Self::Error>>; type Future<'f> = BoxFuture<'f, Result<Self::Response, Self::Error>>;

View file

@ -4,7 +4,7 @@ pub use ntex_tls::rustls::TlsFilter;
pub use tls_rustls::{ClientConfig, ServerName}; pub use tls_rustls::{ClientConfig, ServerName};
use ntex_bytes::PoolId; use ntex_bytes::PoolId;
use ntex_io::{Base, Io}; use ntex_io::{FilterFactory, Io, Layer};
use ntex_service::{Service, ServiceFactory}; use ntex_service::{Service, ServiceFactory};
use ntex_tls::rustls::TlsConnector; use ntex_tls::rustls::TlsConnector;
use ntex_util::future::{BoxFuture, Ready}; use ntex_util::future::{BoxFuture, Ready};
@ -48,7 +48,7 @@ impl<T> Connector<T> {
impl<T: Address + 'static> Connector<T> { impl<T: Address + 'static> Connector<T> {
/// Resolve and connect to remote host /// Resolve and connect to remote host
pub async fn connect<U>(&self, message: U) -> Result<Io<TlsFilter<Base>>, ConnectError> pub async fn connect<U>(&self, message: U) -> Result<Io<Layer<TlsFilter>>, ConnectError>
where where
Connect<T>: From<U>, Connect<T>: From<U>,
{ {
@ -64,7 +64,7 @@ impl<T: Address + 'static> Connector<T> {
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?; .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?;
let connector = connector.server_name(host.clone()); let connector = connector.server_name(host.clone());
match io.add_filter(connector).await { match connector.create(io).await {
Ok(io) => { Ok(io) => {
trace!("TLS Handshake success: {:?}", &host); trace!("TLS Handshake success: {:?}", &host);
Ok(io) Ok(io)
@ -87,7 +87,7 @@ impl<T> Clone for Connector<T> {
} }
impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> { impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
type Response = Io<TlsFilter<Base>>; type Response = Io<Layer<TlsFilter>>;
type Error = ConnectError; type Error = ConnectError;
type Service = Connector<T>; type Service = Connector<T>;
type InitError = (); type InitError = ();
@ -100,7 +100,7 @@ impl<T: Address, C: 'static> ServiceFactory<Connect<T>, C> for Connector<T> {
} }
impl<T: Address> Service<Connect<T>> for Connector<T> { impl<T: Address> Service<Connect<T>> for Connector<T> {
type Response = Io<TlsFilter<Base>>; type Response = Io<Layer<TlsFilter>>;
type Error = ConnectError; type Error = ConnectError;
type Future<'f> = BoxFuture<'f, Result<Self::Response, Self::Error>>; type Future<'f> = BoxFuture<'f, Result<Self::Response, Self::Error>>;

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-glommio" name = "ntex-glommio"
version = "0.2.0" version = "0.2.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "glommio intergration for ntex framework" description = "glommio intergration for ntex framework"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -16,8 +16,8 @@ name = "ntex_glommio"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
ntex-bytes = "0.1.18" ntex-bytes = "0.1.19"
ntex-io = "0.2.0" ntex-io = "0.2.1"
ntex-util = "0.2.0" ntex-util = "0.2.0"
async-oneshot = "0.5.0" async-oneshot = "0.5.0"
futures-lite = "1.12" futures-lite = "1.12"

View file

@ -57,17 +57,10 @@ impl Future for ReadTask {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut(); let this = self.as_mut();
loop { this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) { match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => { ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut buf = this.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket // read data from socket
let mut new_bytes = 0;
let mut close = false;
let mut pending = false;
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = buf.remaining_mut(); let remaining = buf.remaining_mut();
@ -75,56 +68,35 @@ impl Future for ReadTask {
buf.reserve(hw - remaining); buf.reserve(hw - remaining);
} }
match poll_read_buf( return match poll_read_buf(
Pin::new(&mut *this.io.0.borrow_mut()), Pin::new(&mut *this.io.0.borrow_mut()),
cx, cx,
&mut buf, buf,
) { ) {
Poll::Pending => { Poll::Pending => Poll::Pending,
pending = true;
break;
}
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
log::trace!("glommio stream is disconnected"); log::trace!("glommio stream is disconnected");
close = true; Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else { } else {
new_bytes += n; Poll::Pending
if new_bytes <= hw {
continue;
}
} }
break;
} }
Poll::Ready(Err(err)) => { Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err); log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes); Poll::Ready(Err(err))
this.state.close(Some(err));
return Poll::Ready(());
} }
} };
} }
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
this.state.release_read_buf(buf, new_bytes);
return if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
} }
ReadStatus::Terminate => { ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown"); log::trace!("read task is instructed to shutdown");
return Poll::Ready(()); Poll::Ready(Ok(()))
} }
} }
} })
} }
} }
@ -372,10 +344,6 @@ pub fn poll_read_buf<T: AsyncRead>(
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut BytesVec, buf: &mut BytesVec,
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<usize>> {
if !buf.has_remaining_mut() {
return Poll::Ready(Ok(0));
}
let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [u8]) }; let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [u8]) };
let n = ready!(io.poll_read(cx, dst))?; let n = ready!(io.poll_read(cx, dst))?;
@ -407,17 +375,10 @@ impl Future for UnixReadTask {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut(); let this = self.as_mut();
loop { this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) { match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => { ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut buf = this.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket // read data from socket
let mut new_bytes = 0;
let mut close = false;
let mut pending = false;
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = buf.remaining_mut(); let remaining = buf.remaining_mut();
@ -425,56 +386,35 @@ impl Future for UnixReadTask {
buf.reserve(hw - remaining); buf.reserve(hw - remaining);
} }
match poll_read_buf( return match poll_read_buf(
Pin::new(&mut *this.io.0.borrow_mut()), Pin::new(&mut *this.io.0.borrow_mut()),
cx, cx,
&mut buf, buf,
) { ) {
Poll::Pending => { Poll::Pending => Poll::Pending,
pending = true;
break;
}
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
log::trace!("glommio stream is disconnected"); log::trace!("glommio stream is disconnected");
close = true; Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else { } else {
new_bytes += n; Poll::Pending
if new_bytes <= hw {
continue;
}
} }
break;
} }
Poll::Ready(Err(err)) => { Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err); log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes); Poll::Ready(Err(err))
this.state.close(Some(err));
return Poll::Ready(());
} }
} };
} }
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
this.state.release_read_buf(buf, new_bytes);
return if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
} }
ReadStatus::Terminate => { ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown"); log::trace!("read task is instructed to shutdown");
return Poll::Ready(()); Poll::Ready(Ok(()))
} }
} }
} })
} }
} }

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.2.1] - 2023-01-23
* Refactor Io and Filter types
## [0.2.0] - 2023-01-04 ## [0.2.0] - 2023-01-04
* Release * Release

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-io" name = "ntex-io"
version = "0.2.0" version = "0.2.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames" description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -17,13 +17,14 @@ path = "src/lib.rs"
[dependencies] [dependencies]
ntex-codec = "0.6.2" ntex-codec = "0.6.2"
ntex-bytes = "0.1.18" ntex-bytes = "0.1.19"
ntex-util = "0.2.0" ntex-util = "0.2.0"
ntex-service = "1.0.0" ntex-service = "1.0.0"
bitflags = "1.3" bitflags = "1.3"
log = "0.4" log = "0.4"
pin-project-lite = "0.2" pin-project-lite = "0.2"
smallvec = "1"
[dev-dependencies] [dev-dependencies]
rand = "0.8" rand = "0.8"

361
ntex-io/src/buf.rs Normal file
View file

@ -0,0 +1,361 @@
use ntex_bytes::{BytesVec, PoolRef};
use smallvec::SmallVec;
use crate::IoRef;
#[derive(Debug)]
pub struct Stack {
pub(crate) buffers: SmallVec<[(Option<BytesVec>, Option<BytesVec>); 4]>,
}
impl Stack {
pub(crate) fn new() -> Self {
let mut buffers = SmallVec::with_capacity(4);
buffers.push((None, None));
Self { buffers }
}
pub(crate) fn add_layer(&mut self) {
self.buffers.insert(0, (None, None));
}
pub(crate) fn read_buf<F, R>(
&mut self,
io: &IoRef,
idx: usize,
nbytes: usize,
f: F,
) -> R
where
F: FnOnce(&mut ReadBuf<'_>) -> R,
{
let pos = idx + 1;
if self.buffers.len() > pos {
let (curr, next) = self.buffers.split_at_mut(pos);
let mut buf = ReadBuf {
io,
nbytes,
curr: &mut curr[idx],
next: &mut next[0],
};
f(&mut buf)
} else {
let mut val1 = (self.buffers[idx].0.take(), None);
let mut val2 = (None, self.buffers[idx].1.take());
let mut buf = ReadBuf {
io,
nbytes,
curr: &mut val1,
next: &mut val2,
};
let result = f(&mut buf);
self.buffers[idx].0 = val1.0;
self.buffers[idx].1 = val2.1;
result
}
}
pub(crate) fn write_buf<F, R>(&mut self, io: &IoRef, idx: usize, f: F) -> R
where
F: FnOnce(&mut WriteBuf<'_>) -> R,
{
let pos = idx + 1;
if self.buffers.len() > pos {
let (curr, next) = self.buffers.split_at_mut(pos);
let mut buf = WriteBuf {
io,
curr: &mut curr[idx],
next: &mut next[0],
};
f(&mut buf)
} else {
let mut val1 = (self.buffers[idx].0.take(), None);
let mut val2 = (None, self.buffers[idx].1.take());
let mut buf = WriteBuf {
io,
curr: &mut val1,
next: &mut val2,
};
let result = f(&mut buf);
self.buffers[idx].0 = val1.0;
self.buffers[idx].1 = val2.1;
result
}
}
pub(crate) fn first_read_buf_size(&self) -> usize {
self.buffers[0].0.as_ref().map(|b| b.len()).unwrap_or(0)
}
pub(crate) fn first_read_buf(&mut self) -> &mut Option<BytesVec> {
&mut self.buffers[0].0
}
pub(crate) fn first_write_buf(&mut self, io: &IoRef) -> &mut BytesVec {
if self.buffers[0].1.is_none() {
self.buffers[0].1 = Some(io.memory_pool().get_write_buf());
}
self.buffers[0].1.as_mut().unwrap()
}
pub(crate) fn last_read_buf(&mut self) -> &mut Option<BytesVec> {
let idx = self.buffers.len() - 1;
&mut self.buffers[idx].0
}
pub(crate) fn last_write_buf(&mut self) -> &mut Option<BytesVec> {
let idx = self.buffers.len() - 1;
&mut self.buffers[idx].1
}
pub(crate) fn last_write_buf_size(&self) -> usize {
let idx = self.buffers.len() - 1;
self.buffers[idx].1.as_ref().map(|b| b.len()).unwrap_or(0)
}
pub(crate) fn set_last_write_buf(&mut self, buf: BytesVec) {
let idx = self.buffers.len() - 1;
self.buffers[idx].1 = Some(buf);
}
pub(crate) fn release(&mut self, pool: PoolRef) {
for buf in &mut self.buffers {
if let Some(buf) = buf.0.take() {
pool.release_read_buf(buf);
}
if let Some(buf) = buf.1.take() {
pool.release_write_buf(buf);
}
}
}
pub(crate) fn set_memory_pool(&mut self, pool: PoolRef) {
for buf in &mut self.buffers {
if let Some(ref mut b) = buf.0 {
pool.move_vec_in(b);
}
if let Some(ref mut b) = buf.1 {
pool.move_vec_in(b);
}
}
}
}
#[derive(Debug)]
pub struct ReadBuf<'a> {
pub(crate) io: &'a IoRef,
pub(crate) curr: &'a mut (Option<BytesVec>, Option<BytesVec>),
pub(crate) next: &'a mut (Option<BytesVec>, Option<BytesVec>),
pub(crate) nbytes: usize,
}
impl<'a> ReadBuf<'a> {
#[inline]
/// Get number of newly added bytes
pub fn nbytes(&self) -> usize {
self.nbytes
}
#[inline]
/// Initiate graceful io stream shutdown
pub fn want_shutdown(&self) {
self.io.want_shutdown()
}
#[inline]
/// Get reference to source read buffer
pub fn get_src(&mut self) -> &mut BytesVec {
if self.next.0.is_none() {
self.next.0 = Some(self.io.memory_pool().get_read_buf());
}
self.next.0.as_mut().unwrap()
}
#[inline]
/// Take source read buffer
pub fn take_src(&mut self) -> Option<BytesVec> {
self.next
.0
.take()
.and_then(|b| if b.is_empty() { None } else { Some(b) })
}
#[inline]
/// Set source read buffer
pub fn set_src(&mut self, src: Option<BytesVec>) {
if let Some(src) = src {
if src.is_empty() {
self.io.memory_pool().release_read_buf(src);
} else {
if let Some(b) = self.next.0.take() {
self.io.memory_pool().release_read_buf(b);
}
self.next.0 = Some(src);
}
}
}
#[inline]
/// Get reference to destination read buffer
pub fn get_dst(&mut self) -> &mut BytesVec {
if self.curr.0.is_none() {
self.curr.0 = Some(self.io.memory_pool().get_read_buf());
}
self.curr.0.as_mut().unwrap()
}
#[inline]
/// Take destination read buffer
pub fn take_dst(&mut self) -> BytesVec {
self.curr
.0
.take()
.unwrap_or_else(|| self.io.memory_pool().get_read_buf())
}
#[inline]
/// Set destination read buffer
pub fn set_dst(&mut self, dst: Option<BytesVec>) {
if let Some(dst) = dst {
if dst.is_empty() {
self.io.memory_pool().release_read_buf(dst);
} else {
if let Some(b) = self.curr.0.take() {
self.io.memory_pool().release_read_buf(b);
}
self.curr.0 = Some(dst);
}
}
}
#[inline]
/// Get reference to source and destination read buffers (src, dst)
pub fn get_pair(&mut self) -> (&mut BytesVec, &mut BytesVec) {
if self.next.0.is_none() {
self.next.0 = Some(self.io.memory_pool().get_read_buf());
}
if self.curr.0.is_none() {
self.curr.0 = Some(self.io.memory_pool().get_read_buf());
}
(self.next.0.as_mut().unwrap(), self.curr.0.as_mut().unwrap())
}
#[inline]
pub fn with_write_buf<'b, F, R>(&'b mut self, f: F) -> R
where
F: FnOnce(&mut WriteBuf<'b>) -> R,
{
let mut buf = WriteBuf {
io: self.io,
curr: self.curr,
next: self.next,
};
f(&mut buf)
}
}
#[derive(Debug)]
pub struct WriteBuf<'a> {
pub(crate) io: &'a IoRef,
pub(crate) curr: &'a mut (Option<BytesVec>, Option<BytesVec>),
pub(crate) next: &'a mut (Option<BytesVec>, Option<BytesVec>),
}
impl<'a> WriteBuf<'a> {
#[inline]
/// Initiate graceful io stream shutdown
pub fn want_shutdown(&self) {
self.io.want_shutdown()
}
#[inline]
/// Get reference to source write buffer
pub fn get_src(&mut self) -> &mut BytesVec {
if self.curr.1.is_none() {
self.curr.1 = Some(self.io.memory_pool().get_write_buf());
}
self.curr.1.as_mut().unwrap()
}
#[inline]
/// Take source write buffer
pub fn take_src(&mut self) -> Option<BytesVec> {
self.curr
.1
.take()
.and_then(|b| if b.is_empty() { None } else { Some(b) })
}
#[inline]
/// Set source write buffer
pub fn set_src(&mut self, src: Option<BytesVec>) {
if let Some(b) = self.curr.1.take() {
self.io.memory_pool().release_read_buf(b);
}
self.curr.1 = src;
}
#[inline]
/// Get reference to destination write buffer
pub fn get_dst(&mut self) -> &mut BytesVec {
if self.next.1.is_none() {
self.next.1 = Some(self.io.memory_pool().get_write_buf());
}
self.next.1.as_mut().unwrap()
}
#[inline]
/// Take destination write buffer
pub fn take_dst(&mut self) -> BytesVec {
self.next
.1
.take()
.unwrap_or_else(|| self.io.memory_pool().get_write_buf())
}
#[inline]
/// Set destination write buffer
pub fn set_dst(&mut self, dst: Option<BytesVec>) {
if let Some(dst) = dst {
if dst.is_empty() {
self.io.memory_pool().release_write_buf(dst);
} else {
if let Some(b) = self.next.1.take() {
self.io.memory_pool().release_write_buf(b);
}
self.next.1 = Some(dst);
}
}
}
#[inline]
/// Get reference to source and destination buffers (src, dst)
pub fn get_pair(&mut self) -> (&mut BytesVec, &mut BytesVec) {
if self.curr.1.is_none() {
self.curr.1 = Some(self.io.memory_pool().get_write_buf());
}
if self.next.1.is_none() {
self.next.1 = Some(self.io.memory_pool().get_write_buf());
}
(self.curr.1.as_mut().unwrap(), self.next.1.as_mut().unwrap())
}
#[inline]
pub fn with_read_buf<'b, F, R>(&'b mut self, f: F) -> R
where
F: FnOnce(&mut ReadBuf<'b>) -> R,
{
let mut buf = ReadBuf {
io: self.io,
curr: self.curr,
next: self.next,
nbytes: 0,
};
f(&mut buf)
}
}

View file

@ -316,18 +316,15 @@ where
} }
// shutdown service // shutdown service
DispatcherState::Shutdown => { DispatcherState::Shutdown => {
let err = slf.error.take();
return if this.inner.shared.service.poll_shutdown(cx).is_ready() { return if this.inner.shared.service.poll_shutdown(cx).is_ready() {
log::trace!("service shutdown is completed, stop"); log::trace!("service shutdown is completed, stop");
Poll::Ready(if let Some(err) = err { Poll::Ready(if let Some(err) = slf.error.take() {
Err(err) Err(err)
} else { } else {
Ok(()) Ok(())
}) })
} else { } else {
slf.error.set(err);
Poll::Pending Poll::Pending
}; };
} }
@ -632,9 +629,7 @@ mod tests {
// close read side // close read side
client.close().await; client.close().await;
assert!(client.is_server_dropped());
// TODO! fix
// assert!(client.is_server_dropped());
// service must be checked for readiness only once // service must be checked for readiness only once
assert_eq!(counter.get(), 1); assert_eq!(counter.get(), 1);

View file

@ -1,8 +1,6 @@
use std::{any, io, task::Context, task::Poll}; use std::{any, io, task::Context, task::Poll};
use ntex_bytes::BytesVec; use super::{buf::Stack, io::Flags, FilterLayer, IoRef, ReadStatus, WriteStatus};
use super::{io::Flags, Filter, IoRef, ReadStatus, WriteStatus};
/// Default `Io` filter /// Default `Io` filter
pub struct Base(IoRef); pub struct Base(IoRef);
@ -13,8 +11,54 @@ impl Base {
} }
} }
pub struct Layer<F, L = Base>(pub(crate) F, L);
impl<F: FilterLayer, L: Filter> Layer<F, L> {
pub(crate) fn new(f: F, l: L) -> Self {
Self(f, l)
}
}
pub(crate) struct NullFilter;
const NULL: NullFilter = NullFilter;
impl NullFilter {
pub(super) fn get() -> &'static dyn Filter {
&NULL
}
}
pub trait Filter: 'static {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>>;
fn process_read_buf(
&self,
io: &IoRef,
stack: &mut Stack,
idx: usize,
nbytes: usize,
) -> io::Result<usize>;
/// Process write buffer
fn process_write_buf(
&self,
io: &IoRef,
stack: &mut Stack,
idx: usize,
) -> io::Result<()>;
/// Gracefully shutdown filter
fn shutdown(&self, io: &IoRef, stack: &mut Stack, idx: usize) -> io::Result<Poll<()>>;
/// Check readiness for read operations
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus>;
/// Check readiness for write operations
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus>;
}
impl Filter for Base { impl Filter for Base {
#[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> { fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
if let Some(hnd) = self.0 .0.handle.take() { if let Some(hnd) = self.0 .0.handle.take() {
let res = hnd.query(id); let res = hnd.query(id);
@ -25,11 +69,6 @@ impl Filter for Base {
} }
} }
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
#[inline] #[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> { fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
let flags = self.0.flags(); let flags = self.0.flags();
@ -67,51 +106,128 @@ impl Filter for Base {
} }
#[inline] #[inline]
fn get_read_buf(&self) -> Option<BytesVec> { fn process_read_buf(
self.0 .0.read_buf.take() &self,
_: &IoRef,
_: &mut Stack,
_: usize,
nbytes: usize,
) -> io::Result<usize> {
Ok(nbytes)
} }
#[inline] #[inline]
fn get_write_buf(&self) -> Option<BytesVec> { fn process_write_buf(&self, _: &IoRef, s: &mut Stack, _: usize) -> io::Result<()> {
self.0 .0.write_buf.take() if let Some(buf) = s.last_write_buf() {
} if buf.len() >= self.0.memory_pool().write_params_high() {
#[inline]
fn release_read_buf(&self, buf: BytesVec) {
self.0 .0.read_buf.set(Some(buf));
}
#[inline]
fn process_read_buf(&self, _: &IoRef, n: usize) -> io::Result<(usize, usize)> {
let buf = self.0 .0.read_buf.as_ptr();
let ref_buf = unsafe { buf.as_ref().unwrap() };
let total = ref_buf.as_ref().map(|v| v.len()).unwrap_or(0);
Ok((total, n))
}
#[inline]
fn release_write_buf(&self, buf: BytesVec) -> Result<(), io::Error> {
let pool = self.0.memory_pool();
if buf.is_empty() {
pool.release_write_buf(buf);
} else {
if buf.len() >= pool.write_params_high() {
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE); self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
} }
self.0 .0.write_buf.set(Some(buf));
self.0 .0.write_task.wake();
} }
Ok(()) Ok(())
} }
#[inline]
fn shutdown(&self, _: &IoRef, _: &mut Stack, _: usize) -> io::Result<Poll<()>> {
Ok(Poll::Ready(()))
}
} }
pub(crate) struct NullFilter; impl<F, L> Filter for Layer<F, L>
where
F: FilterLayer,
L: Filter,
{
#[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
self.0.query(id).or_else(|| self.1.query(id))
}
const NULL: NullFilter = NullFilter; #[inline]
fn shutdown(&self, io: &IoRef, stack: &mut Stack, idx: usize) -> io::Result<Poll<()>> {
let result1 = stack.write_buf(io, idx, |buf| self.0.shutdown(buf))?;
self.process_write_buf(io, stack, idx)?;
impl NullFilter { let result2 = if F::BUFFERS {
pub(super) fn get() -> &'static dyn Filter { self.1.shutdown(io, stack, idx + 1)?
&NULL } else {
self.1.shutdown(io, stack, idx)?
};
if result1.is_pending() || result2.is_pending() {
Ok(Poll::Pending)
} else {
Ok(Poll::Ready(()))
}
}
#[inline]
fn process_read_buf(
&self,
io: &IoRef,
stack: &mut Stack,
idx: usize,
nbytes: usize,
) -> io::Result<usize> {
let nbytes = if F::BUFFERS {
self.1.process_read_buf(io, stack, idx + 1, nbytes)?
} else {
self.1.process_read_buf(io, stack, idx, nbytes)?
};
stack.read_buf(io, idx, nbytes, |buf| self.0.process_read_buf(buf))
}
#[inline]
fn process_write_buf(
&self,
io: &IoRef,
stack: &mut Stack,
idx: usize,
) -> io::Result<()> {
stack.write_buf(io, idx, |buf| self.0.process_write_buf(buf))?;
if F::BUFFERS {
self.1.process_write_buf(io, stack, idx + 1)
} else {
self.1.process_write_buf(io, stack, idx)
}
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
let res1 = self.0.poll_read_ready(cx);
let res2 = self.1.poll_read_ready(cx);
match res1 {
Poll::Pending => Poll::Pending,
Poll::Ready(ReadStatus::Ready) => res2,
Poll::Ready(ReadStatus::Terminate) => Poll::Ready(ReadStatus::Terminate),
}
}
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
let res1 = self.0.poll_write_ready(cx);
let res2 = self.1.poll_write_ready(cx);
match res1 {
Poll::Pending => Poll::Pending,
Poll::Ready(WriteStatus::Ready) => res2,
Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate),
Poll::Ready(WriteStatus::Shutdown(t)) => {
if res2 == Poll::Ready(WriteStatus::Terminate) {
Poll::Ready(WriteStatus::Terminate)
} else {
Poll::Ready(WriteStatus::Shutdown(t))
}
}
Poll::Ready(WriteStatus::Timeout(t)) => match res2 {
Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate),
Poll::Ready(WriteStatus::Shutdown(t)) => {
Poll::Ready(WriteStatus::Shutdown(t))
}
_ => Poll::Ready(WriteStatus::Timeout(t)),
},
}
} }
} }
@ -121,11 +237,6 @@ impl Filter for NullFilter {
None None
} }
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
#[inline] #[inline]
fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll<ReadStatus> { fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll<ReadStatus> {
Poll::Ready(ReadStatus::Terminate) Poll::Ready(ReadStatus::Terminate)
@ -137,25 +248,23 @@ impl Filter for NullFilter {
} }
#[inline] #[inline]
fn get_read_buf(&self) -> Option<BytesVec> { fn process_read_buf(
None &self,
_: &IoRef,
_: &mut Stack,
_: usize,
_: usize,
) -> io::Result<usize> {
Ok(0)
} }
#[inline] #[inline]
fn get_write_buf(&self) -> Option<BytesVec> { fn process_write_buf(&self, _: &IoRef, _: &mut Stack, _: usize) -> io::Result<()> {
None
}
#[inline]
fn release_read_buf(&self, _: BytesVec) {}
#[inline]
fn process_read_buf(&self, _: &IoRef, _: usize) -> io::Result<(usize, usize)> {
Ok((0, 0))
}
#[inline]
fn release_write_buf(&self, _: BytesVec) -> Result<(), io::Error> {
Ok(()) Ok(())
} }
#[inline]
fn shutdown(&self, _: &IoRef, _: &mut Stack, _: usize) -> io::Result<Poll<()>> {
Ok(Poll::Ready(()))
}
} }

View file

@ -1,4 +1,4 @@
use std::cell::Cell; use std::cell::{Cell, RefCell};
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{fmt, future::Future, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc, time}; use std::{fmt, future::Future, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc, time};
@ -7,10 +7,11 @@ use ntex_codec::{Decoder, Encoder};
use ntex_util::time::{now, Millis}; use ntex_util::time::{now, Millis};
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker}; use ntex_util::{future::poll_fn, future::Either, task::LocalWaker};
use super::filter::{Base, NullFilter}; use crate::buf::Stack;
use super::seal::Sealed; use crate::filter::{Base, Filter, Layer, NullFilter};
use super::tasks::{ReadContext, WriteContext}; use crate::seal::Sealed;
use super::{Filter, FilterFactory, Handle, IoStatusUpdate, IoStream, RecvError}; use crate::tasks::{ReadContext, WriteContext};
use crate::{FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError};
bitflags::bitflags! { bitflags::bitflags! {
pub struct Flags: u16 { pub struct Flags: u16 {
@ -59,8 +60,7 @@ pub(crate) struct IoState {
pub(super) read_task: LocalWaker, pub(super) read_task: LocalWaker,
pub(super) write_task: LocalWaker, pub(super) write_task: LocalWaker,
pub(super) dispatch_task: LocalWaker, pub(super) dispatch_task: LocalWaker,
pub(super) read_buf: Cell<Option<BytesVec>>, pub(super) buffer: RefCell<Stack>,
pub(super) write_buf: Cell<Option<BytesVec>>,
pub(super) filter: Cell<&'static dyn Filter>, pub(super) filter: Cell<&'static dyn Filter>,
pub(super) handle: Cell<Option<Box<dyn Handle>>>, pub(super) handle: Cell<Option<Box<dyn Handle>>>,
#[allow(clippy::box_collection)] #[allow(clippy::box_collection)]
@ -104,7 +104,6 @@ impl IoState {
} }
} }
#[inline]
pub(super) fn io_stopped(&self, err: Option<io::Error>) { pub(super) fn io_stopped(&self, err: Option<io::Error>) {
if err.is_some() { if err.is_some() {
self.error.set(err); self.error.set(err);
@ -119,9 +118,8 @@ impl IoState {
); );
} }
#[inline]
/// Gracefully shutdown read and write io tasks /// Gracefully shutdown read and write io tasks
pub(super) fn init_shutdown(&self, err: Option<io::Error>) { pub(super) fn init_shutdown(&self, err: Option<io::Error>, io: &IoRef) {
if err.is_some() { if err.is_some() {
self.io_stopped(err); self.io_stopped(err);
} else if !self } else if !self
@ -131,28 +129,25 @@ impl IoState {
{ {
log::trace!("initiate io shutdown {:?}", self.flags.get()); log::trace!("initiate io shutdown {:?}", self.flags.get());
self.insert_flags(Flags::IO_STOPPING_FILTERS); self.insert_flags(Flags::IO_STOPPING_FILTERS);
self.shutdown_filters(); self.shutdown_filters(io);
} }
} }
#[inline] pub(super) fn shutdown_filters(&self, io: &IoRef) {
pub(super) fn shutdown_filters(&self) {
if !self if !self
.flags .flags
.get() .get()
.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING)
{ {
match self.filter.get().poll_shutdown() { let mut buffer = self.buffer.borrow_mut();
Poll::Ready(Ok(())) => { match self.filter.get().shutdown(io, &mut buffer, 0) {
Ok(Poll::Ready(())) => {
self.read_task.wake(); self.read_task.wake();
self.write_task.wake(); self.write_task.wake();
self.dispatch_task.wake(); self.dispatch_task.wake();
self.insert_flags(Flags::IO_STOPPING); self.insert_flags(Flags::IO_STOPPING);
} }
Poll::Ready(Err(err)) => { Ok(Poll::Pending) => {
self.io_stopped(Some(err));
}
Poll::Pending => {
let flags = self.flags.get(); let flags = self.flags.get();
// check read buffer, if buffer is not consumed it is unlikely // check read buffer, if buffer is not consumed it is unlikely
// that filter will properly complete shutdown // that filter will properly complete shutdown
@ -165,40 +160,37 @@ impl IoState {
self.insert_flags(Flags::IO_STOPPING); self.insert_flags(Flags::IO_STOPPING);
} }
} }
} Err(err) => {
self.io_stopped(Some(err));
}
};
self.write_task.wake();
} }
} }
#[inline]
pub(super) fn with_read_buf<Fn, Ret>(&self, release: bool, f: Fn) -> Ret pub(super) fn with_read_buf<Fn, Ret>(&self, release: bool, f: Fn) -> Ret
where where
Fn: FnOnce(&mut Option<BytesVec>) -> Ret, Fn: FnOnce(&mut Option<BytesVec>) -> Ret,
{ {
let filter = self.filter.get(); // use top most buffer
let mut buf = filter.get_read_buf(); let mut buffer = self.buffer.borrow_mut();
let result = f(&mut buf); let buf = buffer.first_read_buf();
let result = f(buf);
if let Some(buf) = buf { // release buffer
if release { if release && buf.as_ref().map(|b| b.is_empty()).unwrap_or(false) {
// release buffer if let Some(b) = buf.take() {
if buf.is_empty() { self.pool.get().release_read_buf(b);
self.pool.get().release_read_buf(buf);
return result;
}
} }
filter.release_read_buf(buf);
} }
result result
} }
#[inline]
pub(super) fn with_write_buf<Fn, Ret>(&self, f: Fn) -> Ret pub(super) fn with_write_buf<Fn, Ret>(&self, f: Fn) -> Ret
where where
Fn: FnOnce(&mut Option<BytesVec>) -> Ret, Fn: FnOnce(&mut Option<BytesVec>) -> Ret,
{ {
let buf = self.write_buf.as_ptr(); f(self.buffer.borrow_mut().last_write_buf())
let ref_buf = unsafe { buf.as_mut().unwrap() };
f(ref_buf)
} }
} }
@ -221,12 +213,7 @@ impl hash::Hash for IoState {
impl Drop for IoState { impl Drop for IoState {
#[inline] #[inline]
fn drop(&mut self) { fn drop(&mut self) {
if let Some(buf) = self.read_buf.take() { self.buffer.borrow_mut().release(self.pool.get());
self.pool.get().release_read_buf(buf);
}
if let Some(buf) = self.write_buf.take() {
self.pool.get().release_write_buf(buf);
}
} }
} }
@ -248,8 +235,7 @@ impl Io {
dispatch_task: LocalWaker::new(), dispatch_task: LocalWaker::new(),
read_task: LocalWaker::new(), read_task: LocalWaker::new(),
write_task: LocalWaker::new(), write_task: LocalWaker::new(),
read_buf: Cell::new(None), buffer: RefCell::new(Stack::new()),
write_buf: Cell::new(None),
filter: Cell::new(NullFilter::get()), filter: Cell::new(NullFilter::get()),
handle: Cell::new(None), handle: Cell::new(None),
on_disconnect: Cell::new(None), on_disconnect: Cell::new(None),
@ -277,14 +263,7 @@ impl<F> Io<F> {
#[inline] #[inline]
/// Set memory pool /// Set memory pool
pub fn set_memory_pool(&self, pool: PoolRef) { pub fn set_memory_pool(&self, pool: PoolRef) {
if let Some(mut buf) = self.0 .0.read_buf.take() { self.0 .0.buffer.borrow_mut().set_memory_pool(pool);
pool.move_vec_in(&mut buf);
self.0 .0.read_buf.set(Some(buf));
}
if let Some(mut buf) = self.0 .0.write_buf.take() {
pool.move_vec_in(&mut buf);
self.0 .0.write_buf.set(Some(buf));
}
self.0 .0.pool.set(pool); self.0 .0.pool.set(pool);
} }
@ -312,8 +291,7 @@ impl<F> Io<F> {
dispatch_task: LocalWaker::new(), dispatch_task: LocalWaker::new(),
read_task: LocalWaker::new(), read_task: LocalWaker::new(),
write_task: LocalWaker::new(), write_task: LocalWaker::new(),
read_buf: Cell::new(None), buffer: RefCell::new(Stack::new()),
write_buf: Cell::new(None),
filter: Cell::new(NullFilter::get()), filter: Cell::new(NullFilter::get()),
handle: Cell::new(None), handle: Cell::new(None),
on_disconnect: Cell::new(None), on_disconnect: Cell::new(None),
@ -353,57 +331,37 @@ impl<F> Io<F> {
} }
impl<F: Filter> Io<F> { impl<F: Filter> Io<F> {
#[inline]
/// Get referece to a filter
pub fn filter(&self) -> &F {
self.1.filter()
}
#[inline] #[inline]
/// Convert current io stream into sealed version /// Convert current io stream into sealed version
pub fn seal(mut self) -> Io<Sealed> { pub fn seal(mut self) -> Io<Sealed> {
// get current filter let (filter, filter_ref) = self.1.seal();
let filter = unsafe { self.0 .0.filter.replace(filter_ref);
let filter = self.1.seal(); Io(self.0.clone(), filter)
let filter_ref: &'static dyn Filter = {
let filter: &dyn Filter = filter.0.as_ref();
std::mem::transmute(filter)
};
self.0 .0.filter.replace(filter_ref);
filter
};
Io(self.0.clone(), FilterItem::with_sealed(filter))
}
#[inline]
/// Create new filter and replace current one
pub fn add_filter<T>(self, factory: T) -> T::Future
where
T: FilterFactory<F>,
{
factory.create(self)
} }
#[inline] #[inline]
/// Map current filter with new one /// Map current filter with new one
pub fn map_filter<T, U, E>(mut self, map: U) -> Result<Io<T>, E> pub fn add_filter<U>(mut self, nf: U) -> Io<Layer<U, F>>
where where
T: Filter, U: FilterLayer,
U: FnOnce(F) -> Result<T, E>,
{ {
// replace current filter // add layer to buffers
let filter = unsafe { if U::BUFFERS {
let filter = Box::new(map(*(self.1.get_filter()))?); self.0 .0.buffer.borrow_mut().add_layer();
let filter_ref: &'static dyn Filter = { }
let filter: &dyn Filter = filter.as_ref();
std::mem::transmute(filter)
};
self.0 .0.filter.replace(filter_ref);
filter
};
Ok(Io(self.0.clone(), FilterItem::with_filter(filter))) // replace current filter
let (filter, filter_ref) = self.1.add_filter(nf);
self.0 .0.filter.replace(filter_ref);
Io(self.0.clone(), filter)
}
}
impl<F: FilterLayer, T: Filter> Io<Layer<F, T>> {
#[inline]
/// Get referece to a filter
pub fn filter(&self) -> &F {
&self.1.filter().0
} }
} }
@ -629,8 +587,10 @@ impl<F> Io<F> {
} }
} else { } else {
if !flags.contains(Flags::IO_STOPPING_FILTERS) { if !flags.contains(Flags::IO_STOPPING_FILTERS) {
self.0 .0.init_shutdown(None); self.0 .0.init_shutdown(None, &self.0);
} }
self.0 .0.read_task.wake();
self.0 .0.dispatch_task.register(cx.waker()); self.0 .0.dispatch_task.register(cx.waker());
Poll::Pending Poll::Pending
} }
@ -759,20 +719,6 @@ impl<F> FilterItem<F> {
slf slf
} }
fn with_sealed(f: Sealed) -> Self {
let mut slf = Self {
data: [0; SEALED_SIZE],
_t: marker::PhantomData,
};
unsafe {
let ptr = &mut slf.data as *mut _ as *mut Sealed;
ptr.write(f);
slf.data[KIND_IDX] |= KIND_SEALED;
}
slf
}
/// Get filter, panic if it is not filter /// Get filter, panic if it is not filter
fn filter(&self) -> &F { fn filter(&self) -> &F {
if self.data[KIND_IDX] & KIND_PTR != 0 { if self.data[KIND_IDX] & KIND_PTR != 0 {
@ -786,8 +732,8 @@ impl<F> FilterItem<F> {
} }
} }
/// Get filter, panic if it is not filter /// Get filter, panic if it is not set
fn get_filter(&mut self) -> Box<F> { fn take_filter(&mut self) -> Box<F> {
if self.data[KIND_IDX] & KIND_PTR != 0 { if self.data[KIND_IDX] & KIND_PTR != 0 {
self.data[KIND_IDX] &= KIND_UNMASK; self.data[KIND_IDX] &= KIND_UNMASK;
let ptr = &mut self.data as *mut _ as *mut *mut F; let ptr = &mut self.data as *mut _ as *mut *mut F;
@ -801,7 +747,7 @@ impl<F> FilterItem<F> {
} }
/// Get sealed, panic if it is already sealed /// Get sealed, panic if it is already sealed
fn get_sealed(&mut self) -> Sealed { fn take_sealed(&mut self) -> Sealed {
if self.data[KIND_IDX] & KIND_SEALED != 0 { if self.data[KIND_IDX] & KIND_SEALED != 0 {
self.data[KIND_IDX] &= KIND_UNMASK; self.data[KIND_IDX] &= KIND_UNMASK;
let ptr = &mut self.data as *mut _ as *mut Sealed; let ptr = &mut self.data as *mut _ as *mut Sealed;
@ -820,25 +766,54 @@ impl<F> FilterItem<F> {
fn drop_filter(&mut self) { fn drop_filter(&mut self) {
if self.data[KIND_IDX] & KIND_PTR != 0 { if self.data[KIND_IDX] & KIND_PTR != 0 {
self.get_filter(); self.take_filter();
} else if self.data[KIND_IDX] & KIND_SEALED != 0 { } else if self.data[KIND_IDX] & KIND_SEALED != 0 {
self.get_sealed(); self.take_sealed();
} }
} }
} }
impl<F: Filter> FilterItem<F> { impl<F: Filter> FilterItem<F> {
fn seal(&mut self) -> Sealed { fn add_filter<T: FilterLayer>(
if self.data[KIND_IDX] & KIND_PTR != 0 { &mut self,
Sealed(Box::new(*self.get_filter())) new: T,
) -> (FilterItem<Layer<T, F>>, &'static dyn Filter) {
let filter = Box::new(Layer::new(new, *self.take_filter()));
let filter_ref: &'static dyn Filter = {
let filter: &dyn Filter = filter.as_ref();
unsafe { std::mem::transmute(filter) }
};
(FilterItem::with_filter(filter), filter_ref)
}
fn seal(&mut self) -> (FilterItem<Sealed>, &'static dyn Filter) {
let filter = if self.data[KIND_IDX] & KIND_PTR != 0 {
Sealed(Box::new(*self.take_filter()))
} else if self.data[KIND_IDX] & KIND_SEALED != 0 { } else if self.data[KIND_IDX] & KIND_SEALED != 0 {
self.get_sealed() self.take_sealed()
} else { } else {
panic!( panic!(
"Wrong filter item {:?} expected: {:?}", "Wrong filter item {:?} expected: {:?}",
self.data[KIND_IDX], KIND_PTR self.data[KIND_IDX], KIND_PTR
); );
};
let filter_ref: &'static dyn Filter = {
let filter: &dyn Filter = filter.0.as_ref();
unsafe { std::mem::transmute(filter) }
};
let mut slf = FilterItem {
data: [0; SEALED_SIZE],
_t: marker::PhantomData,
};
unsafe {
let ptr = &mut slf.data as *mut _ as *mut Sealed;
ptr.write(filter);
slf.data[KIND_IDX] |= KIND_SEALED;
} }
(slf, filter_ref)
} }
} }

View file

@ -1,9 +1,9 @@
use std::{any, fmt, hash, io, time}; use std::{any, fmt, hash, io, time};
use ntex_bytes::{BufMut, BytesVec, PoolRef}; use ntex_bytes::{BytesVec, PoolRef};
use ntex_codec::{Decoder, Encoder}; use ntex_codec::{Decoder, Encoder};
use super::{io::Flags, timer, types, Filter, IoRef, OnDisconnect}; use super::{io::Flags, timer, types, Filter, IoRef, OnDisconnect, WriteBuf};
impl IoRef { impl IoRef {
#[inline] #[inline]
@ -49,7 +49,7 @@ impl IoRef {
/// Notify dispatcher and initiate io stream shutdown process. /// Notify dispatcher and initiate io stream shutdown process.
pub fn close(&self) { pub fn close(&self) {
self.0.insert_flags(Flags::DSP_STOP); self.0.insert_flags(Flags::DSP_STOP);
self.0.init_shutdown(None); self.0.init_shutdown(None, self);
} }
#[inline] #[inline]
@ -72,8 +72,16 @@ impl IoRef {
#[inline] #[inline]
/// Gracefully shutdown io stream /// Gracefully shutdown io stream
pub fn want_shutdown(&self, err: Option<io::Error>) { pub fn want_shutdown(&self) {
self.0.init_shutdown(err); if !self
.0
.flags
.get()
.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
{
log::trace!("initiate io shutdown {:?}", self.0.flags.get());
self.0.insert_flags(Flags::IO_STOPPING_FILTERS);
}
} }
#[inline] #[inline]
@ -96,13 +104,8 @@ impl IoRef {
if !flags.contains(Flags::IO_STOPPING) { if !flags.contains(Flags::IO_STOPPING) {
self.with_write_buf(|buf| { self.with_write_buf(|buf| {
let (hw, lw) = self.memory_pool().write_params().unpack();
// make sure we've got room // make sure we've got room
let remaining = buf.remaining_mut(); self.memory_pool().resize_write_buf(buf);
if remaining < lw {
buf.reserve(hw - remaining);
}
// encode item and wake write task // encode item and wake write task
codec.encode_vec(item, buf) codec.encode_vec(item, buf)
@ -151,21 +154,35 @@ impl IoRef {
#[inline] #[inline]
/// Get mut access to write buffer /// Get mut access to write buffer
pub fn with_write_buf<F, R>(&self, f: F) -> Result<R, io::Error> pub fn with_write_buf<F, R>(&self, f: F) -> io::Result<R>
where where
F: FnOnce(&mut BytesVec) -> R, F: FnOnce(&mut BytesVec) -> R,
{ {
let filter = self.0.filter.get(); let mut buffer = self.0.buffer.borrow_mut();
let mut buf = filter let is_write_sleep = buffer.last_write_buf_size() == 0;
.get_write_buf()
.unwrap_or_else(|| self.memory_pool().get_write_buf());
let is_write_sleep = buf.is_empty();
let result = f(&mut buf); let result = f(buffer.first_write_buf(self));
if is_write_sleep { self.0
.filter
.get()
.process_write_buf(self, &mut buffer, 0)?;
if is_write_sleep && buffer.last_write_buf_size() != 0 {
self.0.write_task.wake(); self.0.write_task.wake();
} }
filter.release_write_buf(buf)?; Ok(result)
}
#[inline]
/// Get mut access to write buffer
pub fn with_buf<F, R>(&self, f: F) -> io::Result<R>
where
F: FnOnce(&mut WriteBuf<'_>) -> R,
{
let mut b = self.0.buffer.borrow_mut();
let result = b.write_buf(self, 0, f);
self.0.filter.get().process_write_buf(self, &mut b, 0)?;
self.0.write_task.wake();
Ok(result) Ok(result)
} }
@ -240,16 +257,15 @@ impl fmt::Debug for IoRef {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::{future::Future, pin::Pin, rc::Rc, task::Context, task::Poll}; use std::{future::Future, pin::Pin, rc::Rc, task::Poll};
use ntex_bytes::Bytes; use ntex_bytes::Bytes;
use ntex_codec::BytesCodec; use ntex_codec::BytesCodec;
use ntex_util::future::{lazy, poll_fn, Ready}; use ntex_util::future::{lazy, poll_fn};
use ntex_util::time::{sleep, Millis}; use ntex_util::time::{sleep, Millis};
use super::*; use super::*;
use crate::testing::IoTest; use crate::{testing::IoTest, FilterLayer, Io, ReadBuf, WriteBuf};
use crate::{Filter, FilterFactory, Io, ReadStatus, WriteStatus};
const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n"; const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
const TEXT: &str = "GET /test HTTP/1\r\n\r\n"; const TEXT: &str = "GET /test HTTP/1\r\n\r\n";
@ -370,87 +386,28 @@ mod tests {
assert_eq!(waiter.await, ()); assert_eq!(waiter.await, ());
} }
struct Counter<F> { struct Counter {
idx: usize, idx: usize,
inner: F,
in_bytes: Rc<Cell<usize>>, in_bytes: Rc<Cell<usize>>,
out_bytes: Rc<Cell<usize>>, out_bytes: Rc<Cell<usize>>,
read_order: Rc<RefCell<Vec<usize>>>, read_order: Rc<RefCell<Vec<usize>>>,
write_order: Rc<RefCell<Vec<usize>>>, write_order: Rc<RefCell<Vec<usize>>>,
} }
impl<F: Filter> Filter for Counter<F> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.poll_read_ready(cx)
}
fn get_read_buf(&self) -> Option<BytesVec> { impl FilterLayer for Counter {
self.inner.get_read_buf() const BUFFERS: bool = false;
}
fn release_read_buf(&self, buf: BytesVec) { fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
self.inner.release_read_buf(buf)
}
fn process_read_buf(&self, io: &IoRef, n: usize) -> io::Result<(usize, usize)> {
let result = self.inner.process_read_buf(io, n)?;
self.read_order.borrow_mut().push(self.idx); self.read_order.borrow_mut().push(self.idx);
self.in_bytes.set(self.in_bytes.get() + result.1); self.in_bytes.set(self.in_bytes.get() + buf.nbytes());
Ok(result) Ok(buf.nbytes())
} }
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> { fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
self.inner.poll_write_ready(cx)
}
fn get_write_buf(&self) -> Option<BytesVec> {
if let Some(buf) = self.inner.get_write_buf() {
self.out_bytes.set(self.out_bytes.get() - buf.len());
Some(buf)
} else {
None
}
}
fn release_write_buf(&self, buf: BytesVec) -> Result<(), io::Error> {
self.write_order.borrow_mut().push(self.idx); self.write_order.borrow_mut().push(self.idx);
self.out_bytes.set(self.out_bytes.get() + buf.len()); self.out_bytes
self.inner.release_write_buf(buf) .set(self.out_bytes.get() + buf.get_dst().len());
} Ok(())
}
struct CounterFactory(
usize,
Rc<Cell<usize>>,
Rc<Cell<usize>>,
Rc<RefCell<Vec<usize>>>,
Rc<RefCell<Vec<usize>>>,
);
impl<F: Filter> FilterFactory<F> for CounterFactory {
type Filter = Counter<F>;
type Error = ();
type Future = Ready<Io<Counter<F>>, Self::Error>;
fn create(self, io: Io<F>) -> Self::Future {
let idx = self.0;
let in_bytes = self.1.clone();
let out_bytes = self.2.clone();
let read_order = self.3.clone();
let write_order = self.4;
Ready::Ok(
io.map_filter(|inner| {
Ok::<_, ()>(Counter {
idx,
inner,
in_bytes,
out_bytes,
read_order,
write_order,
})
})
.unwrap(),
)
} }
} }
@ -460,24 +417,22 @@ mod tests {
let out_bytes = Rc::new(Cell::new(0)); let out_bytes = Rc::new(Cell::new(0));
let read_order = Rc::new(RefCell::new(Vec::new())); let read_order = Rc::new(RefCell::new(Vec::new()));
let write_order = Rc::new(RefCell::new(Vec::new())); let write_order = Rc::new(RefCell::new(Vec::new()));
let factory = CounterFactory(
1,
in_bytes.clone(),
out_bytes.clone(),
read_order.clone(),
write_order.clone(),
);
let (client, server) = IoTest::create(); let (client, server) = IoTest::create();
let state = Io::new(server).add_filter(factory).await.unwrap(); let io = Io::new(server).add_filter(Counter {
idx: 1,
in_bytes: in_bytes.clone(),
out_bytes: out_bytes.clone(),
read_order: read_order.clone(),
write_order: write_order.clone(),
});
client.remote_buffer_cap(1024); client.remote_buffer_cap(1024);
client.write(TEXT); client.write(TEXT);
let msg = state.recv(&BytesCodec).await.unwrap().unwrap(); let msg = io.recv(&BytesCodec).await.unwrap().unwrap();
assert_eq!(msg, Bytes::from_static(BIN)); assert_eq!(msg, Bytes::from_static(BIN));
state io.send(Bytes::from_static(b"test"), &BytesCodec)
.send(Bytes::from_static(b"test"), &BytesCodec)
.await .await
.unwrap(); .unwrap();
let buf = client.read().await.unwrap(); let buf = client.read().await.unwrap();
@ -496,24 +451,20 @@ mod tests {
let (client, server) = IoTest::create(); let (client, server) = IoTest::create();
let state = Io::new(server) let state = Io::new(server)
.add_filter(CounterFactory( .add_filter(Counter {
1, idx: 1,
in_bytes.clone(), in_bytes: in_bytes.clone(),
out_bytes.clone(), out_bytes: out_bytes.clone(),
read_order.clone(), read_order: read_order.clone(),
write_order.clone(), write_order: write_order.clone(),
)) })
.await .add_filter(Counter {
.unwrap() idx: 2,
.add_filter(CounterFactory( in_bytes: in_bytes.clone(),
2, out_bytes: out_bytes.clone(),
in_bytes.clone(), read_order: read_order.clone(),
out_bytes.clone(), write_order: write_order.clone(),
read_order.clone(), });
write_order.clone(),
))
.await
.unwrap();
let state = state.seal(); let state = state.seal();
client.remote_buffer_cap(1024); client.remote_buffer_cap(1024);

View file

@ -7,6 +7,7 @@ use std::{
pub mod testing; pub mod testing;
pub mod types; pub mod types;
mod buf;
mod dispatcher; mod dispatcher;
mod filter; mod filter;
mod framed; mod framed;
@ -17,12 +18,12 @@ mod tasks;
mod timer; mod timer;
mod utils; mod utils;
use ntex_bytes::BytesVec;
use ntex_codec::{Decoder, Encoder}; use ntex_codec::{Decoder, Encoder};
use ntex_util::time::Millis; use ntex_util::time::Millis;
pub use self::buf::{ReadBuf, WriteBuf};
pub use self::dispatcher::Dispatcher; pub use self::dispatcher::Dispatcher;
pub use self::filter::Base; pub use self::filter::{Base, Filter, Layer};
pub use self::framed::Framed; pub use self::framed::Framed;
pub use self::io::{Io, IoRef, OnDisconnect}; pub use self::io::{Io, IoRef, OnDisconnect};
pub use self::seal::{IoBoxed, Sealed}; pub use self::seal::{IoBoxed, Sealed};
@ -49,44 +50,51 @@ pub enum WriteStatus {
Terminate, Terminate,
} }
pub trait Filter: 'static { #[allow(unused_variables)]
fn query(&self, _: TypeId) -> Option<Box<dyn Any>> { pub trait FilterLayer: 'static {
None /// Create buffers for this filter
const BUFFERS: bool = true;
#[inline]
/// Check readiness for read operations
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
Poll::Ready(ReadStatus::Ready)
} }
fn get_read_buf(&self) -> Option<BytesVec>; #[inline]
/// Check readiness for write operations
fn release_read_buf(&self, buf: BytesVec); fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
Poll::Ready(WriteStatus::Ready)
}
/// Process read buffer /// Process read buffer
/// ///
/// Returns tuple (total bytes, new bytes) /// Inner filter must process buffer before current.
fn process_read_buf(&self, io: &IoRef, n: usize) -> sio::Result<(usize, usize)>; /// Returns number of new bytes.
fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> sio::Result<usize>;
fn get_write_buf(&self) -> Option<BytesVec>; /// Process write buffer
fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> sio::Result<()>;
fn release_write_buf(&self, buf: BytesVec) -> sio::Result<()>; /// Query internal filter data
fn query(&self, id: TypeId) -> Option<Box<dyn Any>> {
/// Check readiness for read operations None
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus>; }
/// Check readiness for write operations
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus>;
/// Gracefully shutdown filter /// Gracefully shutdown filter
fn poll_shutdown(&self) -> Poll<sio::Result<()>> { fn shutdown(&self, buf: &mut WriteBuf<'_>) -> sio::Result<Poll<()>> {
Poll::Ready(Ok(())) Ok(Poll::Ready(()))
} }
} }
/// Creates new `Filter` values. /// Creates new `Filter` values.
pub trait FilterFactory<F: Filter>: Sized { pub trait FilterFactory<F>: Sized {
/// The `Filter` value created by this factory /// The `Filter` value created by this factory
type Filter: Filter; type Filter: FilterLayer;
/// Errors produced while building a filter. /// Errors produced while building a filter.
type Error: fmt::Debug; type Error: fmt::Debug;
/// The future of the `FilterFactory` instance. /// The future of the `FilterFactory` instance.
type Future: Future<Output = Result<Io<Self::Filter>, Self::Error>>; type Future: Future<Output = Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
/// Create and return a new filter value asynchronously. /// Create and return a new filter value asynchronously.
fn create(self, st: Io<F>) -> Self::Future; fn create(self, st: Io<F>) -> Self::Future;

View file

@ -1,6 +1,6 @@
use std::ops; use std::ops;
use crate::{Filter, Io}; use crate::{filter::Filter, Io};
/// Sealed filter type /// Sealed filter type
pub struct Sealed(pub(crate) Box<dyn Filter>); pub struct Sealed(pub(crate) Box<dyn Filter>);

View file

@ -12,62 +12,93 @@ impl ReadContext {
Self(io.clone()) Self(io.clone())
} }
#[inline]
/// Return memory pool for this context
pub fn memory_pool(&self) -> PoolRef {
self.0.memory_pool()
}
#[inline] #[inline]
/// Check readiness for read operations /// Check readiness for read operations
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> { pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.0.filter().poll_read_ready(cx) self.0.filter().poll_read_ready(cx)
} }
#[inline]
/// Get read buffer /// Get read buffer
pub fn get_read_buf(&self) -> BytesVec { pub fn with_buf<F>(&self, f: F) -> Poll<()>
self.0 where
.0 F: FnOnce(&mut BytesVec, usize, usize) -> Poll<io::Result<()>>,
.read_buf {
let mut stack = self.0 .0.buffer.borrow_mut();
let mut buf = stack
.last_read_buf()
.take() .take()
.unwrap_or_else(|| self.0.memory_pool().get_read_buf()) .unwrap_or_else(|| self.0.memory_pool().get_read_buf());
}
#[inline] let total = buf.len();
/// Release read buffer after io read operations let (hw, lw) = self.0.memory_pool().read_params().unpack();
pub fn release_read_buf(&self, buf: BytesVec, nbytes: usize) {
// call provided callback
let result = f(&mut buf, hw, lw);
// handle buffer changes
if buf.is_empty() { if buf.is_empty() {
self.0.memory_pool().release_read_buf(buf); self.0.memory_pool().release_read_buf(buf);
} else { } else {
self.0 .0.read_buf.set(Some(buf)); let total2 = buf.len();
let filter = self.0.filter(); let nbytes = if total2 > total { total2 - total } else { 0 };
match filter.process_read_buf(&self.0, nbytes) { *stack.last_read_buf() = Some(buf);
Ok((total, nbytes)) => {
if nbytes > 0 { if nbytes > 0 {
if total > self.0.memory_pool().read_params().high as usize { let buf_full = nbytes >= hw;
match self
.0
.filter()
.process_read_buf(&self.0, &mut stack, 0, nbytes)
{
Ok(nbytes) => {
if nbytes > 0 {
if buf_full || stack.first_read_buf_size() >= hw {
log::trace!(
"io read buffer is too large {}, enable read back-pressure",
total2
);
self.0
.0
.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
} else {
self.0 .0.insert_flags(Flags::RD_READY);
}
self.0 .0.dispatch_task.wake();
log::trace!( log::trace!(
"buffer is too large {}, enable read back-pressure", "new {} bytes available, wakeup dispatcher",
total nbytes,
); );
self.0 .0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL); } else if buf_full {
// read task is paused because of read back-pressure
self.0 .0.read_task.wake();
} }
}
Err(err) => {
self.0 .0.dispatch_task.wake(); self.0 .0.dispatch_task.wake();
self.0 .0.insert_flags(Flags::RD_READY); self.0 .0.insert_flags(Flags::RD_READY);
log::trace!("new {} bytes available, wakeup dispatcher", nbytes); self.0 .0.init_shutdown(Some(err), &self.0);
} }
} }
Err(err) => {
self.0 .0.dispatch_task.wake();
self.0 .0.insert_flags(Flags::RD_READY);
self.0.want_shutdown(Some(err));
}
} }
} }
let result = match result {
Poll::Ready(Ok(())) => {
self.0 .0.io_stopped(None);
Poll::Ready(())
}
Poll::Ready(Err(e)) => {
self.0 .0.io_stopped(Some(e));
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
};
drop(stack);
if self.0.flags().contains(Flags::IO_STOPPING_FILTERS) { if self.0.flags().contains(Flags::IO_STOPPING_FILTERS) {
self.0 .0.shutdown_filters(); self.0 .0.shutdown_filters(&self.0);
} }
result
} }
#[inline] #[inline]
@ -100,7 +131,7 @@ impl WriteContext {
#[inline] #[inline]
/// Get write buffer /// Get write buffer
pub fn get_write_buf(&self) -> Option<BytesVec> { pub fn get_write_buf(&self) -> Option<BytesVec> {
self.0 .0.write_buf.take() self.0 .0.buffer.borrow_mut().last_write_buf().take()
} }
#[inline] #[inline]
@ -125,11 +156,11 @@ impl WriteContext {
self.0.set_flags(flags); self.0.set_flags(flags);
self.0 .0.dispatch_task.wake(); self.0 .0.dispatch_task.wake();
} }
self.0 .0.write_buf.set(Some(buf)) self.0 .0.buffer.borrow_mut().set_last_write_buf(buf);
} }
if self.0.flags().contains(Flags::IO_STOPPING_FILTERS) { if self.0.flags().contains(Flags::IO_STOPPING_FILTERS) {
self.0 .0.shutdown_filters(); self.0 .0.shutdown_filters(&self.0);
} }
Ok(()) Ok(())

View file

@ -344,6 +344,12 @@ impl Drop for IoTest {
_ => (), _ => (),
} }
self.state.set(state); self.state.set(state);
let guard = self.remote.lock().unwrap();
let mut remote = guard.borrow_mut();
remote.read = IoTestState::Close;
remote.waker.wake();
log::trace!("drop remote socket");
} }
} }
@ -388,58 +394,58 @@ impl Future for ReadTask {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref(); let this = self.as_ref();
match this.state.poll_ready(cx) { this.state.with_buf(|buf, hw, lw| {
Poll::Ready(ReadStatus::Terminate) => { match this.state.poll_ready(cx) {
log::trace!("read task is instructed to terminate"); Poll::Ready(ReadStatus::Terminate) => {
Poll::Ready(()) log::trace!("read task is instructed to terminate");
} Poll::Ready(Ok(()))
Poll::Ready(ReadStatus::Ready) => { }
let io = &this.io; Poll::Ready(ReadStatus::Ready) => {
let pool = this.state.memory_pool(); let io = &this.io;
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket // read data from socket
let mut new_bytes = 0; let mut new_bytes = 0;
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = buf.remaining_mut(); let remaining = buf.remaining_mut();
if remaining < lw { if remaining < lw {
buf.reserve(hw - remaining); buf.reserve(hw - remaining);
}
match io.poll_read_buf(cx, &mut buf) {
Poll::Pending => {
log::trace!("no more data in io stream, read: {:?}", new_bytes);
break;
} }
Poll::Ready(Ok(n)) => { match io.poll_read_buf(cx, buf) {
if n == 0 { Poll::Pending => {
log::trace!("io stream is disconnected"); log::trace!(
this.state.release_read_buf(buf, new_bytes); "no more data in io stream, read: {:?}",
this.state.close(None); new_bytes
return Poll::Ready(()); );
} else { break;
new_bytes += n; }
if buf.len() > hw { Poll::Ready(Ok(n)) => {
break; if n == 0 {
log::trace!("io stream is disconnected");
return Poll::Ready(Ok(()));
} else {
new_bytes += n;
if buf.len() >= hw {
log::trace!(
"high water mark pause reading, read: {:?}",
new_bytes
);
break;
}
} }
} }
} Poll::Ready(Err(err)) => {
Poll::Ready(Err(err)) => { log::trace!("read task failed on io {:?}", err);
log::trace!("read task failed on io {:?}", err); return Poll::Ready(Err(err));
this.state.release_read_buf(buf, new_bytes); }
this.state.close(Some(err));
return Poll::Ready(());
} }
} }
}
this.state.release_read_buf(buf, new_bytes); Poll::Pending
Poll::Pending }
Poll::Pending => Poll::Pending,
} }
Poll::Pending => Poll::Pending, })
}
} }
} }

View file

@ -3,7 +3,7 @@ use std::marker::PhantomData;
use ntex_service::{fn_service, pipeline_factory, Service, ServiceFactory}; use ntex_service::{fn_service, pipeline_factory, Service, ServiceFactory};
use ntex_util::future::Ready; use ntex_util::future::Ready;
use crate::{Filter, FilterFactory, Io, IoBoxed}; use crate::{Filter, FilterFactory, Io, IoBoxed, Layer};
/// Service that converts any Io<F> stream to IoBoxed stream /// Service that converts any Io<F> stream to IoBoxed stream
pub fn seal<F, S, C>( pub fn seal<F, S, C>(
@ -30,7 +30,6 @@ where
pub fn filter<T, F>(filter: T) -> FilterServiceFactory<T, F> pub fn filter<T, F>(filter: T) -> FilterServiceFactory<T, F>
where where
T: FilterFactory<F> + Clone, T: FilterFactory<F> + Clone,
F: Filter,
{ {
FilterServiceFactory { FilterServiceFactory {
filter, filter,
@ -46,9 +45,8 @@ pub struct FilterServiceFactory<T, F> {
impl<T, F> ServiceFactory<Io<F>> for FilterServiceFactory<T, F> impl<T, F> ServiceFactory<Io<F>> for FilterServiceFactory<T, F>
where where
T: FilterFactory<F> + Clone, T: FilterFactory<F> + Clone,
F: Filter,
{ {
type Response = Io<T::Filter>; type Response = Io<Layer<T::Filter, F>>;
type Error = T::Error; type Error = T::Error;
type Service = FilterService<T, F>; type Service = FilterService<T, F>;
type InitError = (); type InitError = ();
@ -71,25 +69,28 @@ pub struct FilterService<T, F> {
impl<T, F> Service<Io<F>> for FilterService<T, F> impl<T, F> Service<Io<F>> for FilterService<T, F>
where where
T: FilterFactory<F> + Clone, T: FilterFactory<F> + Clone,
F: Filter,
{ {
type Response = Io<T::Filter>; type Response = Io<Layer<T::Filter, F>>;
type Error = T::Error; type Error = T::Error;
type Future<'f> = T::Future where T: 'f; type Future<'f> = T::Future where T: 'f, F: 'f;
#[inline] #[inline]
fn call(&self, req: Io<F>) -> Self::Future<'_> { fn call(&self, req: Io<F>) -> Self::Future<'_> {
req.add_filter(self.filter.clone()) self.filter.clone().create(req)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use ntex_bytes::{Bytes, BytesVec}; use std::io;
use ntex_bytes::Bytes;
use ntex_codec::BytesCodec; use ntex_codec::BytesCodec;
use super::*; use super::*;
use crate::{filter::NullFilter, testing::IoTest}; use crate::{
buf::Stack, filter::NullFilter, testing::IoTest, FilterLayer, ReadBuf, WriteBuf,
};
#[ntex::test] #[ntex::test]
async fn test_utils() { async fn test_utils() {
@ -114,16 +115,28 @@ mod tests {
assert_eq!(buf, b"RES".as_ref()); assert_eq!(buf, b"RES".as_ref());
} }
#[derive(Copy, Clone, Debug)] pub(crate) struct TestFilter;
struct NullFilterFactory;
impl<F: Filter> FilterFactory<F> for NullFilterFactory { impl FilterLayer for TestFilter {
type Filter = crate::filter::NullFilter; fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
Ok(buf.nbytes())
}
fn process_write_buf(&self, _: &mut WriteBuf<'_>) -> io::Result<()> {
Ok(())
}
}
#[derive(Copy, Clone, Debug)]
struct TestFilterFactory;
impl<F: Filter> FilterFactory<F> for TestFilterFactory {
type Filter = TestFilter;
type Error = std::convert::Infallible; type Error = std::convert::Infallible;
type Future = Ready<Io<Self::Filter>, Self::Error>; type Future = Ready<Io<Layer<TestFilter, F>>, Self::Error>;
fn create(self, st: Io<F>) -> Self::Future { fn create(self, st: Io<F>) -> Self::Future {
st.map_filter(|_| Ok(NullFilter)).into() Ready::Ok(st.add_filter(TestFilter).into())
} }
} }
@ -131,7 +144,7 @@ mod tests {
async fn test_utils_filter() { async fn test_utils_filter() {
let (_, server) = IoTest::create(); let (_, server) = IoTest::create();
let svc = pipeline_factory( let svc = pipeline_factory(
filter::<_, crate::filter::Base>(NullFilterFactory) filter::<_, crate::filter::Base>(TestFilterFactory)
.map_err(|_| ()) .map_err(|_| ())
.map_init_err(|_| ()), .map_init_err(|_| ()),
) )
@ -147,8 +160,15 @@ mod tests {
#[ntex::test] #[ntex::test]
async fn test_null_filter() { async fn test_null_filter() {
let (_, server) = IoTest::create();
let io = Io::new(server);
let ioref = io.get_ref();
let mut stack = Stack::new();
assert!(NullFilter.query(std::any::TypeId::of::<()>()).is_none()); assert!(NullFilter.query(std::any::TypeId::of::<()>()).is_none());
assert!(NullFilter.poll_shutdown().is_ready()); assert!(NullFilter
.shutdown(&ioref, &mut stack, 0)
.unwrap()
.is_ready());
assert_eq!( assert_eq!(
ntex_util::future::poll_fn(|cx| NullFilter.poll_read_ready(cx)).await, ntex_util::future::poll_fn(|cx| NullFilter.poll_read_ready(cx)).await,
crate::ReadStatus::Terminate crate::ReadStatus::Terminate
@ -157,16 +177,12 @@ mod tests {
ntex_util::future::poll_fn(|cx| NullFilter.poll_write_ready(cx)).await, ntex_util::future::poll_fn(|cx| NullFilter.poll_write_ready(cx)).await,
crate::WriteStatus::Terminate crate::WriteStatus::Terminate
); );
assert_eq!(NullFilter.get_read_buf(), None); assert!(NullFilter.process_write_buf(&ioref, &mut stack, 0).is_ok());
assert_eq!(NullFilter.get_write_buf(), None);
assert!(NullFilter.release_write_buf(BytesVec::new()).is_ok());
NullFilter.release_read_buf(BytesVec::new());
let (_, server) = IoTest::create();
let io = Io::new(server);
assert_eq!( assert_eq!(
NullFilter.process_read_buf(&io.get_ref(), 10).unwrap(), NullFilter
(0, 0) .process_read_buf(&ioref, &mut stack, 0, 0)
.unwrap(),
(0)
) )
} }
} }

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.2.1] - 2023-01-23
* Update filter implementation
## [0.2.0] - 2023-01-04 ## [0.2.0] - 2023-01-04
* Release * Release

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-tls" name = "ntex-tls"
version = "0.2.0" version = "0.2.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL" description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -25,15 +25,15 @@ openssl = ["tls_openssl"]
rustls = ["tls_rust"] rustls = ["tls_rust"]
[dependencies] [dependencies]
ntex-bytes = "0.1.18" ntex-bytes = "0.1.19"
ntex-io = "0.2.0" ntex-io = "0.2.1"
ntex-util = "0.2.0" ntex-util = "0.2.0"
ntex-service = "1.0.0" ntex-service = "1.0.0"
log = "0.4" log = "0.4"
pin-project-lite = "0.2" pin-project-lite = "0.2"
# openssl # openssl
tls_openssl = { version="0.10.42", package = "openssl", optional = true } tls_openssl = { version="0.10", package = "openssl", optional = true }
# rustls # rustls
tls_rust = { version = "0.20", package = "rustls", optional = true } tls_rust = { version = "0.20", package = "rustls", optional = true }

View file

@ -1,31 +1,31 @@
-----BEGIN CERTIFICATE----- -----BEGIN CERTIFICATE-----
MIIFPjCCAyYCCQDWGwiaSniPcTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJV MIIFPjCCAyYCCQDvLYiYD+jqeTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJV
UzELMAkGA1UECAwCQ0ExCzAJBgNVBAcMAlNGMRAwDgYDVQQKDAdDb21wYW55MQww UzELMAkGA1UECAwCQ0ExCzAJBgNVBAcMAlNGMRAwDgYDVQQKDAdDb21wYW55MQww
CgYDVQQLDANPcmcxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0yMTEyMTgx CgYDVQQLDANPcmcxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xODAxMjUx
NjMwNDlaFw0yMjEyMTgxNjMwNDlaMGExCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJD NzQ2MDFaFw0xOTAxMjUxNzQ2MDFaMGExCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJD
QTELMAkGA1UEBwwCU0YxEDAOBgNVBAoMB0NvbXBhbnkxDDAKBgNVBAsMA09yZzEY QTELMAkGA1UEBwwCU0YxEDAOBgNVBAoMB0NvbXBhbnkxDDAKBgNVBAsMA09yZzEY
MBYGA1UEAwwPd3d3LmV4YW1wbGUuY29tMIICIjANBgkqhkiG9w0BAQEFAAOCAg8A MBYGA1UEAwwPd3d3LmV4YW1wbGUuY29tMIICIjANBgkqhkiG9w0BAQEFAAOCAg8A
MIICCgKCAgEAryUL1k7npaMOck9OO+EjzeL0FoysOP5JrgRh+8BoPY7WPyL56oFP MIICCgKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEPn8k1
aCYKp2YMucmvFh/VZSyupC75JJNIaW0fvcIe4Euzy2Ex0VukPxYteRicaWRsxSId sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+MIK5U
o5RNNHd7JOf3ZWIMqkxmDhPNnqSGHcnVs/14+I5IbJCoba+KNElmL9CrL3gQkqNY NLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM54jXy
Jf2FSIgou5j1OthEdnQpiRxSRLmJ7gXtvpFGgj4AnrHGsMAPHueeop6yOX6egFnw voLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZWLWr
2cwp98c/0tMOUsXnDU1MTGF11+4UVr043SruZKU7bvhMZRcf4NTR2MNin0b3DYJ+ odGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAkoqND
JbTn+HgPhhhx3mrsWRyCvfP23jzwnV/222o+U46i7tNYYrDN8vXIM17gtIvKrv2F xdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNliJDmA
CLTJE6tsp0xAi6dT+J+AIVqkJntrsxqx2CuOYGOOkPPc4rSf64bwOR1mikdvZCnV CRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6/stI
NwGEXcH3nBRFMlk5bByCW0kUy03QNakiUEF+PoFzLrCL+V+21Q6Fd7Jmw06BzVFV yFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuDYX2U
2YtsqFcSo7HXW91XJTDVJCPnrMJOooKQ9Fbq4zbQM0Lv02LyJWyR+0PMBzy4FfkW UuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nPwPTO
ZWz10g3w+CITL/MQ65fsBBc9hRHC3QBWetj3puqM8DlPwqPhgmCA5zo8AWx7CogR vRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA69un
V66ukkeBYXYFHwV5uDJTX91tbwYesOL43rlDT905aV0VbaAyDZflipMCAwEAATAN CEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEAATAN
BgkqhkiG9w0BAQsFAAOCAgEAWeq502+YKMHrk8YD4L2mzY/AHSEWf6XubMgkNRbh BgkqhkiG9w0BAQsFAAOCAgEApavsgsn7SpPHfhDSN5iZs1ILZQRewJg0Bty0xPfk
s72+zJs2SrAzu+y+iv5La4JXOxrWEvZOUCKAK0sRG/+ESQxul5mbyPQLWFJgSqv5 3tynSW6bNH3nSaKbpsdmxxomthNSQgD2heOq1By9YzeOoNR+7Pk3s4FkASnf3ToI
O2RmhQ65l+O6RjPZbHPNJMTLMMlkFrKctgGIg5ysKHWPEZZ7ZlS3maxon+X75/b5 JNTUasBFFfaCG96s4Yvs8KiWS/k84yaWuU8c3Wb1jXs5Rv1qE1Uvuwat1DSGXSoD
uI3BxBpJTWcg6zOxh0+zIxhesgEbRmaEz6qu3ZSktBeUQFpTElreCcbkntlFbr+9 JNluuIkCsC4kWkyq5pWCGQrabWPRTWsHwC3PTcwSRBaFgYLJaR72SloHB1ot02zL
SiKkaO4l6qEwRDhA595/7/JRZo4R5U1MifU6JhTMOyXTsH3BV1aVeS81/9jGPHl8 d2age9dmFRFLLCBzP+D7RojBvL37qS/HR+rQ4SoQwiVc/JzaeqSe7ZbvEH9sZYEu
kgVxeKSpL/jDwuSJdr+dMxs/TJHV6fsnVewFFFmigLWThYGDnKmXqJQNyt8utRpe ALowJzgbwro7oZflwTWunSeSGDSltkqKjvWvZI61pwfHKDahUTmZ5h2y67FuGEaC
6vvReWSSIece1EdBActy0rtjPaUJNTTdYk1UYo63OIbCguLWQD1XYN1qJg4KWJzB CIOUI8dSVSPKITxaq3JL4ze2e9/0Lt7hj19YK2uUmtMAW5Tirz4Yx5lyGH9U8Wur
PjS6KCOLmJvYrAxRMED4XeZ17+PxC3xr2IpAL+loRhZUuxXV4GhccGZ4z89OIdOU y/X8VPxTc4A9TMlJgkyz0hqvhbPOT/zSWB10zXh0glKAsSBryAOEDxV1UygmSir7
x97x2BjjV5Nnnt6eBfF3vP5sOz31QpAS/8tzdlGD+6Xq2/i1ZKMPrwgs2dhTyah0 YV8Qaq+oyKUTMc1MFq5vZ07M51EPaietn85t8V2Y+k/8XYltRp32NxsypxAJuyxh
kCBfdE88Zew/A79z55IsVNiYJ4MrD8WTFjcM2j8SgI7tg+M+X/unj+wnzYT0L0dg g/ko6RVTrWa1sMvz/F9LFqAdKiK5eM96lh9IU4xiLg4ob8aS/GRAA8oIFkZFhLrt
BEfzPd7zWdDOPInlTV9zUj1WOsLHX9odOh/Jj5X0FV5vZtcyQ0sGJAhdgTaXDvXs tOwjIUPmEPyHWFi8dLpNuQKYalLYhuwZftG/9xV+wqhKGZO9iPrpHSYBRTap8w2y
Ing= 1QU=
-----END CERTIFICATE----- -----END CERTIFICATE-----

View file

@ -1,51 +1,51 @@
-----BEGIN RSA PRIVATE KEY----- -----BEGIN RSA PRIVATE KEY-----
MIIJKQIBAAKCAgEAryUL1k7npaMOck9OO+EjzeL0FoysOP5JrgRh+8BoPY7WPyL5 MIIJKAIBAAKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEP
6oFPaCYKp2YMucmvFh/VZSyupC75JJNIaW0fvcIe4Euzy2Ex0VukPxYteRicaWRs n8k1sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+M
xSIdo5RNNHd7JOf3ZWIMqkxmDhPNnqSGHcnVs/14+I5IbJCoba+KNElmL9CrL3gQ IK5UNLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM5
kqNYJf2FSIgou5j1OthEdnQpiRxSRLmJ7gXtvpFGgj4AnrHGsMAPHueeop6yOX6e 4jXyvoLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZ
gFnw2cwp98c/0tMOUsXnDU1MTGF11+4UVr043SruZKU7bvhMZRcf4NTR2MNin0b3 WLWrodGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAk
DYJ+JbTn+HgPhhhx3mrsWRyCvfP23jzwnV/222o+U46i7tNYYrDN8vXIM17gtIvK oqNDxdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNli
rv2FCLTJE6tsp0xAi6dT+J+AIVqkJntrsxqx2CuOYGOOkPPc4rSf64bwOR1mikdv JDmACRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6
ZCnVNwGEXcH3nBRFMlk5bByCW0kUy03QNakiUEF+PoFzLrCL+V+21Q6Fd7Jmw06B /stIyFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuD
zVFV2YtsqFcSo7HXW91XJTDVJCPnrMJOooKQ9Fbq4zbQM0Lv02LyJWyR+0PMBzy4 YX2UUuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nP
FfkWZWz10g3w+CITL/MQ65fsBBc9hRHC3QBWetj3puqM8DlPwqPhgmCA5zo8AWx7 wPTOvRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA
CogRV66ukkeBYXYFHwV5uDJTX91tbwYesOL43rlDT905aV0VbaAyDZflipMCAwEA 69unCEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEA
AQKCAgBoOnqt4a0XNE8PlcRv/A6Loskxdiuzixib133cDOe74nn7frwNY0C3MRRc AQKCAgAME3aoeXNCPxMrSri7u4Xnnk71YXl0Tm9vwvjRQlMusXZggP8VKN/KjP0/
BG4ETlLErtMWb53KlS2tJ30LSGaATbqELmjj2oaEGa5H4NHU4+GJErtsIV5UD5hW 9AE/GhmoxqPLrLCZ9ZE1EIjgmZ9Xgde9+C8rTtfCG2RFUL7/5J2p6NonlocmxoJm
ZdhB4U2n5s60tdxx+jT+eNhbd9aWU3yfJkVRXlDtXW64qQmH4P1OtXvfWBfIG/Qq YkxYwjP6ce86RTjQWL3RF3s09u0inz9/efJk5O7M6bOWMQ9VZXDlBiRY5BYvbqUR
cuUSpvchOrybZYumTdVjkqrTnHGcW+YC8hT6W79rRhB5issr6ZcUghafOWcMpeQ/ 6FeSzD4MnMbdyMRoVBeXE88gTvZk8xhB6DJnLzYgc0tKiRoeKT0iYv5JZw25VyRM
0TJZK0K13ZIfp2WFeuZfRw6Rg/AIJllSScZxxo/oBPfym5P6FGRndxrkzkh19g+q ycLzfTrFmXCPfB1ylb483d9Ly4fBlM8nkx37PzEnAuukIawDxsPOb9yZC+hfvNJI
HQDYA0oYW7clXMMtebbrEIb8kLRdaIHDiwyFXmyywvuAAk0jHbA8snM2dyeJWSRr 7NFiMN+3maEqG2iC00w4Lep4skHY7eHUEUMl+Wjr+koAy2YGLWAwHZQTm7iXn9Ab
WQjvQFccGF4z390ZGUCN0ZeESskndg12r4jYaL/aQ8dQZ1ivS69F8bmbQtQNU2Ej L6adL53zyCKelRuEQOzbeosJAqS+5fpMK0ekXyoFIuskj7bWuIoCX7K/kg6q5IW+
hscTUzEMOnrBTxvRQTjI9nnrbsbklagKmJHXOc/fj13g6/FkcfmeTrjuB30LxJyH vC2FrlsrbQ79GztWLVmHFO1I4J9M5r666YS0qdh8c+2yyRl4FmSiHfGxb3eOKpxQ
j+xXAi8AGv/oZRk6s/txas5hXpcFXnQDRobVoJjV8kuomcDTt1j33H+05ACFyvHM b6uI97iZlkxPF9LYUCSc7wq0V2gGz+6LnGvTHlHrOfVXqw/5pLAKhXqxvnroDTwz
/2jxJ1f3xbFx3fqivL89+Z4r8RYxLoWLg7QuqQLdtRgThEKUG0t3lt59fUo+JVVJ 0Ay/xFF6ei/NSxBY5t8ztGCBm45wCU3l8pW0X6dXqwUipw5b4MRy1VFRu6rqlmbL
CgRbj/OM3n5udgiIeBAyMAMZjVPUKhvLIFpiUY2vKnYx/97L0QKCAQEA4QUt3dEh OPSCuLxqyqsigiEYsBgS/icvXz9DWmCQMPd2XM9YhsHvUq+R4QKCAQEA98EuMMXI
P0L4eQEAzg/J8JuleH7io5VxoK5c2oulhCdQdRDF5HWSesPKJmCmgJRmIXi7zseo 6UKIt1kK2t/3OeJRyDd4iv/fCMUAnuPjLBvFE4cXD/SbqCxcQYqb+pue3PYkiTIC
Sbg7Hd2xt/QnaPhRnASXJOdn7ddtoZ1M6Zb0y+d6mmcG+mK6PshtMCQ5S3Lqhsuh 71rN8OQAc5yKhzmmnCE5N26br/0pG4pwEjIr6mt8kZHmemOCNEzvhhT83nfKmV0g
tYQbwawNlCFzwzCzwGb3aD9lBKQYts7KFrMT3Goexg3Qqv374XGn6Eg1LMhXWYbT 9lNtuGEQMiwmZrpUOF51JOMC39bzcVjYX2Cmvb7cFbIq3lR0zwM+aZpQ4P8LHCIu
M5gcPOYnOT+RugeaTxMnJ6nr6E7kyrLIS+xASXKwXGxSUsQG9VWH7jDuzzARrPEU bgHmwbdlkLyIULJcQmHIbo6nPFB3ZZE4mqmjwY+rA6Fh9rgBa8OFCfTtrgeYXrNb
aeyxWdbDkBn2vzW+wDpMPMqzoShZsRC9NnFfncXRZfUC5DJWGzwA/xZaR0ZNNng2 IgZQ5U8GoYRPNC2ot0vpTinraboa/cgm6oG4M7FW1POCJTl+/ktHEnKuO5oroSga
OE7rILyAH/aZSQKCAQEAx0ICGi7y94vn5KWdaNVol3nPid4aMnk4LDcX5m0tiqUG /BSg7hCNFVaOhwKCAQEA4Kkys0HtwEbV5mY/NnvUD5KwfXX7BxoXc9lZ6seVoLEc
7LIqnFDOOjEdxTf13n7Cv4gotOQNluOypswSDZI4tI0xQ/dJ8PI+vwmA0oHSzf7U KjgPYxqYRVrC7dB2YDwwp3qcRTi/uBAgFNm3iYlDzI4xS5SeaudUWjglj7BSgXE2
ZPO2gzIOzububPllQsCrKHN++2SyyNlKyYFu/akmlu6yIN3EMRLqYKvZaIL5z9Lk iOEa7EwcvVPluLaTgiWjlzUKeUCNNHWSeQOt+paBOT+IgwRVemGVpAgkqQzNh/nP
pTU7eS0AsXJyqD54zRLFkw6S9omQHJXrEzYAuZI+Ue/Arlgyq95mUMsHYRHgaTq4 tl3p9aNtgzEm1qVlPclY/XUCtf3bcOR+z1f1b4jBdn0leu5OhnxkC+Htik+2fTXD
GDMDLHNyrdKUhW+ZiZ9dhX+aRghHgNiXDk/Eh2/RZrLhKdVk94dJQbfGu/aiSk71 jt6JGrMkanN25YzsjnD3Sn+v6SO26H99wnYx5oMSdmb8SlWRrKtfJHnihphjG/YY
dXPEAaQ7o1MDwQgu4TsCVCzac/CeqvmcoMFyx3NA+wKCAQEAoLfLR8hsH7wcroiZ l1cyorV6M/asSgXNQfGJm4OuJi0I4/FL2wLUHnU+JwKCAQEAzh4WipcRthYXXcoj
45QBXzo8WLD//WjrDKIdLfdaE+bkn4iIX6HeKpMXGowjwGi9/aA3O/z85RKSHsXO gMKRkMOb3GFh1OpYqJgVExtudNTJmZxq8GhFU51MR27Eo7LycMwKy2UjEfTOnplh
fp4DXAUofPAGaFRjtcwNwMYSPjEUzWKa/hciM8o6TkdnPWBSD+KXQgnFiVk/Xfge Us2qZiPtW7k8O8S2m6yXlYUQBeNdq9IuuYDTaYD94vsazscJNSAeGodjE+uGvb1q
hrPR9BMgAAdLJIlLBKKUCFXwn3/uaprdOgZ6CPd5ZU+BZvXUDRVW1lnnFc3KNXEJ 1wLqE87yoE7dUInYa1cOA3+xy2/CaNuviBFJHtzOrSb6tqqenQEyQf6h9/12+DTW
iOkvk5iEjYAXkkvadEWNQn2pdBjc3djtwEWaEwVyFt6tROJsX01tAoH6W6G0Fn+/ t5pSIiixHrzxHiFqOoCLRKGToQB+71rSINwTf0nITNpGBWmSj5VcC3VV3TG5/XxI
lHgG9hFUGgZJl44L+MpSLZbQHkehzJWS92ilVQni2HbmG0wC1S+QTJxV1agAZpRc fPlxV2yhD5WFDPVNGBGvwPDSh4jSMZdZMSNBZCy4XWFNSKjGEWoK4DFYed3DoSt9
SvgeCQKCAQB3PnVrnfUhV8Sq/MG63xv8qpUc+KHM2uZW75GKAIRkmGYQeH8vlNwV 5IG1YwKCAQA63ntHl64KJUWlkwNbboU583FF3uWBjee5VqoGKHhf3CkKMxhtGqnt
zxb104t8X3fEj4Ns3Z2UUyey0iVrobn1sxlshyzk2NPcF5/UWoUBaiNJVuA+m1Jp +oN7t5VdUEhbinhqdx1dyPPvIsHCS3K1pkjqii4cyzNCVNYa2dQ00Qq+QWZBpwwc
V6IP7SBAVnUXfCbd42Fq+T7cYG0/uF6zrJ1FNfIXPC6vM6ij9t3xFVBn3fd9iQUF 3GAkz8rFXsGIPMDa1vxpU6mnBjzPniKMcsZ9tmQDppCEpBGfLpio2eAA5IkK8eEf
LGyZaul4MGe0neAtUh3APae0k3jTlUVeW5B/xaBtYmbwqs/7s2sNDmrlcIHRtDVI cIDB3CM0Vo94EvI76CJZabaE9IJ+0HIJb2+jz9BJ00yQBIqvJIYoNy9gP5Xjpi+T
+OCRCjxkM88P+VEl4AaKgRPFKM+ADdbPEvXUxzPpPjkE7yorimmM9rvGUkVWhiZ6 qV/tdMkD5jwWjHD3AYHLWKUGkNwwkAYFeqT/gX6jpWBP+ZRPOp011X3KInJFSpKU
k0+H0ZHckCfQoBcLk1AhGcg2HA7IdZzJAoIBAQDAicb6CWlNdaIcJfADKSNK4+BF DT5GQ1Dux7EMTCwVGtXqjO8Ym5wjwwsfAoIBAEcxlhIW1G6BiNfnWbNPWBdh3v/K
JFbH+lXYrTxVSTV+Ubdi0w8Kwk0bzf20EstJnaOCyLDCjcxafjbmuGBVbw7an0lt 5Ln98Rcrz8UIbWyl7qNPjYb13C1KmifVG1Rym9vWMO3KuG5atK3Mz2yLVRtmWAVc
Yxjx0fWXxMfvb9/3vKuJVUySA4iq/zfXKlokRaFoqbdRfod3PVGUsynCV7HmImf3 fxzR57zz9MZFDun66xo+Z1wN3fVxQB4CYpOEI4Lb9ioX4v85hm3D6RpFukNtRQEc
RZA0WkcSwzbg2E2QNKQ3CPd3cHtPpBX8TwRCotg/R5yCR9lihVfkSyULikwBFvrm Gfr4scTjJX4jFWDp0h6ffMb8mY+quvZoJ0TJqV9L9Yj6Ksdvqez/bdSraev97bHQ
2UKZm4pPESWSfMHBToJoAeO0g67itbwwpNhwvgUdyjaj8u46qyjN1FMx3mBiv7Yq 4gbQxaTZ6WjaD4HjpPQefMdWp97Metg0ZQSS8b8EzmNFgyJ3XcjirzwliKTAQtn6
CIE+H0qNu0jmFhoqPrgxfFrGCi6eDPYjRS86536Nc4m8y24z2hie8JLK8QKQ I2sd0NCIooelrKRD8EJoDUwxoOctY7R97wpZ7/wEHU45cBCbRV3H4JILS5c=
-----END RSA PRIVATE KEY----- -----END RSA PRIVATE KEY-----

View file

@ -27,7 +27,8 @@ async fn main() -> io::Result<()> {
// rustls connector // rustls connector
let connector = connect::rustls::Connector::new(config.clone()); let connector = connect::rustls::Connector::new(config.clone());
let io = connector.connect("www.rust-lang.org:443").await.unwrap(); //let io = connector.connect("www.rust-lang.org:443").await.unwrap();
let io = connector.connect("127.0.0.1:8443").await.unwrap();
println!("Connected to tls server {:?}", io.query::<PeerAddr>().get()); println!("Connected to tls server {:?}", io.query::<PeerAddr>().get());
let result = io let result = io
.send(Bytes::from_static(b"GET /\r\n\r\n"), &codec::BytesCodec) .send(Bytes::from_static(b"GET /\r\n\r\n"), &codec::BytesCodec)

View file

@ -1,7 +1,7 @@
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{error::Error, future::Future, marker::PhantomData, pin::Pin}; use std::{error::Error, future::Future, marker::PhantomData, pin::Pin};
use ntex_io::{Filter, FilterFactory, Io}; use ntex_io::{Filter, FilterFactory, Io, Layer};
use ntex_service::{Service, ServiceFactory}; use ntex_service::{Service, ServiceFactory};
use ntex_util::{future::Ready, time::Millis}; use ntex_util::{future::Ready, time::Millis};
use tls_openssl::ssl::SslAcceptor; use tls_openssl::ssl::SslAcceptor;
@ -53,7 +53,7 @@ impl<F> Clone for Acceptor<F> {
} }
impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> { impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> {
type Response = Io<SslFilter<F>>; type Response = Io<Layer<SslFilter, F>>;
type Error = Box<dyn Error>; type Error = Box<dyn Error>;
type Service = AcceptorService<F>; type Service = AcceptorService<F>;
type InitError = (); type InitError = ();
@ -81,7 +81,7 @@ pub struct AcceptorService<F> {
} }
impl<F: Filter> Service<Io<F>> for AcceptorService<F> { impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
type Response = Io<SslFilter<F>>; type Response = Io<Layer<SslFilter, F>>;
type Error = Box<dyn Error>; type Error = Box<dyn Error>;
type Future<'f> = AcceptorServiceResponse<F>; type Future<'f> = AcceptorServiceResponse<F>;
@ -115,7 +115,7 @@ pin_project_lite::pin_project! {
} }
impl<F: Filter> Future for AcceptorServiceResponse<F> { impl<F: Filter> Future for AcceptorServiceResponse<F> {
type Output = Result<Io<SslFilter<F>>, Box<dyn Error>>; type Output = Result<Io<Layer<SslFilter, F>>, Box<dyn Error>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().fut.poll(cx) self.project().fut.poll(cx)

View file

@ -1,10 +1,9 @@
#![allow(clippy::type_complexity)]
//! An implementation of SSL streams for ntex backed by OpenSSL //! An implementation of SSL streams for ntex backed by OpenSSL
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::{any, cmp, error::Error, io, task::Context, task::Poll}; use std::{any, cmp, error::Error, io, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesVec, PoolRef}; use ntex_bytes::{BufMut, BytesVec, PoolRef};
use ntex_io::{types, Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus}; use ntex_io::{types, Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::poll_fn, future::BoxFuture, ready, time, time::Millis}; use ntex_util::{future::poll_fn, future::BoxFuture, ready, time, time::Millis};
use tls_openssl::ssl::{self, NameType, SslStream}; use tls_openssl::ssl::{self, NameType, SslStream};
use tls_openssl::x509::X509; use tls_openssl::x509::X509;
@ -23,28 +22,27 @@ pub struct PeerCert(pub X509);
pub struct PeerCertChain(pub Vec<X509>); pub struct PeerCertChain(pub Vec<X509>);
/// An implementation of SSL streams /// An implementation of SSL streams
pub struct SslFilter<F = Base> { pub struct SslFilter {
inner: RefCell<SslStream<IoInner<F>>>, inner: RefCell<SslStream<IoInner>>,
pool: PoolRef, pool: PoolRef,
handshake: Cell<bool>, handshake: Cell<bool>,
read_buf: Cell<Option<BytesVec>>,
} }
struct IoInner<F> { struct IoInner {
inner: F, inner_read_buf: Option<BytesVec>,
inner_write_buf: Option<BytesVec>,
pool: PoolRef, pool: PoolRef,
write_buf: Option<BytesVec>,
} }
impl<F: Filter> io::Read for IoInner<F> { impl io::Read for IoInner {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(mut buf) = self.inner.get_read_buf() { if let Some(mut buf) = self.inner_read_buf.take() {
if buf.is_empty() { if buf.is_empty() {
Err(io::Error::from(io::ErrorKind::WouldBlock)) Err(io::Error::from(io::ErrorKind::WouldBlock))
} else { } else {
let len = cmp::min(buf.len(), dst.len()); let len = cmp::min(buf.len(), dst.len());
dst[..len].copy_from_slice(&buf.split_to(len)); dst[..len].copy_from_slice(&buf.split_to(len));
self.inner.release_read_buf(buf); self.inner_read_buf = Some(buf);
Ok(len) Ok(len)
} }
} else { } else {
@ -53,16 +51,16 @@ impl<F: Filter> io::Read for IoInner<F> {
} }
} }
impl<F: Filter> io::Write for IoInner<F> { impl io::Write for IoInner {
fn write(&mut self, src: &[u8]) -> io::Result<usize> { fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.inner.get_write_buf() { let mut buf = if let Some(mut buf) = self.inner_write_buf.take() {
buf.reserve(src.len()); buf.reserve(src.len());
buf buf
} else { } else {
BytesVec::with_capacity_in(src.len(), self.pool) BytesVec::with_capacity_in(src.len(), self.pool)
}; };
buf.extend_from_slice(src); buf.extend_from_slice(src);
self.inner.release_write_buf(buf)?; self.inner_write_buf = Some(buf);
Ok(src.len()) Ok(src.len())
} }
@ -71,7 +69,37 @@ impl<F: Filter> io::Write for IoInner<F> {
} }
} }
impl<F: Filter> Filter for SslFilter<F> { impl SslFilter {
fn with_buffers<F, R>(&self, buf: &mut WriteBuf<'_>, f: F) -> R
where
F: FnOnce() -> R,
{
self.inner.borrow_mut().get_mut().inner_write_buf = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().inner_read_buf =
buf.with_read_buf(|b| b.take_src());
let result = f();
buf.set_dst(self.inner.borrow_mut().get_mut().inner_write_buf.take());
buf.with_read_buf(|b| {
b.set_src(self.inner.borrow_mut().get_mut().inner_read_buf.take())
});
result
}
fn set_buffers(&self, buf: &mut WriteBuf<'_>) {
self.inner.borrow_mut().get_mut().inner_write_buf = Some(buf.take_dst());
self.inner.borrow_mut().get_mut().inner_read_buf =
buf.with_read_buf(|b| b.take_src());
}
fn unset_buffers(&self, buf: &mut WriteBuf<'_>) {
buf.set_dst(self.inner.borrow_mut().get_mut().inner_write_buf.take());
buf.with_read_buf(|b| {
b.set_src(self.inner.borrow_mut().get_mut().inner_read_buf.take())
});
}
}
impl FilterLayer for SslFilter {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> { fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
const H2: &[u8] = b"h2"; const H2: &[u8] = b"h2";
@ -116,86 +144,37 @@ impl<F: Filter> Filter for SslFilter<F> {
None None
} }
} else { } else {
self.inner.borrow().get_ref().inner.query(id) None
} }
} }
fn poll_shutdown(&self) -> Poll<io::Result<()>> { fn shutdown(&self, buf: &mut WriteBuf<'_>) -> io::Result<Poll<()>> {
let ssl_result = self.inner.borrow_mut().shutdown(); let ssl_result = self.with_buffers(buf, || self.inner.borrow_mut().shutdown());
match ssl_result { match ssl_result {
Ok(ssl::ShutdownResult::Sent) => Poll::Pending, Ok(ssl::ShutdownResult::Sent) => Ok(Poll::Pending),
Ok(ssl::ShutdownResult::Received) => { Ok(ssl::ShutdownResult::Received) => Ok(Poll::Ready(())),
self.inner.borrow().get_ref().inner.poll_shutdown() Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(Poll::Ready(())),
}
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => {
self.inner.borrow().get_ref().inner.poll_shutdown()
}
Err(ref e) Err(ref e)
if e.code() == ssl::ErrorCode::WANT_READ if e.code() == ssl::ErrorCode::WANT_READ
|| e.code() == ssl::ErrorCode::WANT_WRITE => || e.code() == ssl::ErrorCode::WANT_WRITE =>
{ {
Poll::Pending Ok(Poll::Pending)
} }
Err(e) => Poll::Ready(Err(e Err(e) => Err(e
.into_io_error() .into_io_error()
.unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))), .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))),
} }
} }
#[inline] fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> { buf.with_write_buf(|b| self.set_buffers(b));
self.inner.borrow().get_ref().inner.poll_read_ready(cx)
}
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.borrow().get_ref().inner.poll_write_ready(cx)
}
#[inline]
fn get_read_buf(&self) -> Option<BytesVec> {
self.read_buf.take()
}
#[inline]
fn get_write_buf(&self) -> Option<BytesVec> {
self.inner.borrow_mut().get_mut().write_buf.take()
}
#[inline]
fn release_read_buf(&self, buf: BytesVec) {
self.read_buf.set(Some(buf));
}
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
// ask inner filter to process read buf
match self
.inner
.borrow_mut()
.get_ref()
.inner
.process_read_buf(io, nbytes)
{
Err(err) => io.want_shutdown(Some(err)),
Ok((n, 0)) => return Ok((n, 0)),
Ok((_, _)) => (),
}
// get processed buffer
let mut dst = if let Some(dst) = self.get_read_buf() {
dst
} else {
self.pool.get_read_buf()
};
let (hw, lw) = self.pool.read_params().unpack();
let dst = buf.get_dst();
let mut new_bytes = usize::from(self.handshake.get()); let mut new_bytes = usize::from(self.handshake.get());
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = dst.remaining_mut(); self.pool.resize_read_buf(dst);
if remaining < lw {
dst.reserve(hw - remaining);
}
let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *dst.chunk_mut()) }; let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
let ssl_result = self.inner.borrow_mut().ssl_read(chunk); let ssl_result = self.inner.borrow_mut().ssl_read(chunk);
@ -209,43 +188,61 @@ impl<F: Filter> Filter for SslFilter<F> {
if e.code() == ssl::ErrorCode::WANT_READ if e.code() == ssl::ErrorCode::WANT_READ
|| e.code() == ssl::ErrorCode::WANT_WRITE => || e.code() == ssl::ErrorCode::WANT_WRITE =>
{ {
Ok((dst.len(), new_bytes)) Ok(new_bytes)
} }
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => { Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => {
io.want_shutdown(None); buf.want_shutdown();
Ok((dst.len(), new_bytes)) Ok(new_bytes)
} }
Err(e) => { Err(e) => {
log::trace!("SSL Error: {:?}", e); log::trace!("SSL Error: {:?}", e);
Err(map_to_ioerr(e)) Err(map_to_ioerr(e))
} }
}; };
self.release_read_buf(dst);
buf.with_write_buf(|b| self.unset_buffers(b));
return result; return result;
} }
} }
fn release_write_buf(&self, mut buf: BytesVec) -> Result<(), io::Error> { fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
loop { if let Some(mut src) = buf.take_src() {
if buf.is_empty() { self.set_buffers(buf);
return Ok(());
} loop {
let ssl_result = self.inner.borrow_mut().ssl_write(&buf); if src.is_empty() {
match ssl_result { self.unset_buffers(buf);
Ok(v) => { return Ok(());
buf.split_to(v);
continue;
} }
Err(e) => { let ssl_result = self.inner.borrow_mut().ssl_write(&src);
if !buf.is_empty() { match ssl_result {
self.inner.borrow_mut().get_mut().write_buf = Some(buf); Ok(v) => {
src.split_to(v);
continue;
}
Err(e) => {
if !src.is_empty() {
buf.set_src(Some(src));
}
self.unset_buffers(buf);
return match e.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => {
buf.set_dst(
self.inner
.borrow_mut()
.get_mut()
.inner_write_buf
.take(),
);
Ok(())
}
_ => Err(map_to_ioerr(e)),
};
} }
return match e.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => Ok(()),
_ => Err(map_to_ioerr(e)),
};
} }
} }
} else {
Ok(())
} }
} }
} }
@ -283,42 +280,47 @@ impl Clone for SslAcceptor {
} }
impl<F: Filter> FilterFactory<F> for SslAcceptor { impl<F: Filter> FilterFactory<F> for SslAcceptor {
type Filter = SslFilter<F>; type Filter = SslFilter;
type Error = Box<dyn Error>; type Error = Box<dyn Error>;
type Future = BoxFuture<'static, Result<Io<Self::Filter>, Self::Error>>; type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
fn create(self, st: Io<F>) -> Self::Future { fn create(self, io: Io<F>) -> Self::Future {
let timeout = self.timeout; let timeout = self.timeout;
let ctx_result = ssl::Ssl::new(self.acceptor.context()); let ctx_result = ssl::Ssl::new(self.acceptor.context());
Box::pin(async move { Box::pin(async move {
time::timeout(timeout, async { time::timeout(timeout, async {
let ssl = ctx_result.map_err(map_to_ioerr)?; let ssl = ctx_result.map_err(map_to_ioerr)?;
let pool = st.memory_pool(); let inner = IoInner {
let st = st.map_filter(|inner: F| { pool: io.memory_pool(),
let inner = IoInner { inner_read_buf: None,
pool, inner_write_buf: None,
inner, };
write_buf: None, let filter = SslFilter {
}; pool: io.memory_pool(),
let ssl_stream = ssl::SslStream::new(ssl, inner)?; handshake: Cell::new(true),
inner: RefCell::new(ssl::SslStream::new(ssl, inner)?),
Ok::<_, Box<dyn Error>>(SslFilter { };
pool, let io = io.add_filter(filter);
read_buf: Cell::new(None),
handshake: Cell::new(true),
inner: RefCell::new(ssl_stream),
})
})?;
poll_fn(|cx| { poll_fn(|cx| {
handle_result(st.filter().inner.borrow_mut().accept(), &st, cx) let result = io
.with_buf(|buf| {
let filter = io.filter();
filter.with_buffers(buf, || filter.inner.borrow_mut().accept())
})
.map_err(|err| {
let err: Box<dyn Error> =
io::Error::new(io::ErrorKind::Other, err).into();
err
})?;
handle_result(result, &io, cx)
}) })
.await?; .await?;
st.filter().handshake.set(false); io.filter().handshake.set(false);
Ok(st) Ok(io)
}) })
.await .await
.map_err(|_| { .map_err(|_| {
@ -341,35 +343,42 @@ impl SslConnector {
} }
impl<F: Filter> FilterFactory<F> for SslConnector { impl<F: Filter> FilterFactory<F> for SslConnector {
type Filter = SslFilter<F>; type Filter = SslFilter;
type Error = Box<dyn Error>; type Error = Box<dyn Error>;
type Future = BoxFuture<'static, Result<Io<Self::Filter>, Self::Error>>; type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, Self::Error>>;
fn create(self, st: Io<F>) -> Self::Future { fn create(self, io: Io<F>) -> Self::Future {
Box::pin(async move { Box::pin(async move {
let ssl = self.ssl; let inner = IoInner {
let pool = st.memory_pool(); pool: io.memory_pool(),
let st = st.map_filter(|inner: F| { inner_read_buf: None,
let inner = IoInner { inner_write_buf: None,
pool, };
inner, let filter = SslFilter {
write_buf: None, pool: io.memory_pool(),
}; handshake: Cell::new(true),
let ssl_stream = ssl::SslStream::new(ssl, inner)?; inner: RefCell::new(ssl::SslStream::new(self.ssl, inner)?),
};
let io = io.add_filter(filter);
Ok::<_, Box<dyn Error>>(SslFilter { poll_fn(|cx| {
pool, let result = io
read_buf: Cell::new(None), .with_buf(|buf| {
handshake: Cell::new(true), let filter = io.filter();
inner: RefCell::new(ssl_stream), filter.with_buffers(buf, || filter.inner.borrow_mut().connect())
}) })
})?; .map_err(|err| {
let err: Box<dyn Error> =
io::Error::new(io::ErrorKind::Other, err).into();
err
})?;
handle_result(result, &io, cx)
})
.await?;
poll_fn(|cx| handle_result(st.filter().inner.borrow_mut().connect(), &st, cx)) io.filter().handshake.set(false);
.await?; Ok(io)
Ok(st)
}) })
} }
} }

View file

@ -3,7 +3,7 @@ use std::{future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};
use tls_rust::ServerConfig; use tls_rust::ServerConfig;
use ntex_io::{Filter, FilterFactory, Io}; use ntex_io::{Filter, FilterFactory, Io, Layer};
use ntex_service::{Service, ServiceFactory}; use ntex_service::{Service, ServiceFactory};
use ntex_util::{future::Ready, time::Millis}; use ntex_util::{future::Ready, time::Millis};
@ -52,7 +52,7 @@ impl<F> Clone for Acceptor<F> {
} }
impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> { impl<F: Filter, C: 'static> ServiceFactory<Io<F>, C> for Acceptor<F> {
type Response = Io<TlsFilter<F>>; type Response = Io<Layer<TlsFilter, F>>;
type Error = io::Error; type Error = io::Error;
type Service = AcceptorService<F>; type Service = AcceptorService<F>;
@ -79,7 +79,7 @@ pub struct AcceptorService<F> {
} }
impl<F: Filter> Service<Io<F>> for AcceptorService<F> { impl<F: Filter> Service<Io<F>> for AcceptorService<F> {
type Response = Io<TlsFilter<F>>; type Response = Io<Layer<TlsFilter, F>>;
type Error = io::Error; type Error = io::Error;
type Future<'f> = AcceptorServiceFut<F>; type Future<'f> = AcceptorServiceFut<F>;
@ -113,7 +113,7 @@ pin_project_lite::pin_project! {
} }
impl<F: Filter> Future for AcceptorServiceFut<F> { impl<F: Filter> Future for AcceptorServiceFut<F> {
type Output = Result<Io<TlsFilter<F>>, io::Error>; type Output = Result<Io<Layer<TlsFilter, F>>, io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.project().fut.poll(cx) self.project().fut.poll(cx)

View file

@ -1,9 +1,9 @@
//! An implementation of SSL streams for ntex backed by OpenSSL //! An implementation of SSL streams for ntex backed by OpenSSL
use std::io::{self, Read as IoRead, Write as IoWrite}; use std::io::{self, Read as IoRead, Write as IoWrite};
use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Context, task::Poll}; use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Poll};
use ntex_bytes::{BufMut, BytesVec}; use ntex_bytes::BufMut;
use ntex_io::{types, Filter, Io, IoRef, ReadStatus, WriteStatus}; use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::poll_fn, ready}; use ntex_util::{future::poll_fn, ready};
use tls_rust::{ClientConfig, ClientConnection, ServerName}; use tls_rust::{ClientConfig, ClientConnection, ServerName};
@ -12,12 +12,12 @@ use crate::rustls::{IoInner, TlsFilter, Wrapper};
use super::{PeerCert, PeerCertChain}; use super::{PeerCert, PeerCertChain};
/// An implementation of SSL streams /// An implementation of SSL streams
pub struct TlsClientFilter<F> { pub struct TlsClientFilter {
inner: IoInner<F>, inner: IoInner,
session: RefCell<ClientConnection>, session: RefCell<ClientConnection>,
} }
impl<F: Filter> Filter for TlsClientFilter<F> { impl FilterLayer for TlsClientFilter {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> { fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
const H2: &[u8] = b"h2"; const H2: &[u8] = b"h2";
@ -52,71 +52,19 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
None None
} }
} else { } else {
self.inner.filter.query(id) None
} }
} }
#[inline] fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.inner.filter.poll_shutdown()
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.filter.poll_read_ready(cx)
}
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.filter.poll_write_ready(cx)
}
#[inline]
fn get_read_buf(&self) -> Option<BytesVec> {
self.inner.read_buf.take()
}
#[inline]
fn get_write_buf(&self) -> Option<BytesVec> {
self.inner.write_buf.take()
}
#[inline]
fn release_read_buf(&self, buf: BytesVec) {
self.inner.read_buf.set(Some(buf));
}
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
let mut session = self.session.borrow_mut(); let mut session = self.session.borrow_mut();
// ask inner filter to process read buf
match self.inner.filter.process_read_buf(io, nbytes) {
Err(err) => io.want_shutdown(Some(err)),
Ok((_, 0)) => return Ok((0, 0)),
Ok(_) => (),
}
// get processed buffer // get processed buffer
let mut dst = if let Some(dst) = self.inner.read_buf.take() { let (src, dst) = buf.get_pair();
dst
} else {
self.inner.pool.get_read_buf()
};
let (hw, lw) = self.inner.pool.read_params().unpack();
let mut src = if let Some(src) = self.inner.filter.get_read_buf() {
src
} else {
return Ok((0, 0));
};
let mut new_bytes = usize::from(self.inner.handshake.get()); let mut new_bytes = usize::from(self.inner.handshake.get());
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = dst.remaining_mut(); self.inner.pool.resize_read_buf(dst);
if remaining < lw {
dst.reserve(hw - remaining);
}
let mut cursor = io::Cursor::new(&src); let mut cursor = io::Cursor::new(&src);
let n = session.read_tls(&mut cursor)?; let n = session.read_tls(&mut cursor)?;
@ -138,73 +86,74 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
} }
} }
let dst_len = dst.len(); Ok(new_bytes)
self.inner.read_buf.set(Some(dst));
self.inner.filter.release_read_buf(src);
Ok((dst_len, new_bytes))
} }
fn release_write_buf(&self, mut src: BytesVec) -> Result<(), io::Error> { fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
let mut session = self.session.borrow_mut(); if let Some(mut src) = buf.take_src() {
let mut io = Wrapper(&self.inner); let mut session = self.session.borrow_mut();
let mut io = Wrapper(&self.inner, buf);
loop {
if !src.is_empty() {
let n = session.writer().write(&src)?;
src.split_to(n);
}
if session.wants_write() {
session.complete_io(&mut io)?;
} else {
break;
}
}
loop {
if !src.is_empty() { if !src.is_empty() {
let n = session.writer().write(&src)?; buf.set_src(Some(src));
src.split_to(n);
}
let n = session.write_tls(&mut io)?;
if n == 0 {
break;
} }
Ok(())
} else {
Ok(())
} }
if !src.is_empty() {
self.inner.write_buf.set(Some(src));
}
Ok(())
} }
} }
impl<F: Filter> TlsClientFilter<F> { impl TlsClientFilter {
pub(crate) async fn create( pub(crate) async fn create<F: Filter>(
io: Io<F>, io: Io<F>,
cfg: Arc<ClientConfig>, cfg: Arc<ClientConfig>,
domain: ServerName, domain: ServerName,
) -> Result<Io<TlsFilter<F>>, io::Error> { ) -> Result<Io<Layer<TlsFilter, F>>, io::Error> {
let pool = io.memory_pool(); let session = ClientConnection::new(cfg, domain)
let session = match ClientConnection::new(cfg, domain) { .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
Ok(session) => session, let filter = TlsFilter::new_client(TlsClientFilter {
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)), inner: IoInner {
}; pool: io.memory_pool(),
let io = io.map_filter(|filter: F| {
let inner = IoInner {
pool,
filter,
read_buf: Cell::new(None),
write_buf: Cell::new(None),
handshake: Cell::new(true), handshake: Cell::new(true),
}; },
session: RefCell::new(session),
Ok::<_, io::Error>(TlsFilter::new_client(TlsClientFilter { });
inner, let io = io.add_filter(filter);
session: RefCell::new(session),
}))
})?;
let filter = io.filter(); let filter = io.filter();
loop { loop {
let (result, wants_read, handshaking) = { let (result, wants_read, handshaking) = io.with_buf(|buf| {
let mut session = filter.client().session.borrow_mut(); let mut session = filter.client().session.borrow_mut();
let mut wrp = Wrapper(&filter.client().inner); let mut wrp = Wrapper(&filter.client().inner, buf);
( let mut result = (
session.complete_io(&mut wrp), session.complete_io(&mut wrp),
session.wants_read(), session.wants_read(),
session.is_handshaking(), session.is_handshaking(),
) );
};
while session.wants_write() {
result.0 = session.complete_io(&mut wrp);
if result.0.is_err() {
break;
}
}
result
})?;
match result { match result {
Ok(_) => { Ok(_) => {
filter.client().inner.handshake.set(false); filter.client().inner.handshake.set(false);

View file

@ -1,11 +1,13 @@
#![allow(clippy::type_complexity)] #![allow(clippy::type_complexity)]
//! An implementation of SSL streams for ntex backed by OpenSSL //! An implementation of SSL streams for ntex backed by OpenSSL
use std::{any, cmp, future::Future, io, pin::Pin, task::Context, task::Poll}; use std::{any, cell::Cell, cmp, io, sync::Arc, task::Context, task::Poll};
use std::{cell::Cell, sync::Arc};
use ntex_bytes::{BytesVec, PoolRef}; use ntex_bytes::PoolRef;
use ntex_io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus}; use ntex_io::{
use ntex_util::time::Millis; Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, ReadStatus, WriteBuf,
WriteStatus,
};
use ntex_util::{future::BoxFuture, time::Millis};
use tls_rust::{Certificate, ClientConfig, ServerConfig, ServerName}; use tls_rust::{Certificate, ClientConfig, ServerConfig, ServerName};
mod accept; mod accept;
@ -25,33 +27,33 @@ pub struct PeerCert(pub Certificate);
pub struct PeerCertChain(pub Vec<Certificate>); pub struct PeerCertChain(pub Vec<Certificate>);
/// An implementation of SSL streams /// An implementation of SSL streams
pub struct TlsFilter<F = Base> { pub struct TlsFilter {
inner: InnerTlsFilter<F>, inner: InnerTlsFilter,
} }
enum InnerTlsFilter<F> { enum InnerTlsFilter {
Server(TlsServerFilter<F>), Server(TlsServerFilter),
Client(TlsClientFilter<F>), Client(TlsClientFilter),
} }
impl<F> TlsFilter<F> { impl TlsFilter {
fn new_server(server: TlsServerFilter<F>) -> Self { fn new_server(server: TlsServerFilter) -> Self {
TlsFilter { TlsFilter {
inner: InnerTlsFilter::Server(server), inner: InnerTlsFilter::Server(server),
} }
} }
fn new_client(client: TlsClientFilter<F>) -> Self { fn new_client(client: TlsClientFilter) -> Self {
TlsFilter { TlsFilter {
inner: InnerTlsFilter::Client(client), inner: InnerTlsFilter::Client(client),
} }
} }
fn server(&self) -> &TlsServerFilter<F> { fn server(&self) -> &TlsServerFilter {
match self.inner { match self.inner {
InnerTlsFilter::Server(ref server) => server, InnerTlsFilter::Server(ref server) => server,
_ => unreachable!(), _ => unreachable!(),
} }
} }
fn client(&self) -> &TlsClientFilter<F> { fn client(&self) -> &TlsClientFilter {
match self.inner { match self.inner {
InnerTlsFilter::Client(ref server) => server, InnerTlsFilter::Client(ref server) => server,
_ => unreachable!(), _ => unreachable!(),
@ -59,7 +61,7 @@ impl<F> TlsFilter<F> {
} }
} }
impl<F: Filter> Filter for TlsFilter<F> { impl FilterLayer for TlsFilter {
#[inline] #[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> { fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
match self.inner { match self.inner {
@ -69,10 +71,10 @@ impl<F: Filter> Filter for TlsFilter<F> {
} }
#[inline] #[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> { fn shutdown(&self, buf: &mut WriteBuf<'_>) -> io::Result<Poll<()>> {
match self.inner { match self.inner {
InnerTlsFilter::Server(ref f) => f.poll_shutdown(), InnerTlsFilter::Server(ref f) => f.shutdown(buf),
InnerTlsFilter::Client(ref f) => f.poll_shutdown(), InnerTlsFilter::Client(ref f) => f.shutdown(buf),
} }
} }
@ -93,42 +95,18 @@ impl<F: Filter> Filter for TlsFilter<F> {
} }
#[inline] #[inline]
fn get_read_buf(&self) -> Option<BytesVec> { fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
match self.inner { match self.inner {
InnerTlsFilter::Server(ref f) => f.get_read_buf(), InnerTlsFilter::Server(ref f) => f.process_read_buf(buf),
InnerTlsFilter::Client(ref f) => f.get_read_buf(), InnerTlsFilter::Client(ref f) => f.process_read_buf(buf),
} }
} }
#[inline] #[inline]
fn get_write_buf(&self) -> Option<BytesVec> { fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
match self.inner { match self.inner {
InnerTlsFilter::Server(ref f) => f.get_write_buf(), InnerTlsFilter::Server(ref f) => f.process_write_buf(buf),
InnerTlsFilter::Client(ref f) => f.get_write_buf(), InnerTlsFilter::Client(ref f) => f.process_write_buf(buf),
}
}
#[inline]
fn release_read_buf(&self, buf: BytesVec) {
match self.inner {
InnerTlsFilter::Server(ref f) => f.release_read_buf(buf),
InnerTlsFilter::Client(ref f) => f.release_read_buf(buf),
}
}
#[inline]
fn process_read_buf(&self, io: &IoRef, nb: usize) -> io::Result<(usize, usize)> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.process_read_buf(io, nb),
InnerTlsFilter::Client(ref f) => f.process_read_buf(io, nb),
}
}
#[inline]
fn release_write_buf(&self, src: BytesVec) -> Result<(), io::Error> {
match self.inner {
InnerTlsFilter::Server(ref f) => f.release_write_buf(src),
InnerTlsFilter::Client(ref f) => f.release_write_buf(src),
} }
} }
} }
@ -172,10 +150,10 @@ impl Clone for TlsAcceptor {
} }
impl<F: Filter> FilterFactory<F> for TlsAcceptor { impl<F: Filter> FilterFactory<F> for TlsAcceptor {
type Filter = TlsFilter<F>; type Filter = TlsFilter;
type Error = io::Error; type Error = io::Error;
type Future = Pin<Box<dyn Future<Output = Result<Io<Self::Filter>, io::Error>>>>; type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, io::Error>>;
fn create(self, st: Io<F>) -> Self::Future { fn create(self, st: Io<F>) -> Self::Future {
let cfg = self.cfg.clone(); let cfg = self.cfg.clone();
@ -227,10 +205,10 @@ impl Clone for TlsConnectorConfigured {
} }
impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured { impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
type Filter = TlsFilter<F>; type Filter = TlsFilter;
type Error = io::Error; type Error = io::Error;
type Future = Pin<Box<dyn Future<Output = Result<Io<Self::Filter>, io::Error>>>>; type Future = BoxFuture<'static, Result<Io<Layer<Self::Filter, F>>, io::Error>>;
fn create(self, st: Io<F>) -> Self::Future { fn create(self, st: Io<F>) -> Self::Future {
let cfg = self.cfg; let cfg = self.cfg;
@ -240,44 +218,31 @@ impl<F: Filter> FilterFactory<F> for TlsConnectorConfigured {
} }
} }
pub(crate) struct IoInner<F> { pub(crate) struct IoInner {
filter: F,
pool: PoolRef, pool: PoolRef,
read_buf: Cell<Option<BytesVec>>,
write_buf: Cell<Option<BytesVec>>,
handshake: Cell<bool>, handshake: Cell<bool>,
} }
pub(crate) struct Wrapper<'a, F>(&'a IoInner<F>); pub(crate) struct Wrapper<'a, 'b>(&'a IoInner, &'a mut WriteBuf<'b>);
impl<'a, F: Filter> io::Read for Wrapper<'a, F> { impl<'a, 'b> io::Read for Wrapper<'a, 'b> {
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
if let Some(mut read_buf) = self.0.filter.get_read_buf() { self.1.with_read_buf(|buf| {
let read_buf = buf.get_src();
let len = cmp::min(read_buf.len(), dst.len()); let len = cmp::min(read_buf.len(), dst.len());
let result = if len > 0 { if len > 0 {
dst[..len].copy_from_slice(&read_buf.split_to(len)); dst[..len].copy_from_slice(&read_buf.split_to(len));
Ok(len) Ok(len)
} else { } else {
Err(io::Error::new(io::ErrorKind::WouldBlock, "")) Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}; }
self.0.filter.release_read_buf(read_buf); })
result
} else {
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
}
} }
} }
impl<'a, F: Filter> io::Write for Wrapper<'a, F> { impl<'a, 'b> io::Write for Wrapper<'a, 'b> {
fn write(&mut self, src: &[u8]) -> io::Result<usize> { fn write(&mut self, src: &[u8]) -> io::Result<usize> {
let mut buf = if let Some(mut buf) = self.0.filter.get_write_buf() { self.1.get_dst().extend_from_slice(src);
buf.reserve(src.len());
buf
} else {
BytesVec::with_capacity_in(src.len(), self.0.pool)
};
buf.extend_from_slice(src);
self.0.filter.release_write_buf(buf)?;
Ok(src.len()) Ok(src.len())
} }

View file

@ -1,9 +1,9 @@
//! An implementation of SSL streams for ntex backed by OpenSSL //! An implementation of SSL streams for ntex backed by OpenSSL
use std::io::{self, Read as IoRead, Write as IoWrite}; use std::io::{self, Read as IoRead, Write as IoWrite};
use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Context, task::Poll}; use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Poll};
use ntex_bytes::{BufMut, BytesVec}; use ntex_bytes::BufMut;
use ntex_io::{types, Filter, Io, IoRef, ReadStatus, WriteStatus}; use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use ntex_util::{future::poll_fn, ready, time, time::Millis}; use ntex_util::{future::poll_fn, ready, time, time::Millis};
use tls_rust::{ServerConfig, ServerConnection}; use tls_rust::{ServerConfig, ServerConnection};
@ -13,12 +13,12 @@ use crate::Servername;
use super::{PeerCert, PeerCertChain}; use super::{PeerCert, PeerCertChain};
/// An implementation of SSL streams /// An implementation of SSL streams
pub struct TlsServerFilter<F> { pub struct TlsServerFilter {
inner: IoInner<F>, inner: IoInner,
session: RefCell<ServerConnection>, session: RefCell<ServerConnection>,
} }
impl<F: Filter> Filter for TlsServerFilter<F> { impl FilterLayer for TlsServerFilter {
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> { fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
const H2: &[u8] = b"h2"; const H2: &[u8] = b"h2";
@ -59,71 +59,19 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
None None
} }
} else { } else {
self.inner.filter.query(id) None
} }
} }
#[inline] fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.inner.filter.poll_shutdown()
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.filter.poll_read_ready(cx)
}
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.filter.poll_write_ready(cx)
}
#[inline]
fn get_read_buf(&self) -> Option<BytesVec> {
self.inner.read_buf.take()
}
#[inline]
fn get_write_buf(&self) -> Option<BytesVec> {
self.inner.write_buf.take()
}
#[inline]
fn release_read_buf(&self, buf: BytesVec) {
self.inner.read_buf.set(Some(buf));
}
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> {
let mut session = self.session.borrow_mut(); let mut session = self.session.borrow_mut();
// ask inner filter to process read buf
match self.inner.filter.process_read_buf(io, nbytes) {
Err(err) => io.want_shutdown(Some(err)),
Ok((_, 0)) => return Ok((0, 0)),
Ok(_) => (),
}
// get processed buffer // get processed buffer
let mut dst = if let Some(dst) = self.inner.read_buf.take() { let (src, dst) = buf.get_pair();
dst
} else {
self.inner.pool.get_read_buf()
};
let (hw, lw) = self.inner.pool.read_params().unpack();
let mut src = if let Some(src) = self.inner.filter.get_read_buf() {
src
} else {
return Ok((0, 0));
};
let mut new_bytes = usize::from(self.inner.handshake.get()); let mut new_bytes = usize::from(self.inner.handshake.get());
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = dst.remaining_mut(); self.inner.pool.resize_read_buf(dst);
if remaining < lw {
dst.reserve(hw - remaining);
}
let mut cursor = io::Cursor::new(&src); let mut cursor = io::Cursor::new(&src);
let n = session.read_tls(&mut cursor)?; let n = session.read_tls(&mut cursor)?;
@ -145,73 +93,73 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
} }
} }
let dst_len = dst.len(); Ok(new_bytes)
self.inner.read_buf.set(Some(dst));
self.inner.filter.release_read_buf(src);
Ok((dst_len, new_bytes))
} }
fn release_write_buf(&self, mut src: BytesVec) -> Result<(), io::Error> { fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
let mut session = self.session.borrow_mut(); if let Some(mut src) = buf.take_src() {
let mut io = Wrapper(&self.inner); let mut session = self.session.borrow_mut();
let mut io = Wrapper(&self.inner, buf);
loop {
if !src.is_empty() {
let n = session.writer().write(&src)?;
src.split_to(n);
}
if session.wants_write() {
session.complete_io(&mut io)?;
} else {
break;
}
}
loop {
if !src.is_empty() { if !src.is_empty() {
let n = session.writer().write(&src)?; buf.set_src(Some(src));
src.split_to(n);
} }
let n = session.write_tls(&mut io)?;
if n == 0 {
break;
}
}
if !src.is_empty() {
self.inner.write_buf.set(Some(src));
} }
Ok(()) Ok(())
} }
} }
impl<F: Filter> TlsServerFilter<F> { impl TlsServerFilter {
pub(crate) async fn create( pub(crate) async fn create<F: Filter>(
io: Io<F>, io: Io<F>,
cfg: Arc<ServerConfig>, cfg: Arc<ServerConfig>,
timeout: Millis, timeout: Millis,
) -> Result<Io<TlsFilter<F>>, io::Error> { ) -> Result<Io<Layer<TlsFilter, F>>, io::Error> {
time::timeout(timeout, async { time::timeout(timeout, async {
let pool = io.memory_pool(); let session = ServerConnection::new(cfg)
let session = match ServerConnection::new(cfg) { .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
Ok(session) => session, let filter = TlsFilter::new_server(TlsServerFilter {
Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)), session: RefCell::new(session),
}; inner: IoInner {
let io = io.map_filter(|filter: F| { pool: io.memory_pool(),
let inner = IoInner {
pool,
filter,
read_buf: Cell::new(None),
write_buf: Cell::new(None),
handshake: Cell::new(true), handshake: Cell::new(true),
}; },
});
Ok::<_, io::Error>(TlsFilter::new_server(TlsServerFilter { let io = io.add_filter(filter);
inner,
session: RefCell::new(session),
}))
})?;
let filter = io.filter(); let filter = io.filter();
loop { loop {
let (result, wants_read, handshaking) = { let (result, wants_read, handshaking) = io.with_buf(|buf| {
let mut session = filter.server().session.borrow_mut(); let mut session = filter.server().session.borrow_mut();
let mut wrp = Wrapper(&filter.server().inner); let mut wrp = Wrapper(&filter.server().inner, buf);
( let mut result = (
session.complete_io(&mut wrp), session.complete_io(&mut wrp),
session.wants_read(), session.wants_read(),
session.is_handshaking(), session.is_handshaking(),
) );
};
while session.wants_write() {
result.0 = session.complete_io(&mut wrp);
if result.0.is_err() {
break;
}
}
result
})?;
match result { match result {
Ok(_) => { Ok(_) => {
filter.server().inner.handshake.set(false); filter.server().inner.handshake.set(false);

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-tokio" name = "ntex-tokio"
version = "0.2.0" version = "0.2.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "tokio intergration for ntex framework" description = "tokio intergration for ntex framework"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -16,8 +16,8 @@ name = "ntex_tokio"
path = "src/lib.rs" path = "src/lib.rs"
[dependencies] [dependencies]
ntex-bytes = "0.1.18" ntex-bytes = "0.1.19"
ntex-io = "0.2.0" ntex-io = "0.2.1"
ntex-util = "0.2.0" ntex-util = "0.2.0"
log = "0.4" log = "0.4"
pin-project-lite = "0.2" pin-project-lite = "0.2"

View file

@ -54,73 +54,42 @@ impl Future for ReadTask {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref(); let this = self.as_ref();
loop { this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) { match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => { ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket // read data from socket
let mut new_bytes = 0; let mut io = this.io.borrow_mut();
let mut close = false;
let mut pending = false;
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = buf.remaining_mut(); let remaining = buf.remaining_mut();
if remaining < lw { if remaining < lw {
buf.reserve(hw - remaining); buf.reserve(hw - remaining);
} }
return match poll_read_buf(Pin::new(&mut *io), cx, buf) {
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) { Poll::Pending => Poll::Pending,
Poll::Pending => {
pending = true;
break;
}
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
log::trace!("tokio stream is disconnected"); log::trace!("tokio stream is disconnected");
close = true; Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else { } else {
new_bytes += n; Poll::Pending
if new_bytes <= hw {
continue;
}
} }
break;
} }
Poll::Ready(Err(err)) => { Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err); log::trace!("read task failed on io {:?}", err);
drop(io); Poll::Ready(Err(err))
this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
} }
} };
} }
drop(io);
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
this.state.release_read_buf(buf, new_bytes);
return if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
} }
ReadStatus::Terminate => { ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown"); log::trace!("read task is instructed to shutdown");
return Poll::Ready(()); Poll::Ready(Ok(()))
} }
} }
} })
} }
} }
@ -269,14 +238,14 @@ impl Future for WriteTask {
if read_buf.filled().is_empty() => if read_buf.filled().is_empty() =>
{ {
this.state.close(None); this.state.close(None);
log::trace!("write task is stopped"); log::trace!("tokio write task is stopped");
return Poll::Ready(()); return Poll::Ready(());
} }
Poll::Pending => { Poll::Pending => {
*count += read_buf.filled().len() as u16; *count += read_buf.filled().len() as u16;
if *count > 4096 { if *count > 4096 {
log::trace!( log::trace!(
"write task is stopped, too much input" "tokio write task is stopped, too much input"
); );
this.state.close(None); this.state.close(None);
return Poll::Ready(()); return Poll::Ready(());
@ -344,7 +313,7 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
} }
} }
} }
log::trace!("flushed {} bytes", written); // log::trace!("flushed {} bytes", written);
// remove written data // remove written data
let result = if written == len { let result = if written == len {
@ -501,18 +470,11 @@ mod unixstream {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref(); let this = self.as_ref();
loop { this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) { match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => { ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket // read data from socket
let mut new_bytes = 0; let mut io = this.io.borrow_mut();
let mut close = false;
let mut pending = false;
loop { loop {
// make sure we've got room // make sure we've got room
let remaining = buf.remaining_mut(); let remaining = buf.remaining_mut();
@ -520,54 +482,31 @@ mod unixstream {
buf.reserve(hw - remaining); buf.reserve(hw - remaining);
} }
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) { return match poll_read_buf(Pin::new(&mut *io), cx, buf) {
Poll::Pending => { Poll::Pending => Poll::Pending,
pending = true;
break;
}
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
log::trace!("unix stream is disconnected"); log::trace!("tokio unix stream is disconnected");
close = true; Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else { } else {
new_bytes += n; Poll::Pending
if new_bytes <= hw {
continue;
}
} }
break;
} }
Poll::Ready(Err(err)) => { Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err); log::trace!("unix stream read task failed {:?}", err);
drop(io); Poll::Ready(Err(err))
this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
} }
} };
} }
drop(io);
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
this.state.release_read_buf(buf, new_bytes);
return if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
} }
ReadStatus::Terminate => { ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown"); log::trace!("read task is instructed to shutdown");
return Poll::Ready(()); Poll::Ready(Ok(()))
} }
} }
} })
} }
} }
@ -735,10 +674,6 @@ pub fn poll_read_buf<T: AsyncRead>(
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut BytesVec, buf: &mut BytesVec,
) -> Poll<io::Result<usize>> { ) -> Poll<io::Result<usize>> {
if !buf.has_remaining_mut() {
return Poll::Ready(Ok(0));
}
let n = { let n = {
let dst = let dst =
unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [mem::MaybeUninit<u8>]) }; unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [mem::MaybeUninit<u8>]) };

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex" name = "ntex"
version = "0.6.0" version = "0.6.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services" description = "Framework for composable network services"
readme = "README.md" readme = "README.md"
@ -49,20 +49,20 @@ async-std = ["ntex-rt/async-std", "ntex-async-std", "ntex-connect/async-std"]
[dependencies] [dependencies]
ntex-codec = "0.6.2" ntex-codec = "0.6.2"
ntex-connect = "0.2.0" ntex-connect = "0.2.1"
ntex-http = "0.1.9" ntex-http = "0.1.9"
ntex-router = "0.5.1" ntex-router = "0.5.1"
ntex-service = "1.0.0" ntex-service = "1.0.0"
ntex-macros = "0.1.3" ntex-macros = "0.1.3"
ntex-util = "0.2.0" ntex-util = "0.2.0"
ntex-bytes = "0.1.18" ntex-bytes = "0.1.19"
ntex-h2 = "0.2.0" ntex-h2 = "0.2.0"
ntex-rt = "0.4.7" ntex-rt = "0.4.7"
ntex-io = "0.2.0" ntex-io = "0.2.1"
ntex-tls = "0.2.0" ntex-tls = "0.2.1"
ntex-tokio = { version = "0.2.0", optional = true } ntex-tokio = { version = "0.2.1", optional = true }
ntex-glommio = { version = "0.2.0", optional = true } ntex-glommio = { version = "0.2.1", optional = true }
ntex-async-std = { version = "0.2.0", optional = true } ntex-async-std = { version = "0.2.1", optional = true }
async-oneshot = "0.5.0" async-oneshot = "0.5.0"
async-channel = "1.8.0" async-channel = "1.8.0"
@ -88,7 +88,7 @@ percent-encoding = "2.1"
serde_json = "1.0" serde_json = "1.0"
serde_urlencoded = "0.7" serde_urlencoded = "0.7"
url-pkg = { version = "2.1", package = "url", optional = true } url-pkg = { version = "2.1", package = "url", optional = true }
coo-kie = { version = "0.16", package = "cookie", optional = true } coo-kie = { version = "0.17", package = "cookie", optional = true }
# openssl # openssl
tls-openssl = { version="0.10", package = "openssl", optional = true } tls-openssl = { version="0.10", package = "openssl", optional = true }

View file

@ -206,6 +206,7 @@ struct H2ClientInner {
streams: RefCell<HashMap<frame::StreamId, StreamInfo>>, streams: RefCell<HashMap<frame::StreamId, StreamInfo>>,
} }
#[derive(Debug)]
struct StreamInfo { struct StreamInfo {
tx: Option<oneshot::Sender<Result<(ResponseHead, Payload), SendRequestError>>>, tx: Option<oneshot::Sender<Result<(ResponseHead, Payload), SendRequestError>>>,
stream: Option<h2::Stream>, stream: Option<h2::Stream>,

View file

@ -10,7 +10,7 @@
//! //!
//! let response = client.get("http://www.rust-lang.org") // <- Create request builder //! let response = client.get("http://www.rust-lang.org") // <- Create request builder
//! .header("User-Agent", "ntex::web") //! .header("User-Agent", "ntex::web")
//! .send() // <- Send http request //! .send() // <- Send http request
//! .await; //! .await;
//! //!
//! println!("Response: {:?}", response); //! println!("Response: {:?}", response);

View file

@ -626,7 +626,6 @@ mod tests {
#[crate::rt_test] #[crate::rt_test]
async fn test_basics() { async fn test_basics() {
env_logger::init();
let store = Rc::new(RefCell::new(Vec::new())); let store = Rc::new(RefCell::new(Vec::new()));
let store2 = store.clone(); let store2 = store.clone();

View file

@ -146,7 +146,7 @@ const DATE_VALUE_DEFAULT: [u8; DATE_VALUE_LENGTH_HDR] = [
b'0', b'0', b'0', b'0', b'0', b'0', b'0', b'\r', b'\n', b'\r', b'\n', b'0', b'0', b'0', b'0', b'0', b'0', b'0', b'\r', b'\n', b'\r', b'\n',
]; ];
#[derive(Clone)] #[derive(Debug, Clone)]
pub struct DateService(Rc<DateServiceInner>); pub struct DateService(Rc<DateServiceInner>);
impl Default for DateService { impl Default for DateService {
@ -155,6 +155,7 @@ impl Default for DateService {
} }
} }
#[derive(Debug)]
struct DateServiceInner { struct DateServiceInner {
current: Cell<bool>, current: Cell<bool>,
current_time: Cell<time::Instant>, current_time: Cell<time::Instant>,

View file

@ -21,16 +21,19 @@ bitflags! {
} }
} }
#[derive(Debug)]
/// HTTP/1 Codec /// HTTP/1 Codec
pub struct ClientCodec { pub struct ClientCodec {
inner: ClientCodecInner, inner: ClientCodecInner,
} }
#[derive(Debug)]
/// HTTP/1 Payload Codec /// HTTP/1 Payload Codec
pub struct ClientPayloadCodec { pub struct ClientPayloadCodec {
inner: ClientCodecInner, inner: ClientCodecInner,
} }
#[derive(Debug)]
struct ClientCodecInner { struct ClientCodecInner {
timer: DateService, timer: DateService,
decoder: decoder::MessageDecoder<ResponseHead>, decoder: decoder::MessageDecoder<ResponseHead>,

View file

@ -14,6 +14,7 @@ use super::MAX_BUFFER_SIZE;
const MAX_HEADERS: usize = 96; const MAX_HEADERS: usize = 96;
#[derive(Debug)]
/// Incoming messagd decoder /// Incoming messagd decoder
pub(super) struct MessageDecoder<T: MessageType>(PhantomData<T>); pub(super) struct MessageDecoder<T: MessageType>(PhantomData<T>);

View file

@ -1015,8 +1015,8 @@ mod tests {
} }
#[crate::rt_test] #[crate::rt_test]
/// /// h1 dispatcher still processes all incoming requests /// h1 dispatcher still processes all incoming requests
/// /// but it does not write any data to socket /// but it does not write any data to socket
async fn test_write_disconnected() { async fn test_write_disconnected() {
let num = Arc::new(AtomicUsize::new(0)); let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone(); let num2 = num.clone();
@ -1039,6 +1039,7 @@ mod tests {
assert_eq!(num.load(Ordering::Relaxed), 1); assert_eq!(num.load(Ordering::Relaxed), 1);
} }
/// max http message size is 32k (no payload)
#[crate::rt_test] #[crate::rt_test]
async fn test_read_large_message() { async fn test_read_large_message() {
let (client, server) = Io::create(); let (client, server) = Io::create();

View file

@ -3,8 +3,7 @@ use std::{cell::RefCell, error::Error, fmt, marker, rc::Rc, task};
use crate::http::body::MessageBody; use crate::http::body::MessageBody;
use crate::http::config::{DispatcherConfig, OnRequest, ServiceConfig}; use crate::http::config::{DispatcherConfig, OnRequest, ServiceConfig};
use crate::http::error::{DispatchError, ResponseError}; use crate::http::error::{DispatchError, ResponseError};
use crate::http::request::Request; use crate::http::{request::Request, response::Response};
use crate::http::response::Response;
use crate::io::{types, Filter, Io}; use crate::io::{types, Filter, Io};
use crate::service::{IntoServiceFactory, Service, ServiceFactory}; use crate::service::{IntoServiceFactory, Service, ServiceFactory};
use crate::{time::Millis, util::BoxFuture}; use crate::{time::Millis, util::BoxFuture};
@ -56,9 +55,9 @@ mod openssl {
use tls_openssl::ssl::SslAcceptor; use tls_openssl::ssl::SslAcceptor;
use super::*; use super::*;
use crate::{server::SslError, service::pipeline_factory}; use crate::{io::Layer, server::SslError, service::pipeline_factory};
impl<F, S, B, X, U> H1Service<SslFilter<F>, S, B, X, U> impl<F, S, B, X, U> H1Service<Layer<SslFilter, F>, S, B, X, U>
where where
F: Filter, F: Filter,
S: ServiceFactory<Request> + 'static, S: ServiceFactory<Request> + 'static,
@ -69,7 +68,8 @@ mod openssl {
X: ServiceFactory<Request, Response = Request> + 'static, X: ServiceFactory<Request, Response = Request> + 'static,
X::Error: ResponseError, X::Error: ResponseError,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Io<SslFilter<F>>, Codec), Response = ()> + 'static, U: ServiceFactory<(Request, Io<Layer<SslFilter, F>>, Codec), Response = ()>
+ 'static,
U::Error: fmt::Display + Error, U::Error: fmt::Display + Error,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {
@ -102,9 +102,9 @@ mod rustls {
use tls_rustls::ServerConfig; use tls_rustls::ServerConfig;
use super::*; use super::*;
use crate::{server::SslError, service::pipeline_factory}; use crate::{io::Layer, server::SslError, service::pipeline_factory};
impl<F, S, B, X, U> H1Service<TlsFilter<F>, S, B, X, U> impl<F, S, B, X, U> H1Service<Layer<TlsFilter, F>, S, B, X, U>
where where
F: Filter, F: Filter,
S: ServiceFactory<Request> + 'static, S: ServiceFactory<Request> + 'static,
@ -115,7 +115,8 @@ mod rustls {
X: ServiceFactory<Request, Response = Request> + 'static, X: ServiceFactory<Request, Response = Request> + 'static,
X::Error: ResponseError, X::Error: ResponseError,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Io<TlsFilter<F>>, Codec), Response = ()> + 'static, U: ServiceFactory<(Request, Io<Layer<TlsFilter, F>>, Codec), Response = ()>
+ 'static,
U::Error: fmt::Display + Error, U::Error: fmt::Display + Error,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {

View file

@ -64,6 +64,7 @@ impl Stream for Payload {
} }
} }
#[derive(Debug)]
/// Sender part of the payload stream /// Sender part of the payload stream
pub struct PayloadSender { pub struct PayloadSender {
inner: Weak<RefCell<Inner>>, inner: Weak<RefCell<Inner>>,

View file

@ -49,13 +49,11 @@ mod openssl {
use ntex_tls::openssl::{Acceptor, SslFilter}; use ntex_tls::openssl::{Acceptor, SslFilter};
use tls_openssl::ssl::SslAcceptor; use tls_openssl::ssl::SslAcceptor;
use crate::io::Filter; use crate::{io::Layer, server::SslError, service::pipeline_factory};
use crate::server::SslError;
use crate::service::pipeline_factory;
use super::*; use super::*;
impl<F, S, B> H2Service<SslFilter<F>, S, B> impl<F, S, B> H2Service<Layer<SslFilter, F>, S, B>
where where
F: Filter, F: Filter,
S: ServiceFactory<Request> + 'static, S: ServiceFactory<Request> + 'static,
@ -90,9 +88,9 @@ mod rustls {
use tls_rustls::ServerConfig; use tls_rustls::ServerConfig;
use super::*; use super::*;
use crate::{server::SslError, service::pipeline_factory}; use crate::{io::Layer, server::SslError, service::pipeline_factory};
impl<F, S, B> H2Service<TlsFilter<F>, S, B> impl<F, S, B> H2Service<Layer<TlsFilter, F>, S, B>
where where
F: Filter, F: Filter,
S: ServiceFactory<Request> + 'static, S: ServiceFactory<Request> + 'static,
@ -394,7 +392,7 @@ where
loop { loop {
match poll_fn(|cx| body.poll_next_chunk(cx)).await { match poll_fn(|cx| body.poll_next_chunk(cx)).await {
None => { None => {
log::debug!("{:?} closing sending payload", msg.id()); log::debug!("{:?} closing payload stream", msg.id());
msg.stream().send_payload(Bytes::new(), true).await?; msg.stream().send_payload(Bytes::new(), true).await?;
break; break;
} }

View file

@ -146,10 +146,9 @@ mod openssl {
use tls_openssl::ssl::SslAcceptor; use tls_openssl::ssl::SslAcceptor;
use super::*; use super::*;
use crate::server::SslError; use crate::{io::Layer, server::SslError, service::pipeline_factory};
use crate::service::pipeline_factory;
impl<F, S, B, X, U> HttpService<SslFilter<F>, S, B, X, U> impl<F, S, B, X, U> HttpService<Layer<SslFilter, F>, S, B, X, U>
where where
F: Filter, F: Filter,
S: ServiceFactory<Request> + 'static, S: ServiceFactory<Request> + 'static,
@ -160,7 +159,8 @@ mod openssl {
X: ServiceFactory<Request, Response = Request> + 'static, X: ServiceFactory<Request, Response = Request> + 'static,
X::Error: ResponseError, X::Error: ResponseError,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Io<SslFilter<F>>, h1::Codec), Response = ()> + 'static, U: ServiceFactory<(Request, Io<Layer<SslFilter, F>>, h1::Codec), Response = ()>
+ 'static,
U::Error: fmt::Display + error::Error, U::Error: fmt::Display + error::Error,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {
@ -191,9 +191,9 @@ mod rustls {
use tls_rustls::ServerConfig; use tls_rustls::ServerConfig;
use super::*; use super::*;
use crate::{server::SslError, service::pipeline_factory}; use crate::{io::Layer, server::SslError, service::pipeline_factory};
impl<F, S, B, X, U> HttpService<TlsFilter<F>, S, B, X, U> impl<F, S, B, X, U> HttpService<Layer<TlsFilter, F>, S, B, X, U>
where where
F: Filter, F: Filter,
S: ServiceFactory<Request> + 'static, S: ServiceFactory<Request> + 'static,
@ -204,7 +204,8 @@ mod rustls {
X: ServiceFactory<Request, Response = Request> + 'static, X: ServiceFactory<Request, Response = Request> + 'static,
X::Error: ResponseError, X::Error: ResponseError,
X::InitError: fmt::Debug, X::InitError: fmt::Debug,
U: ServiceFactory<(Request, Io<TlsFilter<F>>, h1::Codec), Response = ()> + 'static, U: ServiceFactory<(Request, Io<Layer<TlsFilter, F>>, h1::Codec), Response = ()>
+ 'static,
U::Error: fmt::Display + error::Error, U::Error: fmt::Display + error::Error,
U::InitError: fmt::Debug, U::InitError: fmt::Debug,
{ {

View file

@ -4,8 +4,9 @@ use std::{convert::TryFrom, net, str::FromStr, sync::mpsc, thread};
#[cfg(feature = "cookie")] #[cfg(feature = "cookie")]
use coo_kie::{Cookie, CookieJar}; use coo_kie::{Cookie, CookieJar};
use crate::io::{Filter, Io};
use crate::ws::{error::WsClientError, WsClient, WsConnection}; use crate::ws::{error::WsClientError, WsClient, WsConnection};
use crate::{io::Filter, io::Io, rt::System, server::Server, service::ServiceFactory}; use crate::{rt::System, server::Server, service::ServiceFactory};
use crate::{time::Millis, time::Seconds, util::Bytes}; use crate::{time::Millis, time::Seconds, util::Bytes};
use super::client::{Client, ClientRequest, ClientResponse, Connector}; use super::client::{Client, ClientRequest, ClientResponse, Connector};
@ -349,7 +350,10 @@ impl TestServer {
/// Connect to a websocket server /// Connect to a websocket server
pub async fn wss( pub async fn wss(
&mut self, &mut self,
) -> Result<WsConnection<crate::connect::openssl::SslFilter>, WsClientError> { ) -> Result<
WsConnection<crate::io::Layer<crate::connect::openssl::SslFilter>>,
WsClientError,
> {
self.wss_at("/").await self.wss_at("/").await
} }
@ -358,7 +362,10 @@ impl TestServer {
pub async fn wss_at( pub async fn wss_at(
&mut self, &mut self,
path: &str, path: &str,
) -> Result<WsConnection<crate::connect::openssl::SslFilter>, WsClientError> { ) -> Result<
WsConnection<crate::io::Layer<crate::connect::openssl::SslFilter>>,
WsClientError,
> {
use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();

View file

@ -15,13 +15,13 @@ use crate::connect::{Connect, ConnectError, Connector};
use crate::http::header::{self, HeaderMap, HeaderName, HeaderValue, AUTHORIZATION}; use crate::http::header::{self, HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
use crate::http::{body::BodySize, client::ClientResponse, error::HttpError, h1}; use crate::http::{body::BodySize, client::ClientResponse, error::HttpError, h1};
use crate::http::{ConnectionType, RequestHead, RequestHeadType, StatusCode, Uri}; use crate::http::{ConnectionType, RequestHead, RequestHeadType, StatusCode, Uri};
use crate::io::{Base, DispatchItem, Dispatcher, Filter, Io, Sealed}; use crate::io::{Base, DispatchItem, Dispatcher, Filter, Io, Layer, Sealed};
use crate::service::{apply_fn, into_service, IntoService, Service}; use crate::service::{apply_fn, into_service, IntoService, Service};
use crate::time::{timeout, Millis, Seconds}; use crate::time::{timeout, Millis, Seconds};
use crate::{channel::mpsc, rt, util::Ready, ws}; use crate::{channel::mpsc, rt, util::Ready, ws};
use super::error::{WsClientBuilderError, WsClientError, WsError}; use super::error::{WsClientBuilderError, WsClientError, WsError};
use super::transport::{WsTransport, WsTransportFactory}; use super::transport::WsTransport;
/// `WebSocket` client builder /// `WebSocket` client builder
pub struct WsClient<F, T> { pub struct WsClient<F, T> {
@ -527,7 +527,7 @@ where
pub fn openssl( pub fn openssl(
&mut self, &mut self,
connector: openssl::SslConnector, connector: openssl::SslConnector,
) -> WsClientBuilder<openssl::SslFilter, openssl::Connector<Uri>> { ) -> WsClientBuilder<Layer<openssl::SslFilter>, openssl::Connector<Uri>> {
self.connector(openssl::Connector::new(connector)) self.connector(openssl::Connector::new(connector))
} }
@ -536,7 +536,7 @@ where
pub fn rustls( pub fn rustls(
&mut self, &mut self,
config: std::sync::Arc<rustls::ClientConfig>, config: std::sync::Arc<rustls::ClientConfig>,
) -> WsClientBuilder<rustls::TlsFilter, rustls::Connector<Uri>> { ) -> WsClientBuilder<Layer<rustls::TlsFilter>, rustls::Connector<Uri>> {
self.connector(rustls::Connector::from(config)) self.connector(rustls::Connector::from(config))
} }
@ -787,12 +787,8 @@ impl<F: Filter> WsConnection<F> {
} }
/// Convert to ws stream to plain io stream /// Convert to ws stream to plain io stream
pub async fn into_transport(self) -> Io<WsTransport<F>> { pub fn into_transport(self) -> Io<Layer<WsTransport, F>> {
// WsTransportFactory is infallible WsTransport::create(self.io, self.codec)
self.io
.add_filter(WsTransportFactory::new(self.codec))
.await
.unwrap()
} }
} }

View file

@ -1,9 +1,9 @@
//! An implementation of WebSockets base bytes streams //! An implementation of WebSockets base bytes streams
use std::{any, cell::Cell, cmp, io, task::Context, task::Poll}; use std::{cell::Cell, cmp, io, task::Poll};
use crate::codec::{Decoder, Encoder}; use crate::codec::{Decoder, Encoder};
use crate::io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus}; use crate::io::{Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
use crate::util::{BufMut, BytesVec, PoolRef, Ready}; use crate::util::{BufMut, PoolRef, Ready};
use super::{CloseCode, CloseReason, Codec, Frame, Item, Message}; use super::{CloseCode, CloseReason, Codec, Frame, Item, Message};
@ -16,15 +16,24 @@ bitflags::bitflags! {
} }
/// An implementation of WebSockets streams /// An implementation of WebSockets streams
pub struct WsTransport<F = Base> { pub struct WsTransport {
inner: F,
pool: PoolRef, pool: PoolRef,
codec: Codec, codec: Codec,
flags: Cell<Flags>, flags: Cell<Flags>,
read_buf: Cell<Option<BytesVec>>,
} }
impl<F: Filter> WsTransport<F> { impl WsTransport {
/// Create websockets transport
pub fn create<F: Filter>(io: Io<F>, codec: Codec) -> Io<Layer<WsTransport, F>> {
let pool = io.memory_pool();
io.add_filter(WsTransport {
pool,
codec,
flags: Cell::new(Flags::empty()),
})
}
fn insert_flags(&self, flags: Flags) { fn insert_flags(&self, flags: Flags) {
let mut f = self.flags.get(); let mut f = self.flags.get();
f.insert(flags); f.insert(flags);
@ -47,21 +56,12 @@ impl<F: Filter> WsTransport<F> {
} }
} }
impl<F: Filter> Filter for WsTransport<F> { impl FilterLayer for WsTransport {
#[inline] #[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> { fn shutdown(&self, buf: &mut WriteBuf<'_>) -> io::Result<Poll<()>> {
self.inner.query(id)
}
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
let flags = self.flags.get(); let flags = self.flags.get();
if !flags.contains(Flags::CLOSED) { if !flags.contains(Flags::CLOSED) {
self.insert_flags(Flags::CLOSED); self.insert_flags(Flags::CLOSED);
let mut b = self
.inner
.get_write_buf()
.unwrap_or_else(|| self.pool.get_write_buf());
let code = if flags.contains(Flags::PROTO_ERR) { let code = if flags.contains(Flags::PROTO_ERR) {
CloseCode::Protocol CloseCode::Protocol
} else { } else {
@ -72,159 +72,100 @@ impl<F: Filter> Filter for WsTransport<F> {
code, code,
description: None, description: None,
})), })),
&mut b, buf.get_dst(),
); );
self.inner.release_write_buf(b)?;
} }
Ok(Poll::Ready(()))
self.inner.poll_shutdown()
} }
#[inline] fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> { if let Some(mut src) = buf.take_src() {
self.inner.poll_read_ready(cx) let mut dst = buf.take_dst();
} let dst_len = dst.len();
#[inline] loop {
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> { // make sure we've got room
self.inner.poll_write_ready(cx) self.pool.resize_read_buf(&mut dst);
}
#[inline] let frame = if let Some(frame) =
fn get_read_buf(&self) -> Option<BytesVec> { self.codec.decode_vec(&mut src).map_err(|e| {
self.read_buf.take() log::trace!("Failed to decode ws codec frames: {:?}", e);
} self.insert_flags(Flags::PROTO_ERR);
io::Error::new(io::ErrorKind::Other, e)
})? {
frame
} else {
break;
};
#[inline] match frame {
fn get_write_buf(&self) -> Option<BytesVec> { Frame::Binary(bin) => dst.extend_from_slice(&bin),
None Frame::Continuation(Item::FirstBinary(bin)) => {
} self.insert_flags(Flags::CONTINUATION);
dst.extend_from_slice(&bin);
#[inline] }
fn release_read_buf(&self, buf: BytesVec) { Frame::Continuation(Item::Continue(bin)) => {
self.read_buf.set(Some(buf)); self.continuation_must_start("Continuation frame is not started")?;
} dst.extend_from_slice(&bin);
}
fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> { Frame::Continuation(Item::Last(bin)) => {
// ask inner filter to process read buf self.continuation_must_start(
match self.inner.process_read_buf(io, nbytes) { "Continuation frame is not started, last frame is received",
Err(err) => io.want_shutdown(Some(err)), )?;
Ok((_, 0)) => return Ok((0, 0)), dst.extend_from_slice(&bin);
Ok(_) => (), self.remove_flags(Flags::CONTINUATION);
} }
Frame::Continuation(Item::FirstText(_)) => {
// get inner buffer self.insert_flags(Flags::PROTO_ERR);
let mut src = if let Some(src) = self.inner.get_read_buf() { return Err(io::Error::new(
src io::ErrorKind::Other,
} else { "WebSocket Text continuation frames are not supported",
return Ok((0, 0)); ));
}; }
Frame::Text(_) => {
// get processed buffer self.insert_flags(Flags::PROTO_ERR);
let mut dst = if let Some(dst) = self.read_buf.take() { return Err(io::Error::new(
dst io::ErrorKind::Other,
} else { "WebSockets Text frames are not supported",
self.pool.get_read_buf() ));
}; }
let dst_len = dst.len(); Frame::Ping(msg) => {
let (hw, lw) = self.pool.read_params().unpack(); let _ = buf.with_write_buf(|b| {
self.codec.encode_vec(Message::Pong(msg), b.get_dst())
loop { });
// make sure we've got room }
let remaining = dst.remaining_mut(); Frame::Pong(_) => (),
if remaining < lw { Frame::Close(_) => {
dst.reserve(hw - remaining); buf.want_shutdown();
break;
}
};
} }
let frame = if let Some(frame) = let nb = dst.len() - dst_len;
self.codec.decode_vec(&mut src).map_err(|e| { buf.set_dst(Some(dst));
log::trace!("Failed to decode ws codec frames: {:?}", e); buf.set_src(Some(src));
self.insert_flags(Flags::PROTO_ERR); Ok(nb)
io::Error::new(io::ErrorKind::Other, e)
})? {
frame
} else {
break;
};
match frame {
Frame::Binary(bin) => dst.extend_from_slice(&bin),
Frame::Continuation(Item::FirstBinary(bin)) => {
self.insert_flags(Flags::CONTINUATION);
dst.extend_from_slice(&bin);
}
Frame::Continuation(Item::Continue(bin)) => {
self.continuation_must_start("Continuation frame is not started")?;
dst.extend_from_slice(&bin);
}
Frame::Continuation(Item::Last(bin)) => {
self.continuation_must_start(
"Continuation frame is not started, last frame is received",
)?;
dst.extend_from_slice(&bin);
self.remove_flags(Flags::CONTINUATION);
}
Frame::Continuation(Item::FirstText(_)) => {
self.insert_flags(Flags::PROTO_ERR);
return Err(io::Error::new(
io::ErrorKind::Other,
"WebSocket Text continuation frames are not supported",
));
}
Frame::Text(_) => {
self.insert_flags(Flags::PROTO_ERR);
return Err(io::Error::new(
io::ErrorKind::Other,
"WebSockets Text frames are not supported",
));
}
Frame::Ping(msg) => {
let mut b = self
.inner
.get_write_buf()
.unwrap_or_else(|| self.pool.get_write_buf());
let _ = self.codec.encode_vec(Message::Pong(msg), &mut b);
self.inner.release_write_buf(b)?;
}
Frame::Pong(_) => (),
Frame::Close(_) => {
io.want_shutdown(None);
break;
}
};
}
let dlen = dst.len();
let nbytes = dlen - dst_len;
if src.is_empty() {
self.pool.release_read_buf(src);
} else { } else {
self.inner.release_read_buf(src); Ok(0)
} }
self.read_buf.set(Some(dst));
Ok((dlen, nbytes))
} }
fn release_write_buf(&self, src: BytesVec) -> Result<(), io::Error> { fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> {
let mut buf = if let Some(buf) = self.inner.get_write_buf() { if let Some(src) = buf.take_src() {
buf let dst = buf.get_dst();
} else {
self.pool.get_write_buf()
};
// make sure we've got room // make sure we've got room
let (hw, lw) = self.pool.write_params().unpack(); let (hw, lw) = self.pool.write_params().unpack();
let remaining = buf.remaining_mut(); let remaining = dst.remaining_mut();
if remaining < lw { if remaining < lw {
buf.reserve(cmp::max(hw, buf.len() + 12) - remaining); dst.reserve(cmp::max(hw, dst.len() + 12) - remaining);
}
// Encoder ws::Codec do not fail
let _ = self.codec.encode_vec(Message::Binary(src.freeze()), dst);
} }
Ok(())
// Encoder ws::Codec do not fail
let _ = self
.codec
.encode_vec(Message::Binary(src.freeze()), &mut buf);
self.inner.release_write_buf(buf)
} }
} }
@ -241,22 +182,12 @@ impl WsTransportFactory {
} }
impl<F: Filter> FilterFactory<F> for WsTransportFactory { impl<F: Filter> FilterFactory<F> for WsTransportFactory {
type Filter = WsTransport<F>; type Filter = WsTransport;
type Error = io::Error; type Error = io::Error;
type Future = Ready<Io<Self::Filter>, Self::Error>; type Future = Ready<Io<Layer<Self::Filter, F>>, Self::Error>;
fn create(self, st: Io<F>) -> Self::Future { fn create(self, io: Io<F>) -> Self::Future {
let pool = st.memory_pool(); Ready::Ok(WsTransport::create(io, self.codec))
Ready::from(st.map_filter(|inner: F| {
Ok(WsTransport {
pool,
inner,
codec: self.codec,
flags: Cell::new(Flags::empty()),
read_buf: Cell::new(None),
})
}))
} }
} }

View file

@ -494,11 +494,8 @@ async fn test_ws_transport() {
) )
.unwrap(); .unwrap();
let io = io
.add_filter(ws::WsTransportFactory::new(ws::Codec::default()))
.await?;
// start websocket service // start websocket service
let io = ws::WsTransport::create(io, ws::Codec::default());
while let Some(item) = while let Some(item) =
io.recv(&BytesCodec).await.map_err(|e| e.into_inner())? io.recv(&BytesCodec).await.map_err(|e| e.into_inner())?
{ {

View file

@ -254,9 +254,7 @@ async fn test_transport() {
) )
.unwrap(); .unwrap();
let io = io let io = ws::WsTransport::create(io, ws::Codec::default());
.add_filter(ws::WsTransportFactory::new(ws::Codec::default()))
.await?;
// start websocket service // start websocket service
while let Some(item) = while let Some(item) =

View file

@ -95,7 +95,7 @@ async fn test_transport() {
}); });
// client service // client service
let io = srv.ws().await.unwrap().into_transport().await; let io = srv.ws().await.unwrap().into_transport();
io.send(Bytes::from_static(b"text"), &BytesCodec) io.send(Bytes::from_static(b"text"), &BytesCodec)
.await .await

View file

@ -120,6 +120,7 @@ async fn test_run() {
// stop // stop
let _ = srv.stop(false).await; let _ = srv.stop(false).await;
thread::sleep(time::Duration::from_millis(100));
assert!(net::TcpStream::connect(addr).is_err()); assert!(net::TcpStream::connect(addr).is_err());
thread::sleep(time::Duration::from_millis(100)); thread::sleep(time::Duration::from_millis(100));
@ -250,7 +251,6 @@ fn test_configure_async() {
#[cfg(feature = "tokio")] #[cfg(feature = "tokio")]
#[allow(unreachable_code)] #[allow(unreachable_code)]
fn test_panic_in_worker() { fn test_panic_in_worker() {
env_logger::init();
let counter = Arc::new(AtomicUsize::new(0)); let counter = Arc::new(AtomicUsize::new(0));
let counter2 = counter.clone(); let counter2 = counter.clone();