Extend io task contexts, for compio runtime compatibility

This commit is contained in:
Nikolay Kim 2024-08-28 00:21:18 +05:00
parent 487faa3379
commit 6c907d1f45
14 changed files with 334 additions and 72 deletions

View file

@ -1,5 +1,9 @@
# Changes
## [2.3.0] - 2024-08-28
* Extend io task contexts, for "compio" runtime compatibility
## [2.2.0] - 2024-08-12
* Allow to notify dispatcher from IoRef

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-io"
version = "2.2.0"
version = "2.3.0"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]
@ -18,8 +18,8 @@ path = "src/lib.rs"
[dependencies]
ntex-codec = "0.6.2"
ntex-bytes = "0.1.24"
ntex-util = "2.2"
ntex-service = "3.0"
ntex-util = "2.3"
ntex-service = "3"
bitflags = "2"
log = "0.4"

View file

@ -140,6 +140,18 @@ impl Stack {
})
}
/// Take the read source buffer of the last (innermost) filter level,
/// leaving `None` in its place.
pub(crate) fn get_read_source(&self) -> Option<BytesVec> {
    self.get_last_level().0.take()
}
/// Put `buf` back as the read source of the innermost filter level.
/// An empty buffer is released back to the memory pool instead of stored.
pub(crate) fn set_read_source(&self, io: &IoRef, buf: BytesVec) {
    if !buf.is_empty() {
        self.get_last_level().0.set(Some(buf));
    } else {
        io.memory_pool().release_read_buf(buf);
    }
}
pub(crate) fn with_read_source<F, R>(&self, io: &IoRef, f: F) -> R
where
F: FnOnce(&mut BytesVec) -> R,
@ -210,6 +222,10 @@ impl Stack {
result
}
/// Take the final write destination buffer (the bytes destined for the
/// actual io stream), leaving `None` in its place.
pub(crate) fn get_write_destination(&self) -> Option<BytesVec> {
    self.get_last_level().1.take()
}
pub(crate) fn with_write_destination<F, R>(&self, io: &IoRef, f: F) -> R
where
F: FnOnce(&mut Option<BytesVec>) -> R,

View file

