Refactor the io API and the back-pressure API

This commit is contained in:
Nikolay Kim 2021-12-21 14:17:29 +06:00
parent 1c4c842515
commit 6c68a59e99
12 changed files with 290 additions and 250 deletions

View file

@ -2,8 +2,14 @@
## [0.1.0-b.3] - 2021-12-xx
* Add .poll_write_backpressure()
* Rename .poll_read_next() to .poll_recv()
* Rename .poll_write_ready() to .poll_flush()
* Rename .next() to .recv()
* Rename .write_ready() to .flush()
## [0.1.0-b.2] - 2021-12-20

View file

@ -4,8 +4,8 @@ use std::{cell::Cell, future, pin::Pin, rc::Rc, task::Context, task::Poll, time}
use ntex_bytes::Pool;
use ntex_codec::{Decoder, Encoder};
use ntex_service::{IntoService, Service};
use ntex_util::future::Either;
use ntex_util::time::{now, Seconds};
use ntex_util::{future::Either, ready};
use super::{rt::spawn, DispatchItem, IoBoxed, IoRef, Timer};
@ -203,50 +203,46 @@ where
// handle memory pool pressure
if slf.pool.poll_ready(cx).is_pending() {
io.pause(cx);
io.pause();
return Poll::Pending;
}
loop {
match slf.st.get() {
DispatcherState::Processing => {
let result = if let Poll::Ready(result) =
slf.poll_service(this.service, cx, io)
{
result
} else {
return Poll::Pending;
};
let item = match result {
let item = match ready!(slf.poll_service(this.service, cx, io)) {
PollService::Ready => {
if !io.is_write_ready() {
// instruct write task to notify dispatcher when data is flushed
io.enable_write_backpressure(cx);
slf.st.set(DispatcherState::Backpressure);
DispatchItem::WBackPressureEnabled
} else {
// decode incoming bytes if buffer is ready
match io.poll_read_next(&slf.shared.codec, cx) {
Poll::Ready(Some(Ok(el))) => {
slf.update_keepalive();
DispatchItem::Item(el)
}
Poll::Ready(Some(Err(Either::Left(err)))) => {
slf.st.set(DispatcherState::Stop);
slf.unregister_keepalive();
DispatchItem::DecoderError(err)
}
Poll::Ready(Some(Err(Either::Right(err)))) => {
slf.st.set(DispatcherState::Stop);
slf.unregister_keepalive();
DispatchItem::Disconnect(Some(err))
}
Poll::Ready(None) => DispatchItem::Disconnect(None),
Poll::Pending => {
log::trace!("not enough data to decode next frame, register dispatch task");
io.resume();
return Poll::Pending;
match io.poll_write_backpressure(cx) {
Poll::Pending => {
// instruct write task to notify dispatcher when data is flushed
slf.st.set(DispatcherState::Backpressure);
DispatchItem::WBackPressureEnabled
}
Poll::Ready(()) => {
// decode incoming bytes if buffer is ready
match io.poll_recv(&slf.shared.codec, cx) {
Poll::Ready(Some(Ok(el))) => {
slf.update_keepalive();
DispatchItem::Item(el)
}
Poll::Ready(Some(Err(Either::Left(err)))) => {
slf.st.set(DispatcherState::Stop);
slf.unregister_keepalive();
DispatchItem::DecoderError(err)
}
Poll::Ready(Some(Err(Either::Right(err)))) => {
slf.st.set(DispatcherState::Stop);
slf.unregister_keepalive();
DispatchItem::Disconnect(Some(err))
}
Poll::Ready(None) => {
DispatchItem::Disconnect(None)
}
Poll::Pending => {
log::trace!("not enough data to decode next frame, register dispatch task");
io.resume();
return Poll::Pending;
}
}
}
}
@ -274,13 +270,10 @@ where
}
// handle write back-pressure
DispatcherState::Backpressure => {
let result = match slf.poll_service(this.service, cx, io) {
Poll::Ready(result) => result,
Poll::Pending => return Poll::Pending,
};
let result = ready!(slf.poll_service(this.service, cx, io));
let item = match result {
PollService::Ready => {
if io.is_write_ready() {
if slf.io.poll_write_backpressure(cx).is_ready() {
slf.st.set(DispatcherState::Processing);
DispatchItem::WBackPressureDisabled
} else {
@ -308,7 +301,7 @@ where
slf.spawn_service_call(this.service.call(item));
}
}
// drain service responses
// drain service responses and shutdown io
DispatcherState::Stop => {
// service may relay on poll_ready for response results
if !this.inner.ready_err.get() {
@ -434,7 +427,7 @@ where
// pause io read task
Poll::Pending => {
log::trace!("service is not ready, register dispatch task");
io.pause(cx);
io.pause();
Poll::Pending
}
// handle service readiness error

View file

@ -167,6 +167,37 @@ impl IoState {
Ok(())
}
}
#[inline]
// Run `f` with mutable access to the optional read buffer.
//
// When `release` is true and the buffer is empty after `f` returns, the
// buffer is taken out and returned to the memory pool so idle connections
// do not hold on to allocations.
pub(super) fn with_read_buf<Fn, Ret>(&self, release: bool, f: Fn) -> Ret
where
    Fn: FnOnce(&mut Option<BytesMut>) -> Ret,
{
    let buf = self.read_buf.as_ptr();
    // SAFETY: the pointer comes from Cell::as_ptr and is never null;
    // exclusive access is assumed because IoState is used from a single
    // thread — NOTE(review): confirm IoState is !Send/!Sync
    let ref_buf = unsafe { buf.as_mut().unwrap() };
    let result = f(ref_buf);
    // release buffer
    if release {
        if let Some(ref buf) = ref_buf {
            if buf.is_empty() {
                // unwrap is safe: we are inside `if let Some(..)`
                let buf = mem::take(ref_buf).unwrap();
                self.pool.get().release_read_buf(buf);
            }
        }
    }
    result
}
#[inline]
// Run `f` with mutable access to the optional write buffer.
//
// Unlike `with_read_buf` this never releases the buffer back to the pool.
pub(super) fn with_write_buf<Fn, Ret>(&self, f: Fn) -> Ret
where
    Fn: FnOnce(&mut Option<BytesMut>) -> Ret,
{
    let buf = self.write_buf.as_ptr();
    // SAFETY: pointer comes from Cell::as_ptr and is never null; exclusive
    // access is assumed (single-threaded io state) — TODO confirm
    let ref_buf = unsafe { buf.as_mut().unwrap() };
    f(ref_buf)
}
}
impl Eq for IoState {}
@ -376,14 +407,29 @@ impl<F: Filter> Io<F> {
impl<F> Io<F> {
#[inline]
/// Read incoming io stream and decode codec item.
pub async fn next<U>(
pub async fn recv<U>(
&self,
codec: &U,
) -> Option<Result<U::Item, Either<U::Error, io::Error>>>
where
U: Decoder,
{
poll_fn(|cx| self.poll_read_next(codec, cx)).await
poll_fn(|cx| self.poll_recv(codec, cx)).await
}
#[inline]
/// Pause read task
///
/// Sets the `RD_PAUSED` flag; presumably the read task observes this flag
/// and stops reading until `resume()` is called — verify in the read task.
pub fn pause(&self) {
    self.0 .0.insert_flags(Flags::RD_PAUSED);
}
#[inline]
/// Wake read io task if it is paused
pub fn resume(&self) {
    // only clear the flag and wake when actually paused, so an
    // already-running read task is not woken needlessly
    if self.flags().contains(Flags::RD_PAUSED) {
        self.0 .0.remove_flags(Flags::RD_PAUSED);
        self.0 .0.read_task.wake();
    }
}
#[inline]
@ -400,13 +446,8 @@ impl<F> Io<F> {
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.memory_pool().get_write_buf());
let is_write_sleep = buf.is_empty();
codec.encode(item, &mut buf).map_err(Either::Left)?;
filter.release_write_buf(buf).map_err(Either::Right)?;
if is_write_sleep {
self.0 .0.write_task.wake();
}
poll_fn(|cx| self.poll_flush(cx, true))
.await
@ -422,67 +463,11 @@ impl<F> Io<F> {
poll_fn(|cx| self.poll_flush(cx, full)).await
}
#[doc(hidden)]
#[deprecated]
#[inline]
pub async fn write_ready(&self, full: bool) -> Result<(), io::Error> {
poll_fn(|cx| self.poll_flush(cx, full)).await
}
#[inline]
/// Shut down connection
pub async fn shutdown(&self) -> Result<(), io::Error> {
poll_fn(|cx| self.poll_shutdown(cx)).await
}
}
impl<F> Io<F> {
#[inline]
/// Wake write task and instruct to flush data.
///
/// If `full` is true then wake up dispatcher when all data is flushed
/// otherwise wake up when size of write buffer is lower than
/// buffer max size.
pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
// check io error
if !self.0 .0.is_io_open() {
return Poll::Ready(Err(self.0 .0.error.take().unwrap_or_else(|| {
io::Error::new(io::ErrorKind::Other, "disconnected")
})));
}
if let Some(buf) = self.0 .0.write_buf.take() {
let len = buf.len();
if len != 0 {
self.0 .0.write_buf.set(Some(buf));
if full {
self.0 .0.insert_flags(Flags::WR_WAIT);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
} else if len >= self.0.memory_pool().write_params_high() << 1 {
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
} else {
self.0 .0.remove_flags(Flags::WR_BACKPRESSURE);
}
}
}
Poll::Ready(Ok(()))
}
#[doc(hidden)]
#[deprecated]
#[inline]
pub fn poll_write_ready(
&self,
cx: &mut Context<'_>,
full: bool,
) -> Poll<io::Result<()>> {
self.poll_flush(cx, full)
}
#[inline]
/// Wake read task and instruct to read more data
@ -525,7 +510,10 @@ impl<F> Io<F> {
#[inline]
#[allow(clippy::type_complexity)]
pub fn poll_read_next<U>(
/// Decode codec item from incoming bytes stream.
///
/// Wake read task and request to read more data if data is not enough for decoding.
pub fn poll_recv<U>(
&self,
codec: &U,
cx: &mut Context<'_>,
@ -544,6 +532,69 @@ impl<F> Io<F> {
}
}
#[inline]
/// Wake write task and instruct to flush data.
///
/// If `full` is true then wake up dispatcher when all data is flushed
/// otherwise wake up when size of write buffer is lower than
/// buffer max size.
pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
    // check io error
    if !self.0 .0.is_io_open() {
        // report the stored io error once, falling back to a generic
        // "disconnected" error when none was recorded
        return Poll::Ready(Err(self.0 .0.error.take().unwrap_or_else(|| {
            io::Error::new(io::ErrorKind::Other, "disconnected")
        })));
    }
    if let Some(buf) = self.0 .0.write_buf.take() {
        let len = buf.len();
        if len != 0 {
            // put the buffer back — it was only taken to inspect its length
            self.0 .0.write_buf.set(Some(buf));
            if full {
                // caller wants a complete flush: park the dispatcher until
                // the write task drains the buffer entirely
                self.0 .0.insert_flags(Flags::WR_WAIT);
                self.0 .0.dispatch_task.register(cx.waker());
                return Poll::Pending;
            } else if len >= self.0.memory_pool().write_params_high() << 1 {
                // buffer is at or above twice the high watermark: apply
                // back-pressure until the write task catches up
                self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
                self.0 .0.dispatch_task.register(cx.waker());
                return Poll::Pending;
            } else {
                // below the threshold again — clear back-pressure
                self.0 .0.remove_flags(Flags::WR_BACKPRESSURE);
            }
        }
    }
    Poll::Ready(Ok(()))
}
#[inline]
/// Wait until write task flushes data to io stream
///
/// Write task must be woken up separately.
pub fn poll_write_backpressure(&self, cx: &mut Context<'_>) -> Poll<()> {
    if !self.is_io_open() {
        // io is closed or failed: applying back-pressure is pointless,
        // let the caller proceed and observe the error elsewhere
        Poll::Ready(())
    } else if self.flags().contains(Flags::WR_BACKPRESSURE) {
        // back-pressure already enabled; wait for write task to clear it
        self.0 .0.dispatch_task.register(cx.waker());
        Poll::Pending
    } else {
        // number of bytes currently buffered for write (0 when no buffer)
        let len = self
            .0
            .0
            .with_write_buf(|buf| buf.as_ref().map(|b| b.len()).unwrap_or(0));
        let hw = self.memory_pool().write_params_high();
        if len >= hw {
            // buffered bytes reached the high watermark — enable
            // back-pressure and park the dispatcher
            log::trace!("enable write back-pressure");
            self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
            self.0 .0.dispatch_task.register(cx.waker());
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}
#[inline]
/// Shut down connection
pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
@ -565,30 +616,55 @@ impl<F> Io<F> {
}
}
#[doc(hidden)]
#[deprecated]
#[inline]
/// Pause read task
pub fn pause(&self, cx: &mut Context<'_>) {
self.0 .0.insert_flags(Flags::RD_PAUSED);
self.0 .0.dispatch_task.register(cx.waker());
pub async fn next<U>(
&self,
codec: &U,
) -> Option<Result<U::Item, Either<U::Error, io::Error>>>
where
U: Decoder,
{
self.recv(codec).await
}
#[doc(hidden)]
#[deprecated]
#[inline]
/// Wake read io task if it is paused
pub fn resume(&self) -> bool {
let flags = self.0 .0.flags.get();
if flags.contains(Flags::RD_PAUSED) {
self.0 .0.remove_flags(Flags::RD_PAUSED);
self.0 .0.read_task.wake();
true
} else {
false
}
pub async fn write_ready(&self, full: bool) -> Result<(), io::Error> {
poll_fn(|cx| self.poll_flush(cx, full)).await
}
#[doc(hidden)]
#[deprecated(note = "use `poll_flush()` instead")]
#[inline]
// Deprecated alias for `poll_flush`; delegates unchanged so existing
// callers keep working while the deprecation warning guides migration.
pub fn poll_write_ready(
    &self,
    cx: &mut Context<'_>,
    full: bool,
) -> Poll<io::Result<()>> {
    self.poll_flush(cx, full)
}
#[doc(hidden)]
#[deprecated(note = "use `poll_recv()` instead")]
#[inline]
#[allow(clippy::type_complexity)]
// Deprecated alias for `poll_recv`; delegates unchanged so existing
// callers keep working while the deprecation warning guides migration.
pub fn poll_read_next<U>(
    &self,
    codec: &U,
    cx: &mut Context<'_>,
) -> Poll<Option<Result<U::Item, Either<U::Error, io::Error>>>>
where
    U: Decoder,
{
    self.poll_recv(codec, cx)
}
#[doc(hidden)]
#[deprecated]
#[inline]
/// Wait until write task flushes data to io stream
///
/// Write task must be waken up separately.
pub fn enable_write_backpressure(&self, cx: &mut Context<'_>) {
log::trace!("enable write back-pressure for dispatcher");
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);

View file

@ -126,26 +126,19 @@ impl IoRef {
#[inline]
/// Check if write buffer is full
pub fn is_write_buf_full(&self) -> bool {
if let Some(buf) = self.0.read_buf.take() {
let hw = self.memory_pool().write_params_high();
let result = buf.len() >= hw;
self.0.write_buf.set(Some(buf));
result
} else {
false
}
let len = self
.0
.with_write_buf(|buf| buf.as_ref().map(|b| b.len()).unwrap_or(0));
len >= self.memory_pool().write_params_high()
}
#[inline]
/// Check if read buffer is full
pub fn is_read_buf_full(&self) -> bool {
if let Some(buf) = self.0.read_buf.take() {
let result = buf.len() >= self.memory_pool().read_params_high();
self.0.read_buf.set(Some(buf));
result
} else {
false
}
let len = self
.0
.with_read_buf(false, |buf| buf.as_ref().map(|b| b.len()).unwrap_or(0));
len >= self.memory_pool().read_params_high()
}
#[inline]
@ -167,9 +160,6 @@ impl IoRef {
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.memory_pool().get_write_buf());
if buf.is_empty() {
self.0.write_task.wake();
}
let result = f(&mut buf);
filter.release_write_buf(buf)?;
@ -182,18 +172,13 @@ impl IoRef {
where
F: FnOnce(&mut BytesMut) -> R,
{
let mut buf = self
.0
.read_buf
.take()
.unwrap_or_else(|| self.memory_pool().get_read_buf());
let res = f(&mut buf);
if buf.is_empty() {
self.memory_pool().release_read_buf(buf);
} else {
self.0.read_buf.set(Some(buf));
}
res
self.0.with_read_buf(true, |buf| {
// set buf
if buf.is_none() {
*buf = Some(self.memory_pool().get_read_buf());
}
f(buf.as_mut().unwrap())
})
}
#[inline]
@ -252,12 +237,9 @@ impl IoRef {
where
U: Decoder,
{
if let Some(mut buf) = self.0.read_buf.take() {
let result = codec.decode(&mut buf);
self.0.read_buf.set(Some(buf));
return result;
}
Ok(None)
self.0.with_read_buf(false, |buf| {
buf.as_mut().map(|b| codec.decode(b)).unwrap_or(Ok(None))
})
}
#[inline]
@ -325,20 +307,20 @@ mod tests {
assert!(!state.is_read_buf_full());
assert!(!state.is_write_buf_full());
let msg = state.next(&BytesCodec).await.unwrap().unwrap();
let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
assert_eq!(msg, Bytes::from_static(BIN));
let res = poll_fn(|cx| Poll::Ready(state.poll_read_next(&BytesCodec, cx))).await;
let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
assert!(res.is_pending());
client.write(TEXT);
sleep(Millis(50)).await;
let res = poll_fn(|cx| Poll::Ready(state.poll_read_next(&BytesCodec, cx))).await;
let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
if let Poll::Ready(msg) = res {
assert_eq!(msg.unwrap().unwrap(), Bytes::from_static(BIN));
}
client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
let msg = state.next(&BytesCodec).await;
let msg = state.recv(&BytesCodec).await;
assert!(msg.unwrap().is_err());
assert!(state.flags().contains(Flags::IO_ERR));
assert!(state.flags().contains(Flags::DSP_STOP));
@ -348,7 +330,7 @@ mod tests {
let state = Io::new(server);
client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
let res = poll_fn(|cx| Poll::Ready(state.poll_read_next(&BytesCodec, cx))).await;
let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
if let Poll::Ready(msg) = res {
assert!(msg.unwrap().is_err());
assert!(state.flags().contains(Flags::IO_ERR));
@ -506,7 +488,7 @@ mod tests {
client.remote_buffer_cap(1024);
client.write(TEXT);
let msg = state.next(&BytesCodec).await.unwrap().unwrap();
let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
assert_eq!(msg, Bytes::from_static(BIN));
state
@ -537,7 +519,7 @@ mod tests {
client.remote_buffer_cap(1024);
client.write(TEXT);
let msg = state.next(&BytesCodec).await.unwrap().unwrap();
let msg = state.recv(&BytesCodec).await.unwrap().unwrap();
assert_eq!(msg, Bytes::from_static(BIN));
state