Refactor filter shutdown

This commit is contained in:
Nikolay Kim 2024-09-11 00:34:18 +05:00
parent a9407562b5
commit 568df1cbe9
5 changed files with 65 additions and 110 deletions

View file

@ -93,26 +93,16 @@ impl Filter for Base {
#[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
let mut flags = self.0.flags();
let flags = self.0.flags();
if flags.is_stopped() {
Poll::Ready(WriteStatus::Terminate)
} else {
self.0 .0.write_task.register(cx.waker());
if flags.intersects(Flags::IO_STOPPING) {
Poll::Ready(WriteStatus::Shutdown(
self.0 .0.disconnect_timeout.get().into(),
))
} else if flags.contains(Flags::IO_STOPPING_FILTERS)
&& !flags.contains(Flags::IO_FILTERS_TIMEOUT)
{
flags.insert(Flags::IO_FILTERS_TIMEOUT);
self.0.set_flags(flags);
Poll::Ready(WriteStatus::Timeout(
self.0 .0.disconnect_timeout.get().into(),
))
} else if flags.intersects(Flags::WR_PAUSED) {
if flags.contains(Flags::IO_STOPPING) {
Poll::Ready(WriteStatus::Shutdown)
} else if flags.contains(Flags::WR_PAUSED) {
Poll::Pending
} else {
Poll::Ready(WriteStatus::Ready)
@ -242,20 +232,13 @@ where
Poll::Pending => Poll::Pending,
Poll::Ready(WriteStatus::Ready) => res2,
Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate),
Poll::Ready(WriteStatus::Shutdown(t)) => {
Poll::Ready(WriteStatus::Shutdown) => {
if res2 == Poll::Ready(WriteStatus::Terminate) {
Poll::Ready(WriteStatus::Terminate)
} else {
Poll::Ready(WriteStatus::Shutdown(t))
Poll::Ready(WriteStatus::Shutdown)
}
}
Poll::Ready(WriteStatus::Timeout(t)) => match res2 {
Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate),
Poll::Ready(WriteStatus::Shutdown(t)) => {
Poll::Ready(WriteStatus::Shutdown(t))
}
_ => Poll::Ready(WriteStatus::Timeout(t)),
},
}
}
}

View file