@ -1,6 +1,6 @@
use std::{any, io, task::Context, task::Poll};
use super::{buf::Stack, io::Flags, FilterLayer, IoRef, ReadStatus, WriteStatus};
use crate::{buf::Stack, FilterLayer, Flags, IoRef, ReadStatus, WriteStatus};
#[derive(Debug)]
/// Default `Io` filter
@ -80,9 +80,10 @@ impl Filter for Base {
Poll::Ready(ReadStatus::Terminate)
} else {
self.0 .0.read_task.register(cx.waker());
if flags.intersects(Flags::IO_STOPPING_FILTERS) {
Poll::Ready(ReadStatus::Ready)
} else if flags.intersects(Flags::RD_PAUSED | Flags::RD_BUF_FULL) {
} else if flags.cannot_read() {
Poll::Pending
} else {
Poll::Ready(ReadStatus::Ready)
@ -109,6 +110,9 @@ impl Filter for Base {
Poll::Ready(WriteStatus::Timeout(
self.0 .0.disconnect_timeout.get().into(),
))
} else if flags.intersects(Flags::WR_PAUSED) {
self.0 .0.write_task.register(cx.waker());
Poll::Pending
} else {
self.0 .0.write_task.register(cx.waker());
Poll::Ready(WriteStatus::Ready)

58
ntex-io/src/flags.rs Normal file
View file

@ -0,0 +1,58 @@
bitflags::bitflags! {
    /// Internal io state flags shared between the read/write tasks and
    /// the dispatcher.
    ///
    /// NOTE(review): declaration order does not follow bit order —
    /// `RD_NOTIFY` uses bit 7 (0x0080) while `BUF_R_READY`/`BUF_R_FULL`
    /// use bits 5–6; bit 11 (0x0800) is currently unused.
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    pub struct Flags: u16 {
        /// io is closed
        const IO_STOPPED = 0b0000_0000_0000_0001;
        /// shutdown io tasks
        const IO_STOPPING = 0b0000_0000_0000_0010;
        /// shutting down filters
        const IO_STOPPING_FILTERS = 0b0000_0000_0000_0100;
        /// initiate filters shutdown timeout in write task
        const IO_FILTERS_TIMEOUT = 0b0000_0000_0000_1000;

        /// pause io read
        const RD_PAUSED = 0b0000_0000_0001_0000;
        /// read any data and notify dispatcher
        const RD_NOTIFY = 0b0000_0000_1000_0000;

        /// new data is available in read buffer
        const BUF_R_READY = 0b0000_0000_0010_0000;
        /// read buffer is full
        const BUF_R_FULL = 0b0000_0000_0100_0000;

        /// wait while write task flushes buf
        const BUF_W_MUST_FLUSH = 0b0000_0001_0000_0000;

        /// write buffer is full
        const WR_BACKPRESSURE = 0b0000_0010_0000_0000;
        /// write task paused
        const WR_PAUSED = 0b0000_0100_0000_0000;

        /// dispatcher is marked stopped
        const DSP_STOP = 0b0001_0000_0000_0000;
        /// timeout occurred
        const DSP_TIMEOUT = 0b0010_0000_0000_0000;
    }
}
impl Flags {
    /// Dispatcher is waiting for the write task: either an explicit
    /// flush (`BUF_W_MUST_FLUSH`) or write back-pressure is active.
    pub(crate) fn is_waiting_for_write(&self) -> bool {
        self.intersects(Flags::BUF_W_MUST_FLUSH | Flags::WR_BACKPRESSURE)
    }

    /// Clear the wait-for-write condition (flush completed and/or
    /// back-pressure released).
    pub(crate) fn waiting_for_write_is_done(&mut self) {
        self.remove(Flags::BUF_W_MUST_FLUSH | Flags::WR_BACKPRESSURE);
    }

    /// New data is available in the read buffer for the dispatcher.
    pub(crate) fn is_read_buf_ready(&self) -> bool {
        self.contains(Flags::BUF_R_READY)
    }

    /// Read task must not read more data: reads are paused or the read
    /// buffer is full (read back-pressure).
    pub(crate) fn cannot_read(self) -> bool {
        self.intersects(Flags::RD_PAUSED | Flags::BUF_R_FULL)
    }

    /// Reset read-related flags when the dispatcher resumes reading.
    pub(crate) fn cleanup_read_flags(&mut self) {
        self.remove(Flags::BUF_R_READY | Flags::BUF_R_FULL | Flags::RD_PAUSED);
    }
}

View file

@ -9,46 +9,12 @@ use ntex_util::{future::Either, task::LocalWaker, time::Seconds};
use crate::buf::Stack;
use crate::filter::{Base, Filter, Layer, NullFilter};
use crate::flags::Flags;
use crate::seal::Sealed;
use crate::tasks::{ReadContext, WriteContext};
use crate::timer::TimerHandle;
use crate::{Decoded, FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError};
bitflags::bitflags! {
    #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
    pub struct Flags: u16 {
        /// io is closed
        const IO_STOPPED = 0b0000_0000_0000_0001;
        /// shutdown io tasks
        const IO_STOPPING = 0b0000_0000_0000_0010;
        /// shutting down filters
        const IO_STOPPING_FILTERS = 0b0000_0000_0000_0100;
        /// initiate filters shutdown timeout in write task
        const IO_FILTERS_TIMEOUT = 0b0000_0000_0000_1000;

        /// pause io read
        const RD_PAUSED = 0b0000_0000_0001_0000;
        /// new data is available
        const RD_READY = 0b0000_0000_0010_0000;
        /// read buffer is full
        const RD_BUF_FULL = 0b0000_0000_0100_0000;
        /// any new data is available
        const RD_FORCE_READY = 0b0000_0000_1000_0000;

        /// wait write completion
        const WR_WAIT = 0b0000_0001_0000_0000;
        /// write buffer is full
        const WR_BACKPRESSURE = 0b0000_0010_0000_0000;
        /// write task paused
        const WR_PAUSED = 0b0000_0100_0000_0000;

        /// dispatcher is marked stopped
        const DSP_STOP = 0b0001_0000_0000_0000;
        /// timeout occurred
        const DSP_TIMEOUT = 0b0010_0000_0000_0000;
    }
}
/// Interface object to underlying io stream
pub struct Io<F = Base>(UnsafeCell<IoRef>, marker::PhantomData<F>);
@ -384,8 +350,14 @@ impl<F> Io<F> {
#[doc(hidden)]
#[inline]
/// Wait until read becomes ready.
///
/// Resolves as soon as any new data is read from the source (see
/// `poll_read_notify`). Returns `Ok(None)` if the io stream is
/// disconnected.
pub async fn read_notify(&self) -> io::Result<Option<()>> {
    poll_fn(|cx| self.poll_read_notify(cx)).await
}
#[doc(hidden)]
#[deprecated]
pub async fn force_read_ready(&self) -> io::Result<Option<()>> {
poll_fn(|cx| self.poll_force_read_ready(cx)).await
poll_fn(|cx| self.poll_read_notify(cx)).await
}
#[inline]
@ -454,9 +426,9 @@ impl<F> Io<F> {
} else {
st.dispatch_task.register(cx.waker());
let ready = flags.contains(Flags::RD_READY);
if flags.intersects(Flags::RD_BUF_FULL | Flags::RD_PAUSED) {
flags.remove(Flags::RD_READY | Flags::RD_BUF_FULL | Flags::RD_PAUSED);
let ready = flags.contains(Flags::BUF_R_READY);
if flags.cannot_read() {
flags.cleanup_read_flags();
st.read_task.wake();
st.flags.set(flags);
if ready {
@ -465,7 +437,7 @@ impl<F> Io<F> {
Poll::Pending
}
} else if ready {
flags.remove(Flags::RD_READY);
flags.remove(Flags::BUF_R_READY);
st.flags.set(flags);
Poll::Ready(Ok(Some(())))
} else {
@ -489,18 +461,15 @@ impl<F> Io<F> {
/// `Poll::Ready(Ok(Some(()))))` if the io stream is ready for reading.
/// `Poll::Ready(Ok(None))` if io stream is disconnected
/// `Some(Poll::Ready(Err(e)))` if an error is encountered.
pub fn poll_force_read_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<io::Result<Option<()>>> {
pub fn poll_read_notify(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
let ready = self.poll_read_ready(cx);
if ready.is_pending() {
let st = self.st();
if st.remove_flags(Flags::RD_FORCE_READY) {
if st.remove_flags(Flags::RD_NOTIFY) {
Poll::Ready(Ok(Some(())))
} else {
st.insert_flags(Flags::RD_FORCE_READY);
st.insert_flags(Flags::RD_NOTIFY);
Poll::Pending
}
} else {
@ -508,6 +477,15 @@ impl<F> Io<F> {
}
}
/// Deprecated alias for `poll_read_notify()`; kept for backward
/// compatibility with pre-2.3 callers.
#[doc(hidden)]
#[deprecated]
pub fn poll_force_read_ready(
    &self,
    cx: &mut Context<'_>,
) -> Poll<io::Result<Option<()>>> {
    self.poll_read_notify(cx)
}
#[inline]
/// Decode codec item from incoming bytes stream.
///
@ -597,7 +575,7 @@ impl<F> Io<F> {
let len = st.buffer.write_destination_size();
if len > 0 {
if full {
st.insert_flags(Flags::WR_WAIT);
st.insert_flags(Flags::BUF_W_MUST_FLUSH);
st.dispatch_task.register(cx.waker());
return Poll::Pending;
} else if len >= st.pool.get().write_params_high() << 1 {
@ -606,7 +584,7 @@ impl<F> Io<F> {
return Poll::Pending;
}
}
st.remove_flags(Flags::WR_WAIT | Flags::WR_BACKPRESSURE);
st.remove_flags(Flags::BUF_W_MUST_FLUSH | Flags::WR_BACKPRESSURE);
Poll::Ready(Ok(()))
}
}

View file

@ -4,7 +4,7 @@ use ntex_bytes::{BytesVec, PoolRef};
use ntex_codec::{Decoder, Encoder};
use ntex_util::time::Seconds;
use super::{io::Flags, timer, types, Decoded, Filter, IoRef, OnDisconnect, WriteBuf};
use crate::{timer, types, Decoded, Filter, Flags, IoRef, OnDisconnect, WriteBuf};
impl IoRef {
#[inline]

View file

@ -11,6 +11,7 @@ pub mod types;
mod buf;
mod dispatcher;
mod filter;
mod flags;
mod framed;
mod io;
mod ioref;
@ -33,7 +34,7 @@ pub use self::timer::TimerHandle;
pub use self::utils::{seal, Decoded};
#[doc(hidden)]
pub use self::io::Flags;
pub use self::flags::Flags;
/// Status for read task
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]

View file

@ -1,8 +1,9 @@
use std::{io, task::Context, task::Poll};
use std::{future::poll_fn, future::Future, io, task::Context, task::Poll};
use ntex_bytes::{BytesVec, PoolRef};
use ntex_bytes::{BufMut, BytesVec, PoolRef};
use ntex_util::task;
use super::{io::Flags, IoRef, ReadStatus, WriteStatus};
use crate::{Flags, IoRef, ReadStatus, WriteStatus};
#[derive(Debug)]
/// Context for io read task
@ -19,6 +20,31 @@ impl ReadContext {
self.0.tag()
}
#[inline]
/// Wait until the io filter reports readiness for read operations.
pub async fn ready(&self) -> ReadStatus {
    poll_fn(|cx| {
        let filter = self.0.filter();
        filter.poll_read_ready(cx)
    })
    .await
}
#[inline]
/// Wait until the io is closed or is preparing to close.
///
/// Resolves once `IO_STOPPING` or `IO_STOPPED` is set. While waiting it
/// keeps driving filter shutdown when filters are being stopped.
pub async fn wait_for_close(&self) {
    poll_fn(|cx| {
        let flags = self.0.flags();

        if flags.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) {
            Poll::Ready(())
        } else {
            // re-arm the read task waker so state changes wake us again
            self.0 .0.read_task.register(cx.waker());
            if flags.contains(Flags::IO_STOPPING_FILTERS) {
                shutdown_filters(&self.0);
            }
            Poll::Pending
        }
    })
    .await
}
#[inline]
/// Check readiness for read operations
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
@ -56,9 +82,9 @@ impl ReadContext {
self.0.tag(),
total
);
inner.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
} else {
inner.insert_flags(Flags::RD_READY);
inner.insert_flags(Flags::BUF_R_READY);
if nbytes >= hw {
// read task is paused because of read back-pressure
@ -82,8 +108,8 @@ impl ReadContext {
// otherwise read task would sleep forever
inner.read_task.wake();
}
if inner.flags.get().contains(Flags::RD_FORCE_READY) {
// in case of "force read" we must wake up dispatch task
if inner.flags.get().contains(Flags::RD_NOTIFY) {
// in case of "notify" we must wake up dispatch task
// if we read any data from source
inner.dispatch_task.wake();
}
@ -101,7 +127,7 @@ impl ReadContext {
.map_err(|err| {
inner.dispatch_task.wake();
inner.io_stopped(Some(err));
inner.insert_flags(Flags::RD_READY);
inner.insert_flags(Flags::BUF_R_READY);
});
}
@ -122,6 +148,120 @@ impl ReadContext {
}
}
}
/// Get read buffer (async)
///
/// Runs one read cycle via an ownership-passing callback: `f` receives the
/// read buffer by value and must return it together with the io result
/// (number of bytes appended). This supports completion-based runtimes
/// (the "compio" compatibility this API was added for), where the io
/// operation must own the buffer.
///
/// NOTE(review): returns `Poll` from an `async fn` — `Ready(())` means the
/// io stream stopped (eof or error) and the read loop should exit;
/// `Pending` means "keep reading", not a task-level pending.
pub async fn with_buf_async<F, R>(&self, f: F) -> Poll<()>
where
    F: FnOnce(BytesVec) -> R,
    R: Future<Output = (BytesVec, io::Result<usize>)>,
{
    let inner = &self.0 .0;

    // we already pushed new data to read buffer,
    // we have to wait for dispatcher to read data from buffer
    if inner.flags.get().is_read_buf_ready() {
        task::yield_to().await;
    }

    let mut buf = if inner.flags.get().is_read_buf_ready() {
        // read buffer is still not read by dispatcher;
        // we cannot touch it, so read into a fresh pool buffer instead
        inner.pool.get().get_read_buf()
    } else {
        inner
            .buffer
            .get_read_source()
            .unwrap_or_else(|| inner.pool.get().get_read_buf())
    };

    // make sure we've got room (at least the pool's low watermark)
    let remaining = buf.remaining_mut();
    let (hw, lw) = self.0.memory_pool().read_params().unpack();
    if remaining < lw {
        buf.reserve(hw - remaining);
    }
    let total = buf.len();

    // call provided callback
    let (buf, result) = f(buf).await;
    let total2 = buf.len();
    let nbytes = if total2 > total { total2 - total } else { 0 };
    // shadow `total` with the post-read size (used only for logging below)
    let total = total2;

    // merge the returned buffer back into the read source; the dispatcher
    // may have left data in place while we read into a fresh buffer
    if let Some(mut first_buf) = inner.buffer.get_read_source() {
        first_buf.extend_from_slice(&buf);
        inner.buffer.set_read_source(&self.0, first_buf);
    } else {
        inner.buffer.set_read_source(&self.0, buf);
    }

    // handle buffer changes
    if nbytes > 0 {
        let filter = self.0.filter();
        let res = match filter.process_read_buf(&self.0, &inner.buffer, 0, nbytes) {
            Ok(status) => {
                if status.nbytes > 0 {
                    // check read back-pressure
                    if hw < inner.buffer.read_destination_size() {
                        log::trace!(
                            "{}: Io read buffer is too large {}, enable read back-pressure",
                            self.0.tag(),
                            total
                        );
                        inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
                    } else {
                        inner.insert_flags(Flags::BUF_R_READY);
                    }
                    log::trace!(
                        "{}: New {} bytes available, wakeup dispatcher",
                        self.0.tag(),
                        nbytes
                    );
                    // dest buffer has new data, wake up dispatcher
                    inner.dispatch_task.wake();
                } else if inner.flags.get().contains(Flags::RD_NOTIFY) {
                    // in case of "notify" we must wake up dispatch task
                    // if we read any data from source
                    inner.dispatch_task.wake();
                }

                // while reading, filter wrote some data;
                // in that case filters need to process write buffers
                // and potentially wake write task
                if status.need_write {
                    filter.process_write_buf(&self.0, &inner.buffer, 0)
                } else {
                    Ok(())
                }
            }
            Err(err) => Err(err),
        };

        if let Err(err) = res {
            inner.dispatch_task.wake();
            inner.io_stopped(Some(err));
            inner.insert_flags(Flags::BUF_R_READY);
        }
    }

    match result {
        Ok(n) => {
            if n == 0 {
                // zero bytes read: peer closed the stream
                inner.io_stopped(None);
                Poll::Ready(())
            } else {
                if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) {
                    shutdown_filters(&self.0);
                }
                Poll::Pending
            }
        }
        Err(e) => {
            inner.io_stopped(Some(e));
            Poll::Ready(())
        }
    }
}
}
#[derive(Debug)]
@ -145,13 +285,19 @@ impl WriteContext {
self.0.memory_pool()
}
#[inline]
/// Wait until the io filter reports readiness for write operations.
pub async fn ready(&self) -> WriteStatus {
    poll_fn(|cx| {
        let filter = self.0.filter();
        filter.poll_write_ready(cx)
    })
    .await
}
#[inline]
/// Check readiness for write operations
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
    let filter = self.0.filter();
    filter.poll_write_ready(cx)
}
/// Get read buffer
/// Get write buffer
pub fn with_buf<F>(&self, f: F) -> Poll<io::Result<()>>
where
F: FnOnce(&mut Option<BytesVec>) -> Poll<io::Result<()>>,
@ -167,8 +313,8 @@ impl WriteContext {
// if write buffer is smaller than high watermark value, turn off back-pressure
let mut flags = inner.flags.get();
if len == 0 {
if flags.intersects(Flags::WR_WAIT | Flags::WR_BACKPRESSURE) {
flags.remove(Flags::WR_WAIT | Flags::WR_BACKPRESSURE);
if flags.is_waiting_for_write() {
flags.waiting_for_write_is_done();
inner.dispatch_task.wake();
}
} else if flags.contains(Flags::WR_BACKPRESSURE)
@ -188,6 +334,57 @@ impl WriteContext {
result
}
/// Get write buffer (async)
///
/// Flushes the write destination buffer via an ownership-passing
/// callback: `f` receives the buffer by value and performs the actual
/// write. Supports completion-based runtimes where the io operation
/// must own the buffer. Returns the callback's io result.
pub async fn with_buf_async<F, R>(&self, f: F) -> io::Result<()>
where
    F: FnOnce(BytesVec) -> R,
    R: Future<Output = io::Result<()>>,
{
    let inner = &self.0 .0;

    // the write task is running again: clear the paused marker
    let mut flags = inner.flags.get();
    if flags.contains(Flags::WR_PAUSED) {
        flags.remove(Flags::WR_PAUSED);
        inner.flags.set(flags);
    }

    // take the destination buffer; `None` or empty means nothing to write
    let buf = inner.buffer.get_write_destination();

    // call provided callback
    let result = if let Some(buf) = buf {
        if !buf.is_empty() {
            f(buf).await
        } else {
            // NOTE(review): the empty buffer is dropped here rather than
            // released via the pool API (unlike the read path) —
            // presumably BytesVec returns memory on drop; confirm
            Ok(())
        }
    } else {
        Ok(())
    };

    // if write buffer is smaller than high watermark value, turn off back-pressure
    let mut flags = inner.flags.get();
    let len = inner.buffer.write_destination_size();

    if len == 0 {
        // everything flushed: notify dispatcher if it was waiting, then
        // pause the write task until new data arrives
        if flags.is_waiting_for_write() {
            flags.waiting_for_write_is_done();
            inner.dispatch_task.wake();
        }
        flags.insert(Flags::WR_PAUSED);
        inner.flags.set(flags);
    } else if flags.contains(Flags::WR_BACKPRESSURE)
        && len < inner.pool.get().write_params_high() << 1
    {
        flags.remove(Flags::WR_BACKPRESSURE);
        inner.flags.set(flags);
        inner.dispatch_task.wake();
    }
    result
}
#[inline]
/// Indicate that write io task is stopped
pub fn close(&self, err: Option<io::Error>) {
@ -210,7 +407,7 @@ fn shutdown_filters(io: &IoRef) {
// check read buffer, if buffer is not consumed it is unlikely
// that filter will properly complete shutdown
if flags.contains(Flags::RD_PAUSED)
|| flags.contains(Flags::RD_BUF_FULL | Flags::RD_READY)
|| flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
{
st.dispatch_task.wake();
st.insert_flags(Flags::IO_STOPPING);

View file

@ -1,5 +1,9 @@
# Changes
## [2.1.0] - 2024-08-28
* Update io api usage
## [2.0.1] - 2024-08-26
* Fix rustls client/server filters

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-tls"
version = "2.0.1"
version = "2.1.0"
authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"]
@ -26,7 +26,7 @@ rustls = ["tls_rust"]
[dependencies]
ntex-bytes = "0.1"
ntex-io = "2"
ntex-io = "2.3"
ntex-util = "2"
ntex-service = "3"
ntex-net = "2"

View file

@ -248,7 +248,7 @@ async fn handle_result<T, F>(
Ok(v) => Ok(Some(v)),
Err(e) => match e.code() {
ssl::ErrorCode::WANT_READ => {
let res = io.force_read_ready().await;
let res = io.read_notify().await;
match res? {
None => Err(io::Error::new(io::ErrorKind::Other, "disconnected")),
_ => Ok(None),

View file

@ -164,7 +164,7 @@ impl TlsClientFilter {
}
poll_fn(|cx| {
let read_ready = if wants_read {
match ready!(io.poll_force_read_ready(cx))? {
match ready!(io.poll_read_notify(cx))? {
Some(_) => Ok(true),
None => Err(io::Error::new(
io::ErrorKind::Other,

View file

@ -173,7 +173,7 @@ impl TlsServerFilter {
}
poll_fn(|cx| {
let read_ready = if wants_read {
match ready!(io.poll_force_read_ready(cx))? {
match ready!(io.poll_read_notify(cx))? {
Some(_) => Ok(true),
None => Err(io::Error::new(
io::ErrorKind::Other,