cleanup Filter trait

This commit is contained in:
Nikolay Kim 2021-12-22 13:50:14 +06:00
parent fd97208a01
commit 8bbbfde22d
12 changed files with 223 additions and 167 deletions

View file

@ -3,7 +3,7 @@ use std::{any, io, task::Context, task::Poll};
use ntex_bytes::BytesMut;
use super::io::Flags;
use super::{Filter, IoRef, WriteReadiness};
use super::{Filter, IoRef, ReadStatus, WriteStatus};
pub struct Base(IoRef);
@ -14,19 +14,6 @@ impl Base {
}
impl Filter for Base {
#[inline]
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
let mut flags = self.0.flags();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
flags.insert(Flags::IO_SHUTDOWN);
self.0.set_flags(flags);
self.0 .0.read_task.wake();
self.0 .0.write_task.wake();
}
Poll::Ready(Ok(()))
}
#[inline]
fn closed(&self, err: Option<io::Error>) {
self.0 .0.set_error(err);
@ -47,45 +34,61 @@ impl Filter for Base {
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
fn want_read(&self) {
todo!()
}
#[inline]
fn want_shutdown(&self) {
todo!()
}
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
let mut flags = self.0.flags();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
flags.insert(Flags::IO_SHUTDOWN);
self.0.set_flags(flags);
self.0 .0.read_task.wake();
self.0 .0.write_task.wake();
}
Poll::Ready(Ok(()))
}
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
let flags = self.0.flags();
if flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
Poll::Ready(Err(()))
Poll::Ready(ReadStatus::Terminate)
} else if flags.intersects(Flags::RD_PAUSED) {
self.0 .0.read_task.register(cx.waker());
Poll::Pending
} else {
self.0 .0.read_task.register(cx.waker());
Poll::Ready(Ok(()))
Poll::Ready(ReadStatus::Ready)
}
}
#[inline]
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
let mut flags = self.0.flags();
if flags.contains(Flags::IO_ERR) {
Poll::Ready(Err(WriteReadiness::Terminate))
Poll::Ready(WriteStatus::Terminate)
} else if flags.intersects(Flags::IO_SHUTDOWN) {
Poll::Ready(Err(WriteReadiness::Shutdown(
self.0 .0.disconnect_timeout.get(),
)))
Poll::Ready(WriteStatus::Shutdown(self.0 .0.disconnect_timeout.get()))
} else if flags.contains(Flags::IO_FILTERS)
&& !flags.contains(Flags::IO_FILTERS_TO)
{
flags.insert(Flags::IO_FILTERS_TO);
self.0.set_flags(flags);
self.0 .0.write_task.register(cx.waker());
Poll::Ready(Err(WriteReadiness::Timeout(
self.0 .0.disconnect_timeout.get(),
)))
Poll::Ready(WriteStatus::Timeout(self.0 .0.disconnect_timeout.get()))
} else {
self.0 .0.write_task.register(cx.waker());
Poll::Ready(Ok(()))
Poll::Ready(WriteStatus::Ready)
}
}
@ -144,22 +147,26 @@ impl NullFilter {
}
impl Filter for NullFilter {
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
fn closed(&self, _: Option<io::Error>) {}
fn query(&self, _: any::TypeId) -> Option<Box<dyn any::Any>> {
None
}
fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
Poll::Ready(Err(()))
fn closed(&self, _: Option<io::Error>) {}
fn want_read(&self) {}
fn want_shutdown(&self) {}
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_write_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), WriteReadiness>> {
Poll::Ready(Err(WriteReadiness::Terminate))
fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll<ReadStatus> {
Poll::Ready(ReadStatus::Terminate)
}
fn poll_write_ready(&self, _: &mut Context<'_>) -> Poll<WriteStatus> {
Poll::Ready(WriteStatus::Terminate)
}
fn get_read_buf(&self) -> Option<BytesMut> {

View file

@ -125,13 +125,13 @@ impl IoState {
#[inline]
/// Gracefully shutdown read and write io tasks
pub(super) fn init_shutdown(&self, cx: Option<&mut Context<'_>>, st: &IoRef) {
pub(super) fn init_shutdown(&self, cx: Option<&mut Context<'_>>) {
let flags = self.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
log::trace!("initiate io shutdown {:?}", flags);
self.insert_flags(Flags::IO_FILTERS);
if let Err(err) = self.shutdown_filters(st) {
if let Err(err) = self.shutdown_filters() {
self.error.set(Some(err));
}
@ -144,10 +144,10 @@ impl IoState {
}
#[inline]
pub(super) fn shutdown_filters(&self, st: &IoRef) -> Result<(), io::Error> {
pub(super) fn shutdown_filters(&self) -> Result<(), io::Error> {
let mut flags = self.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
let result = match self.filter.get().shutdown(st) {
let result = match self.filter.get().poll_shutdown() {
Poll::Pending => return Ok(()),
Poll::Ready(Ok(())) => {
flags.insert(Flags::IO_SHUTDOWN);
@ -619,7 +619,7 @@ impl<F> Io<F> {
Poll::Ready(Ok(()))
} else {
if !flags.contains(Flags::IO_FILTERS) {
self.0 .0.init_shutdown(Some(cx), self.as_ref());
self.0 .0.init_shutdown(Some(cx));
}
if let Some(err) = self.0 .0.error.take() {

View file

@ -292,7 +292,7 @@ mod tests {
use super::*;
use crate::testing::IoTest;
use crate::{Filter, FilterFactory, Io, WriteReadiness};
use crate::{Filter, FilterFactory, Io, ReadStatus, WriteStatus};
const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
const TEXT: &str = "GET /test HTTP/1\r\n\r\n";
@ -402,15 +402,19 @@ mod tests {
out_bytes: Rc<Cell<usize>>,
}
impl<F: Filter> Filter for Counter<F> {
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn want_read(&self) {}
fn want_shutdown(&self) {}
fn query(&self, _: std::any::TypeId) -> Option<Box<dyn std::any::Any>> {
None
}
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.inner.poll_read_ready(cx)
}
@ -431,10 +435,7 @@ mod tests {
self.inner.release_read_buf(buf, new_bytes)
}
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.inner.poll_write_ready(cx)
}

View file

@ -34,31 +34,43 @@ pub use self::utils::{filter_factory, into_boxed};
pub type IoBoxed = Io<Box<dyn Filter>>;
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum WriteReadiness {
pub enum ReadStatus {
Ready,
Terminate,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum WriteStatus {
Ready,
Timeout(Millis),
Shutdown(Millis),
Terminate,
}
pub trait Filter: 'static {
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), IoError>>;
fn closed(&self, err: Option<IoError>);
fn query(&self, id: TypeId) -> Option<Box<dyn Any>>;
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>>;
/// Filter needs incoming data from io stream
fn want_read(&self);
fn poll_write_ready(&self, cx: &mut Context<'_>)
-> Poll<Result<(), WriteReadiness>>;
/// Filter wants gracefully shutdown io stream
fn want_shutdown(&self);
fn poll_shutdown(&self) -> Poll<std::io::Result<()>>;
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus>;
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus>;
fn get_read_buf(&self) -> Option<BytesMut>;
fn get_write_buf(&self) -> Option<BytesMut>;
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), IoError>;
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> std::io::Result<()>;
fn release_write_buf(&self, buf: BytesMut) -> Result<(), IoError>;
fn release_write_buf(&self, buf: BytesMut) -> std::io::Result<()>;
fn closed(&self, err: Option<std::io::Error>);
}
pub trait FilterFactory<F: Filter>: Sized {

View file

@ -2,7 +2,7 @@ use std::{io, task::Context, task::Poll};
use ntex_bytes::{BytesMut, PoolRef};
use super::{io::Flags, IoRef, WriteReadiness};
use super::{io::Flags, IoRef, ReadStatus, WriteStatus};
pub struct ReadContext(pub(super) IoRef);
@ -13,7 +13,7 @@ impl ReadContext {
}
#[inline]
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.0.filter().poll_read_ready(cx)
}
@ -51,7 +51,7 @@ impl ReadContext {
self.0.filter().release_read_buf(buf, new_bytes)?;
if flags.contains(Flags::IO_FILTERS) {
self.0 .0.shutdown_filters(&self.0)?;
self.0 .0.shutdown_filters()?;
}
Ok(())
}
@ -67,7 +67,7 @@ impl WriteContext {
}
#[inline]
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), WriteReadiness>> {
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
self.0.filter().poll_write_ready(cx)
}
@ -106,7 +106,7 @@ impl WriteContext {
}
if flags.contains(Flags::IO_FILTERS) {
self.0 .0.shutdown_filters(&self.0)?;
self.0 .0.shutdown_filters()?;
}
Ok(())
}

View file

@ -7,7 +7,9 @@ use ntex_bytes::{Buf, BufMut, BytesMut};
use ntex_util::future::poll_fn;
use ntex_util::time::{sleep, Millis, Sleep};
use crate::{types, Handle, IoStream, ReadContext, WriteContext, WriteReadiness};
use crate::{
types, Handle, IoStream, ReadContext, ReadStatus, WriteContext, WriteStatus,
};
#[derive(Default)]
struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>);
@ -388,11 +390,11 @@ impl Future for ReadTask {
let this = self.as_ref();
match this.state.poll_ready(cx) {
Poll::Ready(Err(())) => {
Poll::Ready(ReadStatus::Terminate) => {
log::trace!("read task is instructed to terminate");
Poll::Ready(())
}
Poll::Ready(Ok(())) => {
Poll::Ready(ReadStatus::Ready) => {
let io = &this.io;
let pool = this.state.memory_pool();
let mut buf = self.state.get_read_buf();
@ -474,20 +476,20 @@ impl Future for WriteTask {
match this.st {
IoWriteState::Processing(ref mut delay) => {
match this.state.poll_ready(cx) {
Poll::Ready(Ok(())) => {
Poll::Ready(WriteStatus::Ready) => {
// flush framed instance
match flush_io(&this.io, &this.state, cx) {
Poll::Pending | Poll::Ready(true) => Poll::Pending,
Poll::Ready(false) => Poll::Ready(()),
}
}
Poll::Ready(Err(WriteReadiness::Timeout(time))) => {
Poll::Ready(WriteStatus::Timeout(time)) => {
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
Poll::Ready(Err(WriteReadiness::Shutdown(time))) => {
Poll::Ready(WriteStatus::Shutdown(time)) => {
log::trace!("write task is instructed to shutdown");
let timeout = if let Some(delay) = delay.take() {
@ -499,7 +501,7 @@ impl Future for WriteTask {
this.st = IoWriteState::Shutdown(Some(timeout), Shutdown::None);
self.poll(cx)
}
Poll::Ready(Err(WriteReadiness::Terminate)) => {
Poll::Ready(WriteStatus::Terminate) => {
log::trace!("write task is instructed to terminate");
// shutdown WRITE side
this.io

View file

@ -7,8 +7,8 @@ use tok_io::io::{AsyncRead, AsyncWrite, ReadBuf};
use tok_io::net::TcpStream;
use crate::{
types, Filter, Handle, Io, IoBoxed, IoStream, ReadContext, WriteContext,
WriteReadiness,
types, Filter, Handle, Io, IoBoxed, IoStream, ReadContext, ReadStatus, WriteContext,
WriteStatus,
};
impl IoStream for TcpStream {
@ -52,11 +52,7 @@ impl Future for ReadTask {
let this = self.as_ref();
match this.state.poll_ready(cx) {
Poll::Ready(Err(())) => {
log::trace!("read task is instructed to shutdown");
Poll::Ready(())
}
Poll::Ready(Ok(())) => {
Poll::Ready(ReadStatus::Ready) => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
@ -107,6 +103,10 @@ impl Future for ReadTask {
Poll::Pending
}
}
Poll::Ready(ReadStatus::Terminate) => {
log::trace!("read task is instructed to shutdown");
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
@ -152,7 +152,7 @@ impl Future for WriteTask {
match this.st {
IoWriteState::Processing(ref mut delay) => {
match this.state.poll_ready(cx) {
Poll::Ready(Ok(())) => {
Poll::Ready(WriteStatus::Ready) => {
if let Some(delay) = delay {
if delay.poll_elapsed(cx).is_ready() {
this.state.close(Some(io::Error::new(
@ -169,14 +169,14 @@ impl Future for WriteTask {
Poll::Ready(false) => Poll::Ready(()),
}
}
Poll::Ready(Err(WriteReadiness::Timeout(time))) => {
Poll::Ready(WriteStatus::Timeout(time)) => {
log::trace!("initiate timeout delay for {:?}", time);
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
Poll::Ready(Err(WriteReadiness::Shutdown(time))) => {
Poll::Ready(WriteStatus::Shutdown(time)) => {
log::trace!("write task is instructed to shutdown");
let timeout = if let Some(delay) = delay.take() {
@ -188,7 +188,7 @@ impl Future for WriteTask {
this.st = IoWriteState::Shutdown(timeout, Shutdown::None);
self.poll(cx)
}
Poll::Ready(Err(WriteReadiness::Terminate)) => {
Poll::Ready(WriteStatus::Terminate) => {
log::trace!("write task is instructed to terminate");
let _ = Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx);
@ -479,11 +479,7 @@ mod unixstream {
let this = self.as_ref();
match this.state.poll_ready(cx) {
Poll::Ready(Err(())) => {
log::trace!("read task is instructed to shutdown");
Poll::Ready(())
}
Poll::Ready(Ok(())) => {
Poll::Ready(ReadStatus::Ready) => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
@ -534,6 +530,10 @@ mod unixstream {
Poll::Pending
}
}
Poll::Ready(ReadStatus::Terminate) => {
log::trace!("read task is instructed to shutdown");
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
@ -566,7 +566,7 @@ mod unixstream {
match this.st {
IoWriteState::Processing(ref mut delay) => {
match this.state.poll_ready(cx) {
Poll::Ready(Ok(())) => {
Poll::Ready(WriteStatus::Ready) => {
if let Some(delay) = delay {
if delay.poll_elapsed(cx).is_ready() {
this.state.close(Some(io::Error::new(
@ -583,13 +583,13 @@ mod unixstream {
Poll::Ready(false) => Poll::Ready(()),
}
}
Poll::Ready(Err(WriteReadiness::Timeout(time))) => {
Poll::Ready(WriteStatus::Timeout(time)) => {
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
Poll::Ready(Err(WriteReadiness::Shutdown(time))) => {
Poll::Ready(WriteStatus::Shutdown(time)) => {
log::trace!("write task is instructed to shutdown");
let timeout = if let Some(delay) = delay.take() {
@ -601,7 +601,7 @@ mod unixstream {
this.st = IoWriteState::Shutdown(timeout, Shutdown::None);
self.poll(cx)
}
Poll::Ready(Err(WriteReadiness::Terminate)) => {
Poll::Ready(WriteStatus::Terminate) => {
log::trace!("write task is instructed to terminate");
let _ =