@ -7,8 +7,6 @@ bitflags::bitflags! {
const IO_STOPPING = 0b0000_0000_0000_0010;
/// shuting down filters
const IO_STOPPING_FILTERS = 0b0000_0000_0000_0100;
/// initiate filters shutdown timeout in write task
const IO_FILTERS_TIMEOUT = 0b0000_0000_0000_1000;
/// pause io read
const RD_PAUSED = 0b0000_0000_0001_0000;

View file

@ -14,12 +14,6 @@ impl IoRef {
self.0.flags.get()
}
#[inline]
/// Set flags
pub(crate) fn set_flags(&self, flags: Flags) {
self.0.flags.set(flags)
}
#[inline]
/// Get current filter
pub(crate) fn filter(&self) -> &dyn Filter {

View file

@ -23,7 +23,6 @@ mod utils;
use ntex_bytes::BytesVec;
use ntex_codec::{Decoder, Encoder};
use ntex_util::time::Millis;
pub use self::buf::{ReadBuf, WriteBuf};
pub use self::dispatcher::{Dispatcher, DispatcherConfig};
@ -64,10 +63,8 @@ pub enum ReadStatus {
pub enum WriteStatus {
/// Write task is clear to proceed with write operation
Ready,
/// Initiate timeout for normal write operations, shutdown connection after timeout
Timeout(Millis),
/// Initiate graceful io shutdown operation with timeout
Shutdown(Millis),
/// Initiate graceful io shutdown operation
Shutdown,
/// Immediately terminate connection
Terminate,
}

View file

@ -1,17 +1,22 @@
use std::{future::poll_fn, io, task::Poll};
use std::{cell::Cell, fmt, future::poll_fn, io, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesVec};
use ntex_util::{future::select, future::Either, time::sleep};
use ntex_util::{future::lazy, future::select, future::Either, time::sleep, time::Sleep};
use crate::{AsyncRead, AsyncWrite, Flags, IoRef, ReadStatus, WriteStatus};
#[derive(Debug)]
/// Context for io read task
pub struct ReadContext(IoRef);
pub struct ReadContext(IoRef, Cell<Option<Sleep>>);
impl fmt::Debug for ReadContext {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ReadContext").field("io", &self.0).finish()
}
}
impl ReadContext {
pub(crate) fn new(io: &IoRef) -> Self {
Self(io.clone())
Self(io.clone(), Cell::new(None))
}
#[inline]
@ -30,7 +35,7 @@ impl ReadContext {
} else {
self.0 .0.read_task.register(cx.waker());
if flags.contains(Flags::IO_STOPPING_FILTERS) {
shutdown_filters(&self.0);
self.shutdown_filters(cx);
}
Poll::Pending
}
@ -149,7 +154,7 @@ impl ReadContext {
}
Ok(_) => {
if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) {
shutdown_filters(&self.0);
lazy(|cx| self.shutdown_filters(cx)).await;
}
}
Err(err) => {
@ -160,6 +165,48 @@ impl ReadContext {
}
}
}
fn shutdown_filters(&self, cx: &mut Context<'_>) {
let st = &self.0 .0;
let filter = self.0.filter();
match filter.shutdown(&self.0, &st.buffer, 0) {
Ok(Poll::Ready(())) => {
st.dispatch_task.wake();
st.insert_flags(Flags::IO_STOPPING);
}
Ok(Poll::Pending) => {
let flags = st.flags.get();
// check read buffer, if buffer is not consumed it is unlikely
// that filter will properly complete shutdown
if flags.contains(Flags::RD_PAUSED)
|| flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
{
st.dispatch_task.wake();
st.insert_flags(Flags::IO_STOPPING);
} else {
// filter shutdown timeout
let timeout = self
.1
.take()
.unwrap_or_else(|| sleep(st.disconnect_timeout.get()));
if timeout.poll_elapsed(cx).is_ready() {
st.dispatch_task.wake();
st.insert_flags(Flags::IO_STOPPING);
} else {
self.1.set(Some(timeout));
}
}
}
Err(err) => {
st.io_stopped(Some(err));
}
}
if let Err(err) = filter.process_write_buf(&self.0, &st.buffer, 0) {
st.io_stopped(Some(err));
}
}
}
#[derive(Debug)]
@ -212,41 +259,13 @@ impl WriteContext {
where
T: AsyncWrite,
{
let inner = &self.0 .0;
let mut delay = None;
let mut buf = WriteContextBuf {
io: self.0.clone(),
buf: None,
};
loop {
// check readiness
let result = if let Some(ref mut sleep) = delay {
let result = match select(sleep, self.ready()).await {
Either::Left(_) => {
self.close(Some(io::Error::new(
io::ErrorKind::TimedOut,
"Operation timedout",
)));
return;
}
Either::Right(res) => res,
};
delay = None;
result
} else {
self.ready().await
};
// running
let mut flags = inner.flags.get();
if flags.contains(Flags::WR_PAUSED) {
flags.remove(Flags::WR_PAUSED);
inner.flags.set(flags);
}
// handle write
match result {
match self.ready().await {
WriteStatus::Ready => {
// write io stream
match select(io.write(&mut buf), self.when_stopped()).await {
@ -255,12 +274,7 @@ impl WriteContext {
Either::Right(_) => return,
}
}
WriteStatus::Timeout(time) => {
log::trace!("{}: Initiate timeout delay for {:?}", self.tag(), time);
delay = Some(sleep(time));
continue;
}
WriteStatus::Shutdown(time) => {
WriteStatus::Shutdown => {
log::trace!("{}: Write task is instructed to shutdown", self.tag());
let fut = async {
@ -270,7 +284,7 @@ impl WriteContext {
io.shutdown().await?;
Ok(())
};
match select(sleep(time), fut).await {
match select(sleep(self.0 .0.disconnect_timeout.get()), fut).await {
Either::Left(_) => self.close(None),
Either::Right(res) => self.close(res.err()),
}
@ -328,34 +342,3 @@ impl WriteContextBuf {
}
}
}
fn shutdown_filters(io: &IoRef) {
let st = &io.0;
let flags = st.flags.get();
if !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
let filter = io.filter();
match filter.shutdown(io, &st.buffer, 0) {
Ok(Poll::Ready(())) => {
st.dispatch_task.wake();
st.insert_flags(Flags::IO_STOPPING);
}
Ok(Poll::Pending) => {
// check read buffer, if buffer is not consumed it is unlikely
// that filter will properly complete shutdown
if flags.contains(Flags::RD_PAUSED)
|| flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
{
st.dispatch_task.wake();
st.insert_flags(Flags::IO_STOPPING);
}
}
Err(err) => {
st.io_stopped(Some(err));
}
}
if let Err(err) = filter.process_write_buf(io, &st.buffer, 0) {
st.io_stopped(Some(err));
}
}
}