diff --git a/ntex-io/CHANGES.md b/ntex-io/CHANGES.md index 0759e780..7275df30 100644 --- a/ntex-io/CHANGES.md +++ b/ntex-io/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.2.2] - 2023-01-24 + +* Process write buffer if filter wrote to write buffer during reading + ## [0.2.1] - 2023-01-23 * Refactor Io and Filter types diff --git a/ntex-io/Cargo.toml b/ntex-io/Cargo.toml index f03a853e..73c14813 100644 --- a/ntex-io/Cargo.toml +++ b/ntex-io/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-io" -version = "0.2.1" +version = "0.2.2" authors = ["ntex contributors "] description = "Utilities for encoding and decoding frames" keywords = ["network", "framework", "async", "futures"] @@ -30,4 +30,4 @@ smallvec = "1" rand = "0.8" env_logger = "0.10" -ntex = { version = "0.6.0", features = ["tokio"] } +ntex = { version = "0.6.1", features = ["tokio"] } diff --git a/ntex-io/src/buf.rs b/ntex-io/src/buf.rs index e45f482b..be1bb30e 100644 --- a/ntex-io/src/buf.rs +++ b/ntex-io/src/buf.rs @@ -38,6 +38,7 @@ impl Stack { nbytes, curr: &mut curr[idx], next: &mut next[0], + need_write: false, }; f(&mut buf) } else { @@ -49,6 +50,7 @@ impl Stack { nbytes, curr: &mut val1, next: &mut val2, + need_write: false, }; let result = f(&mut buf); @@ -69,6 +71,7 @@ impl Stack { io, curr: &mut curr[idx], next: &mut next[0], + need_write: false, }; f(&mut buf) } else { @@ -79,6 +82,7 @@ impl Stack { io, curr: &mut val1, next: &mut val2, + need_write: false, }; let result = f(&mut buf); @@ -152,6 +156,7 @@ pub struct ReadBuf<'a> { pub(crate) curr: &'a mut (Option, Option), pub(crate) next: &'a mut (Option, Option), pub(crate) nbytes: usize, + pub(crate) need_write: bool, } impl<'a> ReadBuf<'a> { @@ -254,8 +259,11 @@ impl<'a> ReadBuf<'a> { io: self.io, curr: self.curr, next: self.next, + need_write: self.need_write, }; - f(&mut buf) + let result = f(&mut buf); + self.need_write = buf.need_write; + result } } @@ -264,6 +272,7 @@ pub struct WriteBuf<'a> { pub(crate) io: &'a IoRef, pub(crate) curr: &'a mut (Option, Option), pub(crate) next: &'a mut (Option, Option), + pub(crate) need_write: bool, } impl<'a> WriteBuf<'a> { @@ -302,11 +311,17 @@ impl<'a> WriteBuf<'a> { #[inline] /// Get reference to destination write buffer - pub fn get_dst(&mut self) -> &mut BytesVec { + pub fn with_dst_buf(&mut self, f: F) -> R + where + F: FnOnce(&mut BytesVec) -> R, + { if self.next.1.is_none() { self.next.1 = Some(self.io.memory_pool().get_write_buf()); } - self.next.1.as_mut().unwrap() + let buf = self.next.1.as_mut().unwrap(); + let r = f(buf); + self.need_write |= !buf.is_empty(); + r } #[inline] @@ -328,23 +343,12 @@ impl<'a> WriteBuf<'a> { if let Some(b) = self.next.1.take() { self.io.memory_pool().release_write_buf(b); } + self.need_write |= !dst.is_empty(); self.next.1 = Some(dst); } } } - #[inline] - /// Get reference to source and destination buffers (src, dst) - pub fn get_pair(&mut self) -> (&mut BytesVec, &mut BytesVec) { - if self.curr.1.is_none() { - self.curr.1 = Some(self.io.memory_pool().get_write_buf()); - } - if self.next.1.is_none() { - self.next.1 = Some(self.io.memory_pool().get_write_buf()); - } - (self.curr.1.as_mut().unwrap(), self.next.1.as_mut().unwrap()) - } - #[inline] pub fn with_read_buf<'b, F, R>(&'b mut self, f: F) -> R where @@ -355,7 +359,10 @@ impl<'a> WriteBuf<'a> { curr: self.curr, next: self.next, nbytes: 0, + need_write: self.need_write, }; - f(&mut buf) + let result = f(&mut buf); + self.need_write = buf.need_write; + result } } diff --git a/ntex-io/src/filter.rs b/ntex-io/src/filter.rs index f4c613ec..608dfe30 100644 --- a/ntex-io/src/filter.rs +++ b/ntex-io/src/filter.rs @@ -29,6 +29,12 @@ impl NullFilter { } } +#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)] +pub struct FilterReadStatus { + pub nbytes: usize, + pub need_write: bool, +} + pub trait Filter: 'static { fn query(&self, id: any::TypeId) -> Option>; @@ -38,7 +44,7 @@ pub trait Filter: 'static { stack: &mut Stack, idx: usize, nbytes: usize, - ) -> io::Result; + ) -> io::Result; /// Process write buffer fn process_write_buf( @@ -112,8 +118,11 @@ impl Filter for Base { _: &mut Stack, _: usize, nbytes: usize, - ) -> io::Result { - Ok(nbytes) + ) -> io::Result { + Ok(FilterReadStatus { + nbytes, + need_write: false, + }) } #[inline] @@ -167,13 +176,18 @@ where stack: &mut Stack, idx: usize, nbytes: usize, - ) -> io::Result { - let nbytes = if F::BUFFERS { + ) -> io::Result { + let status = if F::BUFFERS { self.1.process_read_buf(io, stack, idx + 1, nbytes)? } else { self.1.process_read_buf(io, stack, idx, nbytes)? }; - stack.read_buf(io, idx, nbytes, |buf| self.0.process_read_buf(buf)) + stack.read_buf(io, idx, status.nbytes, |buf| { + self.0.process_read_buf(buf).map(|nbytes| FilterReadStatus { + nbytes, + need_write: status.need_write || buf.need_write, + }) + }) } #[inline] @@ -254,8 +268,8 @@ impl Filter for NullFilter { _: &mut Stack, _: usize, _: usize, - ) -> io::Result { - Ok(0) + ) -> io::Result { + Ok(Default::default()) } #[inline] diff --git a/ntex-io/src/ioref.rs b/ntex-io/src/ioref.rs index bab38364..2fc994b9 100644 --- a/ntex-io/src/ioref.rs +++ b/ntex-io/src/ioref.rs @@ -406,7 +406,7 @@ mod tests { fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> { self.write_order.borrow_mut().push(self.idx); self.out_bytes - .set(self.out_bytes.get() + buf.get_dst().len()); + .set(self.out_bytes.get() + buf.with_dst_buf(|b| b.len())); Ok(()) } } diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs index 1a874096..28d8503b 100644 --- a/ntex-io/src/tasks.rs +++ b/ntex-io/src/tasks.rs @@ -24,6 +24,7 @@ impl ReadContext { F: FnOnce(&mut BytesVec, usize, usize) -> Poll>, { let mut stack = self.0 .0.buffer.borrow_mut(); + let is_write_sleep = stack.last_write_buf_size() == 0; let mut buf = stack .last_read_buf() .take() @@ -50,8 +51,8 @@ impl ReadContext { .filter() .process_read_buf(&self.0, &mut stack, 0, nbytes) { - Ok(nbytes) => { - if nbytes > 0 { + Ok(status) => { + if status.nbytes > 0 { if buf_full || stack.first_read_buf_size() >= hw { log::trace!( "io read buffer is too large {}, enable read back-pressure", @@ -70,8 +71,27 @@ impl ReadContext { ); } else if buf_full { // read task is paused because of read back-pressure + // but there is no new data in top most read buffer + // so we need to wake up read task to read more data + // otherwise read task would sleep forever self.0 .0.read_task.wake(); } + + // while reading, filter wrote some data + // in that case filters need to process write buffers + // and potentialy wake write task + if status.need_write { + if let Err(err) = + self.0.filter().process_write_buf(&self.0, &mut stack, 0) + { + self.0 .0.dispatch_task.wake(); + self.0 .0.insert_flags(Flags::RD_READY); + self.0 .0.init_shutdown(Some(err), &self.0); + } + if is_write_sleep && stack.last_write_buf_size() != 0 { + self.0 .0.write_task.wake(); + } + } } Err(err) => { self.0 .0.dispatch_task.wake(); diff --git a/ntex-io/src/utils.rs b/ntex-io/src/utils.rs index bfe4aac6..c196ee45 100644 --- a/ntex-io/src/utils.rs +++ b/ntex-io/src/utils.rs @@ -182,7 +182,7 @@ mod tests { NullFilter .process_read_buf(&ioref, &mut stack, 0, 0) .unwrap(), - (0) + Default::default() ) } } diff --git a/ntex-tls/CHANGES.md b/ntex-tls/CHANGES.md index 2c3ff422..e18b1896 100644 --- a/ntex-tls/CHANGES.md +++ b/ntex-tls/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.2.2] - 2023-01-24 + +* Update ntex-io to 0.2.2 + ## [0.2.1] - 2023-01-23 * Update filter implementation diff --git a/ntex-tls/Cargo.toml b/ntex-tls/Cargo.toml index 933aaf61..10cffd29 100644 --- a/ntex-tls/Cargo.toml +++ b/ntex-tls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-tls" -version = "0.2.1" +version = "0.2.2" authors = ["ntex contributors "] description = "An implementation of SSL streams for ntex backed by OpenSSL" keywords = ["network", "framework", "async", "futures"] @@ -26,7 +26,7 @@ rustls = ["tls_rust"] [dependencies] ntex-bytes = "0.1.19" -ntex-io = "0.2.1" +ntex-io = "0.2.2" ntex-util = "0.2.0" ntex-service = "1.0.0" log = "0.4" @@ -39,7 +39,7 @@ tls_openssl = { version="0.10", package = "openssl", optional = true } tls_rust = { version = "0.20", package = "rustls", optional = true } [dev-dependencies] -ntex = { version = "0.6.0", features = ["openssl", "rustls", "tokio"] } +ntex = { version = "0.6.1", features = ["openssl", "rustls", "tokio"] } env_logger = "0.10" rustls-pemfile = { version = "1.0" } webpki-roots = { version = "0.22" } diff --git a/ntex-tls/src/openssl/mod.rs b/ntex-tls/src/openssl/mod.rs index c78b76d0..5022ddd1 100644 --- a/ntex-tls/src/openssl/mod.rs +++ b/ntex-tls/src/openssl/mod.rs @@ -29,20 +29,20 @@ pub struct SslFilter { } struct IoInner { - inner_read_buf: Option, - inner_write_buf: Option, + source: Option, + destination: Option, pool: PoolRef, } impl io::Read for IoInner { fn read(&mut self, dst: &mut [u8]) -> io::Result { - if let Some(mut buf) = self.inner_read_buf.take() { + if let Some(mut buf) = self.source.take() { if buf.is_empty() { Err(io::Error::from(io::ErrorKind::WouldBlock)) } else { let len = cmp::min(buf.len(), dst.len()); dst[..len].copy_from_slice(&buf.split_to(len)); - self.inner_read_buf = Some(buf); + self.source = Some(buf); Ok(len) } } else { @@ -53,14 +53,14 @@ impl io::Read for IoInner { impl io::Write for IoInner { fn write(&mut self, src: &[u8]) -> io::Result { - let mut buf = if let Some(mut buf) = self.inner_write_buf.take() { + let mut buf = if let Some(mut buf) = self.destination.take() { buf.reserve(src.len()); buf } else { BytesVec::with_capacity_in(src.len(), self.pool) }; buf.extend_from_slice(src); - self.inner_write_buf = Some(buf); + self.destination = Some(buf); Ok(src.len()) } @@ -74,28 +74,22 @@ impl SslFilter { where F: FnOnce() -> R, { - self.inner.borrow_mut().get_mut().inner_write_buf = Some(buf.take_dst()); - self.inner.borrow_mut().get_mut().inner_read_buf = - buf.with_read_buf(|b| b.take_src()); + self.inner.borrow_mut().get_mut().destination = Some(buf.take_dst()); + self.inner.borrow_mut().get_mut().source = buf.with_read_buf(|b| b.take_src()); let result = f(); - buf.set_dst(self.inner.borrow_mut().get_mut().inner_write_buf.take()); - buf.with_read_buf(|b| { - b.set_src(self.inner.borrow_mut().get_mut().inner_read_buf.take()) - }); + buf.set_dst(self.inner.borrow_mut().get_mut().destination.take()); + buf.with_read_buf(|b| b.set_src(self.inner.borrow_mut().get_mut().source.take())); result } fn set_buffers(&self, buf: &mut WriteBuf<'_>) { - self.inner.borrow_mut().get_mut().inner_write_buf = Some(buf.take_dst()); - self.inner.borrow_mut().get_mut().inner_read_buf = - buf.with_read_buf(|b| b.take_src()); + self.inner.borrow_mut().get_mut().destination = Some(buf.take_dst()); + self.inner.borrow_mut().get_mut().source = buf.with_read_buf(|b| b.take_src()); } fn unset_buffers(&self, buf: &mut WriteBuf<'_>) { - buf.set_dst(self.inner.borrow_mut().get_mut().inner_write_buf.take()); - buf.with_read_buf(|b| { - b.set_src(self.inner.borrow_mut().get_mut().inner_read_buf.take()) - }); + buf.set_dst(self.inner.borrow_mut().get_mut().destination.take()); + buf.with_read_buf(|b| b.set_src(self.inner.borrow_mut().get_mut().source.take())); } } @@ -171,7 +165,8 @@ impl FilterLayer for SslFilter { buf.with_write_buf(|b| self.set_buffers(b)); let dst = buf.get_dst(); - let mut new_bytes = usize::from(self.handshake.get()); + //let mut new_bytes = usize::from(self.handshake.get()); + let mut new_bytes = 1; loop { // make sure we've got room self.pool.resize_read_buf(dst); @@ -228,11 +223,7 @@ impl FilterLayer for SslFilter { return match e.code() { ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => { buf.set_dst( - self.inner - .borrow_mut() - .get_mut() - .inner_write_buf - .take(), + self.inner.borrow_mut().get_mut().destination.take(), ); Ok(()) } @@ -294,8 +285,8 @@ impl FilterFactory for SslAcceptor { let ssl = ctx_result.map_err(map_to_ioerr)?; let inner = IoInner { pool: io.memory_pool(), - inner_read_buf: None, - inner_write_buf: None, + source: None, + destination: None, }; let filter = SslFilter { pool: io.memory_pool(), @@ -352,8 +343,8 @@ impl FilterFactory for SslConnector { Box::pin(async move { let inner = IoInner { pool: io.memory_pool(), - inner_read_buf: None, - inner_write_buf: None, + source: None, + destination: None, }; let filter = SslFilter { pool: io.memory_pool(), diff --git a/ntex-tls/src/rustls/mod.rs b/ntex-tls/src/rustls/mod.rs index 24b470fd..814c26c2 100644 --- a/ntex-tls/src/rustls/mod.rs +++ b/ntex-tls/src/rustls/mod.rs @@ -242,7 +242,7 @@ impl<'a, 'b> io::Read for Wrapper<'a, 'b> { impl<'a, 'b> io::Write for Wrapper<'a, 'b> { fn write(&mut self, src: &[u8]) -> io::Result { - self.1.get_dst().extend_from_slice(src); + self.1.with_dst_buf(|buf| buf.extend_from_slice(src)); Ok(src.len()) } diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index f0f33dc8..ba41c71f 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.6.2] - 2023-01-24 + +* Update ntex-io, ntex-tls deps + ## [0.6.1] - 2023-01-23 * Refactor io subsystem diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index eba9d7cd..917c9617 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "0.6.1" +version = "0.6.2" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" @@ -56,10 +56,10 @@ ntex-service = "1.0.0" ntex-macros = "0.1.3" ntex-util = "0.2.0" ntex-bytes = "0.1.19" -ntex-h2 = "0.2.0" +ntex-h2 = "0.2.1" ntex-rt = "0.4.7" -ntex-io = "0.2.1" -ntex-tls = "0.2.1" +ntex-io = "0.2.2" +ntex-tls = "0.2.2" ntex-tokio = { version = "0.2.1", optional = true } ntex-glommio = { version = "0.2.1", optional = true } ntex-async-std = { version = "0.2.1", optional = true } diff --git a/ntex/src/ws/transport.rs b/ntex/src/ws/transport.rs index b4deb49a..7938be7a 100644 --- a/ntex/src/ws/transport.rs +++ b/ntex/src/ws/transport.rs @@ -67,13 +67,15 @@ impl FilterLayer for WsTransport { } else { CloseCode::Normal }; - let _ = self.codec.encode_vec( - Message::Close(Some(CloseReason { - code, - description: None, - })), - buf.get_dst(), - ); + let _ = buf.with_dst_buf(|buf| { + self.codec.encode_vec( + Message::Close(Some(CloseReason { + code, + description: None, + })), + buf, + ) + }); } Ok(Poll::Ready(())) } @@ -131,7 +133,7 @@ impl FilterLayer for WsTransport { } Frame::Ping(msg) => { let _ = buf.with_write_buf(|b| { - self.codec.encode_vec(Message::Pong(msg), b.get_dst()) + b.with_dst_buf(|b| self.codec.encode_vec(Message::Pong(msg), b)) }); } Frame::Pong(_) => (), @@ -153,17 +155,17 @@ impl FilterLayer for WsTransport { fn process_write_buf(&self, buf: &mut WriteBuf<'_>) -> io::Result<()> { if let Some(src) = buf.take_src() { - let dst = buf.get_dst(); + buf.with_dst_buf(|dst| { + // make sure we've got room + let (hw, lw) = self.pool.write_params().unpack(); + let remaining = dst.remaining_mut(); + if remaining < lw { + dst.reserve(cmp::max(hw, dst.len() + 12) - remaining); + } - // make sure we've got room - let (hw, lw) = self.pool.write_params().unpack(); - let remaining = dst.remaining_mut(); - if remaining < lw { - dst.reserve(cmp::max(hw, dst.len() + 12) - remaining); - } - - // Encoder ws::Codec do not fail - let _ = self.codec.encode_vec(Message::Binary(src.freeze()), dst); + // Encoder ws::Codec do not fail + let _ = self.codec.encode_vec(Message::Binary(src.freeze()), dst); + }); } Ok(()) }