Fix read filters ordering

This commit is contained in:
Nikolay Kim 2021-12-25 21:21:34 +06:00
parent 531bafbae2
commit 079a1c9cbf
23 changed files with 642 additions and 325 deletions

View file

@ -1,5 +1,9 @@
# Changes
## [0.1.0-b.6] - 2021-12-xx
* Fix read filters ordering
## [0.1.0-b.5] - 2021-12-24
* Use new ntex-service traits

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-io"
version = "0.1.0-b.5"
version = "0.1.0-b.6"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]

View file

@ -67,7 +67,7 @@ impl Filter for Base {
if flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
Poll::Ready(ReadStatus::Terminate)
} else if flags.intersects(Flags::RD_PAUSED) {
} else if flags.intersects(Flags::RD_PAUSED | Flags::RD_BUF_FULL) {
self.0 .0.read_task.register(cx.waker());
Poll::Pending
} else {
@ -98,7 +98,7 @@ impl Filter for Base {
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
self.0 .0.read_buf.take()
None
}
#[inline]
@ -107,21 +107,18 @@ impl Filter for Base {
}
#[inline]
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), io::Error> {
if nbytes > 0 {
if buf.len() > self.0.memory_pool().read_params().high as usize {
log::trace!(
"buffer is too large {}, enable read back-pressure",
buf.len()
);
self.0 .0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
} else {
self.0 .0.insert_flags(Flags::RD_READY);
}
fn release_read_buf(
&self,
buf: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
) -> io::Result<usize> {
if let Some(ref mut dst) = dst {
dst.extend_from_slice(&buf)
} else {
*dst = Some(buf)
}
self.0 .0.read_buf.set(Some(buf));
Ok(())
Ok(nbytes)
}
#[inline]
@ -178,8 +175,13 @@ impl Filter for NullFilter {
None
}
fn release_read_buf(&self, _: BytesMut, _: usize) -> Result<(), io::Error> {
Ok(())
fn release_read_buf(
&self,
_: BytesMut,
_: &mut Option<BytesMut>,
_: usize,
) -> io::Result<usize> {
Ok(0)
}
fn release_write_buf(&self, _: BytesMut) -> Result<(), io::Error> {

View file

@ -279,7 +279,8 @@ impl fmt::Debug for IoRef {
#[cfg(test)]
mod tests {
use std::{cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll};
use std::cell::{Cell, RefCell};
use std::{future::Future, pin::Pin, rc::Rc, task::Context, task::Poll};
use ntex_bytes::Bytes;
use ntex_codec::BytesCodec;
@ -393,9 +394,12 @@ mod tests {
}
struct Counter<F> {
idx: usize,
inner: F,
in_bytes: Rc<Cell<usize>>,
out_bytes: Rc<Cell<usize>>,
read_order: Rc<RefCell<Vec<usize>>>,
write_order: Rc<RefCell<Vec<usize>>>,
}
impl<F: Filter> Filter for Counter<F> {
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
@ -425,10 +429,13 @@ mod tests {
fn release_read_buf(
&self,
buf: BytesMut,
dst: &mut Option<BytesMut>,
new_bytes: usize,
) -> Result<(), io::Error> {
self.in_bytes.set(self.in_bytes.get() + new_bytes);
self.inner.release_read_buf(buf, new_bytes)
) -> io::Result<usize> {
let result = self.inner.release_read_buf(buf, dst, new_bytes)?;
self.read_order.borrow_mut().push(self.idx);
self.in_bytes.set(self.in_bytes.get() + result);
Ok(result)
}
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
@ -445,12 +452,19 @@ mod tests {
}
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
self.write_order.borrow_mut().push(self.idx);
self.out_bytes.set(self.out_bytes.get() + buf.len());
self.inner.release_write_buf(buf)
}
}
struct CounterFactory(Rc<Cell<usize>>, Rc<Cell<usize>>);
struct CounterFactory(
usize,
Rc<Cell<usize>>,
Rc<Cell<usize>>,
Rc<RefCell<Vec<usize>>>,
Rc<RefCell<Vec<usize>>>,
);
impl<F: Filter> FilterFactory<F> for CounterFactory {
type Filter = Counter<F>;
@ -459,14 +473,20 @@ mod tests {
type Future = Ready<Io<Counter<F>>, Self::Error>;
fn create(self, io: Io<F>) -> Self::Future {
let in_bytes = self.0.clone();
let out_bytes = self.1.clone();
let idx = self.0;
let in_bytes = self.1.clone();
let out_bytes = self.2.clone();
let read_order = self.3.clone();
let write_order = self.4.clone();
Ready::Ok(
io.map_filter(|inner| {
Ok::<_, ()>(Counter {
idx,
inner,
in_bytes,
out_bytes,
read_order,
write_order,
})
})
.unwrap(),
@ -478,7 +498,15 @@ mod tests {
async fn filter() {
let in_bytes = Rc::new(Cell::new(0));
let out_bytes = Rc::new(Cell::new(0));
let factory = CounterFactory(in_bytes.clone(), out_bytes.clone());
let read_order = Rc::new(RefCell::new(Vec::new()));
let write_order = Rc::new(RefCell::new(Vec::new()));
let factory = CounterFactory(
1,
in_bytes.clone(),
out_bytes.clone(),
read_order.clone(),
write_order.clone(),
);
let (client, server) = IoTest::create();
let state = Io::new(server).add_filter(factory).await.unwrap();
@ -503,13 +531,27 @@ mod tests {
async fn boxed_filter() {
let in_bytes = Rc::new(Cell::new(0));
let out_bytes = Rc::new(Cell::new(0));
let read_order = Rc::new(RefCell::new(Vec::new()));
let write_order = Rc::new(RefCell::new(Vec::new()));
let (client, server) = IoTest::create();
let state = Io::new(server)
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
.add_filter(CounterFactory(
1,
in_bytes.clone(),
out_bytes.clone(),
read_order.clone(),
write_order.clone(),
))
.await
.unwrap()
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
.add_filter(CounterFactory(
2,
in_bytes.clone(),
out_bytes.clone(),
read_order.clone(),
write_order.clone(),
))
.await
.unwrap();
let state = state.seal();
@ -533,5 +575,7 @@ mod tests {
assert_eq!(Rc::strong_count(&in_bytes), 3);
drop(state);
assert_eq!(Rc::strong_count(&in_bytes), 1);
assert_eq!(*read_order.borrow(), &[1, 2][..]);
assert_eq!(*write_order.borrow(), &[2, 1][..]);
}
}

View file

@ -1,6 +1,6 @@
use std::{
any::Any, any::TypeId, fmt, future::Future, io::Error as IoError, task::Context,
task::Poll,
any::Any, any::TypeId, fmt, future::Future, io as sio, io::Error as IoError,
task::Context, task::Poll,
};
pub mod testing;
@ -55,7 +55,7 @@ pub trait Filter: 'static {
/// Filter wants gracefully shutdown io stream
fn want_shutdown(&self);
fn poll_shutdown(&self) -> Poll<std::io::Result<()>>;
fn poll_shutdown(&self) -> Poll<sio::Result<()>>;
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus>;
@ -65,11 +65,16 @@ pub trait Filter: 'static {
fn get_write_buf(&self) -> Option<BytesMut>;
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> std::io::Result<()>;
fn release_read_buf(
&self,
src: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
) -> sio::Result<usize>;
fn release_write_buf(&self, buf: BytesMut) -> std::io::Result<()>;
fn release_write_buf(&self, buf: BytesMut) -> sio::Result<()>;
fn closed(&self, err: Option<std::io::Error>);
fn closed(&self, err: Option<sio::Error>);
}
pub trait FilterFactory<F: Filter>: Sized {

View file

@ -31,27 +31,35 @@ impl ReadContext {
}
#[inline]
pub fn release_read_buf(
&self,
buf: BytesMut,
new_bytes: usize,
) -> Result<(), io::Error> {
pub fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), io::Error> {
if buf.is_empty() {
self.0.memory_pool().release_read_buf(buf);
Ok(())
} else {
let mut flags = self.0.flags();
if new_bytes > 0 {
flags.insert(Flags::RD_READY);
self.0.set_flags(flags);
let mut dst = self.0 .0.read_buf.take();
let nbytes = self.0.filter().release_read_buf(buf, &mut dst, nbytes)?;
if let Some(dst) = dst {
if self.0.flags().contains(Flags::IO_FILTERS) {
self.0 .0.shutdown_filters()?;
}
if nbytes > 0 {
if dst.len() > self.0.memory_pool().read_params().high as usize {
log::trace!(
"buffer is too large {}, enable read back-pressure",
dst.len()
);
self.0 .0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
} else {
self.0 .0.insert_flags(Flags::RD_READY);
log::trace!("new {} bytes available, wakeup dispatcher", nbytes);
}
self.0 .0.dispatch_task.wake();
}
self.0 .0.read_buf.set(Some(dst));
} else if nbytes > 0 {
self.0 .0.dispatch_task.wake();
log::trace!("new {} bytes available, wakeup dispatcher", new_bytes);
}
self.0.filter().release_read_buf(buf, new_bytes)?;
if flags.contains(Flags::IO_FILTERS) {
self.0 .0.shutdown_filters()?;
self.0 .0.insert_flags(Flags::RD_READY);
}
Ok(())
}

View file

@ -51,62 +51,72 @@ impl Future for ReadTask {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
match this.state.poll_ready(cx) {
Poll::Ready(ReadStatus::Ready) => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
loop {
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
// read data from socket
let mut new_bytes = 0;
let mut close = false;
let mut pending = false;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
Poll::Pending => break,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("io stream is disconnected");
if let Err(e) = this.state.release_read_buf(buf, new_bytes)
{
this.state.close(Some(e));
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
Poll::Pending => {
pending = true;
break;
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("tokio stream is disconnected");
close = true;
} else {
this.state.close(None);
new_bytes += n;
if new_bytes < hw {
continue;
}
}
break;
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
} else {
new_bytes += n;
if buf.len() > hw {
break;
}
}
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
}
}
}
if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
this.state.close(Some(e));
Poll::Ready(())
} else {
Poll::Pending
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
return if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
this.state.close(Some(e));
Poll::Ready(())
} else if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
return Poll::Ready(());
}
}
Poll::Ready(ReadStatus::Terminate) => {
log::trace!("read task is instructed to shutdown");
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
}
@ -466,63 +476,72 @@ mod unixstream {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
match this.state.poll_ready(cx) {
Poll::Ready(ReadStatus::Ready) => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
loop {
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
let pool = this.state.memory_pool();
let mut io = this.io.borrow_mut();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
// read data from socket
let mut new_bytes = 0;
let mut close = false;
let mut pending = false;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
Poll::Pending => break,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("io stream is disconnected");
if let Err(e) =
this.state.release_read_buf(buf, new_bytes)
{
this.state.close(Some(e));
match poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
Poll::Pending => {
pending = true;
break;
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("unix stream is disconnected");
close = true;
} else {
this.state.close(None);
new_bytes += n;
if new_bytes < hw {
continue;
}
}
break;
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
} else {
new_bytes += n;
if buf.len() > hw {
break;
}
}
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
}
}
}
if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
this.state.close(Some(e));
Poll::Ready(())
} else {
Poll::Pending
if new_bytes == 0 && close {
this.state.close(None);
return Poll::Ready(());
}
return if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
this.state.close(Some(e));
Poll::Ready(())
} else if close {
this.state.close(None);
Poll::Ready(())
} else if pending {
Poll::Pending
} else {
continue;
};
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
return Poll::Ready(());
}
}
Poll::Ready(ReadStatus::Terminate) => {
log::trace!("read task is instructed to shutdown");
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
}