diff --git a/ntex-async-std/Cargo.toml b/ntex-async-std/Cargo.toml index 48f361a8..c9a277b3 100644 --- a/ntex-async-std/Cargo.toml +++ b/ntex-async-std/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-async-std" -version = "0.2.0" +version = "0.2.1" authors = ["ntex contributors "] description = "async-std intergration for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -16,8 +16,8 @@ name = "ntex_async_std" path = "src/lib.rs" [dependencies] -ntex-bytes = "0.1.11" -ntex-io = "0.2.0" +ntex-bytes = "0.1.19" +ntex-io = "0.2.1" ntex-util = "0.2.0" async-oneshot = "0.5.0" log = "0.4" diff --git a/ntex-async-std/src/io.rs b/ntex-async-std/src/io.rs index 4f64e091..40edecc2 100644 --- a/ntex-async-std/src/io.rs +++ b/ntex-async-std/src/io.rs @@ -1,4 +1,4 @@ -use std::{any, future::Future, io, pin::Pin, task::Context, task::Poll}; +use std::{any, cell::RefCell, future::Future, io, pin::Pin, task::Context, task::Poll}; use async_std::io::{Read, Write}; use ntex_bytes::{Buf, BufMut, BytesVec}; @@ -30,35 +30,31 @@ impl Handle for TcpStream { /// Read io task struct ReadTask { - io: TcpStream, + io: RefCell, state: ReadContext, } impl ReadTask { /// Create new read io task fn new(io: TcpStream, state: ReadContext) -> Self { - Self { io, state } + Self { + state, + io: RefCell::new(io), + } } } impl Future for ReadTask { type Output = (); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_ref(); - loop { + this.state.with_buf(|buf, hw, lw| { match ready!(this.state.poll_ready(cx)) { ReadStatus::Ready => { - let pool = this.state.memory_pool(); - let mut buf = this.state.get_read_buf(); - let io = &mut this.io; - let (hw, lw) = pool.read_params().unpack(); - // read data from socket - let mut new_bytes = 0; - let mut close = false; - let mut pending = false; + let mut io = self.io.borrow_mut(); loop { // make sure we've got room let remaining = buf.remaining_mut(); @@ -66,52 +62,31 @@ impl Future for ReadTask { buf.reserve(hw - remaining); } - match poll_read_buf(Pin::new(&mut io.0), cx, &mut buf) { - Poll::Pending => { - pending = true; - break; - } + return match poll_read_buf(Pin::new(&mut io.0), cx, buf) { + Poll::Pending => Poll::Pending, Poll::Ready(Ok(n)) => { if n == 0 { log::trace!("async-std stream is disconnected"); - close = true; + Poll::Ready(Ok(())) + } else if buf.len() < hw { + continue; } else { - new_bytes += n; - if new_bytes <= hw { - continue; - } + Poll::Pending } - break; } Poll::Ready(Err(err)) => { - log::trace!("read task failed on io {:?}", err); - this.state.release_read_buf(buf, new_bytes); - this.state.close(Some(err)); - return Poll::Ready(()); + log::trace!("async-std read task failed on io {:?}", err); + Poll::Ready(Err(err)) } - } + }; } - - if new_bytes == 0 && close { - this.state.close(None); - return Poll::Ready(()); - } - this.state.release_read_buf(buf, new_bytes); - return if close { - this.state.close(None); - Poll::Ready(()) - } else if pending { - Poll::Pending - } else { - continue; - }; } ReadStatus::Terminate => { log::trace!("read task is instructed to shutdown"); - return Poll::Ready(()); + Poll::Ready(Ok(())) } } - } + }) } } @@ -358,10 +333,6 @@ pub fn poll_read_buf( cx: &mut Context<'_>, buf: &mut BytesVec, ) -> Poll> { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [u8]) }; let n = ready!(io.poll_read(cx, dst))?; @@ -389,35 +360,31 @@ mod unixstream { /// Read io task struct ReadTask { - io: UnixStream, + io: RefCell, state: ReadContext, } impl ReadTask { /// Create new read io task fn new(io: UnixStream, state: ReadContext) -> Self { - Self { io, state } + Self { + state, + io: RefCell::new(io), + } } } impl Future for ReadTask { type Output = (); - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut(); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_ref(); - loop { + this.state.with_buf(|buf, hw, lw| { match ready!(this.state.poll_ready(cx)) { ReadStatus::Ready => { - let pool = this.state.memory_pool(); - let mut buf = this.state.get_read_buf(); - let io = &mut this.io; - let (hw, lw) = pool.read_params().unpack(); - // read data from socket - let mut new_bytes = 0; - let mut close = false; - let mut pending = false; + let mut io = this.io.borrow_mut(); loop { // make sure we've got room let remaining = buf.remaining_mut(); @@ -425,52 +392,31 @@ mod unixstream { buf.reserve(hw - remaining); } - match poll_read_buf(Pin::new(&mut io.0), cx, &mut buf) { - Poll::Pending => { - pending = true; - break; - } + return match poll_read_buf(Pin::new(&mut io.0), cx, buf) { + Poll::Pending => Poll::Pending, Poll::Ready(Ok(n)) => { if n == 0 { log::trace!("async-std stream is disconnected"); - close = true; + Poll::Ready(Ok(())) + } else if buf.len() < hw { + continue; } else { - new_bytes += n; - if new_bytes <= hw { - continue; - } + Poll::Pending } - break; } Poll::Ready(Err(err)) => { log::trace!("read task failed on io {:?}", err); - this.state.release_read_buf(buf, new_bytes); - this.state.close(Some(err)); - return Poll::Ready(()); + Poll::Ready(Err(err)) } - } + }; } - - if new_bytes == 0 && close { - this.state.close(None); - return Poll::Ready(()); - } - this.state.release_read_buf(buf, new_bytes); - return if close { - this.state.close(None); - Poll::Ready(()) - } else if pending { - Poll::Pending - } else { - continue; - }; } ReadStatus::Terminate => { log::trace!("read task is instructed to shutdown"); - return Poll::Ready(()); + Poll::Ready(Ok(())) } } - } + }) } } diff --git a/ntex-bytes/CHANGELOG.md b/ntex-bytes/CHANGELOG.md index 5b4671b5..fbe058dd 100644 --- a/ntex-bytes/CHANGELOG.md +++ b/ntex-bytes/CHANGELOG.md @@ -1,5 +1,9 @@ # Changes +## [0.1.19] (2023-01-23) + +* Add PollRef::resize_read_buf() and PollRef::resize_write_buf() helpers + ## [0.1.18] (2022-12-13) * Add Bytes<&Bytes> for Bytes impl diff --git a/ntex-bytes/Cargo.toml b/ntex-bytes/Cargo.toml index 9ce5f2c9..75437f3d 100644 --- a/ntex-bytes/Cargo.toml +++ b/ntex-bytes/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-bytes" -version = "0.1.18" +version = "0.1.19" license = "MIT" authors = ["Nikolay Kim ", "Carl Lerche "] description = "Types and traits for working with bytes (bytes crate fork)" diff --git a/ntex-bytes/src/pool.rs b/ntex-bytes/src/pool.rs index eb83337b..eeb721bc 100644 --- a/ntex-bytes/src/pool.rs +++ b/ntex-bytes/src/pool.rs @@ -6,7 +6,7 @@ use std::{cell::Cell, cell::RefCell, fmt, future::Future, mem, pin::Pin, ptr, rc use futures_core::task::__internal::AtomicWaker; -use crate::{BytesMut, BytesVec}; +use crate::{BufMut, BytesMut, BytesVec}; pub struct Pool { idx: Cell, @@ -293,6 +293,17 @@ impl PoolRef { } } + #[doc(hidden)] + #[inline] + /// Resize read buffer + pub fn resize_read_buf(self, buf: &mut BytesVec) { + let (hw, lw) = self.0.write_wm.get().unpack(); + let remaining = buf.remaining_mut(); + if remaining < lw { + buf.reserve(hw - remaining); + } + } + #[doc(hidden)] #[inline] /// Release read buffer, buf must be allocated from this pool @@ -318,6 +329,17 @@ impl PoolRef { } } + #[doc(hidden)] + #[inline] + /// Resize write buffer + pub fn resize_write_buf(self, buf: &mut BytesVec) { + let (hw, lw) = self.0.write_wm.get().unpack(); + let remaining = buf.remaining_mut(); + if remaining < lw { + buf.reserve(hw - remaining); + } + } + #[doc(hidden)] #[inline] /// Release write buffer, buf must be allocated from this pool diff --git a/ntex-connect/CHANGES.md b/ntex-connect/CHANGES.md index 4ca6350e..fa39795c 100644 --- a/ntex-connect/CHANGES.md +++ b/ntex-connect/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.2.1] - 2023-01-23 + +* Use new Io object + ## [0.2.0] - 2023-01-04 * Release diff --git a/ntex-connect/Cargo.toml b/ntex-connect/Cargo.toml index 11c1cff0..8ca9f4ed 100644 --- a/ntex-connect/Cargo.toml +++ b/ntex-connect/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-connect" -version = "0.2.0" +version = "0.2.1" authors = ["ntex contributors "] description = "ntexwork connect utils for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -35,18 +35,18 @@ async-std = ["ntex-rt/async-std", "ntex-async-std"] [dependencies] ntex-service = "1.0.0" -ntex-bytes = "0.1.18" +ntex-bytes = "0.1.19" ntex-http = "0.1.8" -ntex-io = "0.2.0" +ntex-io = "0.2.1" ntex-rt = "0.4.7" -ntex-tls = "0.2.0" +ntex-tls = "0.2.1" ntex-util = "0.2.0" log = "0.4" thiserror = "1.0" -ntex-tokio = { version = "0.2.0", optional = true } -ntex-glommio = { version = "0.2.0", optional = true } -ntex-async-std = { version = "0.2.0", optional = true } +ntex-tokio = { version = "0.2.1", optional = true } +ntex-glommio = { version = "0.2.1", optional = true } +ntex-async-std = { version = "0.2.1", optional = true } # openssl tls-openssl = { version="0.10", package = "openssl", optional = true } diff --git a/ntex-connect/src/openssl.rs b/ntex-connect/src/openssl.rs index 5bd25f7a..e14e6ad0 100644 --- a/ntex-connect/src/openssl.rs +++ b/ntex-connect/src/openssl.rs @@ -4,7 +4,7 @@ pub use ntex_tls::openssl::SslFilter; pub use tls_openssl::ssl::{Error as SslError, HandshakeError, SslConnector, SslMethod}; use ntex_bytes::PoolId; -use ntex_io::{Base, Io}; +use ntex_io::{FilterFactory, Io, Layer}; use ntex_service::{Service, ServiceFactory}; use ntex_tls::openssl::SslConnector as IoSslConnector; use ntex_util::future::{BoxFuture, Ready}; @@ -39,7 +39,7 @@ impl Connector { impl Connector { /// Resolve and connect to remote host - pub async fn connect(&self, message: U) -> Result>, ConnectError> + pub async fn connect(&self, message: U) -> Result>, ConnectError> where Connect: From, { @@ -57,7 +57,7 @@ impl Connector { let ssl = config .into_ssl(&host) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - match io.add_filter(IoSslConnector::new(ssl)).await { + match IoSslConnector::new(ssl).create(io).await { Ok(io) => { trace!("SSL Handshake success: {:?}", host); Ok(io) @@ -82,7 +82,7 @@ impl Clone for Connector { } impl ServiceFactory, C> for Connector { - type Response = Io>; + type Response = Io>; type Error = ConnectError; type Service = Connector; type InitError = (); @@ -95,7 +95,7 @@ impl ServiceFactory, C> for Connector { } impl Service> for Connector { - type Response = Io>; + type Response = Io>; type Error = ConnectError; type Future<'f> = BoxFuture<'f, Result>; diff --git a/ntex-connect/src/rustls.rs b/ntex-connect/src/rustls.rs index e93fa62e..2c37beff 100644 --- a/ntex-connect/src/rustls.rs +++ b/ntex-connect/src/rustls.rs @@ -4,7 +4,7 @@ pub use ntex_tls::rustls::TlsFilter; pub use tls_rustls::{ClientConfig, ServerName}; use ntex_bytes::PoolId; -use ntex_io::{Base, Io}; +use ntex_io::{FilterFactory, Io, Layer}; use ntex_service::{Service, ServiceFactory}; use ntex_tls::rustls::TlsConnector; use ntex_util::future::{BoxFuture, Ready}; @@ -48,7 +48,7 @@ impl Connector { impl Connector { /// Resolve and connect to remote host - pub async fn connect(&self, message: U) -> Result>, ConnectError> + pub async fn connect(&self, message: U) -> Result>, ConnectError> where Connect: From, { @@ -64,7 +64,7 @@ impl Connector { .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))?; let connector = connector.server_name(host.clone()); - match io.add_filter(connector).await { + match connector.create(io).await { Ok(io) => { trace!("TLS Handshake success: {:?}", &host); Ok(io) @@ -87,7 +87,7 @@ impl Clone for Connector { } impl ServiceFactory, C> for Connector { - type Response = Io>; + type Response = Io>; type Error = ConnectError; type Service = Connector; type InitError = (); @@ -100,7 +100,7 @@ impl ServiceFactory, C> for Connector { } impl Service> for Connector { - type Response = Io>; + type Response = Io>; type Error = ConnectError; type Future<'f> = BoxFuture<'f, Result>; diff --git a/ntex-glommio/Cargo.toml b/ntex-glommio/Cargo.toml index b734bac2..b6f00c85 100644 --- a/ntex-glommio/Cargo.toml +++ b/ntex-glommio/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-glommio" -version = "0.2.0" +version = "0.2.1" authors = ["ntex contributors "] description = "glommio intergration for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -16,8 +16,8 @@ name = "ntex_glommio" path = "src/lib.rs" [dependencies] -ntex-bytes = "0.1.18" -ntex-io = "0.2.0" +ntex-bytes = "0.1.19" +ntex-io = "0.2.1" ntex-util = "0.2.0" async-oneshot = "0.5.0" futures-lite = "1.12" diff --git a/ntex-glommio/src/io.rs b/ntex-glommio/src/io.rs index 08510317..68e7cac3 100644 --- a/ntex-glommio/src/io.rs +++ b/ntex-glommio/src/io.rs @@ -57,17 +57,10 @@ impl Future for ReadTask { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_mut(); - loop { + this.state.with_buf(|buf, hw, lw| { match ready!(this.state.poll_ready(cx)) { ReadStatus::Ready => { - let pool = this.state.memory_pool(); - let mut buf = this.state.get_read_buf(); - let (hw, lw) = pool.read_params().unpack(); - // read data from socket - let mut new_bytes = 0; - let mut close = false; - let mut pending = false; loop { // make sure we've got room let remaining = buf.remaining_mut(); @@ -75,56 +68,35 @@ impl Future for ReadTask { buf.reserve(hw - remaining); } - match poll_read_buf( + return match poll_read_buf( Pin::new(&mut *this.io.0.borrow_mut()), cx, - &mut buf, + buf, ) { - Poll::Pending => { - pending = true; - break; - } + Poll::Pending => Poll::Pending, Poll::Ready(Ok(n)) => { if n == 0 { log::trace!("glommio stream is disconnected"); - close = true; + Poll::Ready(Ok(())) + } else if buf.len() < hw { + continue; } else { - new_bytes += n; - if new_bytes <= hw { - continue; - } + Poll::Pending } - break; } Poll::Ready(Err(err)) => { log::trace!("read task failed on io {:?}", err); - let _ = this.state.release_read_buf(buf, new_bytes); - this.state.close(Some(err)); - return Poll::Ready(()); + Poll::Ready(Err(err)) } - } + }; } - - if new_bytes == 0 && close { - this.state.close(None); - return Poll::Ready(()); - } - this.state.release_read_buf(buf, new_bytes); - return if close { - this.state.close(None); - Poll::Ready(()) - } else if pending { - Poll::Pending - } else { - continue; - }; } ReadStatus::Terminate => { log::trace!("read task is instructed to shutdown"); - return Poll::Ready(()); + Poll::Ready(Ok(())) } } - } + }) } } @@ -372,10 +344,6 @@ pub fn poll_read_buf( cx: &mut Context<'_>, buf: &mut BytesVec, ) -> Poll> { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [u8]) }; let n = ready!(io.poll_read(cx, dst))?; @@ -407,17 +375,10 @@ impl Future for UnixReadTask { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_mut(); - loop { + this.state.with_buf(|buf, hw, lw| { match ready!(this.state.poll_ready(cx)) { ReadStatus::Ready => { - let pool = this.state.memory_pool(); - let mut buf = this.state.get_read_buf(); - let (hw, lw) = pool.read_params().unpack(); - // read data from socket - let mut new_bytes = 0; - let mut close = false; - let mut pending = false; loop { // make sure we've got room let remaining = buf.remaining_mut(); @@ -425,56 +386,35 @@ impl Future for UnixReadTask { buf.reserve(hw - remaining); } - match poll_read_buf( + return match poll_read_buf( Pin::new(&mut *this.io.0.borrow_mut()), cx, - &mut buf, + buf, ) { - Poll::Pending => { - pending = true; - break; - } + Poll::Pending => Poll::Pending, Poll::Ready(Ok(n)) => { if n == 0 { log::trace!("glommio stream is disconnected"); - close = true; + Poll::Ready(Ok(())) + } else if buf.len() < hw { + continue; } else { - new_bytes += n; - if new_bytes <= hw { - continue; - } + Poll::Pending } - break; } Poll::Ready(Err(err)) => { log::trace!("read task failed on io {:?}", err); - let _ = this.state.release_read_buf(buf, new_bytes); - this.state.close(Some(err)); - return Poll::Ready(()); + Poll::Ready(Err(err)) } - } + }; } - - if new_bytes == 0 && close { - this.state.close(None); - return Poll::Ready(()); - } - this.state.release_read_buf(buf, new_bytes); - return if close { - this.state.close(None); - Poll::Ready(()) - } else if pending { - Poll::Pending - } else { - continue; - }; } ReadStatus::Terminate => { log::trace!("read task is instructed to shutdown"); - return Poll::Ready(()); + Poll::Ready(Ok(())) } } - } + }) } } diff --git a/ntex-io/CHANGES.md b/ntex-io/CHANGES.md index 9c189c8f..0759e780 100644 --- a/ntex-io/CHANGES.md +++ b/ntex-io/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.2.1] - 2023-01-23 + +* Refactor Io and Filter types + ## [0.2.0] - 2023-01-04 * Release diff --git a/ntex-io/Cargo.toml b/ntex-io/Cargo.toml index b9236ebb..f03a853e 100644 --- a/ntex-io/Cargo.toml +++ b/ntex-io/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-io" -version = "0.2.0" +version = "0.2.1" authors = ["ntex contributors "] description = "Utilities for encoding and decoding frames" keywords = ["network", "framework", "async", "futures"] @@ -17,13 +17,14 @@ path = "src/lib.rs" [dependencies] ntex-codec = "0.6.2" -ntex-bytes = "0.1.18" +ntex-bytes = "0.1.19" ntex-util = "0.2.0" ntex-service = "1.0.0" bitflags = "1.3" log = "0.4" pin-project-lite = "0.2" +smallvec = "1" [dev-dependencies] rand = "0.8" diff --git a/ntex-io/src/buf.rs b/ntex-io/src/buf.rs new file mode 100644 index 00000000..e45f482b --- /dev/null +++ b/ntex-io/src/buf.rs @@ -0,0 +1,361 @@ +use ntex_bytes::{BytesVec, PoolRef}; +use smallvec::SmallVec; + +use crate::IoRef; + +#[derive(Debug)] +pub struct Stack { + pub(crate) buffers: SmallVec<[(Option, Option); 4]>, +} + +impl Stack { + pub(crate) fn new() -> Self { + let mut buffers = SmallVec::with_capacity(4); + buffers.push((None, None)); + + Self { buffers } + } + + pub(crate) fn add_layer(&mut self) { + self.buffers.insert(0, (None, None)); + } + + pub(crate) fn read_buf( + &mut self, + io: &IoRef, + idx: usize, + nbytes: usize, + f: F, + ) -> R + where + F: FnOnce(&mut ReadBuf<'_>) -> R, + { + let pos = idx + 1; + if self.buffers.len() > pos { + let (curr, next) = self.buffers.split_at_mut(pos); + let mut buf = ReadBuf { + io, + nbytes, + curr: &mut curr[idx], + next: &mut next[0], + }; + f(&mut buf) + } else { + let mut val1 = (self.buffers[idx].0.take(), None); + let mut val2 = (None, self.buffers[idx].1.take()); + + let mut buf = ReadBuf { + io, + nbytes, + curr: &mut val1, + next: &mut val2, + }; + let result = f(&mut buf); + + self.buffers[idx].0 = val1.0; + self.buffers[idx].1 = val2.1; + result + } + } + + pub(crate) fn write_buf(&mut self, io: &IoRef, idx: usize, f: F) -> R + where + F: FnOnce(&mut WriteBuf<'_>) -> R, + { + let pos = idx + 1; + if self.buffers.len() > pos { + let (curr, next) = self.buffers.split_at_mut(pos); + let mut buf = WriteBuf { + io, + curr: &mut curr[idx], + next: &mut next[0], + }; + f(&mut buf) + } else { + let mut val1 = (self.buffers[idx].0.take(), None); + let mut val2 = (None, self.buffers[idx].1.take()); + + let mut buf = WriteBuf { + io, + curr: &mut val1, + next: &mut val2, + }; + let result = f(&mut buf); + + self.buffers[idx].0 = val1.0; + self.buffers[idx].1 = val2.1; + result + } + } + + pub(crate) fn first_read_buf_size(&self) -> usize { + self.buffers[0].0.as_ref().map(|b| b.len()).unwrap_or(0) + } + + pub(crate) fn first_read_buf(&mut self) -> &mut Option { + &mut self.buffers[0].0 + } + + pub(crate) fn first_write_buf(&mut self, io: &IoRef) -> &mut BytesVec { + if self.buffers[0].1.is_none() { + self.buffers[0].1 = Some(io.memory_pool().get_write_buf()); + } + self.buffers[0].1.as_mut().unwrap() + } + + pub(crate) fn last_read_buf(&mut self) -> &mut Option { + let idx = self.buffers.len() - 1; + &mut self.buffers[idx].0 + } + + pub(crate) fn last_write_buf(&mut self) -> &mut Option { + let idx = self.buffers.len() - 1; + &mut self.buffers[idx].1 + } + + pub(crate) fn last_write_buf_size(&self) -> usize { + let idx = self.buffers.len() - 1; + self.buffers[idx].1.as_ref().map(|b| b.len()).unwrap_or(0) + } + + pub(crate) fn set_last_write_buf(&mut self, buf: BytesVec) { + let idx = self.buffers.len() - 1; + self.buffers[idx].1 = Some(buf); + } + + pub(crate) fn release(&mut self, pool: PoolRef) { + for buf in &mut self.buffers { + if let Some(buf) = buf.0.take() { + pool.release_read_buf(buf); + } + if let Some(buf) = buf.1.take() { + pool.release_write_buf(buf); + } + } + } + + pub(crate) fn set_memory_pool(&mut self, pool: PoolRef) { + for buf in &mut self.buffers { + if let Some(ref mut b) = buf.0 { + pool.move_vec_in(b); + } + if let Some(ref mut b) = buf.1 { + pool.move_vec_in(b); + } + } + } +} + +#[derive(Debug)] +pub struct ReadBuf<'a> { + pub(crate) io: &'a IoRef, + pub(crate) curr: &'a mut (Option, Option), + pub(crate) next: &'a mut (Option, Option), + pub(crate) nbytes: usize, +} + +impl<'a> ReadBuf<'a> { + #[inline] + /// Get number of newly added bytes + pub fn nbytes(&self) -> usize { + self.nbytes + } + + #[inline] + /// Initiate graceful io stream shutdown + pub fn want_shutdown(&self) { + self.io.want_shutdown() + } + + #[inline] + /// Get reference to source read buffer + pub fn get_src(&mut self) -> &mut BytesVec { + if self.next.0.is_none() { + self.next.0 = Some(self.io.memory_pool().get_read_buf()); + } + self.next.0.as_mut().unwrap() + } + + #[inline] + /// Take source read buffer + pub fn take_src(&mut self) -> Option { + self.next + .0 + .take() + .and_then(|b| if b.is_empty() { None } else { Some(b) }) + } + + #[inline] + /// Set source read buffer + pub fn set_src(&mut self, src: Option) { + if let Some(src) = src { + if src.is_empty() { + self.io.memory_pool().release_read_buf(src); + } else { + if let Some(b) = self.next.0.take() { + self.io.memory_pool().release_read_buf(b); + } + self.next.0 = Some(src); + } + } + } + + #[inline] + /// Get reference to destination read buffer + pub fn get_dst(&mut self) -> &mut BytesVec { + if self.curr.0.is_none() { + self.curr.0 = Some(self.io.memory_pool().get_read_buf()); + } + self.curr.0.as_mut().unwrap() + } + + #[inline] + /// Take destination read buffer + pub fn take_dst(&mut self) -> BytesVec { + self.curr + .0 + .take() + .unwrap_or_else(|| self.io.memory_pool().get_read_buf()) + } + + #[inline] + /// Set destination read buffer + pub fn set_dst(&mut self, dst: Option) { + if let Some(dst) = dst { + if dst.is_empty() { + self.io.memory_pool().release_read_buf(dst); + } else { + if let Some(b) = self.curr.0.take() { + self.io.memory_pool().release_read_buf(b); + } + self.curr.0 = Some(dst); + } + } + } + + #[inline] + /// Get reference to source and destination read buffers (src, dst) + pub fn get_pair(&mut self) -> (&mut BytesVec, &mut BytesVec) { + if self.next.0.is_none() { + self.next.0 = Some(self.io.memory_pool().get_read_buf()); + } + if self.curr.0.is_none() { + self.curr.0 = Some(self.io.memory_pool().get_read_buf()); + } + (self.next.0.as_mut().unwrap(), self.curr.0.as_mut().unwrap()) + } + + #[inline] + pub fn with_write_buf<'b, F, R>(&'b mut self, f: F) -> R + where + F: FnOnce(&mut WriteBuf<'b>) -> R, + { + let mut buf = WriteBuf { + io: self.io, + curr: self.curr, + next: self.next, + }; + f(&mut buf) + } +} + +#[derive(Debug)] +pub struct WriteBuf<'a> { + pub(crate) io: &'a IoRef, + pub(crate) curr: &'a mut (Option, Option), + pub(crate) next: &'a mut (Option, Option), +} + +impl<'a> WriteBuf<'a> { + #[inline] + /// Initiate graceful io stream shutdown + pub fn want_shutdown(&self) { + self.io.want_shutdown() + } + + #[inline] + /// Get reference to source write buffer + pub fn get_src(&mut self) -> &mut BytesVec { + if self.curr.1.is_none() { + self.curr.1 = Some(self.io.memory_pool().get_write_buf()); + } + self.curr.1.as_mut().unwrap() + } + + #[inline] + /// Take source write buffer + pub fn take_src(&mut self) -> Option { + self.curr + .1 + .take() + .and_then(|b| if b.is_empty() { None } else { Some(b) }) + } + + #[inline] + /// Set source write buffer + pub fn set_src(&mut self, src: Option) { + if let Some(b) = self.curr.1.take() { + self.io.memory_pool().release_read_buf(b); + } + self.curr.1 = src; + } + + #[inline] + /// Get reference to destination write buffer + pub fn get_dst(&mut self) -> &mut BytesVec { + if self.next.1.is_none() { + self.next.1 = Some(self.io.memory_pool().get_write_buf()); + } + self.next.1.as_mut().unwrap() + } + + #[inline] + /// Take destination write buffer + pub fn take_dst(&mut self) -> BytesVec { + self.next + .1 + .take() + .unwrap_or_else(|| self.io.memory_pool().get_write_buf()) + } + + #[inline] + /// Set destination write buffer + pub fn set_dst(&mut self, dst: Option) { + if let Some(dst) = dst { + if dst.is_empty() { + self.io.memory_pool().release_write_buf(dst); + } else { + if let Some(b) = self.next.1.take() { + self.io.memory_pool().release_write_buf(b); + } + self.next.1 = Some(dst); + } + } + } + + #[inline] + /// Get reference to source and destination buffers (src, dst) + pub fn get_pair(&mut self) -> (&mut BytesVec, &mut BytesVec) { + if self.curr.1.is_none() { + self.curr.1 = Some(self.io.memory_pool().get_write_buf()); + } + if self.next.1.is_none() { + self.next.1 = Some(self.io.memory_pool().get_write_buf()); + } + (self.curr.1.as_mut().unwrap(), self.next.1.as_mut().unwrap()) + } + + #[inline] + pub fn with_read_buf<'b, F, R>(&'b mut self, f: F) -> R + where + F: FnOnce(&mut ReadBuf<'b>) -> R, + { + let mut buf = ReadBuf { + io: self.io, + curr: self.curr, + next: self.next, + nbytes: 0, + }; + f(&mut buf) + } +} diff --git a/ntex-io/src/dispatcher.rs b/ntex-io/src/dispatcher.rs index 0024edb0..95e518c4 100644 --- a/ntex-io/src/dispatcher.rs +++ b/ntex-io/src/dispatcher.rs @@ -316,18 +316,15 @@ where } // shutdown service DispatcherState::Shutdown => { - let err = slf.error.take(); - return if this.inner.shared.service.poll_shutdown(cx).is_ready() { log::trace!("service shutdown is completed, stop"); - Poll::Ready(if let Some(err) = err { + Poll::Ready(if let Some(err) = slf.error.take() { Err(err) } else { Ok(()) }) } else { - slf.error.set(err); Poll::Pending }; } @@ -632,9 +629,7 @@ mod tests { // close read side client.close().await; - - // TODO! fix - // assert!(client.is_server_dropped()); + assert!(client.is_server_dropped()); // service must be checked for readiness only once assert_eq!(counter.get(), 1); diff --git a/ntex-io/src/filter.rs b/ntex-io/src/filter.rs index 6692a860..f4c613ec 100644 --- a/ntex-io/src/filter.rs +++ b/ntex-io/src/filter.rs @@ -1,8 +1,6 @@ use std::{any, io, task::Context, task::Poll}; -use ntex_bytes::BytesVec; - -use super::{io::Flags, Filter, IoRef, ReadStatus, WriteStatus}; +use super::{buf::Stack, io::Flags, FilterLayer, IoRef, ReadStatus, WriteStatus}; /// Default `Io` filter pub struct Base(IoRef); @@ -13,8 +11,54 @@ impl Base { } } +pub struct Layer(pub(crate) F, L); + +impl Layer { + pub(crate) fn new(f: F, l: L) -> Self { + Self(f, l) + } +} + +pub(crate) struct NullFilter; + +const NULL: NullFilter = NullFilter; + +impl NullFilter { + pub(super) fn get() -> &'static dyn Filter { + &NULL + } +} + +pub trait Filter: 'static { + fn query(&self, id: any::TypeId) -> Option>; + + fn process_read_buf( + &self, + io: &IoRef, + stack: &mut Stack, + idx: usize, + nbytes: usize, + ) -> io::Result; + + /// Process write buffer + fn process_write_buf( + &self, + io: &IoRef, + stack: &mut Stack, + idx: usize, + ) -> io::Result<()>; + + /// Gracefully shutdown filter + fn shutdown(&self, io: &IoRef, stack: &mut Stack, idx: usize) -> io::Result>; + + /// Check readiness for read operations + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll; + + /// Check readiness for write operations + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll; +} + impl Filter for Base { - #[inline] fn query(&self, id: any::TypeId) -> Option> { if let Some(hnd) = self.0 .0.handle.take() { let res = hnd.query(id); @@ -25,11 +69,6 @@ impl Filter for Base { } } - #[inline] - fn poll_shutdown(&self) -> Poll> { - Poll::Ready(Ok(())) - } - #[inline] fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll { let flags = self.0.flags(); @@ -67,51 +106,128 @@ impl Filter for Base { } #[inline] - fn get_read_buf(&self) -> Option { - self.0 .0.read_buf.take() + fn process_read_buf( + &self, + _: &IoRef, + _: &mut Stack, + _: usize, + nbytes: usize, + ) -> io::Result { + Ok(nbytes) } #[inline] - fn get_write_buf(&self) -> Option { - self.0 .0.write_buf.take() - } - - #[inline] - fn release_read_buf(&self, buf: BytesVec) { - self.0 .0.read_buf.set(Some(buf)); - } - - #[inline] - fn process_read_buf(&self, _: &IoRef, n: usize) -> io::Result<(usize, usize)> { - let buf = self.0 .0.read_buf.as_ptr(); - let ref_buf = unsafe { buf.as_ref().unwrap() }; - let total = ref_buf.as_ref().map(|v| v.len()).unwrap_or(0); - Ok((total, n)) - } - - #[inline] - fn release_write_buf(&self, buf: BytesVec) -> Result<(), io::Error> { - let pool = self.0.memory_pool(); - if buf.is_empty() { - pool.release_write_buf(buf); - } else { - if buf.len() >= pool.write_params_high() { + fn process_write_buf(&self, _: &IoRef, s: &mut Stack, _: usize) -> io::Result<()> { + if let Some(buf) = s.last_write_buf() { + if buf.len() >= self.0.memory_pool().write_params_high() { self.0 .0.insert_flags(Flags::WR_BACKPRESSURE); } - self.0 .0.write_buf.set(Some(buf)); - self.0 .0.write_task.wake(); } Ok(()) } + + #[inline] + fn shutdown(&self, _: &IoRef, _: &mut Stack, _: usize) -> io::Result> { + Ok(Poll::Ready(())) + } } -pub(crate) struct NullFilter; +impl Filter for Layer +where + F: FilterLayer, + L: Filter, +{ + #[inline] + fn query(&self, id: any::TypeId) -> Option> { + self.0.query(id).or_else(|| self.1.query(id)) + } -const NULL: NullFilter = NullFilter; + #[inline] + fn shutdown(&self, io: &IoRef, stack: &mut Stack, idx: usize) -> io::Result> { + let result1 = stack.write_buf(io, idx, |buf| self.0.shutdown(buf))?; + self.process_write_buf(io, stack, idx)?; -impl NullFilter { - pub(super) fn get() -> &'static dyn Filter { - &NULL + let result2 = if F::BUFFERS { + self.1.shutdown(io, stack, idx + 1)? + } else { + self.1.shutdown(io, stack, idx)? + }; + + if result1.is_pending() || result2.is_pending() { + Ok(Poll::Pending) + } else { + Ok(Poll::Ready(())) + } + } + + #[inline] + fn process_read_buf( + &self, + io: &IoRef, + stack: &mut Stack, + idx: usize, + nbytes: usize, + ) -> io::Result { + let nbytes = if F::BUFFERS { + self.1.process_read_buf(io, stack, idx + 1, nbytes)? + } else { + self.1.process_read_buf(io, stack, idx, nbytes)? + }; + stack.read_buf(io, idx, nbytes, |buf| self.0.process_read_buf(buf)) + } + + #[inline] + fn process_write_buf( + &self, + io: &IoRef, + stack: &mut Stack, + idx: usize, + ) -> io::Result<()> { + stack.write_buf(io, idx, |buf| self.0.process_write_buf(buf))?; + + if F::BUFFERS { + self.1.process_write_buf(io, stack, idx + 1) + } else { + self.1.process_write_buf(io, stack, idx) + } + } + + #[inline] + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll { + let res1 = self.0.poll_read_ready(cx); + let res2 = self.1.poll_read_ready(cx); + + match res1 { + Poll::Pending => Poll::Pending, + Poll::Ready(ReadStatus::Ready) => res2, + Poll::Ready(ReadStatus::Terminate) => Poll::Ready(ReadStatus::Terminate), + } + } + + #[inline] + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { + let res1 = self.0.poll_write_ready(cx); + let res2 = self.1.poll_write_ready(cx); + + match res1 { + Poll::Pending => Poll::Pending, + Poll::Ready(WriteStatus::Ready) => res2, + Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate), + Poll::Ready(WriteStatus::Shutdown(t)) => { + if res2 == Poll::Ready(WriteStatus::Terminate) { + Poll::Ready(WriteStatus::Terminate) + } else { + Poll::Ready(WriteStatus::Shutdown(t)) + } + } + Poll::Ready(WriteStatus::Timeout(t)) => match res2 { + Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate), + Poll::Ready(WriteStatus::Shutdown(t)) => { + Poll::Ready(WriteStatus::Shutdown(t)) + } + _ => Poll::Ready(WriteStatus::Timeout(t)), + }, + } } } @@ -121,11 +237,6 @@ impl Filter for NullFilter { None } - #[inline] - fn poll_shutdown(&self) -> Poll> { - Poll::Ready(Ok(())) - } - #[inline] fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll { Poll::Ready(ReadStatus::Terminate) @@ -137,25 +248,23 @@ impl Filter for NullFilter { } #[inline] - fn get_read_buf(&self) -> Option { - None + fn process_read_buf( + &self, + _: &IoRef, + _: &mut Stack, + _: usize, + _: usize, + ) -> io::Result { + Ok(0) } #[inline] - fn get_write_buf(&self) -> Option { - None - } - - #[inline] - fn release_read_buf(&self, _: BytesVec) {} - - #[inline] - fn process_read_buf(&self, _: &IoRef, _: usize) -> io::Result<(usize, usize)> { - Ok((0, 0)) - } - - #[inline] - fn release_write_buf(&self, _: BytesVec) -> Result<(), io::Error> { + fn process_write_buf(&self, _: &IoRef, _: &mut Stack, _: usize) -> io::Result<()> { Ok(()) } + + #[inline] + fn shutdown(&self, _: &IoRef, _: &mut Stack, _: usize) -> io::Result> { + Ok(Poll::Ready(())) + } } diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs index d0d86a41..362134f6 100644 --- a/ntex-io/src/io.rs +++ b/ntex-io/src/io.rs @@ -1,4 +1,4 @@ -use std::cell::Cell; +use std::cell::{Cell, RefCell}; use std::task::{Context, Poll}; use std::{fmt, future::Future, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc, time}; @@ -7,10 +7,11 @@ use ntex_codec::{Decoder, Encoder}; use ntex_util::time::{now, Millis}; use ntex_util::{future::poll_fn, future::Either, task::LocalWaker}; -use super::filter::{Base, NullFilter}; -use super::seal::Sealed; -use super::tasks::{ReadContext, WriteContext}; -use super::{Filter, FilterFactory, Handle, IoStatusUpdate, IoStream, RecvError}; +use crate::buf::Stack; +use crate::filter::{Base, Filter, Layer, NullFilter}; +use crate::seal::Sealed; +use crate::tasks::{ReadContext, WriteContext}; +use crate::{FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError}; bitflags::bitflags! { pub struct Flags: u16 { @@ -59,8 +60,7 @@ pub(crate) struct IoState { pub(super) read_task: LocalWaker, pub(super) write_task: LocalWaker, pub(super) dispatch_task: LocalWaker, - pub(super) read_buf: Cell>, - pub(super) write_buf: Cell>, + pub(super) buffer: RefCell, pub(super) filter: Cell<&'static dyn Filter>, pub(super) handle: Cell>>, #[allow(clippy::box_collection)] @@ -104,7 +104,6 @@ impl IoState { } } - #[inline] pub(super) fn io_stopped(&self, err: Option) { if err.is_some() { self.error.set(err); @@ -119,9 +118,8 @@ impl IoState { ); } - #[inline] /// Gracefully shutdown read and write io tasks - pub(super) fn init_shutdown(&self, err: Option) { + pub(super) fn init_shutdown(&self, err: Option, io: &IoRef) { if err.is_some() { self.io_stopped(err); } else if !self @@ -131,28 +129,25 @@ impl IoState { { log::trace!("initiate io shutdown {:?}", self.flags.get()); self.insert_flags(Flags::IO_STOPPING_FILTERS); - self.shutdown_filters(); + self.shutdown_filters(io); } } - #[inline] - pub(super) fn shutdown_filters(&self) { + pub(super) fn shutdown_filters(&self, io: &IoRef) { if !self .flags .get() .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) { - match self.filter.get().poll_shutdown() { - Poll::Ready(Ok(())) => { + let mut buffer = self.buffer.borrow_mut(); + match self.filter.get().shutdown(io, &mut buffer, 0) { + Ok(Poll::Ready(())) => { self.read_task.wake(); self.write_task.wake(); self.dispatch_task.wake(); self.insert_flags(Flags::IO_STOPPING); } - Poll::Ready(Err(err)) => { - self.io_stopped(Some(err)); - } - Poll::Pending => { + Ok(Poll::Pending) => { let flags = self.flags.get(); // check read buffer, if buffer is not consumed it is unlikely // that filter will properly complete shutdown @@ -165,40 +160,37 @@ impl IoState { self.insert_flags(Flags::IO_STOPPING); } } - } + Err(err) => { + self.io_stopped(Some(err)); + } + }; + self.write_task.wake(); } } - #[inline] pub(super) fn with_read_buf(&self, release: bool, f: Fn) -> Ret where Fn: FnOnce(&mut Option) -> Ret, { - let filter = self.filter.get(); - let mut buf = filter.get_read_buf(); - let result = f(&mut buf); + // use top most buffer + let mut buffer = self.buffer.borrow_mut(); + let buf = buffer.first_read_buf(); + let result = f(buf); - if let Some(buf) = buf { - if release { - // release buffer - if buf.is_empty() { - self.pool.get().release_read_buf(buf); - return result; - } + // release buffer + if release && buf.as_ref().map(|b| b.is_empty()).unwrap_or(false) { + if let Some(b) = buf.take() { + self.pool.get().release_read_buf(b); } - filter.release_read_buf(buf); } result } - #[inline] pub(super) fn with_write_buf(&self, f: Fn) -> Ret where Fn: FnOnce(&mut Option) -> Ret, { - let buf = self.write_buf.as_ptr(); - let ref_buf = unsafe { buf.as_mut().unwrap() }; - f(ref_buf) + f(self.buffer.borrow_mut().last_write_buf()) } } @@ -221,12 +213,7 @@ impl hash::Hash for IoState { impl Drop for IoState { #[inline] fn drop(&mut self) { - if let Some(buf) = self.read_buf.take() { - self.pool.get().release_read_buf(buf); - } - if let Some(buf) = self.write_buf.take() { - self.pool.get().release_write_buf(buf); - } + self.buffer.borrow_mut().release(self.pool.get()); } } @@ -248,8 +235,7 @@ impl Io { dispatch_task: LocalWaker::new(), read_task: LocalWaker::new(), write_task: LocalWaker::new(), - read_buf: Cell::new(None), - write_buf: Cell::new(None), + buffer: RefCell::new(Stack::new()), filter: Cell::new(NullFilter::get()), handle: Cell::new(None), on_disconnect: Cell::new(None), @@ -277,14 +263,7 @@ impl Io { #[inline] /// Set memory pool pub fn set_memory_pool(&self, pool: PoolRef) { - if let Some(mut buf) = self.0 .0.read_buf.take() { - pool.move_vec_in(&mut buf); - self.0 .0.read_buf.set(Some(buf)); - } - if let Some(mut buf) = self.0 .0.write_buf.take() { - pool.move_vec_in(&mut buf); - self.0 .0.write_buf.set(Some(buf)); - } + self.0 .0.buffer.borrow_mut().set_memory_pool(pool); self.0 .0.pool.set(pool); } @@ -312,8 +291,7 @@ impl Io { dispatch_task: LocalWaker::new(), read_task: LocalWaker::new(), write_task: LocalWaker::new(), - read_buf: Cell::new(None), - write_buf: Cell::new(None), + buffer: RefCell::new(Stack::new()), filter: Cell::new(NullFilter::get()), handle: Cell::new(None), on_disconnect: Cell::new(None), @@ -353,57 +331,37 @@ impl Io { } impl Io { - #[inline] - /// Get referece to a filter - pub fn filter(&self) -> &F { - self.1.filter() - } - #[inline] /// Convert current io stream into sealed version pub fn seal(mut self) -> Io { - // get current filter - let filter = unsafe { - let filter = self.1.seal(); - let filter_ref: &'static dyn Filter = { - let filter: &dyn Filter = filter.0.as_ref(); - std::mem::transmute(filter) - }; - self.0 .0.filter.replace(filter_ref); - filter - }; - - Io(self.0.clone(), FilterItem::with_sealed(filter)) - } - - #[inline] - /// Create new filter and replace current one - pub fn add_filter(self, factory: T) -> T::Future - where - T: FilterFactory, - { - factory.create(self) + let (filter, filter_ref) = self.1.seal(); + self.0 .0.filter.replace(filter_ref); + Io(self.0.clone(), filter) } #[inline] /// Map current filter with new one - pub fn map_filter(mut self, map: U) -> Result, E> + pub fn add_filter(mut self, nf: U) -> Io> where - T: Filter, - U: FnOnce(F) -> Result, + U: FilterLayer, { - // replace current filter - let filter = unsafe { - let filter = Box::new(map(*(self.1.get_filter()))?); - let filter_ref: &'static dyn Filter = { - let filter: &dyn Filter = filter.as_ref(); - std::mem::transmute(filter) - }; - self.0 .0.filter.replace(filter_ref); - filter - }; + // add layer to buffers + if U::BUFFERS { + self.0 .0.buffer.borrow_mut().add_layer(); + } - Ok(Io(self.0.clone(), FilterItem::with_filter(filter))) + // replace current filter + let (filter, filter_ref) = self.1.add_filter(nf); + self.0 .0.filter.replace(filter_ref); + Io(self.0.clone(), filter) + } +} + +impl Io> { + #[inline] + /// Get referece to a filter + pub fn filter(&self) -> &F { + &self.1.filter().0 } } @@ -629,8 +587,10 @@ impl Io { } } else { if !flags.contains(Flags::IO_STOPPING_FILTERS) { - self.0 .0.init_shutdown(None); + self.0 .0.init_shutdown(None, &self.0); } + + self.0 .0.read_task.wake(); self.0 .0.dispatch_task.register(cx.waker()); Poll::Pending } @@ -759,20 +719,6 @@ impl FilterItem { slf } - fn with_sealed(f: Sealed) -> Self { - let mut slf = Self { - data: [0; SEALED_SIZE], - _t: marker::PhantomData, - }; - - unsafe { - let ptr = &mut slf.data as *mut _ as *mut Sealed; - ptr.write(f); - slf.data[KIND_IDX] |= KIND_SEALED; - } - slf - } - /// Get filter, panic if it is not filter fn filter(&self) -> &F { if self.data[KIND_IDX] & KIND_PTR != 0 { @@ -786,8 +732,8 @@ impl FilterItem { } } - /// Get filter, panic if it is not filter - fn get_filter(&mut self) -> Box { + /// Get filter, panic if it is not set + fn take_filter(&mut self) -> Box { if self.data[KIND_IDX] & KIND_PTR != 0 { self.data[KIND_IDX] &= KIND_UNMASK; let ptr = &mut self.data as *mut _ as *mut *mut F; @@ -801,7 +747,7 @@ impl FilterItem { } /// Get sealed, panic if it is already sealed - fn get_sealed(&mut self) -> Sealed { + fn take_sealed(&mut self) -> Sealed { if self.data[KIND_IDX] & KIND_SEALED != 0 { self.data[KIND_IDX] &= KIND_UNMASK; let ptr = &mut self.data as *mut _ as *mut Sealed; @@ -820,25 +766,54 @@ impl FilterItem { fn drop_filter(&mut self) { if self.data[KIND_IDX] & KIND_PTR != 0 { - self.get_filter(); + self.take_filter(); } else if self.data[KIND_IDX] & KIND_SEALED != 0 { - self.get_sealed(); + self.take_sealed(); } } } impl FilterItem { - fn seal(&mut self) -> Sealed { - if self.data[KIND_IDX] & KIND_PTR != 0 { - Sealed(Box::new(*self.get_filter())) + fn add_filter( + &mut self, + new: T, + ) -> (FilterItem>, &'static dyn Filter) { + let filter = Box::new(Layer::new(new, *self.take_filter())); + let filter_ref: &'static dyn Filter = { + let filter: &dyn Filter = filter.as_ref(); + unsafe { std::mem::transmute(filter) } + }; + (FilterItem::with_filter(filter), filter_ref) + } + + fn seal(&mut self) -> (FilterItem, &'static dyn Filter) { + let filter = if self.data[KIND_IDX] & KIND_PTR != 0 { + Sealed(Box::new(*self.take_filter())) } else if self.data[KIND_IDX] & KIND_SEALED != 0 { - self.get_sealed() + self.take_sealed() } else { panic!( "Wrong filter item {:?} expected: {:?}", self.data[KIND_IDX], KIND_PTR ); + }; + + let filter_ref: &'static dyn Filter = { + let filter: &dyn Filter = filter.0.as_ref(); + unsafe { std::mem::transmute(filter) } + }; + + let mut slf = FilterItem { + data: [0; SEALED_SIZE], + _t: marker::PhantomData, + }; + + unsafe { + let ptr = &mut slf.data as *mut _ as *mut Sealed; + ptr.write(filter); + slf.data[KIND_IDX] |= KIND_SEALED; } + (slf, filter_ref) } } diff --git a/ntex-io/src/ioref.rs b/ntex-io/src/ioref.rs index 12b548dc..bab38364 100644 --- a/ntex-io/src/ioref.rs +++ b/ntex-io/src/ioref.rs @@ -1,9 +1,9 @@ use std::{any, fmt, hash, io, time}; -use ntex_bytes::{BufMut, BytesVec, PoolRef}; +use ntex_bytes::{BytesVec, PoolRef}; use ntex_codec::{Decoder, Encoder}; -use super::{io::Flags, timer, types, Filter, IoRef, OnDisconnect}; +use super::{io::Flags, timer, types, Filter, IoRef, OnDisconnect, WriteBuf}; impl IoRef { #[inline] @@ -49,7 +49,7 @@ impl IoRef { /// Notify dispatcher and initiate io stream shutdown process. pub fn close(&self) { self.0.insert_flags(Flags::DSP_STOP); - self.0.init_shutdown(None); + self.0.init_shutdown(None, self); } #[inline] @@ -72,8 +72,16 @@ impl IoRef { #[inline] /// Gracefully shutdown io stream - pub fn want_shutdown(&self, err: Option) { - self.0.init_shutdown(err); + pub fn want_shutdown(&self) { + if !self + .0 + .flags + .get() + .intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS) + { + log::trace!("initiate io shutdown {:?}", self.0.flags.get()); + self.0.insert_flags(Flags::IO_STOPPING_FILTERS); + } } #[inline] @@ -96,13 +104,8 @@ impl IoRef { if !flags.contains(Flags::IO_STOPPING) { self.with_write_buf(|buf| { - let (hw, lw) = self.memory_pool().write_params().unpack(); - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); - } + self.memory_pool().resize_write_buf(buf); // encode item and wake write task codec.encode_vec(item, buf) @@ -151,21 +154,35 @@ impl IoRef { #[inline] /// Get mut access to write buffer - pub fn with_write_buf(&self, f: F) -> Result + pub fn with_write_buf(&self, f: F) -> io::Result where F: FnOnce(&mut BytesVec) -> R, { - let filter = self.0.filter.get(); - let mut buf = filter - .get_write_buf() - .unwrap_or_else(|| self.memory_pool().get_write_buf()); - let is_write_sleep = buf.is_empty(); + let mut buffer = self.0.buffer.borrow_mut(); + let is_write_sleep = buffer.last_write_buf_size() == 0; - let result = f(&mut buf); - if is_write_sleep { + let result = f(buffer.first_write_buf(self)); + self.0 + .filter + .get() + .process_write_buf(self, &mut buffer, 0)?; + + if is_write_sleep && buffer.last_write_buf_size() != 0 { self.0.write_task.wake(); } - filter.release_write_buf(buf)?; + Ok(result) + } + + #[inline] + /// Get mut access to write buffer + pub fn with_buf(&self, f: F) -> io::Result + where + F: FnOnce(&mut WriteBuf<'_>) -> R, + { + let mut b = self.0.buffer.borrow_mut(); + let result = b.write_buf(self, 0, f); + self.0.filter.get().process_write_buf(self, &mut b, 0)?; + self.0.write_task.wake(); Ok(result) } @@ -240,16 +257,15 @@ impl fmt::Debug for IoRef { #[cfg(test)] mod tests { use std::cell::{Cell, RefCell}; - use std::{future::Future, pin::Pin, rc::Rc, task::Context, task::Poll}; + use std::{future::Future, pin::Pin, rc::Rc, task::Poll}; use ntex_bytes::Bytes; use ntex_codec::BytesCodec; - use ntex_util::future::{lazy, poll_fn, Ready}; + use ntex_util::future::{lazy, poll_fn}; use ntex_util::time::{sleep, Millis}; use super::*; - use crate::testing::IoTest; - use crate::{Filter, FilterFactory, Io, ReadStatus, WriteStatus}; + use crate::{testing::IoTest, FilterLayer, Io, ReadBuf, WriteBuf}; const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n"; const TEXT: &str = "GET /test HTTP/1\r\n\r\n"; @@ -370,87 +386,28 @@ mod tests { assert_eq!(waiter.await, ()); } - struct Counter { + struct Counter { idx: usize, - inner: F, in_bytes: Rc>, out_bytes: Rc>, read_order: Rc>>, write_order: Rc>>, } - impl Filter for Counter { - fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll { - self.inner.poll_read_ready(cx) - } - fn get_read_buf(&self) -> Option { - self.inner.get_read_buf() - } + impl FilterLayer for Counter { + const BUFFERS: bool = false; - fn release_read_buf(&self, buf: BytesVec) { - self.inner.release_read_buf(buf) - } - - fn process_read_buf(&self, io: &IoRef, n: usize) -> io::Result<(usize, usize)> { - let result = self.inner.process_read_buf(io, n)?; + fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result { self.read_order.borrow_mut().push(self.idx); - self.in_bytes.set(self.in_bytes.get() + result.1); - Ok(result) + self.in_bytes.set(self.in_bytes.get() + buf.nbytes()); + Ok(buf.nbytes()) } - fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { - self.inner.poll_write_ready(cx) - } - - fn get_write_buf(&self) -> Option { - if let Some(buf) = self.inner.get_write_buf() { - self.out_bytes.set(self.out_bytes.get() - buf.len()); - Some(buf) - } else { - None - } - } - - fn release_write_buf(&self, buf: BytesVec) -> Result<(), io::Error> { + fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> { self.write_order.borrow_mut().push(self.idx); - self.out_bytes.set(self.out_bytes.get() + buf.len()); - self.inner.release_write_buf(buf) - } - } - - struct CounterFactory( - usize, - Rc>, - Rc>, - Rc>>, - Rc>>, - ); - - impl FilterFactory for CounterFactory { - type Filter = Counter; - - type Error = (); - type Future = Ready>, Self::Error>; - - fn create(self, io: Io) -> Self::Future { - let idx = self.0; - let in_bytes = self.1.clone(); - let out_bytes = self.2.clone(); - let read_order = self.3.clone(); - let write_order = self.4; - Ready::Ok( - io.map_filter(|inner| { - Ok::<_, ()>(Counter { - idx, - inner, - in_bytes, - out_bytes, - read_order, - write_order, - }) - }) - .unwrap(), - ) + self.out_bytes + .set(self.out_bytes.get() + buf.get_dst().len()); + Ok(()) } } @@ -460,24 +417,22 @@ mod tests { let out_bytes = Rc::new(Cell::new(0)); let read_order = Rc::new(RefCell::new(Vec::new())); let write_order = Rc::new(RefCell::new(Vec::new())); - let factory = CounterFactory( - 1, - in_bytes.clone(), - out_bytes.clone(), - read_order.clone(), - write_order.clone(), - ); let (client, server) = IoTest::create(); - let state = Io::new(server).add_filter(factory).await.unwrap(); + let io = Io::new(server).add_filter(Counter { + idx: 1, + in_bytes: in_bytes.clone(), + out_bytes: out_bytes.clone(), + read_order: read_order.clone(), + write_order: write_order.clone(), + }); client.remote_buffer_cap(1024); client.write(TEXT); - let msg = state.recv(&BytesCodec).await.unwrap().unwrap(); + let msg = io.recv(&BytesCodec).await.unwrap().unwrap(); assert_eq!(msg, Bytes::from_static(BIN)); - state - .send(Bytes::from_static(b"test"), &BytesCodec) + io.send(Bytes::from_static(b"test"), &BytesCodec) .await .unwrap(); let buf = client.read().await.unwrap(); @@ -496,24 +451,20 @@ mod tests { let (client, server) = IoTest::create(); let state = Io::new(server) - .add_filter(CounterFactory( - 1, - in_bytes.clone(), - out_bytes.clone(), - read_order.clone(), - write_order.clone(), - )) - .await - .unwrap() - .add_filter(CounterFactory( - 2, - in_bytes.clone(), - out_bytes.clone(), - read_order.clone(), - write_order.clone(), - )) - .await - .unwrap(); + .add_filter(Counter { + idx: 1, + in_bytes: in_bytes.clone(), + out_bytes: out_bytes.clone(), + read_order: read_order.clone(), + write_order: write_order.clone(), + }) + .add_filter(Counter { + idx: 2, + in_bytes: in_bytes.clone(), + out_bytes: out_bytes.clone(), + read_order: read_order.clone(), + write_order: write_order.clone(), + }); let state = state.seal(); client.remote_buffer_cap(1024); diff --git a/ntex-io/src/lib.rs b/ntex-io/src/lib.rs index c59dcacd..87a307b1 100644 --- a/ntex-io/src/lib.rs +++ b/ntex-io/src/lib.rs @@ -7,6 +7,7 @@ use std::{ pub mod testing; pub mod types; +mod buf; mod dispatcher; mod filter; mod framed; @@ -17,12 +18,12 @@ mod tasks; mod timer; mod utils; -use ntex_bytes::BytesVec; use ntex_codec::{Decoder, Encoder}; use ntex_util::time::Millis; +pub use self::buf::{ReadBuf, WriteBuf}; pub use self::dispatcher::Dispatcher; -pub use self::filter::Base; +pub use self::filter::{Base, Filter, Layer}; pub use self::framed::Framed; pub use self::io::{Io, IoRef, OnDisconnect}; pub use self::seal::{IoBoxed, Sealed}; @@ -49,44 +50,51 @@ pub enum WriteStatus { Terminate, } -pub trait Filter: 'static { - fn query(&self, _: TypeId) -> Option> { - None +#[allow(unused_variables)] +pub trait FilterLayer: 'static { + /// Create buffers for this filter + const BUFFERS: bool = true; + + #[inline] + /// Check readiness for read operations + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll { + Poll::Ready(ReadStatus::Ready) } - fn get_read_buf(&self) -> Option; - - fn release_read_buf(&self, buf: BytesVec); + #[inline] + /// Check readiness for write operations + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { + Poll::Ready(WriteStatus::Ready) + } /// Process read buffer /// - /// Returns tuple (total bytes, new bytes) - fn process_read_buf(&self, io: &IoRef, n: usize) -> sio::Result<(usize, usize)>; + /// Inner filter must process buffer before current. + /// Returns number of new bytes. + fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> sio::Result; - fn get_write_buf(&self) -> Option; + /// Process write buffer + fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> sio::Result<()>; - fn release_write_buf(&self, buf: BytesVec) -> sio::Result<()>; - - /// Check readiness for read operations - fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll; - - /// Check readiness for write operations - fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll; + /// Query internal filter data + fn query(&self, id: TypeId) -> Option> { + None + } /// Gracefully shutdown filter - fn poll_shutdown(&self) -> Poll> { - Poll::Ready(Ok(())) + fn shutdown(&self, buf: &mut WriteBuf<'_>) -> sio::Result> { + Ok(Poll::Ready(())) } } /// Creates new `Filter` values. -pub trait FilterFactory: Sized { +pub trait FilterFactory: Sized { /// The `Filter` value created by this factory - type Filter: Filter; + type Filter: FilterLayer; /// Errors produced while building a filter. type Error: fmt::Debug; /// The future of the `FilterFactory` instance. - type Future: Future, Self::Error>>; + type Future: Future>, Self::Error>>; /// Create and return a new filter value asynchronously. fn create(self, st: Io) -> Self::Future; diff --git a/ntex-io/src/seal.rs b/ntex-io/src/seal.rs index 347cd608..018da256 100644 --- a/ntex-io/src/seal.rs +++ b/ntex-io/src/seal.rs @@ -1,6 +1,6 @@ use std::ops; -use crate::{Filter, Io}; +use crate::{filter::Filter, Io}; /// Sealed filter type pub struct Sealed(pub(crate) Box); diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs index e04c3a2c..1a874096 100644 --- a/ntex-io/src/tasks.rs +++ b/ntex-io/src/tasks.rs @@ -12,62 +12,93 @@ impl ReadContext { Self(io.clone()) } - #[inline] - /// Return memory pool for this context - pub fn memory_pool(&self) -> PoolRef { - self.0.memory_pool() - } - #[inline] /// Check readiness for read operations pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll { self.0.filter().poll_read_ready(cx) } - #[inline] /// Get read buffer - pub fn get_read_buf(&self) -> BytesVec { - self.0 - .0 - .read_buf + pub fn with_buf(&self, f: F) -> Poll<()> + where + F: FnOnce(&mut BytesVec, usize, usize) -> Poll>, + { + let mut stack = self.0 .0.buffer.borrow_mut(); + let mut buf = stack + .last_read_buf() .take() - .unwrap_or_else(|| self.0.memory_pool().get_read_buf()) - } + .unwrap_or_else(|| self.0.memory_pool().get_read_buf()); - #[inline] - /// Release read buffer after io read operations - pub fn release_read_buf(&self, buf: BytesVec, nbytes: usize) { + let total = buf.len(); + let (hw, lw) = self.0.memory_pool().read_params().unpack(); + + // call provided callback + let result = f(&mut buf, hw, lw); + + // handle buffer changes if buf.is_empty() { self.0.memory_pool().release_read_buf(buf); } else { - self.0 .0.read_buf.set(Some(buf)); - let filter = self.0.filter(); - match filter.process_read_buf(&self.0, nbytes) { - Ok((total, nbytes)) => { - if nbytes > 0 { - if total > self.0.memory_pool().read_params().high as usize { + let total2 = buf.len(); + let nbytes = if total2 > total { total2 - total } else { 0 }; + *stack.last_read_buf() = Some(buf); + + if nbytes > 0 { + let buf_full = nbytes >= hw; + match self + .0 + .filter() + .process_read_buf(&self.0, &mut stack, 0, nbytes) + { + Ok(nbytes) => { + if nbytes > 0 { + if buf_full || stack.first_read_buf_size() >= hw { + log::trace!( + "io read buffer is too large {}, enable read back-pressure", + total2 + ); + self.0 + .0 + .insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL); + } else { + self.0 .0.insert_flags(Flags::RD_READY); + } + self.0 .0.dispatch_task.wake(); log::trace!( - "buffer is too large {}, enable read back-pressure", - total + "new {} bytes available, wakeup dispatcher", + nbytes, ); - self.0 .0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL); + } else if buf_full { + // read task is paused because of read back-pressure + self.0 .0.read_task.wake(); } + } + Err(err) => { self.0 .0.dispatch_task.wake(); self.0 .0.insert_flags(Flags::RD_READY); - log::trace!("new {} bytes available, wakeup dispatcher", nbytes); + self.0 .0.init_shutdown(Some(err), &self.0); } } - Err(err) => { - self.0 .0.dispatch_task.wake(); - self.0 .0.insert_flags(Flags::RD_READY); - self.0.want_shutdown(Some(err)); - } } } + let result = match result { + Poll::Ready(Ok(())) => { + self.0 .0.io_stopped(None); + Poll::Ready(()) + } + Poll::Ready(Err(e)) => { + self.0 .0.io_stopped(Some(e)); + Poll::Ready(()) + } + Poll::Pending => Poll::Pending, + }; + + drop(stack); if self.0.flags().contains(Flags::IO_STOPPING_FILTERS) { - self.0 .0.shutdown_filters(); + self.0 .0.shutdown_filters(&self.0); } + result } #[inline] @@ -100,7 +131,7 @@ impl WriteContext { #[inline] /// Get write buffer pub fn get_write_buf(&self) -> Option { - self.0 .0.write_buf.take() + self.0 .0.buffer.borrow_mut().last_write_buf().take() } #[inline] @@ -125,11 +156,11 @@ impl WriteContext { self.0.set_flags(flags); self.0 .0.dispatch_task.wake(); } - self.0 .0.write_buf.set(Some(buf)) + self.0 .0.buffer.borrow_mut().set_last_write_buf(buf); } if self.0.flags().contains(Flags::IO_STOPPING_FILTERS) { - self.0 .0.shutdown_filters(); + self.0 .0.shutdown_filters(&self.0); } Ok(()) diff --git a/ntex-io/src/testing.rs b/ntex-io/src/testing.rs index e903a554..901d19e4 100644 --- a/ntex-io/src/testing.rs +++ b/ntex-io/src/testing.rs @@ -344,6 +344,12 @@ impl Drop for IoTest { _ => (), } self.state.set(state); + + let guard = self.remote.lock().unwrap(); + let mut remote = guard.borrow_mut(); + remote.read = IoTestState::Close; + remote.waker.wake(); + log::trace!("drop remote socket"); } } @@ -388,58 +394,58 @@ impl Future for ReadTask { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_ref(); - match this.state.poll_ready(cx) { - Poll::Ready(ReadStatus::Terminate) => { - log::trace!("read task is instructed to terminate"); - Poll::Ready(()) - } - Poll::Ready(ReadStatus::Ready) => { - let io = &this.io; - let pool = this.state.memory_pool(); - let mut buf = self.state.get_read_buf(); - let (hw, lw) = pool.read_params().unpack(); + this.state.with_buf(|buf, hw, lw| { + match this.state.poll_ready(cx) { + Poll::Ready(ReadStatus::Terminate) => { + log::trace!("read task is instructed to terminate"); + Poll::Ready(Ok(())) + } + Poll::Ready(ReadStatus::Ready) => { + let io = &this.io; - // read data from socket - let mut new_bytes = 0; - loop { - // make sure we've got room - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(hw - remaining); - } - - match io.poll_read_buf(cx, &mut buf) { - Poll::Pending => { - log::trace!("no more data in io stream, read: {:?}", new_bytes); - break; + // read data from socket + let mut new_bytes = 0; + loop { + // make sure we've got room + let remaining = buf.remaining_mut(); + if remaining < lw { + buf.reserve(hw - remaining); } - Poll::Ready(Ok(n)) => { - if n == 0 { - log::trace!("io stream is disconnected"); - this.state.release_read_buf(buf, new_bytes); - this.state.close(None); - return Poll::Ready(()); - } else { - new_bytes += n; - if buf.len() > hw { - break; + match io.poll_read_buf(cx, buf) { + Poll::Pending => { + log::trace!( + "no more data in io stream, read: {:?}", + new_bytes + ); + break; + } + Poll::Ready(Ok(n)) => { + if n == 0 { + log::trace!("io stream is disconnected"); + return Poll::Ready(Ok(())); + } else { + new_bytes += n; + if buf.len() >= hw { + log::trace!( + "high water mark pause reading, read: {:?}", + new_bytes + ); + break; + } } } - } - Poll::Ready(Err(err)) => { - log::trace!("read task failed on io {:?}", err); - this.state.release_read_buf(buf, new_bytes); - this.state.close(Some(err)); - return Poll::Ready(()); + Poll::Ready(Err(err)) => { + log::trace!("read task failed on io {:?}", err); + return Poll::Ready(Err(err)); + } } } - } - this.state.release_read_buf(buf, new_bytes); - Poll::Pending + Poll::Pending + } + Poll::Pending => Poll::Pending, } - Poll::Pending => Poll::Pending, - } + }) } } diff --git a/ntex-io/src/utils.rs b/ntex-io/src/utils.rs index 97f8cc06..bfe4aac6 100644 --- a/ntex-io/src/utils.rs +++ b/ntex-io/src/utils.rs @@ -3,7 +3,7 @@ use std::marker::PhantomData; use ntex_service::{fn_service, pipeline_factory, Service, ServiceFactory}; use ntex_util::future::Ready; -use crate::{Filter, FilterFactory, Io, IoBoxed}; +use crate::{Filter, FilterFactory, Io, IoBoxed, Layer}; /// Service that converts any Io stream to IoBoxed stream pub fn seal( @@ -30,7 +30,6 @@ where pub fn filter(filter: T) -> FilterServiceFactory where T: FilterFactory + Clone, - F: Filter, { FilterServiceFactory { filter, @@ -46,9 +45,8 @@ pub struct FilterServiceFactory { impl ServiceFactory> for FilterServiceFactory where T: FilterFactory + Clone, - F: Filter, { - type Response = Io; + type Response = Io>; type Error = T::Error; type Service = FilterService; type InitError = (); @@ -71,25 +69,28 @@ pub struct FilterService { impl Service> for FilterService where T: FilterFactory + Clone, - F: Filter, { - type Response = Io; + type Response = Io>; type Error = T::Error; - type Future<'f> = T::Future where T: 'f; + type Future<'f> = T::Future where T: 'f, F: 'f; #[inline] fn call(&self, req: Io) -> Self::Future<'_> { - req.add_filter(self.filter.clone()) + self.filter.clone().create(req) } } #[cfg(test)] mod tests { - use ntex_bytes::{Bytes, BytesVec}; + use std::io; + + use ntex_bytes::Bytes; use ntex_codec::BytesCodec; use super::*; - use crate::{filter::NullFilter, testing::IoTest}; + use crate::{ + buf::Stack, filter::NullFilter, testing::IoTest, FilterLayer, ReadBuf, WriteBuf, + }; #[ntex::test] async fn test_utils() { @@ -114,16 +115,28 @@ mod tests { assert_eq!(buf, b"RES".as_ref()); } - #[derive(Copy, Clone, Debug)] - struct NullFilterFactory; + pub(crate) struct TestFilter; - impl FilterFactory for NullFilterFactory { - type Filter = crate::filter::NullFilter; + impl FilterLayer for TestFilter { + fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result { + Ok(buf.nbytes()) + } + + fn process_write_buf(&self, _: &mut WriteBuf<'_>) -> io::Result<()> { + Ok(()) + } + } + + #[derive(Copy, Clone, Debug)] + struct TestFilterFactory; + + impl FilterFactory for TestFilterFactory { + type Filter = TestFilter; type Error = std::convert::Infallible; - type Future = Ready, Self::Error>; + type Future = Ready>, Self::Error>; fn create(self, st: Io) -> Self::Future { - st.map_filter(|_| Ok(NullFilter)).into() + Ready::Ok(st.add_filter(TestFilter).into()) } } @@ -131,7 +144,7 @@ mod tests { async fn test_utils_filter() { let (_, server) = IoTest::create(); let svc = pipeline_factory( - filter::<_, crate::filter::Base>(NullFilterFactory) + filter::<_, crate::filter::Base>(TestFilterFactory) .map_err(|_| ()) .map_init_err(|_| ()), ) @@ -147,8 +160,15 @@ mod tests { #[ntex::test] async fn test_null_filter() { + let (_, server) = IoTest::create(); + let io = Io::new(server); + let ioref = io.get_ref(); + let mut stack = Stack::new(); assert!(NullFilter.query(std::any::TypeId::of::<()>()).is_none()); - assert!(NullFilter.poll_shutdown().is_ready()); + assert!(NullFilter + .shutdown(&ioref, &mut stack, 0) + .unwrap() + .is_ready()); assert_eq!( ntex_util::future::poll_fn(|cx| NullFilter.poll_read_ready(cx)).await, crate::ReadStatus::Terminate @@ -157,16 +177,12 @@ mod tests { ntex_util::future::poll_fn(|cx| NullFilter.poll_write_ready(cx)).await, crate::WriteStatus::Terminate ); - assert_eq!(NullFilter.get_read_buf(), None); - assert_eq!(NullFilter.get_write_buf(), None); - assert!(NullFilter.release_write_buf(BytesVec::new()).is_ok()); - NullFilter.release_read_buf(BytesVec::new()); - - let (_, server) = IoTest::create(); - let io = Io::new(server); + assert!(NullFilter.process_write_buf(&ioref, &mut stack, 0).is_ok()); assert_eq!( - NullFilter.process_read_buf(&io.get_ref(), 10).unwrap(), - (0, 0) + NullFilter + .process_read_buf(&ioref, &mut stack, 0, 0) + .unwrap(), + (0) ) } } diff --git a/ntex-tls/CHANGES.md b/ntex-tls/CHANGES.md index 11f4d797..2c3ff422 100644 --- a/ntex-tls/CHANGES.md +++ b/ntex-tls/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.2.1] - 2023-01-23 + +* Update filter implementation + ## [0.2.0] - 2023-01-04 * Release diff --git a/ntex-tls/Cargo.toml b/ntex-tls/Cargo.toml index 7bd33ab1..933aaf61 100644 --- a/ntex-tls/Cargo.toml +++ b/ntex-tls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-tls" -version = "0.2.0" +version = "0.2.1" authors = ["ntex contributors "] description = "An implementation of SSL streams for ntex backed by OpenSSL" keywords = ["network", "framework", "async", "futures"] @@ -25,15 +25,15 @@ openssl = ["tls_openssl"] rustls = ["tls_rust"] [dependencies] -ntex-bytes = "0.1.18" -ntex-io = "0.2.0" +ntex-bytes = "0.1.19" +ntex-io = "0.2.1" ntex-util = "0.2.0" ntex-service = "1.0.0" log = "0.4" pin-project-lite = "0.2" # openssl -tls_openssl = { version="0.10.42", package = "openssl", optional = true } +tls_openssl = { version="0.10", package = "openssl", optional = true } # rustls tls_rust = { version = "0.20", package = "rustls", optional = true } diff --git a/ntex-tls/examples/cert.pem b/ntex-tls/examples/cert.pem index bad428a9..159aacea 100644 --- a/ntex-tls/examples/cert.pem +++ b/ntex-tls/examples/cert.pem @@ -1,31 +1,31 @@ -----BEGIN CERTIFICATE----- -MIIFPjCCAyYCCQDWGwiaSniPcTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJV +MIIFPjCCAyYCCQDvLYiYD+jqeTANBgkqhkiG9w0BAQsFADBhMQswCQYDVQQGEwJV UzELMAkGA1UECAwCQ0ExCzAJBgNVBAcMAlNGMRAwDgYDVQQKDAdDb21wYW55MQww -CgYDVQQLDANPcmcxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0yMTEyMTgx -NjMwNDlaFw0yMjEyMTgxNjMwNDlaMGExCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJD +CgYDVQQLDANPcmcxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xODAxMjUx +NzQ2MDFaFw0xOTAxMjUxNzQ2MDFaMGExCzAJBgNVBAYTAlVTMQswCQYDVQQIDAJD QTELMAkGA1UEBwwCU0YxEDAOBgNVBAoMB0NvbXBhbnkxDDAKBgNVBAsMA09yZzEY MBYGA1UEAwwPd3d3LmV4YW1wbGUuY29tMIICIjANBgkqhkiG9w0BAQEFAAOCAg8A -MIICCgKCAgEAryUL1k7npaMOck9OO+EjzeL0FoysOP5JrgRh+8BoPY7WPyL56oFP -aCYKp2YMucmvFh/VZSyupC75JJNIaW0fvcIe4Euzy2Ex0VukPxYteRicaWRsxSId -o5RNNHd7JOf3ZWIMqkxmDhPNnqSGHcnVs/14+I5IbJCoba+KNElmL9CrL3gQkqNY -Jf2FSIgou5j1OthEdnQpiRxSRLmJ7gXtvpFGgj4AnrHGsMAPHueeop6yOX6egFnw -2cwp98c/0tMOUsXnDU1MTGF11+4UVr043SruZKU7bvhMZRcf4NTR2MNin0b3DYJ+ -JbTn+HgPhhhx3mrsWRyCvfP23jzwnV/222o+U46i7tNYYrDN8vXIM17gtIvKrv2F -CLTJE6tsp0xAi6dT+J+AIVqkJntrsxqx2CuOYGOOkPPc4rSf64bwOR1mikdvZCnV -NwGEXcH3nBRFMlk5bByCW0kUy03QNakiUEF+PoFzLrCL+V+21Q6Fd7Jmw06BzVFV -2YtsqFcSo7HXW91XJTDVJCPnrMJOooKQ9Fbq4zbQM0Lv02LyJWyR+0PMBzy4FfkW -ZWz10g3w+CITL/MQ65fsBBc9hRHC3QBWetj3puqM8DlPwqPhgmCA5zo8AWx7CogR -V66ukkeBYXYFHwV5uDJTX91tbwYesOL43rlDT905aV0VbaAyDZflipMCAwEAATAN -BgkqhkiG9w0BAQsFAAOCAgEAWeq502+YKMHrk8YD4L2mzY/AHSEWf6XubMgkNRbh -s72+zJs2SrAzu+y+iv5La4JXOxrWEvZOUCKAK0sRG/+ESQxul5mbyPQLWFJgSqv5 -O2RmhQ65l+O6RjPZbHPNJMTLMMlkFrKctgGIg5ysKHWPEZZ7ZlS3maxon+X75/b5 -uI3BxBpJTWcg6zOxh0+zIxhesgEbRmaEz6qu3ZSktBeUQFpTElreCcbkntlFbr+9 -SiKkaO4l6qEwRDhA595/7/JRZo4R5U1MifU6JhTMOyXTsH3BV1aVeS81/9jGPHl8 -kgVxeKSpL/jDwuSJdr+dMxs/TJHV6fsnVewFFFmigLWThYGDnKmXqJQNyt8utRpe -6vvReWSSIece1EdBActy0rtjPaUJNTTdYk1UYo63OIbCguLWQD1XYN1qJg4KWJzB -PjS6KCOLmJvYrAxRMED4XeZ17+PxC3xr2IpAL+loRhZUuxXV4GhccGZ4z89OIdOU -x97x2BjjV5Nnnt6eBfF3vP5sOz31QpAS/8tzdlGD+6Xq2/i1ZKMPrwgs2dhTyah0 -kCBfdE88Zew/A79z55IsVNiYJ4MrD8WTFjcM2j8SgI7tg+M+X/unj+wnzYT0L0dg -BEfzPd7zWdDOPInlTV9zUj1WOsLHX9odOh/Jj5X0FV5vZtcyQ0sGJAhdgTaXDvXs -Ing= +MIICCgKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEPn8k1 +sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+MIK5U +NLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM54jXy +voLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZWLWr +odGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAkoqND +xdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNliJDmA +CRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6/stI +yFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuDYX2U +UuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nPwPTO +vRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA69un +CEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEAATAN +BgkqhkiG9w0BAQsFAAOCAgEApavsgsn7SpPHfhDSN5iZs1ILZQRewJg0Bty0xPfk +3tynSW6bNH3nSaKbpsdmxxomthNSQgD2heOq1By9YzeOoNR+7Pk3s4FkASnf3ToI +JNTUasBFFfaCG96s4Yvs8KiWS/k84yaWuU8c3Wb1jXs5Rv1qE1Uvuwat1DSGXSoD +JNluuIkCsC4kWkyq5pWCGQrabWPRTWsHwC3PTcwSRBaFgYLJaR72SloHB1ot02zL +d2age9dmFRFLLCBzP+D7RojBvL37qS/HR+rQ4SoQwiVc/JzaeqSe7ZbvEH9sZYEu +ALowJzgbwro7oZflwTWunSeSGDSltkqKjvWvZI61pwfHKDahUTmZ5h2y67FuGEaC +CIOUI8dSVSPKITxaq3JL4ze2e9/0Lt7hj19YK2uUmtMAW5Tirz4Yx5lyGH9U8Wur +y/X8VPxTc4A9TMlJgkyz0hqvhbPOT/zSWB10zXh0glKAsSBryAOEDxV1UygmSir7 +YV8Qaq+oyKUTMc1MFq5vZ07M51EPaietn85t8V2Y+k/8XYltRp32NxsypxAJuyxh +g/ko6RVTrWa1sMvz/F9LFqAdKiK5eM96lh9IU4xiLg4ob8aS/GRAA8oIFkZFhLrt +tOwjIUPmEPyHWFi8dLpNuQKYalLYhuwZftG/9xV+wqhKGZO9iPrpHSYBRTap8w2y +1QU= -----END CERTIFICATE----- diff --git a/ntex-tls/examples/key.pem b/ntex-tls/examples/key.pem index e958f1a1..aac387c6 100644 --- a/ntex-tls/examples/key.pem +++ b/ntex-tls/examples/key.pem @@ -1,51 +1,51 @@ -----BEGIN RSA PRIVATE KEY----- -MIIJKQIBAAKCAgEAryUL1k7npaMOck9OO+EjzeL0FoysOP5JrgRh+8BoPY7WPyL5 -6oFPaCYKp2YMucmvFh/VZSyupC75JJNIaW0fvcIe4Euzy2Ex0VukPxYteRicaWRs -xSIdo5RNNHd7JOf3ZWIMqkxmDhPNnqSGHcnVs/14+I5IbJCoba+KNElmL9CrL3gQ -kqNYJf2FSIgou5j1OthEdnQpiRxSRLmJ7gXtvpFGgj4AnrHGsMAPHueeop6yOX6e -gFnw2cwp98c/0tMOUsXnDU1MTGF11+4UVr043SruZKU7bvhMZRcf4NTR2MNin0b3 -DYJ+JbTn+HgPhhhx3mrsWRyCvfP23jzwnV/222o+U46i7tNYYrDN8vXIM17gtIvK -rv2FCLTJE6tsp0xAi6dT+J+AIVqkJntrsxqx2CuOYGOOkPPc4rSf64bwOR1mikdv -ZCnVNwGEXcH3nBRFMlk5bByCW0kUy03QNakiUEF+PoFzLrCL+V+21Q6Fd7Jmw06B -zVFV2YtsqFcSo7HXW91XJTDVJCPnrMJOooKQ9Fbq4zbQM0Lv02LyJWyR+0PMBzy4 -FfkWZWz10g3w+CITL/MQ65fsBBc9hRHC3QBWetj3puqM8DlPwqPhgmCA5zo8AWx7 -CogRV66ukkeBYXYFHwV5uDJTX91tbwYesOL43rlDT905aV0VbaAyDZflipMCAwEA -AQKCAgBoOnqt4a0XNE8PlcRv/A6Loskxdiuzixib133cDOe74nn7frwNY0C3MRRc -BG4ETlLErtMWb53KlS2tJ30LSGaATbqELmjj2oaEGa5H4NHU4+GJErtsIV5UD5hW -ZdhB4U2n5s60tdxx+jT+eNhbd9aWU3yfJkVRXlDtXW64qQmH4P1OtXvfWBfIG/Qq -cuUSpvchOrybZYumTdVjkqrTnHGcW+YC8hT6W79rRhB5issr6ZcUghafOWcMpeQ/ -0TJZK0K13ZIfp2WFeuZfRw6Rg/AIJllSScZxxo/oBPfym5P6FGRndxrkzkh19g+q -HQDYA0oYW7clXMMtebbrEIb8kLRdaIHDiwyFXmyywvuAAk0jHbA8snM2dyeJWSRr -WQjvQFccGF4z390ZGUCN0ZeESskndg12r4jYaL/aQ8dQZ1ivS69F8bmbQtQNU2Ej -hscTUzEMOnrBTxvRQTjI9nnrbsbklagKmJHXOc/fj13g6/FkcfmeTrjuB30LxJyH -j+xXAi8AGv/oZRk6s/txas5hXpcFXnQDRobVoJjV8kuomcDTt1j33H+05ACFyvHM -/2jxJ1f3xbFx3fqivL89+Z4r8RYxLoWLg7QuqQLdtRgThEKUG0t3lt59fUo+JVVJ -CgRbj/OM3n5udgiIeBAyMAMZjVPUKhvLIFpiUY2vKnYx/97L0QKCAQEA4QUt3dEh -P0L4eQEAzg/J8JuleH7io5VxoK5c2oulhCdQdRDF5HWSesPKJmCmgJRmIXi7zseo -Sbg7Hd2xt/QnaPhRnASXJOdn7ddtoZ1M6Zb0y+d6mmcG+mK6PshtMCQ5S3Lqhsuh -tYQbwawNlCFzwzCzwGb3aD9lBKQYts7KFrMT3Goexg3Qqv374XGn6Eg1LMhXWYbT -M5gcPOYnOT+RugeaTxMnJ6nr6E7kyrLIS+xASXKwXGxSUsQG9VWH7jDuzzARrPEU -aeyxWdbDkBn2vzW+wDpMPMqzoShZsRC9NnFfncXRZfUC5DJWGzwA/xZaR0ZNNng2 -OE7rILyAH/aZSQKCAQEAx0ICGi7y94vn5KWdaNVol3nPid4aMnk4LDcX5m0tiqUG -7LIqnFDOOjEdxTf13n7Cv4gotOQNluOypswSDZI4tI0xQ/dJ8PI+vwmA0oHSzf7U -ZPO2gzIOzububPllQsCrKHN++2SyyNlKyYFu/akmlu6yIN3EMRLqYKvZaIL5z9Lk -pTU7eS0AsXJyqD54zRLFkw6S9omQHJXrEzYAuZI+Ue/Arlgyq95mUMsHYRHgaTq4 -GDMDLHNyrdKUhW+ZiZ9dhX+aRghHgNiXDk/Eh2/RZrLhKdVk94dJQbfGu/aiSk71 -dXPEAaQ7o1MDwQgu4TsCVCzac/CeqvmcoMFyx3NA+wKCAQEAoLfLR8hsH7wcroiZ -45QBXzo8WLD//WjrDKIdLfdaE+bkn4iIX6HeKpMXGowjwGi9/aA3O/z85RKSHsXO -fp4DXAUofPAGaFRjtcwNwMYSPjEUzWKa/hciM8o6TkdnPWBSD+KXQgnFiVk/Xfge -hrPR9BMgAAdLJIlLBKKUCFXwn3/uaprdOgZ6CPd5ZU+BZvXUDRVW1lnnFc3KNXEJ -iOkvk5iEjYAXkkvadEWNQn2pdBjc3djtwEWaEwVyFt6tROJsX01tAoH6W6G0Fn+/ -lHgG9hFUGgZJl44L+MpSLZbQHkehzJWS92ilVQni2HbmG0wC1S+QTJxV1agAZpRc -SvgeCQKCAQB3PnVrnfUhV8Sq/MG63xv8qpUc+KHM2uZW75GKAIRkmGYQeH8vlNwV -zxb104t8X3fEj4Ns3Z2UUyey0iVrobn1sxlshyzk2NPcF5/UWoUBaiNJVuA+m1Jp -V6IP7SBAVnUXfCbd42Fq+T7cYG0/uF6zrJ1FNfIXPC6vM6ij9t3xFVBn3fd9iQUF -LGyZaul4MGe0neAtUh3APae0k3jTlUVeW5B/xaBtYmbwqs/7s2sNDmrlcIHRtDVI -+OCRCjxkM88P+VEl4AaKgRPFKM+ADdbPEvXUxzPpPjkE7yorimmM9rvGUkVWhiZ6 -k0+H0ZHckCfQoBcLk1AhGcg2HA7IdZzJAoIBAQDAicb6CWlNdaIcJfADKSNK4+BF -JFbH+lXYrTxVSTV+Ubdi0w8Kwk0bzf20EstJnaOCyLDCjcxafjbmuGBVbw7an0lt -Yxjx0fWXxMfvb9/3vKuJVUySA4iq/zfXKlokRaFoqbdRfod3PVGUsynCV7HmImf3 -RZA0WkcSwzbg2E2QNKQ3CPd3cHtPpBX8TwRCotg/R5yCR9lihVfkSyULikwBFvrm -2UKZm4pPESWSfMHBToJoAeO0g67itbwwpNhwvgUdyjaj8u46qyjN1FMx3mBiv7Yq -CIE+H0qNu0jmFhoqPrgxfFrGCi6eDPYjRS86536Nc4m8y24z2hie8JLK8QKQ +MIIJKAIBAAKCAgEA2WzIA2IpVR9Tb9EFhITlxuhE5rY2a3S6qzYNzQVgSFggxXEP +n8k1sQEcer5BfAP986Sck3H0FvB4Bt/I8PwOtUCmhwcc8KtB5TcGPR4fjXnrpC+M +IK5UNLkwuyBDKziYzTdBj8kUFX1WxmvEHEgqToPOZfBgsS71cJAR/zOWraDLSRM5 +4jXyvoLZN4Ti9rQagQrvTQ44Vz5ycDQy7UxtbUGh1CVv69vNVr7/SOOh/Nw5FNOZ +WLWrodGyoec5wh9iqRZgRqiTUc6Lt7V2RWc2X2gjwST2UfI+U46Ip3oaQ7ZD4eAk +oqNDxdniBZAykVG3c/99ux4BAESTF8fsNch6UticBxYMuTu+ouvP0psfI9wwwNli +JDmACRUTB9AgRynbL1AzhqQoDfsb98IZfjfNOpwnwuLwpMAPhbgd5KNdZaIJ4Hb6 +/stIyFElOExxd3TAxF2Gshd/lq1JcNHAZ1DSXV5MvOWT/NWgXwbIzUgQ8eIi+HuD +YX2UUuaB6R8tbd52H7rbUv6HrfinuSlKWqjSYLkiKHkwUpoMw8y9UycRSzs1E9nP +wPTOvRXb0mNCQeBCV9FvStNVXdCUTT8LGPv87xSD2pmt7LijlE6mHLG8McfcWkzA +69unCEHIFAFDimTuN7EBljc119xWFTcHMyoZAfFF+oTqwSbBGImruCxnaJECAwEA +AQKCAgAME3aoeXNCPxMrSri7u4Xnnk71YXl0Tm9vwvjRQlMusXZggP8VKN/KjP0/ +9AE/GhmoxqPLrLCZ9ZE1EIjgmZ9Xgde9+C8rTtfCG2RFUL7/5J2p6NonlocmxoJm +YkxYwjP6ce86RTjQWL3RF3s09u0inz9/efJk5O7M6bOWMQ9VZXDlBiRY5BYvbqUR +6FeSzD4MnMbdyMRoVBeXE88gTvZk8xhB6DJnLzYgc0tKiRoeKT0iYv5JZw25VyRM +ycLzfTrFmXCPfB1ylb483d9Ly4fBlM8nkx37PzEnAuukIawDxsPOb9yZC+hfvNJI +7NFiMN+3maEqG2iC00w4Lep4skHY7eHUEUMl+Wjr+koAy2YGLWAwHZQTm7iXn9Ab +L6adL53zyCKelRuEQOzbeosJAqS+5fpMK0ekXyoFIuskj7bWuIoCX7K/kg6q5IW+ +vC2FrlsrbQ79GztWLVmHFO1I4J9M5r666YS0qdh8c+2yyRl4FmSiHfGxb3eOKpxQ +b6uI97iZlkxPF9LYUCSc7wq0V2gGz+6LnGvTHlHrOfVXqw/5pLAKhXqxvnroDTwz +0Ay/xFF6ei/NSxBY5t8ztGCBm45wCU3l8pW0X6dXqwUipw5b4MRy1VFRu6rqlmbL +OPSCuLxqyqsigiEYsBgS/icvXz9DWmCQMPd2XM9YhsHvUq+R4QKCAQEA98EuMMXI +6UKIt1kK2t/3OeJRyDd4iv/fCMUAnuPjLBvFE4cXD/SbqCxcQYqb+pue3PYkiTIC +71rN8OQAc5yKhzmmnCE5N26br/0pG4pwEjIr6mt8kZHmemOCNEzvhhT83nfKmV0g +9lNtuGEQMiwmZrpUOF51JOMC39bzcVjYX2Cmvb7cFbIq3lR0zwM+aZpQ4P8LHCIu +bgHmwbdlkLyIULJcQmHIbo6nPFB3ZZE4mqmjwY+rA6Fh9rgBa8OFCfTtrgeYXrNb +IgZQ5U8GoYRPNC2ot0vpTinraboa/cgm6oG4M7FW1POCJTl+/ktHEnKuO5oroSga +/BSg7hCNFVaOhwKCAQEA4Kkys0HtwEbV5mY/NnvUD5KwfXX7BxoXc9lZ6seVoLEc +KjgPYxqYRVrC7dB2YDwwp3qcRTi/uBAgFNm3iYlDzI4xS5SeaudUWjglj7BSgXE2 +iOEa7EwcvVPluLaTgiWjlzUKeUCNNHWSeQOt+paBOT+IgwRVemGVpAgkqQzNh/nP +tl3p9aNtgzEm1qVlPclY/XUCtf3bcOR+z1f1b4jBdn0leu5OhnxkC+Htik+2fTXD +jt6JGrMkanN25YzsjnD3Sn+v6SO26H99wnYx5oMSdmb8SlWRrKtfJHnihphjG/YY +l1cyorV6M/asSgXNQfGJm4OuJi0I4/FL2wLUHnU+JwKCAQEAzh4WipcRthYXXcoj +gMKRkMOb3GFh1OpYqJgVExtudNTJmZxq8GhFU51MR27Eo7LycMwKy2UjEfTOnplh +Us2qZiPtW7k8O8S2m6yXlYUQBeNdq9IuuYDTaYD94vsazscJNSAeGodjE+uGvb1q +1wLqE87yoE7dUInYa1cOA3+xy2/CaNuviBFJHtzOrSb6tqqenQEyQf6h9/12+DTW +t5pSIiixHrzxHiFqOoCLRKGToQB+71rSINwTf0nITNpGBWmSj5VcC3VV3TG5/XxI +fPlxV2yhD5WFDPVNGBGvwPDSh4jSMZdZMSNBZCy4XWFNSKjGEWoK4DFYed3DoSt9 +5IG1YwKCAQA63ntHl64KJUWlkwNbboU583FF3uWBjee5VqoGKHhf3CkKMxhtGqnt ++oN7t5VdUEhbinhqdx1dyPPvIsHCS3K1pkjqii4cyzNCVNYa2dQ00Qq+QWZBpwwc +3GAkz8rFXsGIPMDa1vxpU6mnBjzPniKMcsZ9tmQDppCEpBGfLpio2eAA5IkK8eEf +cIDB3CM0Vo94EvI76CJZabaE9IJ+0HIJb2+jz9BJ00yQBIqvJIYoNy9gP5Xjpi+T +qV/tdMkD5jwWjHD3AYHLWKUGkNwwkAYFeqT/gX6jpWBP+ZRPOp011X3KInJFSpKU +DT5GQ1Dux7EMTCwVGtXqjO8Ym5wjwwsfAoIBAEcxlhIW1G6BiNfnWbNPWBdh3v/K +5Ln98Rcrz8UIbWyl7qNPjYb13C1KmifVG1Rym9vWMO3KuG5atK3Mz2yLVRtmWAVc +fxzR57zz9MZFDun66xo+Z1wN3fVxQB4CYpOEI4Lb9ioX4v85hm3D6RpFukNtRQEc +Gfr4scTjJX4jFWDp0h6ffMb8mY+quvZoJ0TJqV9L9Yj6Ksdvqez/bdSraev97bHQ +4gbQxaTZ6WjaD4HjpPQefMdWp97Metg0ZQSS8b8EzmNFgyJ3XcjirzwliKTAQtn6 +I2sd0NCIooelrKRD8EJoDUwxoOctY7R97wpZ7/wEHU45cBCbRV3H4JILS5c= -----END RSA PRIVATE KEY----- diff --git a/ntex-tls/examples/rustls-client.rs b/ntex-tls/examples/rustls-client.rs index d094140d..c9776a91 100644 --- a/ntex-tls/examples/rustls-client.rs +++ b/ntex-tls/examples/rustls-client.rs @@ -27,7 +27,8 @@ async fn main() -> io::Result<()> { // rustls connector let connector = connect::rustls::Connector::new(config.clone()); - let io = connector.connect("www.rust-lang.org:443").await.unwrap(); + //let io = connector.connect("www.rust-lang.org:443").await.unwrap(); + let io = connector.connect("127.0.0.1:8443").await.unwrap(); println!("Connected to tls server {:?}", io.query::().get()); let result = io .send(Bytes::from_static(b"GET /\r\n\r\n"), &codec::BytesCodec) diff --git a/ntex-tls/src/openssl/accept.rs b/ntex-tls/src/openssl/accept.rs index 1c384757..3e919fa6 100644 --- a/ntex-tls/src/openssl/accept.rs +++ b/ntex-tls/src/openssl/accept.rs @@ -1,7 +1,7 @@ use std::task::{Context, Poll}; use std::{error::Error, future::Future, marker::PhantomData, pin::Pin}; -use ntex_io::{Filter, FilterFactory, Io}; +use ntex_io::{Filter, FilterFactory, Io, Layer}; use ntex_service::{Service, ServiceFactory}; use ntex_util::{future::Ready, time::Millis}; use tls_openssl::ssl::SslAcceptor; @@ -53,7 +53,7 @@ impl Clone for Acceptor { } impl ServiceFactory, C> for Acceptor { - type Response = Io>; + type Response = Io>; type Error = Box; type Service = AcceptorService; type InitError = (); @@ -81,7 +81,7 @@ pub struct AcceptorService { } impl Service> for AcceptorService { - type Response = Io>; + type Response = Io>; type Error = Box; type Future<'f> = AcceptorServiceResponse; @@ -115,7 +115,7 @@ pin_project_lite::pin_project! { } impl Future for AcceptorServiceResponse { - type Output = Result>, Box>; + type Output = Result>, Box>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.project().fut.poll(cx) diff --git a/ntex-tls/src/openssl/mod.rs b/ntex-tls/src/openssl/mod.rs index f427716b..c78b76d0 100644 --- a/ntex-tls/src/openssl/mod.rs +++ b/ntex-tls/src/openssl/mod.rs @@ -1,10 +1,9 @@ -#![allow(clippy::type_complexity)] //! An implementation of SSL streams for ntex backed by OpenSSL use std::cell::{Cell, RefCell}; use std::{any, cmp, error::Error, io, task::Context, task::Poll}; use ntex_bytes::{BufMut, BytesVec, PoolRef}; -use ntex_io::{types, Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus}; +use ntex_io::{types, Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf}; use ntex_util::{future::poll_fn, future::BoxFuture, ready, time, time::Millis}; use tls_openssl::ssl::{self, NameType, SslStream}; use tls_openssl::x509::X509; @@ -23,28 +22,27 @@ pub struct PeerCert(pub X509); pub struct PeerCertChain(pub Vec); /// An implementation of SSL streams -pub struct SslFilter { - inner: RefCell>>, +pub struct SslFilter { + inner: RefCell>, pool: PoolRef, handshake: Cell, - read_buf: Cell>, } -struct IoInner { - inner: F, +struct IoInner { + inner_read_buf: Option, + inner_write_buf: Option, pool: PoolRef, - write_buf: Option, } -impl io::Read for IoInner { +impl io::Read for IoInner { fn read(&mut self, dst: &mut [u8]) -> io::Result { - if let Some(mut buf) = self.inner.get_read_buf() { + if let Some(mut buf) = self.inner_read_buf.take() { if buf.is_empty() { Err(io::Error::from(io::ErrorKind::WouldBlock)) } else { let len = cmp::min(buf.len(), dst.len()); dst[..len].copy_from_slice(&buf.split_to(len)); - self.inner.release_read_buf(buf); + self.inner_read_buf = Some(buf); Ok(len) } } else { @@ -53,16 +51,16 @@ impl io::Read for IoInner { } } -impl io::Write for IoInner { +impl io::Write for IoInner { fn write(&mut self, src: &[u8]) -> io::Result { - let mut buf = if let Some(mut buf) = self.inner.get_write_buf() { + let mut buf = if let Some(mut buf) = self.inner_write_buf.take() { buf.reserve(src.len()); buf } else { BytesVec::with_capacity_in(src.len(), self.pool) }; buf.extend_from_slice(src); - self.inner.release_write_buf(buf)?; + self.inner_write_buf = Some(buf); Ok(src.len()) } @@ -71,7 +69,37 @@ impl io::Write for IoInner { } } -impl Filter for SslFilter { +impl SslFilter { + fn with_buffers(&self, buf: &mut WriteBuf<'_>, f: F) -> R + where + F: FnOnce() -> R, + { + self.inner.borrow_mut().get_mut().inner_write_buf = Some(buf.take_dst()); + self.inner.borrow_mut().get_mut().inner_read_buf = + buf.with_read_buf(|b| b.take_src()); + let result = f(); + buf.set_dst(self.inner.borrow_mut().get_mut().inner_write_buf.take()); + buf.with_read_buf(|b| { + b.set_src(self.inner.borrow_mut().get_mut().inner_read_buf.take()) + }); + result + } + + fn set_buffers(&self, buf: &mut WriteBuf<'_>) { + self.inner.borrow_mut().get_mut().inner_write_buf = Some(buf.take_dst()); + self.inner.borrow_mut().get_mut().inner_read_buf = + buf.with_read_buf(|b| b.take_src()); + } + + fn unset_buffers(&self, buf: &mut WriteBuf<'_>) { + buf.set_dst(self.inner.borrow_mut().get_mut().inner_write_buf.take()); + buf.with_read_buf(|b| { + b.set_src(self.inner.borrow_mut().get_mut().inner_read_buf.take()) + }); + } +} + +impl FilterLayer for SslFilter { fn query(&self, id: any::TypeId) -> Option> { const H2: &[u8] = b"h2"; @@ -116,86 +144,37 @@ impl Filter for SslFilter { None } } else { - self.inner.borrow().get_ref().inner.query(id) + None } } - fn poll_shutdown(&self) -> Poll> { - let ssl_result = self.inner.borrow_mut().shutdown(); + fn shutdown(&self, buf: &mut WriteBuf<'_>) -> io::Result> { + let ssl_result = self.with_buffers(buf, || self.inner.borrow_mut().shutdown()); + match ssl_result { - Ok(ssl::ShutdownResult::Sent) => Poll::Pending, - Ok(ssl::ShutdownResult::Received) => { - self.inner.borrow().get_ref().inner.poll_shutdown() - } - Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => { - self.inner.borrow().get_ref().inner.poll_shutdown() - } + Ok(ssl::ShutdownResult::Sent) => Ok(Poll::Pending), + Ok(ssl::ShutdownResult::Received) => Ok(Poll::Ready(())), + Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(Poll::Ready(())), Err(ref e) if e.code() == ssl::ErrorCode::WANT_READ || e.code() == ssl::ErrorCode::WANT_WRITE => { - Poll::Pending + Ok(Poll::Pending) } - Err(e) => Poll::Ready(Err(e + Err(e) => Err(e .into_io_error() - .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))), + .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))), } } - #[inline] - fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll { - self.inner.borrow().get_ref().inner.poll_read_ready(cx) - } - - #[inline] - fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { - self.inner.borrow().get_ref().inner.poll_write_ready(cx) - } - - #[inline] - fn get_read_buf(&self) -> Option { - self.read_buf.take() - } - - #[inline] - fn get_write_buf(&self) -> Option { - self.inner.borrow_mut().get_mut().write_buf.take() - } - - #[inline] - fn release_read_buf(&self, buf: BytesVec) { - self.read_buf.set(Some(buf)); - } - - fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> { - // ask inner filter to process read buf - match self - .inner - .borrow_mut() - .get_ref() - .inner - .process_read_buf(io, nbytes) - { - Err(err) => io.want_shutdown(Some(err)), - Ok((n, 0)) => return Ok((n, 0)), - Ok((_, _)) => (), - } - - // get processed buffer - let mut dst = if let Some(dst) = self.get_read_buf() { - dst - } else { - self.pool.get_read_buf() - }; - let (hw, lw) = self.pool.read_params().unpack(); + fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result { + buf.with_write_buf(|b| self.set_buffers(b)); + let dst = buf.get_dst(); let mut new_bytes = usize::from(self.handshake.get()); loop { // make sure we've got room - let remaining = dst.remaining_mut(); - if remaining < lw { - dst.reserve(hw - remaining); - } + self.pool.resize_read_buf(dst); let chunk: &mut [u8] = unsafe { std::mem::transmute(&mut *dst.chunk_mut()) }; let ssl_result = self.inner.borrow_mut().ssl_read(chunk); @@ -209,43 +188,61 @@ impl Filter for SslFilter { if e.code() == ssl::ErrorCode::WANT_READ || e.code() == ssl::ErrorCode::WANT_WRITE => { - Ok((dst.len(), new_bytes)) + Ok(new_bytes) } Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => { - io.want_shutdown(None); - Ok((dst.len(), new_bytes)) + buf.want_shutdown(); + Ok(new_bytes) } Err(e) => { log::trace!("SSL Error: {:?}", e); Err(map_to_ioerr(e)) } }; - self.release_read_buf(dst); + + buf.with_write_buf(|b| self.unset_buffers(b)); return result; } } - fn release_write_buf(&self, mut buf: BytesVec) -> Result<(), io::Error> { - loop { - if buf.is_empty() { - return Ok(()); - } - let ssl_result = self.inner.borrow_mut().ssl_write(&buf); - match ssl_result { - Ok(v) => { - buf.split_to(v); - continue; + fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> { + if let Some(mut src) = buf.take_src() { + self.set_buffers(buf); + + loop { + if src.is_empty() { + self.unset_buffers(buf); + return Ok(()); } - Err(e) => { - if !buf.is_empty() { - self.inner.borrow_mut().get_mut().write_buf = Some(buf); + let ssl_result = self.inner.borrow_mut().ssl_write(&src); + match ssl_result { + Ok(v) => { + src.split_to(v); + continue; + } + Err(e) => { + if !src.is_empty() { + buf.set_src(Some(src)); + } + self.unset_buffers(buf); + return match e.code() { + ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => { + buf.set_dst( + self.inner + .borrow_mut() + .get_mut() + .inner_write_buf + .take(), + ); + Ok(()) + } + _ => Err(map_to_ioerr(e)), + }; } - return match e.code() { - ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => Ok(()), - _ => Err(map_to_ioerr(e)), - }; } } + } else { + Ok(()) } } } @@ -283,42 +280,47 @@ impl Clone for SslAcceptor { } impl FilterFactory for SslAcceptor { - type Filter = SslFilter; + type Filter = SslFilter; type Error = Box; - type Future = BoxFuture<'static, Result, Self::Error>>; + type Future = BoxFuture<'static, Result>, Self::Error>>; - fn create(self, st: Io) -> Self::Future { + fn create(self, io: Io) -> Self::Future { let timeout = self.timeout; let ctx_result = ssl::Ssl::new(self.acceptor.context()); Box::pin(async move { time::timeout(timeout, async { let ssl = ctx_result.map_err(map_to_ioerr)?; - let pool = st.memory_pool(); - let st = st.map_filter(|inner: F| { - let inner = IoInner { - pool, - inner, - write_buf: None, - }; - let ssl_stream = ssl::SslStream::new(ssl, inner)?; - - Ok::<_, Box>(SslFilter { - pool, - read_buf: Cell::new(None), - handshake: Cell::new(true), - inner: RefCell::new(ssl_stream), - }) - })?; + let inner = IoInner { + pool: io.memory_pool(), + inner_read_buf: None, + inner_write_buf: None, + }; + let filter = SslFilter { + pool: io.memory_pool(), + handshake: Cell::new(true), + inner: RefCell::new(ssl::SslStream::new(ssl, inner)?), + }; + let io = io.add_filter(filter); poll_fn(|cx| { - handle_result(st.filter().inner.borrow_mut().accept(), &st, cx) + let result = io + .with_buf(|buf| { + let filter = io.filter(); + filter.with_buffers(buf, || filter.inner.borrow_mut().accept()) + }) + .map_err(|err| { + let err: Box = + io::Error::new(io::ErrorKind::Other, err).into(); + err + })?; + handle_result(result, &io, cx) }) .await?; - st.filter().handshake.set(false); - Ok(st) + io.filter().handshake.set(false); + Ok(io) }) .await .map_err(|_| { @@ -341,35 +343,42 @@ impl SslConnector { } impl FilterFactory for SslConnector { - type Filter = SslFilter; + type Filter = SslFilter; type Error = Box; - type Future = BoxFuture<'static, Result, Self::Error>>; + type Future = BoxFuture<'static, Result>, Self::Error>>; - fn create(self, st: Io) -> Self::Future { + fn create(self, io: Io) -> Self::Future { Box::pin(async move { - let ssl = self.ssl; - let pool = st.memory_pool(); - let st = st.map_filter(|inner: F| { - let inner = IoInner { - pool, - inner, - write_buf: None, - }; - let ssl_stream = ssl::SslStream::new(ssl, inner)?; + let inner = IoInner { + pool: io.memory_pool(), + inner_read_buf: None, + inner_write_buf: None, + }; + let filter = SslFilter { + pool: io.memory_pool(), + handshake: Cell::new(true), + inner: RefCell::new(ssl::SslStream::new(self.ssl, inner)?), + }; + let io = io.add_filter(filter); - Ok::<_, Box>(SslFilter { - pool, - read_buf: Cell::new(None), - handshake: Cell::new(true), - inner: RefCell::new(ssl_stream), - }) - })?; + poll_fn(|cx| { + let result = io + .with_buf(|buf| { + let filter = io.filter(); + filter.with_buffers(buf, || filter.inner.borrow_mut().connect()) + }) + .map_err(|err| { + let err: Box = + io::Error::new(io::ErrorKind::Other, err).into(); + err + })?; + handle_result(result, &io, cx) + }) + .await?; - poll_fn(|cx| handle_result(st.filter().inner.borrow_mut().connect(), &st, cx)) - .await?; - - Ok(st) + io.filter().handshake.set(false); + Ok(io) }) } } diff --git a/ntex-tls/src/rustls/accept.rs b/ntex-tls/src/rustls/accept.rs index 6c87e9d3..f1373096 100644 --- a/ntex-tls/src/rustls/accept.rs +++ b/ntex-tls/src/rustls/accept.rs @@ -3,7 +3,7 @@ use std::{future::Future, io, marker::PhantomData, pin::Pin, sync::Arc}; use tls_rust::ServerConfig; -use ntex_io::{Filter, FilterFactory, Io}; +use ntex_io::{Filter, FilterFactory, Io, Layer}; use ntex_service::{Service, ServiceFactory}; use ntex_util::{future::Ready, time::Millis}; @@ -52,7 +52,7 @@ impl Clone for Acceptor { } impl ServiceFactory, C> for Acceptor { - type Response = Io>; + type Response = Io>; type Error = io::Error; type Service = AcceptorService; @@ -79,7 +79,7 @@ pub struct AcceptorService { } impl Service> for AcceptorService { - type Response = Io>; + type Response = Io>; type Error = io::Error; type Future<'f> = AcceptorServiceFut; @@ -113,7 +113,7 @@ pin_project_lite::pin_project! { } impl Future for AcceptorServiceFut { - type Output = Result>, io::Error>; + type Output = Result>, io::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.project().fut.poll(cx) diff --git a/ntex-tls/src/rustls/client.rs b/ntex-tls/src/rustls/client.rs index c4c02b9d..4a24d595 100644 --- a/ntex-tls/src/rustls/client.rs +++ b/ntex-tls/src/rustls/client.rs @@ -1,9 +1,9 @@ //! An implementation of SSL streams for ntex backed by OpenSSL use std::io::{self, Read as IoRead, Write as IoWrite}; -use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Context, task::Poll}; +use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Poll}; -use ntex_bytes::{BufMut, BytesVec}; -use ntex_io::{types, Filter, Io, IoRef, ReadStatus, WriteStatus}; +use ntex_bytes::BufMut; +use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf}; use ntex_util::{future::poll_fn, ready}; use tls_rust::{ClientConfig, ClientConnection, ServerName}; @@ -12,12 +12,12 @@ use crate::rustls::{IoInner, TlsFilter, Wrapper}; use super::{PeerCert, PeerCertChain}; /// An implementation of SSL streams -pub struct TlsClientFilter { - inner: IoInner, +pub struct TlsClientFilter { + inner: IoInner, session: RefCell, } -impl Filter for TlsClientFilter { +impl FilterLayer for TlsClientFilter { fn query(&self, id: any::TypeId) -> Option> { const H2: &[u8] = b"h2"; @@ -52,71 +52,19 @@ impl Filter for TlsClientFilter { None } } else { - self.inner.filter.query(id) + None } } - #[inline] - fn poll_shutdown(&self) -> Poll> { - self.inner.filter.poll_shutdown() - } - - #[inline] - fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll { - self.inner.filter.poll_read_ready(cx) - } - - #[inline] - fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { - self.inner.filter.poll_write_ready(cx) - } - - #[inline] - fn get_read_buf(&self) -> Option { - self.inner.read_buf.take() - } - - #[inline] - fn get_write_buf(&self) -> Option { - self.inner.write_buf.take() - } - - #[inline] - fn release_read_buf(&self, buf: BytesVec) { - self.inner.read_buf.set(Some(buf)); - } - - fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> { + fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result { let mut session = self.session.borrow_mut(); - // ask inner filter to process read buf - match self.inner.filter.process_read_buf(io, nbytes) { - Err(err) => io.want_shutdown(Some(err)), - Ok((_, 0)) => return Ok((0, 0)), - Ok(_) => (), - } - // get processed buffer - let mut dst = if let Some(dst) = self.inner.read_buf.take() { - dst - } else { - self.inner.pool.get_read_buf() - }; - let (hw, lw) = self.inner.pool.read_params().unpack(); - - let mut src = if let Some(src) = self.inner.filter.get_read_buf() { - src - } else { - return Ok((0, 0)); - }; - + let (src, dst) = buf.get_pair(); let mut new_bytes = usize::from(self.inner.handshake.get()); loop { // make sure we've got room - let remaining = dst.remaining_mut(); - if remaining < lw { - dst.reserve(hw - remaining); - } + self.inner.pool.resize_read_buf(dst); let mut cursor = io::Cursor::new(&src); let n = session.read_tls(&mut cursor)?; @@ -138,73 +86,74 @@ impl Filter for TlsClientFilter { } } - let dst_len = dst.len(); - self.inner.read_buf.set(Some(dst)); - self.inner.filter.release_read_buf(src); - Ok((dst_len, new_bytes)) + Ok(new_bytes) } - fn release_write_buf(&self, mut src: BytesVec) -> Result<(), io::Error> { - let mut session = self.session.borrow_mut(); - let mut io = Wrapper(&self.inner); + fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> { + if let Some(mut src) = buf.take_src() { + let mut session = self.session.borrow_mut(); + let mut io = Wrapper(&self.inner, buf); + + loop { + if !src.is_empty() { + let n = session.writer().write(&src)?; + src.split_to(n); + } + + if session.wants_write() { + session.complete_io(&mut io)?; + } else { + break; + } + } - loop { if !src.is_empty() { - let n = session.writer().write(&src)?; - src.split_to(n); - } - - let n = session.write_tls(&mut io)?; - if n == 0 { - break; + buf.set_src(Some(src)); } + Ok(()) + } else { + Ok(()) } - - if !src.is_empty() { - self.inner.write_buf.set(Some(src)); - } - - Ok(()) } } -impl TlsClientFilter { - pub(crate) async fn create( +impl TlsClientFilter { + pub(crate) async fn create( io: Io, cfg: Arc, domain: ServerName, - ) -> Result>, io::Error> { - let pool = io.memory_pool(); - let session = match ClientConnection::new(cfg, domain) { - Ok(session) => session, - Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)), - }; - let io = io.map_filter(|filter: F| { - let inner = IoInner { - pool, - filter, - read_buf: Cell::new(None), - write_buf: Cell::new(None), + ) -> Result>, io::Error> { + let session = ClientConnection::new(cfg, domain) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + let filter = TlsFilter::new_client(TlsClientFilter { + inner: IoInner { + pool: io.memory_pool(), handshake: Cell::new(true), - }; - - Ok::<_, io::Error>(TlsFilter::new_client(TlsClientFilter { - inner, - session: RefCell::new(session), - })) - })?; + }, + session: RefCell::new(session), + }); + let io = io.add_filter(filter); let filter = io.filter(); loop { - let (result, wants_read, handshaking) = { + let (result, wants_read, handshaking) = io.with_buf(|buf| { let mut session = filter.client().session.borrow_mut(); - let mut wrp = Wrapper(&filter.client().inner); - ( + let mut wrp = Wrapper(&filter.client().inner, buf); + let mut result = ( session.complete_io(&mut wrp), session.wants_read(), session.is_handshaking(), - ) - }; + ); + + while session.wants_write() { + result.0 = session.complete_io(&mut wrp); + if result.0.is_err() { + break; + } + } + result + })?; + match result { Ok(_) => { filter.client().inner.handshake.set(false); diff --git a/ntex-tls/src/rustls/mod.rs b/ntex-tls/src/rustls/mod.rs index 11060f5c..24b470fd 100644 --- a/ntex-tls/src/rustls/mod.rs +++ b/ntex-tls/src/rustls/mod.rs @@ -1,11 +1,13 @@ #![allow(clippy::type_complexity)] //! An implementation of SSL streams for ntex backed by OpenSSL -use std::{any, cmp, future::Future, io, pin::Pin, task::Context, task::Poll}; -use std::{cell::Cell, sync::Arc}; +use std::{any, cell::Cell, cmp, io, sync::Arc, task::Context, task::Poll}; -use ntex_bytes::{BytesVec, PoolRef}; -use ntex_io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus}; -use ntex_util::time::Millis; +use ntex_bytes::PoolRef; +use ntex_io::{ + Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, ReadStatus, WriteBuf, + WriteStatus, +}; +use ntex_util::{future::BoxFuture, time::Millis}; use tls_rust::{Certificate, ClientConfig, ServerConfig, ServerName}; mod accept; @@ -25,33 +27,33 @@ pub struct PeerCert(pub Certificate); pub struct PeerCertChain(pub Vec); /// An implementation of SSL streams -pub struct TlsFilter { - inner: InnerTlsFilter, +pub struct TlsFilter { + inner: InnerTlsFilter, } -enum InnerTlsFilter { - Server(TlsServerFilter), - Client(TlsClientFilter), +enum InnerTlsFilter { + Server(TlsServerFilter), + Client(TlsClientFilter), } -impl TlsFilter { - fn new_server(server: TlsServerFilter) -> Self { +impl TlsFilter { + fn new_server(server: TlsServerFilter) -> Self { TlsFilter { inner: InnerTlsFilter::Server(server), } } - fn new_client(client: TlsClientFilter) -> Self { + fn new_client(client: TlsClientFilter) -> Self { TlsFilter { inner: InnerTlsFilter::Client(client), } } - fn server(&self) -> &TlsServerFilter { + fn server(&self) -> &TlsServerFilter { match self.inner { InnerTlsFilter::Server(ref server) => server, _ => unreachable!(), } } - fn client(&self) -> &TlsClientFilter { + fn client(&self) -> &TlsClientFilter { match self.inner { InnerTlsFilter::Client(ref server) => server, _ => unreachable!(), @@ -59,7 +61,7 @@ impl TlsFilter { } } -impl Filter for TlsFilter { +impl FilterLayer for TlsFilter { #[inline] fn query(&self, id: any::TypeId) -> Option> { match self.inner { @@ -69,10 +71,10 @@ impl Filter for TlsFilter { } #[inline] - fn poll_shutdown(&self) -> Poll> { + fn shutdown(&self, buf: &mut WriteBuf<'_>) -> io::Result> { match self.inner { - InnerTlsFilter::Server(ref f) => f.poll_shutdown(), - InnerTlsFilter::Client(ref f) => f.poll_shutdown(), + InnerTlsFilter::Server(ref f) => f.shutdown(buf), + InnerTlsFilter::Client(ref f) => f.shutdown(buf), } } @@ -93,42 +95,18 @@ impl Filter for TlsFilter { } #[inline] - fn get_read_buf(&self) -> Option { + fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result { match self.inner { - InnerTlsFilter::Server(ref f) => f.get_read_buf(), - InnerTlsFilter::Client(ref f) => f.get_read_buf(), + InnerTlsFilter::Server(ref f) => f.process_read_buf(buf), + InnerTlsFilter::Client(ref f) => f.process_read_buf(buf), } } #[inline] - fn get_write_buf(&self) -> Option { + fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> { match self.inner { - InnerTlsFilter::Server(ref f) => f.get_write_buf(), - InnerTlsFilter::Client(ref f) => f.get_write_buf(), - } - } - - #[inline] - fn release_read_buf(&self, buf: BytesVec) { - match self.inner { - InnerTlsFilter::Server(ref f) => f.release_read_buf(buf), - InnerTlsFilter::Client(ref f) => f.release_read_buf(buf), - } - } - - #[inline] - fn process_read_buf(&self, io: &IoRef, nb: usize) -> io::Result<(usize, usize)> { - match self.inner { - InnerTlsFilter::Server(ref f) => f.process_read_buf(io, nb), - InnerTlsFilter::Client(ref f) => f.process_read_buf(io, nb), - } - } - - #[inline] - fn release_write_buf(&self, src: BytesVec) -> Result<(), io::Error> { - match self.inner { - InnerTlsFilter::Server(ref f) => f.release_write_buf(src), - InnerTlsFilter::Client(ref f) => f.release_write_buf(src), + InnerTlsFilter::Server(ref f) => f.process_write_buf(buf), + InnerTlsFilter::Client(ref f) => f.process_write_buf(buf), } } } @@ -172,10 +150,10 @@ impl Clone for TlsAcceptor { } impl FilterFactory for TlsAcceptor { - type Filter = TlsFilter; + type Filter = TlsFilter; type Error = io::Error; - type Future = Pin, io::Error>>>>; + type Future = BoxFuture<'static, Result>, io::Error>>; fn create(self, st: Io) -> Self::Future { let cfg = self.cfg.clone(); @@ -227,10 +205,10 @@ impl Clone for TlsConnectorConfigured { } impl FilterFactory for TlsConnectorConfigured { - type Filter = TlsFilter; + type Filter = TlsFilter; type Error = io::Error; - type Future = Pin, io::Error>>>>; + type Future = BoxFuture<'static, Result>, io::Error>>; fn create(self, st: Io) -> Self::Future { let cfg = self.cfg; @@ -240,44 +218,31 @@ impl FilterFactory for TlsConnectorConfigured { } } -pub(crate) struct IoInner { - filter: F, +pub(crate) struct IoInner { pool: PoolRef, - read_buf: Cell>, - write_buf: Cell>, handshake: Cell, } -pub(crate) struct Wrapper<'a, F>(&'a IoInner); +pub(crate) struct Wrapper<'a, 'b>(&'a IoInner, &'a mut WriteBuf<'b>); -impl<'a, F: Filter> io::Read for Wrapper<'a, F> { +impl<'a, 'b> io::Read for Wrapper<'a, 'b> { fn read(&mut self, dst: &mut [u8]) -> io::Result { - if let Some(mut read_buf) = self.0.filter.get_read_buf() { + self.1.with_read_buf(|buf| { + let read_buf = buf.get_src(); let len = cmp::min(read_buf.len(), dst.len()); - let result = if len > 0 { + if len > 0 { dst[..len].copy_from_slice(&read_buf.split_to(len)); Ok(len) } else { Err(io::Error::new(io::ErrorKind::WouldBlock, "")) - }; - self.0.filter.release_read_buf(read_buf); - result - } else { - Err(io::Error::new(io::ErrorKind::WouldBlock, "")) - } + } + }) } } -impl<'a, F: Filter> io::Write for Wrapper<'a, F> { +impl<'a, 'b> io::Write for Wrapper<'a, 'b> { fn write(&mut self, src: &[u8]) -> io::Result { - let mut buf = if let Some(mut buf) = self.0.filter.get_write_buf() { - buf.reserve(src.len()); - buf - } else { - BytesVec::with_capacity_in(src.len(), self.0.pool) - }; - buf.extend_from_slice(src); - self.0.filter.release_write_buf(buf)?; + self.1.get_dst().extend_from_slice(src); Ok(src.len()) } diff --git a/ntex-tls/src/rustls/server.rs b/ntex-tls/src/rustls/server.rs index 89926b15..14a83adb 100644 --- a/ntex-tls/src/rustls/server.rs +++ b/ntex-tls/src/rustls/server.rs @@ -1,9 +1,9 @@ //! An implementation of SSL streams for ntex backed by OpenSSL use std::io::{self, Read as IoRead, Write as IoWrite}; -use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Context, task::Poll}; +use std::{any, cell::Cell, cell::RefCell, sync::Arc, task::Poll}; -use ntex_bytes::{BufMut, BytesVec}; -use ntex_io::{types, Filter, Io, IoRef, ReadStatus, WriteStatus}; +use ntex_bytes::BufMut; +use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf}; use ntex_util::{future::poll_fn, ready, time, time::Millis}; use tls_rust::{ServerConfig, ServerConnection}; @@ -13,12 +13,12 @@ use crate::Servername; use super::{PeerCert, PeerCertChain}; /// An implementation of SSL streams -pub struct TlsServerFilter { - inner: IoInner, +pub struct TlsServerFilter { + inner: IoInner, session: RefCell, } -impl Filter for TlsServerFilter { +impl FilterLayer for TlsServerFilter { fn query(&self, id: any::TypeId) -> Option> { const H2: &[u8] = b"h2"; @@ -59,71 +59,19 @@ impl Filter for TlsServerFilter { None } } else { - self.inner.filter.query(id) + None } } - #[inline] - fn poll_shutdown(&self) -> Poll> { - self.inner.filter.poll_shutdown() - } - - #[inline] - fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll { - self.inner.filter.poll_read_ready(cx) - } - - #[inline] - fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { - self.inner.filter.poll_write_ready(cx) - } - - #[inline] - fn get_read_buf(&self) -> Option { - self.inner.read_buf.take() - } - - #[inline] - fn get_write_buf(&self) -> Option { - self.inner.write_buf.take() - } - - #[inline] - fn release_read_buf(&self, buf: BytesVec) { - self.inner.read_buf.set(Some(buf)); - } - - fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> { + fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result { let mut session = self.session.borrow_mut(); - // ask inner filter to process read buf - match self.inner.filter.process_read_buf(io, nbytes) { - Err(err) => io.want_shutdown(Some(err)), - Ok((_, 0)) => return Ok((0, 0)), - Ok(_) => (), - } - // get processed buffer - let mut dst = if let Some(dst) = self.inner.read_buf.take() { - dst - } else { - self.inner.pool.get_read_buf() - }; - let (hw, lw) = self.inner.pool.read_params().unpack(); - - let mut src = if let Some(src) = self.inner.filter.get_read_buf() { - src - } else { - return Ok((0, 0)); - }; - + let (src, dst) = buf.get_pair(); let mut new_bytes = usize::from(self.inner.handshake.get()); loop { // make sure we've got room - let remaining = dst.remaining_mut(); - if remaining < lw { - dst.reserve(hw - remaining); - } + self.inner.pool.resize_read_buf(dst); let mut cursor = io::Cursor::new(&src); let n = session.read_tls(&mut cursor)?; @@ -145,73 +93,73 @@ impl Filter for TlsServerFilter { } } - let dst_len = dst.len(); - self.inner.read_buf.set(Some(dst)); - self.inner.filter.release_read_buf(src); - Ok((dst_len, new_bytes)) + Ok(new_bytes) } - fn release_write_buf(&self, mut src: BytesVec) -> Result<(), io::Error> { - let mut session = self.session.borrow_mut(); - let mut io = Wrapper(&self.inner); + fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> { + if let Some(mut src) = buf.take_src() { + let mut session = self.session.borrow_mut(); + let mut io = Wrapper(&self.inner, buf); + + loop { + if !src.is_empty() { + let n = session.writer().write(&src)?; + src.split_to(n); + } + + if session.wants_write() { + session.complete_io(&mut io)?; + } else { + break; + } + } - loop { if !src.is_empty() { - let n = session.writer().write(&src)?; - src.split_to(n); + buf.set_src(Some(src)); } - - let n = session.write_tls(&mut io)?; - if n == 0 { - break; - } - } - - if !src.is_empty() { - self.inner.write_buf.set(Some(src)); } Ok(()) } } -impl TlsServerFilter { - pub(crate) async fn create( +impl TlsServerFilter { + pub(crate) async fn create( io: Io, cfg: Arc, timeout: Millis, - ) -> Result>, io::Error> { + ) -> Result>, io::Error> { time::timeout(timeout, async { - let pool = io.memory_pool(); - let session = match ServerConnection::new(cfg) { - Ok(session) => session, - Err(error) => return Err(io::Error::new(io::ErrorKind::Other, error)), - }; - let io = io.map_filter(|filter: F| { - let inner = IoInner { - pool, - filter, - read_buf: Cell::new(None), - write_buf: Cell::new(None), + let session = ServerConnection::new(cfg) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + let filter = TlsFilter::new_server(TlsServerFilter { + session: RefCell::new(session), + inner: IoInner { + pool: io.memory_pool(), handshake: Cell::new(true), - }; - - Ok::<_, io::Error>(TlsFilter::new_server(TlsServerFilter { - inner, - session: RefCell::new(session), - })) - })?; + }, + }); + let io = io.add_filter(filter); let filter = io.filter(); loop { - let (result, wants_read, handshaking) = { + let (result, wants_read, handshaking) = io.with_buf(|buf| { let mut session = filter.server().session.borrow_mut(); - let mut wrp = Wrapper(&filter.server().inner); - ( + let mut wrp = Wrapper(&filter.server().inner, buf); + let mut result = ( session.complete_io(&mut wrp), session.wants_read(), session.is_handshaking(), - ) - }; + ); + + while session.wants_write() { + result.0 = session.complete_io(&mut wrp); + if result.0.is_err() { + break; + } + } + result + })?; + match result { Ok(_) => { filter.server().inner.handshake.set(false); diff --git a/ntex-tokio/Cargo.toml b/ntex-tokio/Cargo.toml index c143179c..7d1bc9e1 100644 --- a/ntex-tokio/Cargo.toml +++ b/ntex-tokio/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-tokio" -version = "0.2.0" +version = "0.2.1" authors = ["ntex contributors "] description = "tokio intergration for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -16,8 +16,8 @@ name = "ntex_tokio" path = "src/lib.rs" [dependencies] -ntex-bytes = "0.1.18" -ntex-io = "0.2.0" +ntex-bytes = "0.1.19" +ntex-io = "0.2.1" ntex-util = "0.2.0" log = "0.4" pin-project-lite = "0.2" diff --git a/ntex-tokio/src/io.rs b/ntex-tokio/src/io.rs index 54844d93..fcc02848 100644 --- a/ntex-tokio/src/io.rs +++ b/ntex-tokio/src/io.rs @@ -54,73 +54,42 @@ impl Future for ReadTask { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_ref(); - loop { + this.state.with_buf(|buf, hw, lw| { match ready!(this.state.poll_ready(cx)) { ReadStatus::Ready => { - let pool = this.state.memory_pool(); - let mut io = this.io.borrow_mut(); - let mut buf = self.state.get_read_buf(); - let (hw, lw) = pool.read_params().unpack(); - // read data from socket - let mut new_bytes = 0; - let mut close = false; - let mut pending = false; + let mut io = this.io.borrow_mut(); loop { // make sure we've got room let remaining = buf.remaining_mut(); if remaining < lw { buf.reserve(hw - remaining); } - - match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) { - Poll::Pending => { - pending = true; - break; - } + return match poll_read_buf(Pin::new(&mut *io), cx, buf) { + Poll::Pending => Poll::Pending, Poll::Ready(Ok(n)) => { if n == 0 { log::trace!("tokio stream is disconnected"); - close = true; + Poll::Ready(Ok(())) + } else if buf.len() < hw { + continue; } else { - new_bytes += n; - if new_bytes <= hw { - continue; - } + Poll::Pending } - break; } Poll::Ready(Err(err)) => { log::trace!("read task failed on io {:?}", err); - drop(io); - this.state.release_read_buf(buf, new_bytes); - this.state.close(Some(err)); - return Poll::Ready(()); + Poll::Ready(Err(err)) } - } + }; } - - drop(io); - if new_bytes == 0 && close { - this.state.close(None); - return Poll::Ready(()); - } - this.state.release_read_buf(buf, new_bytes); - return if close { - this.state.close(None); - Poll::Ready(()) - } else if pending { - Poll::Pending - } else { - continue; - }; } ReadStatus::Terminate => { log::trace!("read task is instructed to shutdown"); - return Poll::Ready(()); + Poll::Ready(Ok(())) } } - } + }) } } @@ -269,14 +238,14 @@ impl Future for WriteTask { if read_buf.filled().is_empty() => { this.state.close(None); - log::trace!("write task is stopped"); + log::trace!("tokio write task is stopped"); return Poll::Ready(()); } Poll::Pending => { *count += read_buf.filled().len() as u16; if *count > 4096 { log::trace!( - "write task is stopped, too much input" + "tokio write task is stopped, too much input" ); this.state.close(None); return Poll::Ready(()); @@ -344,7 +313,7 @@ pub(super) fn flush_io( } } } - log::trace!("flushed {} bytes", written); + // log::trace!("flushed {} bytes", written); // remove written data let result = if written == len { @@ -501,18 +470,11 @@ mod unixstream { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_ref(); - loop { + this.state.with_buf(|buf, hw, lw| { match ready!(this.state.poll_ready(cx)) { ReadStatus::Ready => { - let pool = this.state.memory_pool(); - let mut io = this.io.borrow_mut(); - let mut buf = self.state.get_read_buf(); - let (hw, lw) = pool.read_params().unpack(); - // read data from socket - let mut new_bytes = 0; - let mut close = false; - let mut pending = false; + let mut io = this.io.borrow_mut(); loop { // make sure we've got room let remaining = buf.remaining_mut(); @@ -520,54 +482,31 @@ mod unixstream { buf.reserve(hw - remaining); } - match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) { - Poll::Pending => { - pending = true; - break; - } + return match poll_read_buf(Pin::new(&mut *io), cx, buf) { + Poll::Pending => Poll::Pending, Poll::Ready(Ok(n)) => { if n == 0 { - log::trace!("unix stream is disconnected"); - close = true; + log::trace!("tokio unix stream is disconnected"); + Poll::Ready(Ok(())) + } else if buf.len() < hw { + continue; } else { - new_bytes += n; - if new_bytes <= hw { - continue; - } + Poll::Pending } - break; } Poll::Ready(Err(err)) => { - log::trace!("read task failed on io {:?}", err); - drop(io); - this.state.release_read_buf(buf, new_bytes); - this.state.close(Some(err)); - return Poll::Ready(()); + log::trace!("unix stream read task failed {:?}", err); + Poll::Ready(Err(err)) } - } + }; } - - drop(io); - if new_bytes == 0 && close { - this.state.close(None); - return Poll::Ready(()); - } - this.state.release_read_buf(buf, new_bytes); - return if close { - this.state.close(None); - Poll::Ready(()) - } else if pending { - Poll::Pending - } else { - continue; - }; } ReadStatus::Terminate => { log::trace!("read task is instructed to shutdown"); - return Poll::Ready(()); + Poll::Ready(Ok(())) } } - } + }) } } @@ -735,10 +674,6 @@ pub fn poll_read_buf( cx: &mut Context<'_>, buf: &mut BytesVec, ) -> Poll> { - if !buf.has_remaining_mut() { - return Poll::Ready(Ok(0)); - } - let n = { let dst = unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [mem::MaybeUninit]) }; diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 8ce850b8..eba9d7cd 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "0.6.0" +version = "0.6.1" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" @@ -49,20 +49,20 @@ async-std = ["ntex-rt/async-std", "ntex-async-std", "ntex-connect/async-std"] [dependencies] ntex-codec = "0.6.2" -ntex-connect = "0.2.0" +ntex-connect = "0.2.1" ntex-http = "0.1.9" ntex-router = "0.5.1" ntex-service = "1.0.0" ntex-macros = "0.1.3" ntex-util = "0.2.0" -ntex-bytes = "0.1.18" +ntex-bytes = "0.1.19" ntex-h2 = "0.2.0" ntex-rt = "0.4.7" -ntex-io = "0.2.0" -ntex-tls = "0.2.0" -ntex-tokio = { version = "0.2.0", optional = true } -ntex-glommio = { version = "0.2.0", optional = true } -ntex-async-std = { version = "0.2.0", optional = true } +ntex-io = "0.2.1" +ntex-tls = "0.2.1" +ntex-tokio = { version = "0.2.1", optional = true } +ntex-glommio = { version = "0.2.1", optional = true } +ntex-async-std = { version = "0.2.1", optional = true } async-oneshot = "0.5.0" async-channel = "1.8.0" @@ -88,7 +88,7 @@ percent-encoding = "2.1" serde_json = "1.0" serde_urlencoded = "0.7" url-pkg = { version = "2.1", package = "url", optional = true } -coo-kie = { version = "0.16", package = "cookie", optional = true } +coo-kie = { version = "0.17", package = "cookie", optional = true } # openssl tls-openssl = { version="0.10", package = "openssl", optional = true } diff --git a/ntex/src/http/client/h2proto.rs b/ntex/src/http/client/h2proto.rs index 28649bcd..9811c65b 100644 --- a/ntex/src/http/client/h2proto.rs +++ b/ntex/src/http/client/h2proto.rs @@ -206,6 +206,7 @@ struct H2ClientInner { streams: RefCell>, } +#[derive(Debug)] struct StreamInfo { tx: Option>>, stream: Option, diff --git a/ntex/src/http/client/mod.rs b/ntex/src/http/client/mod.rs index 7a4aa6dd..fa8b1024 100644 --- a/ntex/src/http/client/mod.rs +++ b/ntex/src/http/client/mod.rs @@ -10,7 +10,7 @@ //! //! let response = client.get("http://www.rust-lang.org") // <- Create request builder //! .header("User-Agent", "ntex::web") -//! .send() // <- Send http request +//! .send() // <- Send http request //! .await; //! //! println!("Response: {:?}", response); diff --git a/ntex/src/http/client/pool.rs b/ntex/src/http/client/pool.rs index d605f661..cfadf83d 100644 --- a/ntex/src/http/client/pool.rs +++ b/ntex/src/http/client/pool.rs @@ -626,7 +626,6 @@ mod tests { #[crate::rt_test] async fn test_basics() { - env_logger::init(); let store = Rc::new(RefCell::new(Vec::new())); let store2 = store.clone(); diff --git a/ntex/src/http/config.rs b/ntex/src/http/config.rs index d103151a..c29b3d2e 100644 --- a/ntex/src/http/config.rs +++ b/ntex/src/http/config.rs @@ -146,7 +146,7 @@ const DATE_VALUE_DEFAULT: [u8; DATE_VALUE_LENGTH_HDR] = [ b'0', b'0', b'0', b'0', b'0', b'0', b'0', b'\r', b'\n', b'\r', b'\n', ]; -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct DateService(Rc); impl Default for DateService { @@ -155,6 +155,7 @@ impl Default for DateService { } } +#[derive(Debug)] struct DateServiceInner { current: Cell, current_time: Cell, diff --git a/ntex/src/http/h1/client.rs b/ntex/src/http/h1/client.rs index 6b144b36..9c3e8907 100644 --- a/ntex/src/http/h1/client.rs +++ b/ntex/src/http/h1/client.rs @@ -21,16 +21,19 @@ bitflags! { } } +#[derive(Debug)] /// HTTP/1 Codec pub struct ClientCodec { inner: ClientCodecInner, } +#[derive(Debug)] /// HTTP/1 Payload Codec pub struct ClientPayloadCodec { inner: ClientCodecInner, } +#[derive(Debug)] struct ClientCodecInner { timer: DateService, decoder: decoder::MessageDecoder, diff --git a/ntex/src/http/h1/decoder.rs b/ntex/src/http/h1/decoder.rs index ea3b445e..54b28d97 100644 --- a/ntex/src/http/h1/decoder.rs +++ b/ntex/src/http/h1/decoder.rs @@ -14,6 +14,7 @@ use super::MAX_BUFFER_SIZE; const MAX_HEADERS: usize = 96; +#[derive(Debug)] /// Incoming messagd decoder pub(super) struct MessageDecoder(PhantomData); diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index e4f95548..b4dc77f9 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -1015,8 +1015,8 @@ mod tests { } #[crate::rt_test] - /// /// h1 dispatcher still processes all incoming requests - /// /// but it does not write any data to socket + /// h1 dispatcher still processes all incoming requests + /// but it does not write any data to socket async fn test_write_disconnected() { let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone(); @@ -1039,6 +1039,7 @@ mod tests { assert_eq!(num.load(Ordering::Relaxed), 1); } + /// max http message size is 32k (no payload) #[crate::rt_test] async fn test_read_large_message() { let (client, server) = Io::create(); diff --git a/ntex/src/http/h1/service.rs b/ntex/src/http/h1/service.rs index 8e556dcd..d37956cd 100644 --- a/ntex/src/http/h1/service.rs +++ b/ntex/src/http/h1/service.rs @@ -3,8 +3,7 @@ use std::{cell::RefCell, error::Error, fmt, marker, rc::Rc, task}; use crate::http::body::MessageBody; use crate::http::config::{DispatcherConfig, OnRequest, ServiceConfig}; use crate::http::error::{DispatchError, ResponseError}; -use crate::http::request::Request; -use crate::http::response::Response; +use crate::http::{request::Request, response::Response}; use crate::io::{types, Filter, Io}; use crate::service::{IntoServiceFactory, Service, ServiceFactory}; use crate::{time::Millis, util::BoxFuture}; @@ -56,9 +55,9 @@ mod openssl { use tls_openssl::ssl::SslAcceptor; use super::*; - use crate::{server::SslError, service::pipeline_factory}; + use crate::{io::Layer, server::SslError, service::pipeline_factory}; - impl H1Service, S, B, X, U> + impl H1Service, S, B, X, U> where F: Filter, S: ServiceFactory + 'static, @@ -69,7 +68,8 @@ mod openssl { X: ServiceFactory + 'static, X::Error: ResponseError, X::InitError: fmt::Debug, - U: ServiceFactory<(Request, Io>, Codec), Response = ()> + 'static, + U: ServiceFactory<(Request, Io>, Codec), Response = ()> + + 'static, U::Error: fmt::Display + Error, U::InitError: fmt::Debug, { @@ -102,9 +102,9 @@ mod rustls { use tls_rustls::ServerConfig; use super::*; - use crate::{server::SslError, service::pipeline_factory}; + use crate::{io::Layer, server::SslError, service::pipeline_factory}; - impl H1Service, S, B, X, U> + impl H1Service, S, B, X, U> where F: Filter, S: ServiceFactory + 'static, @@ -115,7 +115,8 @@ mod rustls { X: ServiceFactory + 'static, X::Error: ResponseError, X::InitError: fmt::Debug, - U: ServiceFactory<(Request, Io>, Codec), Response = ()> + 'static, + U: ServiceFactory<(Request, Io>, Codec), Response = ()> + + 'static, U::Error: fmt::Display + Error, U::InitError: fmt::Debug, { diff --git a/ntex/src/http/h2/payload.rs b/ntex/src/http/h2/payload.rs index 796c1968..a8b753ab 100644 --- a/ntex/src/http/h2/payload.rs +++ b/ntex/src/http/h2/payload.rs @@ -64,6 +64,7 @@ impl Stream for Payload { } } +#[derive(Debug)] /// Sender part of the payload stream pub struct PayloadSender { inner: Weak>, diff --git a/ntex/src/http/h2/service.rs b/ntex/src/http/h2/service.rs index 8ef9334d..d0fd5ce4 100644 --- a/ntex/src/http/h2/service.rs +++ b/ntex/src/http/h2/service.rs @@ -49,13 +49,11 @@ mod openssl { use ntex_tls::openssl::{Acceptor, SslFilter}; use tls_openssl::ssl::SslAcceptor; - use crate::io::Filter; - use crate::server::SslError; - use crate::service::pipeline_factory; + use crate::{io::Layer, server::SslError, service::pipeline_factory}; use super::*; - impl H2Service, S, B> + impl H2Service, S, B> where F: Filter, S: ServiceFactory + 'static, @@ -90,9 +88,9 @@ mod rustls { use tls_rustls::ServerConfig; use super::*; - use crate::{server::SslError, service::pipeline_factory}; + use crate::{io::Layer, server::SslError, service::pipeline_factory}; - impl H2Service, S, B> + impl H2Service, S, B> where F: Filter, S: ServiceFactory + 'static, @@ -394,7 +392,7 @@ where loop { match poll_fn(|cx| body.poll_next_chunk(cx)).await { None => { - log::debug!("{:?} closing sending payload", msg.id()); + log::debug!("{:?} closing payload stream", msg.id()); msg.stream().send_payload(Bytes::new(), true).await?; break; } diff --git a/ntex/src/http/service.rs b/ntex/src/http/service.rs index fb338840..b621200b 100644 --- a/ntex/src/http/service.rs +++ b/ntex/src/http/service.rs @@ -146,10 +146,9 @@ mod openssl { use tls_openssl::ssl::SslAcceptor; use super::*; - use crate::server::SslError; - use crate::service::pipeline_factory; + use crate::{io::Layer, server::SslError, service::pipeline_factory}; - impl HttpService, S, B, X, U> + impl HttpService, S, B, X, U> where F: Filter, S: ServiceFactory + 'static, @@ -160,7 +159,8 @@ mod openssl { X: ServiceFactory + 'static, X::Error: ResponseError, X::InitError: fmt::Debug, - U: ServiceFactory<(Request, Io>, h1::Codec), Response = ()> + 'static, + U: ServiceFactory<(Request, Io>, h1::Codec), Response = ()> + + 'static, U::Error: fmt::Display + error::Error, U::InitError: fmt::Debug, { @@ -191,9 +191,9 @@ mod rustls { use tls_rustls::ServerConfig; use super::*; - use crate::{server::SslError, service::pipeline_factory}; + use crate::{io::Layer, server::SslError, service::pipeline_factory}; - impl HttpService, S, B, X, U> + impl HttpService, S, B, X, U> where F: Filter, S: ServiceFactory + 'static, @@ -204,7 +204,8 @@ mod rustls { X: ServiceFactory + 'static, X::Error: ResponseError, X::InitError: fmt::Debug, - U: ServiceFactory<(Request, Io>, h1::Codec), Response = ()> + 'static, + U: ServiceFactory<(Request, Io>, h1::Codec), Response = ()> + + 'static, U::Error: fmt::Display + error::Error, U::InitError: fmt::Debug, { diff --git a/ntex/src/http/test.rs b/ntex/src/http/test.rs index f792a514..8ab18c19 100644 --- a/ntex/src/http/test.rs +++ b/ntex/src/http/test.rs @@ -4,8 +4,9 @@ use std::{convert::TryFrom, net, str::FromStr, sync::mpsc, thread}; #[cfg(feature = "cookie")] use coo_kie::{Cookie, CookieJar}; +use crate::io::{Filter, Io}; use crate::ws::{error::WsClientError, WsClient, WsConnection}; -use crate::{io::Filter, io::Io, rt::System, server::Server, service::ServiceFactory}; +use crate::{rt::System, server::Server, service::ServiceFactory}; use crate::{time::Millis, time::Seconds, util::Bytes}; use super::client::{Client, ClientRequest, ClientResponse, Connector}; @@ -349,7 +350,10 @@ impl TestServer { /// Connect to a websocket server pub async fn wss( &mut self, - ) -> Result, WsClientError> { + ) -> Result< + WsConnection>, + WsClientError, + > { self.wss_at("/").await } @@ -358,7 +362,10 @@ impl TestServer { pub async fn wss_at( &mut self, path: &str, - ) -> Result, WsClientError> { + ) -> Result< + WsConnection>, + WsClientError, + > { use tls_openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); diff --git a/ntex/src/ws/client.rs b/ntex/src/ws/client.rs index 867b0c4e..267152ed 100644 --- a/ntex/src/ws/client.rs +++ b/ntex/src/ws/client.rs @@ -15,13 +15,13 @@ use crate::connect::{Connect, ConnectError, Connector}; use crate::http::header::{self, HeaderMap, HeaderName, HeaderValue, AUTHORIZATION}; use crate::http::{body::BodySize, client::ClientResponse, error::HttpError, h1}; use crate::http::{ConnectionType, RequestHead, RequestHeadType, StatusCode, Uri}; -use crate::io::{Base, DispatchItem, Dispatcher, Filter, Io, Sealed}; +use crate::io::{Base, DispatchItem, Dispatcher, Filter, Io, Layer, Sealed}; use crate::service::{apply_fn, into_service, IntoService, Service}; use crate::time::{timeout, Millis, Seconds}; use crate::{channel::mpsc, rt, util::Ready, ws}; use super::error::{WsClientBuilderError, WsClientError, WsError}; -use super::transport::{WsTransport, WsTransportFactory}; +use super::transport::WsTransport; /// `WebSocket` client builder pub struct WsClient { @@ -527,7 +527,7 @@ where pub fn openssl( &mut self, connector: openssl::SslConnector, - ) -> WsClientBuilder> { + ) -> WsClientBuilder, openssl::Connector> { self.connector(openssl::Connector::new(connector)) } @@ -536,7 +536,7 @@ where pub fn rustls( &mut self, config: std::sync::Arc, - ) -> WsClientBuilder> { + ) -> WsClientBuilder, rustls::Connector> { self.connector(rustls::Connector::from(config)) } @@ -787,12 +787,8 @@ impl WsConnection { } /// Convert to ws stream to plain io stream - pub async fn into_transport(self) -> Io> { - // WsTransportFactory is infallible - self.io - .add_filter(WsTransportFactory::new(self.codec)) - .await - .unwrap() + pub fn into_transport(self) -> Io> { + WsTransport::create(self.io, self.codec) } } diff --git a/ntex/src/ws/transport.rs b/ntex/src/ws/transport.rs index 58913e83..b4deb49a 100644 --- a/ntex/src/ws/transport.rs +++ b/ntex/src/ws/transport.rs @@ -1,9 +1,9 @@ //! An implementation of WebSockets base bytes streams -use std::{any, cell::Cell, cmp, io, task::Context, task::Poll}; +use std::{cell::Cell, cmp, io, task::Poll}; use crate::codec::{Decoder, Encoder}; -use crate::io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus}; -use crate::util::{BufMut, BytesVec, PoolRef, Ready}; +use crate::io::{Filter, FilterFactory, FilterLayer, Io, Layer, ReadBuf, WriteBuf}; +use crate::util::{BufMut, PoolRef, Ready}; use super::{CloseCode, CloseReason, Codec, Frame, Item, Message}; @@ -16,15 +16,24 @@ bitflags::bitflags! { } /// An implementation of WebSockets streams -pub struct WsTransport { - inner: F, +pub struct WsTransport { pool: PoolRef, codec: Codec, flags: Cell, - read_buf: Cell>, } -impl WsTransport { +impl WsTransport { + /// Create websockets transport + pub fn create(io: Io, codec: Codec) -> Io> { + let pool = io.memory_pool(); + + io.add_filter(WsTransport { + pool, + codec, + flags: Cell::new(Flags::empty()), + }) + } + fn insert_flags(&self, flags: Flags) { let mut f = self.flags.get(); f.insert(flags); @@ -47,21 +56,12 @@ impl WsTransport { } } -impl Filter for WsTransport { +impl FilterLayer for WsTransport { #[inline] - fn query(&self, id: any::TypeId) -> Option> { - self.inner.query(id) - } - - #[inline] - fn poll_shutdown(&self) -> Poll> { + fn shutdown(&self, buf: &mut WriteBuf<'_>) -> io::Result> { let flags = self.flags.get(); if !flags.contains(Flags::CLOSED) { self.insert_flags(Flags::CLOSED); - let mut b = self - .inner - .get_write_buf() - .unwrap_or_else(|| self.pool.get_write_buf()); let code = if flags.contains(Flags::PROTO_ERR) { CloseCode::Protocol } else { @@ -72,159 +72,100 @@ impl Filter for WsTransport { code, description: None, })), - &mut b, + buf.get_dst(), ); - self.inner.release_write_buf(b)?; } - - self.inner.poll_shutdown() + Ok(Poll::Ready(())) } - #[inline] - fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll { - self.inner.poll_read_ready(cx) - } + fn process_read_buf(&self, buf: &mut ReadBuf<'_>) -> io::Result { + if let Some(mut src) = buf.take_src() { + let mut dst = buf.take_dst(); + let dst_len = dst.len(); - #[inline] - fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll { - self.inner.poll_write_ready(cx) - } + loop { + // make sure we've got room + self.pool.resize_read_buf(&mut dst); - #[inline] - fn get_read_buf(&self) -> Option { - self.read_buf.take() - } + let frame = if let Some(frame) = + self.codec.decode_vec(&mut src).map_err(|e| { + log::trace!("Failed to decode ws codec frames: {:?}", e); + self.insert_flags(Flags::PROTO_ERR); + io::Error::new(io::ErrorKind::Other, e) + })? { + frame + } else { + break; + }; - #[inline] - fn get_write_buf(&self) -> Option { - None - } - - #[inline] - fn release_read_buf(&self, buf: BytesVec) { - self.read_buf.set(Some(buf)); - } - - fn process_read_buf(&self, io: &IoRef, nbytes: usize) -> io::Result<(usize, usize)> { - // ask inner filter to process read buf - match self.inner.process_read_buf(io, nbytes) { - Err(err) => io.want_shutdown(Some(err)), - Ok((_, 0)) => return Ok((0, 0)), - Ok(_) => (), - } - - // get inner buffer - let mut src = if let Some(src) = self.inner.get_read_buf() { - src - } else { - return Ok((0, 0)); - }; - - // get processed buffer - let mut dst = if let Some(dst) = self.read_buf.take() { - dst - } else { - self.pool.get_read_buf() - }; - let dst_len = dst.len(); - let (hw, lw) = self.pool.read_params().unpack(); - - loop { - // make sure we've got room - let remaining = dst.remaining_mut(); - if remaining < lw { - dst.reserve(hw - remaining); + match frame { + Frame::Binary(bin) => dst.extend_from_slice(&bin), + Frame::Continuation(Item::FirstBinary(bin)) => { + self.insert_flags(Flags::CONTINUATION); + dst.extend_from_slice(&bin); + } + Frame::Continuation(Item::Continue(bin)) => { + self.continuation_must_start("Continuation frame is not started")?; + dst.extend_from_slice(&bin); + } + Frame::Continuation(Item::Last(bin)) => { + self.continuation_must_start( + "Continuation frame is not started, last frame is received", + )?; + dst.extend_from_slice(&bin); + self.remove_flags(Flags::CONTINUATION); + } + Frame::Continuation(Item::FirstText(_)) => { + self.insert_flags(Flags::PROTO_ERR); + return Err(io::Error::new( + io::ErrorKind::Other, + "WebSocket Text continuation frames are not supported", + )); + } + Frame::Text(_) => { + self.insert_flags(Flags::PROTO_ERR); + return Err(io::Error::new( + io::ErrorKind::Other, + "WebSockets Text frames are not supported", + )); + } + Frame::Ping(msg) => { + let _ = buf.with_write_buf(|b| { + self.codec.encode_vec(Message::Pong(msg), b.get_dst()) + }); + } + Frame::Pong(_) => (), + Frame::Close(_) => { + buf.want_shutdown(); + break; + } + }; } - let frame = if let Some(frame) = - self.codec.decode_vec(&mut src).map_err(|e| { - log::trace!("Failed to decode ws codec frames: {:?}", e); - self.insert_flags(Flags::PROTO_ERR); - io::Error::new(io::ErrorKind::Other, e) - })? { - frame - } else { - break; - }; - - match frame { - Frame::Binary(bin) => dst.extend_from_slice(&bin), - Frame::Continuation(Item::FirstBinary(bin)) => { - self.insert_flags(Flags::CONTINUATION); - dst.extend_from_slice(&bin); - } - Frame::Continuation(Item::Continue(bin)) => { - self.continuation_must_start("Continuation frame is not started")?; - dst.extend_from_slice(&bin); - } - Frame::Continuation(Item::Last(bin)) => { - self.continuation_must_start( - "Continuation frame is not started, last frame is received", - )?; - dst.extend_from_slice(&bin); - self.remove_flags(Flags::CONTINUATION); - } - Frame::Continuation(Item::FirstText(_)) => { - self.insert_flags(Flags::PROTO_ERR); - return Err(io::Error::new( - io::ErrorKind::Other, - "WebSocket Text continuation frames are not supported", - )); - } - Frame::Text(_) => { - self.insert_flags(Flags::PROTO_ERR); - return Err(io::Error::new( - io::ErrorKind::Other, - "WebSockets Text frames are not supported", - )); - } - Frame::Ping(msg) => { - let mut b = self - .inner - .get_write_buf() - .unwrap_or_else(|| self.pool.get_write_buf()); - let _ = self.codec.encode_vec(Message::Pong(msg), &mut b); - self.inner.release_write_buf(b)?; - } - Frame::Pong(_) => (), - Frame::Close(_) => { - io.want_shutdown(None); - break; - } - }; - } - - let dlen = dst.len(); - let nbytes = dlen - dst_len; - - if src.is_empty() { - self.pool.release_read_buf(src); + let nb = dst.len() - dst_len; + buf.set_dst(Some(dst)); + buf.set_src(Some(src)); + Ok(nb) } else { - self.inner.release_read_buf(src); + Ok(0) } - self.read_buf.set(Some(dst)); - Ok((dlen, nbytes)) } - fn release_write_buf(&self, src: BytesVec) -> Result<(), io::Error> { - let mut buf = if let Some(buf) = self.inner.get_write_buf() { - buf - } else { - self.pool.get_write_buf() - }; + fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> { + if let Some(src) = buf.take_src() { + let dst = buf.get_dst(); - // make sure we've got room - let (hw, lw) = self.pool.write_params().unpack(); - let remaining = buf.remaining_mut(); - if remaining < lw { - buf.reserve(cmp::max(hw, buf.len() + 12) - remaining); + // make sure we've got room + let (hw, lw) = self.pool.write_params().unpack(); + let remaining = dst.remaining_mut(); + if remaining < lw { + dst.reserve(cmp::max(hw, dst.len() + 12) - remaining); + } + + // Encoder ws::Codec do not fail + let _ = self.codec.encode_vec(Message::Binary(src.freeze()), dst); } - - // Encoder ws::Codec do not fail - let _ = self - .codec - .encode_vec(Message::Binary(src.freeze()), &mut buf); - self.inner.release_write_buf(buf) + Ok(()) } } @@ -241,22 +182,12 @@ impl WsTransportFactory { } impl FilterFactory for WsTransportFactory { - type Filter = WsTransport; + type Filter = WsTransport; type Error = io::Error; - type Future = Ready, Self::Error>; + type Future = Ready>, Self::Error>; - fn create(self, st: Io) -> Self::Future { - let pool = st.memory_pool(); - - Ready::from(st.map_filter(|inner: F| { - Ok(WsTransport { - pool, - inner, - codec: self.codec, - flags: Cell::new(Flags::empty()), - read_buf: Cell::new(None), - }) - })) + fn create(self, io: Io) -> Self::Future { + Ready::Ok(WsTransport::create(io, self.codec)) } } diff --git a/ntex/tests/http_openssl.rs b/ntex/tests/http_openssl.rs index 07dbdea1..4507a361 100644 --- a/ntex/tests/http_openssl.rs +++ b/ntex/tests/http_openssl.rs @@ -494,11 +494,8 @@ async fn test_ws_transport() { ) .unwrap(); - let io = io - .add_filter(ws::WsTransportFactory::new(ws::Codec::default())) - .await?; - // start websocket service + let io = ws::WsTransport::create(io, ws::Codec::default()); while let Some(item) = io.recv(&BytesCodec).await.map_err(|e| e.into_inner())? { diff --git a/ntex/tests/http_ws.rs b/ntex/tests/http_ws.rs index 9cb9ec61..871cbd8d 100644 --- a/ntex/tests/http_ws.rs +++ b/ntex/tests/http_ws.rs @@ -254,9 +254,7 @@ async fn test_transport() { ) .unwrap(); - let io = io - .add_filter(ws::WsTransportFactory::new(ws::Codec::default())) - .await?; + let io = ws::WsTransport::create(io, ws::Codec::default()); // start websocket service while let Some(item) = diff --git a/ntex/tests/http_ws_client.rs b/ntex/tests/http_ws_client.rs index 24b33d35..e8fc0e13 100644 --- a/ntex/tests/http_ws_client.rs +++ b/ntex/tests/http_ws_client.rs @@ -95,7 +95,7 @@ async fn test_transport() { }); // client service - let io = srv.ws().await.unwrap().into_transport().await; + let io = srv.ws().await.unwrap().into_transport(); io.send(Bytes::from_static(b"text"), &BytesCodec) .await diff --git a/ntex/tests/server.rs b/ntex/tests/server.rs index 27bd61ec..8971f0f5 100644 --- a/ntex/tests/server.rs +++ b/ntex/tests/server.rs @@ -120,6 +120,7 @@ async fn test_run() { // stop let _ = srv.stop(false).await; + thread::sleep(time::Duration::from_millis(100)); assert!(net::TcpStream::connect(addr).is_err()); thread::sleep(time::Duration::from_millis(100)); @@ -250,7 +251,6 @@ fn test_configure_async() { #[cfg(feature = "tokio")] #[allow(unreachable_code)] fn test_panic_in_worker() { - env_logger::init(); let counter = Arc::new(AtomicUsize::new(0)); let counter2 = counter.clone();