diff --git a/ntex-io/CHANGES.md b/ntex-io/CHANGES.md index 3c9c32cf..b8490723 100644 --- a/ntex-io/CHANGES.md +++ b/ntex-io/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [2.3.0] - 2024-08-28 + +* Extend io task contexts, for "compio" runtime compatibility + ## [2.2.0] - 2024-08-12 * Allow to notify dispatcher from IoRef diff --git a/ntex-io/Cargo.toml b/ntex-io/Cargo.toml index 47f52cb0..c66aca7f 100644 --- a/ntex-io/Cargo.toml +++ b/ntex-io/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-io" -version = "2.2.0" +version = "2.3.0" authors = ["ntex contributors "] description = "Utilities for encoding and decoding frames" keywords = ["network", "framework", "async", "futures"] @@ -18,8 +18,8 @@ path = "src/lib.rs" [dependencies] ntex-codec = "0.6.2" ntex-bytes = "0.1.24" -ntex-util = "2.2" -ntex-service = "3.0" +ntex-util = "2.3" +ntex-service = "3" bitflags = "2" log = "0.4" diff --git a/ntex-io/src/buf.rs b/ntex-io/src/buf.rs index 57709c17..478442ef 100644 --- a/ntex-io/src/buf.rs +++ b/ntex-io/src/buf.rs @@ -140,6 +140,18 @@ impl Stack { }) } + pub(crate) fn get_read_source(&self) -> Option { + self.get_last_level().0.take() + } + + pub(crate) fn set_read_source(&self, io: &IoRef, buf: BytesVec) { + if buf.is_empty() { + io.memory_pool().release_read_buf(buf); + } else { + self.get_last_level().0.set(Some(buf)); + } + } + pub(crate) fn with_read_source(&self, io: &IoRef, f: F) -> R where F: FnOnce(&mut BytesVec) -> R, @@ -210,6 +222,10 @@ impl Stack { result } + pub(crate) fn get_write_destination(&self) -> Option { + self.get_last_level().1.take() + } + pub(crate) fn with_write_destination(&self, io: &IoRef, f: F) -> R where F: FnOnce(&mut Option) -> R, diff --git a/ntex-io/src/filter.rs b/ntex-io/src/filter.rs index 4094f95a..ecb9383e 100644 --- a/ntex-io/src/filter.rs +++ b/ntex-io/src/filter.rs @@ -1,6 +1,6 @@ use std::{any, io, task::Context, task::Poll}; -use super::{buf::Stack, io::Flags, FilterLayer, IoRef, ReadStatus, WriteStatus}; +use crate::{buf::Stack, FilterLayer, Flags, IoRef, ReadStatus, WriteStatus}; #[derive(Debug)] /// Default `Io` filter @@ -80,9 +80,10 @@ impl Filter for Base { Poll::Ready(ReadStatus::Terminate) } else { self.0 .0.read_task.register(cx.waker()); + if flags.intersects(Flags::IO_STOPPING_FILTERS) { Poll::Ready(ReadStatus::Ready) - } else if flags.intersects(Flags::RD_PAUSED | Flags::RD_BUF_FULL) { + } else if flags.cannot_read() { Poll::Pending } else { Poll::Ready(ReadStatus::Ready) @@ -109,6 +110,9 @@ impl Filter for Base { Poll::Ready(WriteStatus::Timeout( self.0 .0.disconnect_timeout.get().into(), )) + } else if flags.intersects(Flags::WR_PAUSED) { + self.0 .0.write_task.register(cx.waker()); + Poll::Pending } else { self.0 .0.write_task.register(cx.waker()); Poll::Ready(WriteStatus::Ready) diff --git a/ntex-io/src/flags.rs b/ntex-io/src/flags.rs new file mode 100644 index 00000000..8e80f2cf --- /dev/null +++ b/ntex-io/src/flags.rs @@ -0,0 +1,58 @@ +bitflags::bitflags! 
{ + #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] + pub struct Flags: u16 { + /// io is closed + const IO_STOPPED = 0b0000_0000_0000_0001; + /// shutdown io tasks + const IO_STOPPING = 0b0000_0000_0000_0010; + /// shuting down filters + const IO_STOPPING_FILTERS = 0b0000_0000_0000_0100; + /// initiate filters shutdown timeout in write task + const IO_FILTERS_TIMEOUT = 0b0000_0000_0000_1000; + + /// pause io read + const RD_PAUSED = 0b0000_0000_0001_0000; + /// read any data and notify dispatcher + const RD_NOTIFY = 0b0000_0000_1000_0000; + + /// new data is available in read buffer + const BUF_R_READY = 0b0000_0000_0010_0000; + /// read buffer is full + const BUF_R_FULL = 0b0000_0000_0100_0000; + + /// wait while write task flushes buf + const BUF_W_MUST_FLUSH = 0b0000_0001_0000_0000; + + /// write buffer is full + const WR_BACKPRESSURE = 0b0000_0010_0000_0000; + /// write task paused + const WR_PAUSED = 0b0000_0100_0000_0000; + + /// dispatcher is marked stopped + const DSP_STOP = 0b0001_0000_0000_0000; + /// timeout occured + const DSP_TIMEOUT = 0b0010_0000_0000_0000; + } +} + +impl Flags { + pub(crate) fn is_waiting_for_write(&self) -> bool { + self.intersects(Flags::BUF_W_MUST_FLUSH | Flags::WR_BACKPRESSURE) + } + + pub(crate) fn waiting_for_write_is_done(&mut self) { + self.remove(Flags::BUF_W_MUST_FLUSH | Flags::WR_BACKPRESSURE); + } + + pub(crate) fn is_read_buf_ready(&self) -> bool { + self.contains(Flags::BUF_R_READY) + } + + pub(crate) fn cannot_read(self) -> bool { + self.intersects(Flags::RD_PAUSED | Flags::BUF_R_FULL) + } + + pub(crate) fn cleanup_read_flags(&mut self) { + self.remove(Flags::BUF_R_READY | Flags::BUF_R_FULL | Flags::RD_PAUSED); + } +} diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs index 15d5ce0f..b3e9e91d 100644 --- a/ntex-io/src/io.rs +++ b/ntex-io/src/io.rs @@ -9,46 +9,12 @@ use ntex_util::{future::Either, task::LocalWaker, time::Seconds}; use crate::buf::Stack; use crate::filter::{Base, Filter, Layer, NullFilter}; +use crate::flags::Flags; use crate::seal::Sealed; use crate::tasks::{ReadContext, WriteContext}; use crate::timer::TimerHandle; use crate::{Decoded, FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError}; -bitflags::bitflags! { - #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] - pub struct Flags: u16 { - /// io is closed - const IO_STOPPED = 0b0000_0000_0000_0001; - /// shutdown io tasks - const IO_STOPPING = 0b0000_0000_0000_0010; - /// shuting down filters - const IO_STOPPING_FILTERS = 0b0000_0000_0000_0100; - /// initiate filters shutdown timeout in write task - const IO_FILTERS_TIMEOUT = 0b0000_0000_0000_1000; - - /// pause io read - const RD_PAUSED = 0b0000_0000_0001_0000; - /// new data is available - const RD_READY = 0b0000_0000_0010_0000; - /// read buffer is full - const RD_BUF_FULL = 0b0000_0000_0100_0000; - /// any new data is available - const RD_FORCE_READY = 0b0000_0000_1000_0000; - - /// wait write completion - const WR_WAIT = 0b0000_0001_0000_0000; - /// write buffer is full - const WR_BACKPRESSURE = 0b0000_0010_0000_0000; - /// write task paused - const WR_PAUSED = 0b0000_0100_0000_0000; - - /// dispatcher is marked stopped - const DSP_STOP = 0b0001_0000_0000_0000; - /// timeout occured - const DSP_TIMEOUT = 0b0010_0000_0000_0000; - } -} - /// Interface object to underlying io stream pub struct Io(UnsafeCell, marker::PhantomData); @@ -384,8 +350,14 @@ impl Io { #[doc(hidden)] #[inline] /// Wait until read becomes ready. 
+ pub async fn read_notify(&self) -> io::Result> { + poll_fn(|cx| self.poll_read_notify(cx)).await + } + + #[doc(hidden)] + #[deprecated] pub async fn force_read_ready(&self) -> io::Result> { - poll_fn(|cx| self.poll_force_read_ready(cx)).await + poll_fn(|cx| self.poll_read_notify(cx)).await } #[inline] @@ -454,9 +426,9 @@ impl Io { } else { st.dispatch_task.register(cx.waker()); - let ready = flags.contains(Flags::RD_READY); - if flags.intersects(Flags::RD_BUF_FULL | Flags::RD_PAUSED) { - flags.remove(Flags::RD_READY | Flags::RD_BUF_FULL | Flags::RD_PAUSED); + let ready = flags.contains(Flags::BUF_R_READY); + if flags.cannot_read() { + flags.cleanup_read_flags(); st.read_task.wake(); st.flags.set(flags); if ready { @@ -465,7 +437,7 @@ impl Io { Poll::Pending } } else if ready { - flags.remove(Flags::RD_READY); + flags.remove(Flags::BUF_R_READY); st.flags.set(flags); Poll::Ready(Ok(Some(()))) } else { @@ -489,18 +461,15 @@ impl Io { /// `Poll::Ready(Ok(Some(()))))` if the io stream is ready for reading. /// `Poll::Ready(Ok(None))` if io stream is disconnected /// `Some(Poll::Ready(Err(e)))` if an error is encountered. - pub fn poll_force_read_ready( - &self, - cx: &mut Context<'_>, - ) -> Poll>> { + pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll>> { let ready = self.poll_read_ready(cx); if ready.is_pending() { let st = self.st(); - if st.remove_flags(Flags::RD_FORCE_READY) { + if st.remove_flags(Flags::RD_NOTIFY) { Poll::Ready(Ok(Some(()))) } else { - st.insert_flags(Flags::RD_FORCE_READY); + st.insert_flags(Flags::RD_NOTIFY); Poll::Pending } } else { @@ -508,6 +477,15 @@ impl Io { } } + #[doc(hidden)] + #[deprecated] + pub fn poll_force_read_ready( + &self, + cx: &mut Context<'_>, + ) -> Poll>> { + self.poll_read_notify(cx) + } + #[inline] /// Decode codec item from incoming bytes stream. 
/// @@ -597,7 +575,7 @@ impl Io { let len = st.buffer.write_destination_size(); if len > 0 { if full { - st.insert_flags(Flags::WR_WAIT); + st.insert_flags(Flags::BUF_W_MUST_FLUSH); st.dispatch_task.register(cx.waker()); return Poll::Pending; } else if len >= st.pool.get().write_params_high() << 1 { @@ -606,7 +584,7 @@ impl Io { return Poll::Pending; } } - st.remove_flags(Flags::WR_WAIT | Flags::WR_BACKPRESSURE); + st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::WR_BACKPRESSURE); Poll::Ready(Ok(())) } } diff --git a/ntex-io/src/ioref.rs b/ntex-io/src/ioref.rs index 6ae3dead..168171e4 100644 --- a/ntex-io/src/ioref.rs +++ b/ntex-io/src/ioref.rs @@ -4,7 +4,7 @@ use ntex_bytes::{BytesVec, PoolRef}; use ntex_codec::{Decoder, Encoder}; use ntex_util::time::Seconds; -use super::{io::Flags, timer, types, Decoded, Filter, IoRef, OnDisconnect, WriteBuf}; +use crate::{timer, types, Decoded, Filter, Flags, IoRef, OnDisconnect, WriteBuf}; impl IoRef { #[inline] diff --git a/ntex-io/src/lib.rs b/ntex-io/src/lib.rs index c613a240..7c034ce7 100644 --- a/ntex-io/src/lib.rs +++ b/ntex-io/src/lib.rs @@ -11,6 +11,7 @@ pub mod types; mod buf; mod dispatcher; mod filter; +mod flags; mod framed; mod io; mod ioref; @@ -33,7 +34,7 @@ pub use self::timer::TimerHandle; pub use self::utils::{seal, Decoded}; #[doc(hidden)] -pub use self::io::Flags; +pub use self::flags::Flags; /// Status for read task #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs index fafc8ec8..3545635e 100644 --- a/ntex-io/src/tasks.rs +++ b/ntex-io/src/tasks.rs @@ -1,8 +1,9 @@ -use std::{io, task::Context, task::Poll}; +use std::{future::poll_fn, future::Future, io, task::Context, task::Poll}; -use ntex_bytes::{BytesVec, PoolRef}; +use ntex_bytes::{BufMut, BytesVec, PoolRef}; +use ntex_util::task; -use super::{io::Flags, IoRef, ReadStatus, WriteStatus}; +use crate::{Flags, IoRef, ReadStatus, WriteStatus}; #[derive(Debug)] /// Context for io read task @@ -19,6 +20,31 @@ impl ReadContext { self.0.tag() } + #[inline] + /// Check readiness for read operations + pub async fn ready(&self) -> ReadStatus { + poll_fn(|cx| self.0.filter().poll_read_ready(cx)).await + } + + #[inline] + /// Wait when io get closed or preparing for close + pub async fn wait_for_close(&self) { + poll_fn(|cx| { + let flags = self.0.flags(); + + if flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) { + Poll::Ready(()) + } else { + self.0 .0.read_task.register(cx.waker()); + if flags.contains(Flags::IO_STOPPING_FILTERS) { + shutdown_filters(&self.0); + } + Poll::Pending + } + }) + .await + } + #[inline] /// Check readiness for read operations pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll { @@ -56,9 +82,9 @@ impl ReadContext { self.0.tag(), total ); - inner.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL); + inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL); } else { - inner.insert_flags(Flags::RD_READY); + inner.insert_flags(Flags::BUF_R_READY); if nbytes >= hw { // read task is paused because of read back-pressure @@ -82,8 +108,8 @@ impl ReadContext { // otherwise read task would sleep forever inner.read_task.wake(); } - if inner.flags.get().contains(Flags::RD_FORCE_READY) { - // in case of "force read" we must wake up dispatch task + if inner.flags.get().contains(Flags::RD_NOTIFY) { + // in case of "notify" we must wake up dispatch task // if we read any data from source inner.dispatch_task.wake(); } @@ -101,7 +127,7 @@ impl ReadContext { .map_err(|err| { inner.dispatch_task.wake(); 
inner.io_stopped(Some(err)); - inner.insert_flags(Flags::RD_READY); + inner.insert_flags(Flags::BUF_R_READY); }); } @@ -122,6 +148,120 @@ impl ReadContext { } } } + + /// Get read buffer (async) + pub async fn with_buf_async(&self, f: F) -> Poll<()> + where + F: FnOnce(BytesVec) -> R, + R: Future)>, + { + let inner = &self.0 .0; + + // we already pushed new data to read buffer, + // we have to wait for dispatcher to read data from buffer + if inner.flags.get().is_read_buf_ready() { + task::yield_to().await; + } + + let mut buf = if inner.flags.get().is_read_buf_ready() { + // read buffer is still not read by dispatcher + // we cannot touch it + inner.pool.get().get_read_buf() + } else { + inner + .buffer + .get_read_source() + .unwrap_or_else(|| inner.pool.get().get_read_buf()) + }; + + // make sure we've got room + let remaining = buf.remaining_mut(); + let (hw, lw) = self.0.memory_pool().read_params().unpack(); + if remaining < lw { + buf.reserve(hw - remaining); + } + let total = buf.len(); + + // call provided callback + let (buf, result) = f(buf).await; + let total2 = buf.len(); + let nbytes = if total2 > total { total2 - total } else { 0 }; + let total = total2; + + if let Some(mut first_buf) = inner.buffer.get_read_source() { + first_buf.extend_from_slice(&buf); + inner.buffer.set_read_source(&self.0, first_buf); + } else { + inner.buffer.set_read_source(&self.0, buf); + } + + // handle buffer changes + if nbytes > 0 { + let filter = self.0.filter(); + let res = match filter.process_read_buf(&self.0, &inner.buffer, 0, nbytes) { + Ok(status) => { + if status.nbytes > 0 { + // check read back-pressure + if hw < inner.buffer.read_destination_size() { + log::trace!( + "{}: Io read buffer is too large {}, enable read back-pressure", + self.0.tag(), + total + ); + inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL); + } else { + inner.insert_flags(Flags::BUF_R_READY); + } + log::trace!( + "{}: New {} bytes available, wakeup dispatcher", + self.0.tag(), + nbytes + ); + // dest buffer has new data, wake up dispatcher + inner.dispatch_task.wake(); + } else if inner.flags.get().contains(Flags::RD_NOTIFY) { + // in case of "notify" we must wake up dispatch task + // if we read any data from source + inner.dispatch_task.wake(); + } + + // while reading, filter wrote some data + // in that case filters need to process write buffers + // and potentialy wake write task + if status.need_write { + filter.process_write_buf(&self.0, &inner.buffer, 0) + } else { + Ok(()) + } + } + Err(err) => Err(err), + }; + + if let Err(err) = res { + inner.dispatch_task.wake(); + inner.io_stopped(Some(err)); + inner.insert_flags(Flags::BUF_R_READY); + } + } + + match result { + Ok(n) => { + if n == 0 { + inner.io_stopped(None); + Poll::Ready(()) + } else { + if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) { + shutdown_filters(&self.0); + } + Poll::Pending + } + } + Err(e) => { + inner.io_stopped(Some(e)); + Poll::Ready(()) + } + } + } } #[derive(Debug)] @@ -145,13 +285,19 @@ impl WriteContext { self.0.memory_pool() } + #[inline] + /// Check readiness for write operations + pub async fn ready(&self) -> WriteStatus { + poll_fn(|cx| self.0.filter().poll_write_ready(cx)).await + } + #[inline] /// Check readiness for write operations pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll { self.0.filter().poll_write_ready(cx) } - /// Get read buffer + /// Get write buffer pub fn with_buf(&self, f: F) -> Poll> where F: FnOnce(&mut Option) -> Poll>, @@ -167,8 +313,8 @@ impl WriteContext { // if write 
buffer is smaller than high watermark value, turn off back-pressure let mut flags = inner.flags.get(); if len == 0 { - if flags.intersects(Flags::WR_WAIT | Flags::WR_BACKPRESSURE) { - flags.remove(Flags::WR_WAIT | Flags::WR_BACKPRESSURE); + if flags.is_waiting_for_write() { + flags.waiting_for_write_is_done(); inner.dispatch_task.wake(); } } else if flags.contains(Flags::WR_BACKPRESSURE) @@ -188,6 +334,57 @@ impl WriteContext { result } + /// Get write buffer (async) + pub async fn with_buf_async(&self, f: F) -> io::Result<()> + where + F: FnOnce(BytesVec) -> R, + R: Future>, + { + let inner = &self.0 .0; + + // running + let mut flags = inner.flags.get(); + if flags.contains(Flags::WR_PAUSED) { + flags.remove(Flags::WR_PAUSED); + inner.flags.set(flags); + } + + // buffer + let buf = inner.buffer.get_write_destination(); + + // call provided callback + let result = if let Some(buf) = buf { + if !buf.is_empty() { + f(buf).await + } else { + Ok(()) + } + } else { + Ok(()) + }; + + // if write buffer is smaller than high watermark value, turn off back-pressure + let mut flags = inner.flags.get(); + let len = inner.buffer.write_destination_size(); + + if len == 0 { + if flags.is_waiting_for_write() { + flags.waiting_for_write_is_done(); + inner.dispatch_task.wake(); + } + flags.insert(Flags::WR_PAUSED); + inner.flags.set(flags); + } else if flags.contains(Flags::WR_BACKPRESSURE) + && len < inner.pool.get().write_params_high() << 1 + { + flags.remove(Flags::WR_BACKPRESSURE); + inner.flags.set(flags); + inner.dispatch_task.wake(); + } + + result + } + #[inline] /// Indicate that write io task is stopped pub fn close(&self, err: Option) { @@ -210,7 +407,7 @@ fn shutdown_filters(io: &IoRef) { // check read buffer, if buffer is not consumed it is unlikely // that filter will properly complete shutdown if flags.contains(Flags::RD_PAUSED) - || flags.contains(Flags::RD_BUF_FULL | Flags::RD_READY) + || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY) { st.dispatch_task.wake(); st.insert_flags(Flags::IO_STOPPING); diff --git a/ntex-tls/CHANGES.md b/ntex-tls/CHANGES.md index 80dfd39f..bbae1b99 100644 --- a/ntex-tls/CHANGES.md +++ b/ntex-tls/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [2.1.0] - 2024-08-28 + +* Update io api usage + ## [2.0.1] - 2024-08-26 * Fix rustls client/server filters diff --git a/ntex-tls/Cargo.toml b/ntex-tls/Cargo.toml index 415b427f..412b02e9 100644 --- a/ntex-tls/Cargo.toml +++ b/ntex-tls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-tls" -version = "2.0.1" +version = "2.1.0" authors = ["ntex contributors "] description = "An implementation of SSL streams for ntex backed by OpenSSL" keywords = ["network", "framework", "async", "futures"] @@ -26,7 +26,7 @@ rustls = ["tls_rust"] [dependencies] ntex-bytes = "0.1" -ntex-io = "2" +ntex-io = "2.3" ntex-util = "2" ntex-service = "3" ntex-net = "2" diff --git a/ntex-tls/src/openssl/mod.rs b/ntex-tls/src/openssl/mod.rs index 29d4d9fe..45ed1fcd 100644 --- a/ntex-tls/src/openssl/mod.rs +++ b/ntex-tls/src/openssl/mod.rs @@ -248,7 +248,7 @@ async fn handle_result( Ok(v) => Ok(Some(v)), Err(e) => match e.code() { ssl::ErrorCode::WANT_READ => { - let res = io.force_read_ready().await; + let res = io.read_notify().await; match res? 
{ None => Err(io::Error::new(io::ErrorKind::Other, "disconnected")), _ => Ok(None), diff --git a/ntex-tls/src/rustls/client.rs b/ntex-tls/src/rustls/client.rs index 70a8c264..1ebe0669 100644 --- a/ntex-tls/src/rustls/client.rs +++ b/ntex-tls/src/rustls/client.rs @@ -164,7 +164,7 @@ impl TlsClientFilter { } poll_fn(|cx| { let read_ready = if wants_read { - match ready!(io.poll_force_read_ready(cx))? { + match ready!(io.poll_read_notify(cx))? { Some(_) => Ok(true), None => Err(io::Error::new( io::ErrorKind::Other, diff --git a/ntex-tls/src/rustls/server.rs b/ntex-tls/src/rustls/server.rs index 574183cd..2637b82f 100644 --- a/ntex-tls/src/rustls/server.rs +++ b/ntex-tls/src/rustls/server.rs @@ -173,7 +173,7 @@ impl TlsServerFilter { } poll_fn(|cx| { let read_ready = if wants_read { - match ready!(io.poll_force_read_ready(cx))? { + match ready!(io.poll_read_notify(cx))? { Some(_) => Ok(true), None => Err(io::Error::new( io::ErrorKind::Other,
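
A note on the new API surface: this change extends the io task contexts with async variants (`ReadContext::ready()`/`with_buf_async()`/`wait_for_close()` and `WriteContext::ready()`/`with_buf_async()`) so that drivers for completion-based runtimes such as "compio" can take ownership of a buffer for the duration of an operation. Below is a minimal sketch (against ntex-io 2.3 and ntex-bytes) of how such a driver's read and write tasks might use it; the `OwnedRead`/`OwnedWrite` traits are assumed stand-ins for a compio-style stream and are not part of this change, and shutdown/error handling is simplified.

```rust
use std::io;

use ntex_bytes::BytesVec;
use ntex_io::{ReadContext, ReadStatus, WriteContext, WriteStatus};

/// Hypothetical completion-style stream halves: each operation takes
/// ownership of the buffer and hands it back together with the io result.
/// These traits are illustrative only and are not part of ntex-io.
trait OwnedRead {
    async fn read(&mut self, buf: BytesVec) -> (BytesVec, io::Result<usize>);
}

trait OwnedWrite {
    async fn write_all(&mut self, buf: BytesVec) -> io::Result<()>;
}

/// Sketch of a read task built on the async `ReadContext` API.
async fn read_task<T: OwnedRead>(mut src: T, read: ReadContext) {
    loop {
        match read.ready().await {
            ReadStatus::Ready => {
                // Hand the pooled read buffer to the source; the call
                // resolves to `Poll::Ready(())` once the io is stopped
                // (eof or error), otherwise the loop keeps reading.
                let status = read.with_buf_async(|buf| src.read(buf)).await;
                if status.is_ready() {
                    break;
                }
            }
            // A real driver would also tear down the source here.
            ReadStatus::Terminate => break,
        }
    }
}

/// Sketch of a write task built on the async `WriteContext` API.
async fn write_task<T: OwnedWrite>(mut dst: T, write: WriteContext) {
    loop {
        match write.ready().await {
            WriteStatus::Ready => {
                // Flush whatever the dispatcher queued; an empty or
                // missing buffer resolves to `Ok(())` right away.
                if let Err(e) = write.with_buf_async(|buf| dst.write_all(buf)).await {
                    write.close(Some(e));
                    return;
                }
            }
            // Timeout / Shutdown / Terminate: the sketch just stops the
            // task; a real driver would run the shutdown sequence first.
            _ => {
                write.close(None);
                return;
            }
        }
    }
}
```

On the dispatcher side nothing has to change immediately: `Io::force_read_ready()` and `Io::poll_force_read_ready()` are kept as deprecated wrappers over the new `read_notify()`/`poll_read_notify()`, which is exactly the migration the openssl and rustls filters perform in this patch.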