Remove ReadFilter/WriteFilter traits

This commit is contained in:
Nikolay Kim 2021-12-19 10:44:12 +06:00
parent 1af728eb01
commit 1ccb87ea51
6 changed files with 153 additions and 187 deletions

View file

@ -1,8 +1,8 @@
# Changes
## [0.1.0-b.1] - 2021-12-18
## [0.1.0-b.1] - 2021-12-19
* Modify filter's release_read/write_buf return type
* Remove ReadFilter/WriteFilter traits.
## [0.1.0-b.0] - 2021-12-18

View file

@ -527,28 +527,28 @@ mod tests {
use ntex_util::time::{sleep, Millis};
use crate::testing::IoTest;
use crate::{state::Flags, state::IoStateInner, Io, IoStream, WriteRef};
use crate::{state::Flags, Io, IoRef, IoStream, WriteRef};
use super::*;
pub(crate) struct State(Rc<IoStateInner>);
pub(crate) struct State(IoRef);
impl State {
fn flags(&self) -> Flags {
self.0.flags.get()
self.0.flags()
}
fn write(&'_ self) -> WriteRef<'_> {
WriteRef(self.0.as_ref())
WriteRef(&self.0)
}
fn close(&self) {
self.0.insert_flags(Flags::DSP_STOP);
self.0.dispatch_task.wake();
self.0 .0.insert_flags(Flags::DSP_STOP);
self.0 .0.dispatch_task.wake();
}
fn set_memory_pool(&self, pool: PoolRef) {
self.0.pool.set(pool);
self.0 .0.pool.set(pool);
}
}
@ -572,7 +572,7 @@ mod tests {
error: Cell::new(None),
inflight: Cell::new(0),
});
let inner = State(state.0 .0.clone());
let inner = State(state.get_ref());
let expire = ka_updated + Duration::from_millis(500);
timer.register(expire, expire, &state);
@ -872,7 +872,7 @@ mod tests {
.keepalive_timeout(Seconds(1))
.await;
});
state.0.disconnect_timeout.set(Millis::ONE_SEC);
state.0 .0.disconnect_timeout.set(Millis::ONE_SEC);
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

View file

@ -1,14 +1,14 @@
use std::{any, io, rc::Rc, task::Context, task::Poll};
use std::{any, io, task::Context, task::Poll};
use ntex_bytes::BytesMut;
use super::state::{Flags, IoRef, IoStateInner};
use super::{Filter, ReadFilter, WriteFilter, WriteReadiness};
use super::state::{Flags, IoRef};
use super::{Filter, WriteReadiness};
pub struct DefaultFilter(Rc<IoStateInner>);
pub struct DefaultFilter(IoRef);
impl DefaultFilter {
pub(crate) fn new(inner: Rc<IoStateInner>) -> Self {
pub(crate) fn new(inner: IoRef) -> Self {
DefaultFilter(inner)
}
}
@ -16,64 +16,95 @@ impl DefaultFilter {
impl Filter for DefaultFilter {
#[inline]
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
let mut flags = self.0.flags.get();
let mut flags = self.0.flags();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
flags.insert(Flags::IO_SHUTDOWN);
self.0.flags.set(flags);
self.0.read_task.wake();
self.0.write_task.wake();
self.0.set_flags(flags);
self.0 .0.read_task.wake();
self.0 .0.write_task.wake();
}
Poll::Ready(Ok(()))
}
#[inline]
fn closed(&self, err: Option<io::Error>) {
self.0 .0.set_error(err);
self.0 .0.handle.take();
self.0 .0.insert_flags(Flags::IO_CLOSED);
self.0 .0.dispatch_task.wake();
}
#[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
if let Some(hnd) = self.0.handle.take() {
if let Some(hnd) = self.0 .0.handle.take() {
let res = hnd.query(id);
self.0.handle.set(Some(hnd));
self.0 .0.handle.set(Some(hnd));
res
} else {
None
}
}
}
impl ReadFilter for DefaultFilter {
#[inline]
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
let flags = self.0.flags.get();
let flags = self.0.flags();
if flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
Poll::Ready(Err(()))
} else if flags.intersects(Flags::RD_PAUSED) {
self.0.read_task.register(cx.waker());
self.0 .0.read_task.register(cx.waker());
Poll::Pending
} else {
self.0.read_task.register(cx.waker());
self.0 .0.read_task.register(cx.waker());
Poll::Ready(Ok(()))
}
}
#[inline]
fn read_closed(&self, err: Option<io::Error>) {
self.0.set_error(err);
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
let mut flags = self.0.flags();
if flags.contains(Flags::IO_ERR) {
Poll::Ready(Err(WriteReadiness::Terminate))
} else if flags.intersects(Flags::IO_SHUTDOWN) {
Poll::Ready(Err(WriteReadiness::Shutdown(
self.0 .0.disconnect_timeout.get(),
)))
} else if flags.contains(Flags::IO_FILTERS)
&& !flags.contains(Flags::IO_FILTERS_TO)
{
flags.insert(Flags::IO_FILTERS_TO);
self.0.set_flags(flags);
self.0 .0.write_task.register(cx.waker());
Poll::Ready(Err(WriteReadiness::Timeout(
self.0 .0.disconnect_timeout.get(),
)))
} else {
self.0 .0.write_task.register(cx.waker());
Poll::Ready(Ok(()))
}
}
#[inline]
fn get_read_buf(&self) -> Option<BytesMut> {
self.0.read_buf.take()
self.0 .0.read_buf.take()
}
#[inline]
fn release_read_buf(
&self,
buf: BytesMut,
new_bytes: usize,
) -> Result<bool, io::Error> {
let mut flags = self.0.flags.get();
fn get_write_buf(&self) -> Option<BytesMut> {
self.0 .0.write_buf.take()
}
if new_bytes > 0 {
if buf.len() > self.0.pool.get().read_params().high as usize {
#[inline]
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), io::Error> {
let mut flags = self.0.flags();
if nbytes > 0 {
if buf.len() > self.0.memory_pool().read_params().high as usize {
log::trace!(
"buffer is too large {}, enable read back-pressure",
buf.len()
@ -82,66 +113,23 @@ impl ReadFilter for DefaultFilter {
} else {
flags.insert(Flags::RD_READY);
}
self.0.flags.set(flags);
self.0.set_flags(flags);
}
self.0.read_buf.set(Some(buf));
Ok(false)
}
}
impl WriteFilter for DefaultFilter {
#[inline]
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(), WriteReadiness>> {
let mut flags = self.0.flags.get();
if flags.contains(Flags::IO_ERR) {
Poll::Ready(Err(WriteReadiness::Terminate))
} else if flags.intersects(Flags::IO_SHUTDOWN) {
Poll::Ready(Err(WriteReadiness::Shutdown(
self.0.disconnect_timeout.get(),
)))
} else if flags.contains(Flags::IO_FILTERS)
&& !flags.contains(Flags::IO_FILTERS_TO)
{
flags.insert(Flags::IO_FILTERS_TO);
self.0.flags.set(flags);
self.0.write_task.register(cx.waker());
Poll::Ready(Err(WriteReadiness::Timeout(
self.0.disconnect_timeout.get(),
)))
} else {
self.0.write_task.register(cx.waker());
Poll::Ready(Ok(()))
}
self.0 .0.read_buf.set(Some(buf));
Ok(())
}
#[inline]
fn write_closed(&self, err: Option<io::Error>) {
self.0.set_error(err);
self.0.handle.take();
self.0.insert_flags(Flags::IO_CLOSED);
self.0.dispatch_task.wake();
}
#[inline]
fn get_write_buf(&self) -> Option<BytesMut> {
self.0.write_buf.take()
}
#[inline]
fn release_write_buf(&self, buf: BytesMut) -> Result<bool, io::Error> {
let pool = self.0.pool.get();
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
let pool = self.0.memory_pool();
if buf.is_empty() {
pool.release_write_buf(buf);
} else {
self.0.write_buf.set(Some(buf));
self.0.write_task.wake();
self.0 .0.write_buf.set(Some(buf));
self.0 .0.write_task.wake();
}
Ok(false)
Ok(())
}
}
@ -160,39 +148,33 @@ impl Filter for NullFilter {
Poll::Ready(Ok(()))
}
fn closed(&self, _: Option<io::Error>) {}
fn query(&self, _: any::TypeId) -> Option<Box<dyn any::Any>> {
None
}
}
impl ReadFilter for NullFilter {
fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
Poll::Ready(Err(()))
}
fn read_closed(&self, _: Option<io::Error>) {}
fn poll_write_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), WriteReadiness>> {
Poll::Ready(Err(WriteReadiness::Terminate))
}
fn get_read_buf(&self) -> Option<BytesMut> {
None
}
fn release_read_buf(&self, _: BytesMut, _: usize) -> Result<bool, io::Error> {
Ok(true)
}
}
impl WriteFilter for NullFilter {
fn poll_write_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), WriteReadiness>> {
Poll::Ready(Err(WriteReadiness::Terminate))
}
fn write_closed(&self, _: Option<io::Error>) {}
fn get_write_buf(&self) -> Option<BytesMut> {
None
}
fn release_write_buf(&self, _: BytesMut) -> Result<bool, io::Error> {
Ok(true)
fn release_read_buf(&self, _: BytesMut, _: usize) -> Result<(), io::Error> {
Ok(())
}
fn release_write_buf(&self, _: BytesMut) -> Result<(), io::Error> {
Ok(())
}
}

View file

@ -36,31 +36,25 @@ pub enum WriteReadiness {
Terminate,
}
pub trait ReadFilter {
pub trait Filter: 'static {
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), io::Error>>;
fn closed(&self, err: Option<io::Error>);
fn query(&self, id: TypeId) -> Option<Box<dyn Any>>;
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>>;
fn read_closed(&self, err: Option<io::Error>);
fn get_read_buf(&self) -> Option<BytesMut>;
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<bool, io::Error>;
}
pub trait WriteFilter {
fn poll_write_ready(&self, cx: &mut Context<'_>)
-> Poll<Result<(), WriteReadiness>>;
fn write_closed(&self, err: Option<io::Error>);
fn get_read_buf(&self) -> Option<BytesMut>;
fn get_write_buf(&self) -> Option<BytesMut>;
fn release_write_buf(&self, buf: BytesMut) -> Result<bool, io::Error>;
}
fn release_read_buf(&self, buf: BytesMut, nbytes: usize) -> Result<(), io::Error>;
pub trait Filter: ReadFilter + WriteFilter + 'static {
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), io::Error>>;
fn query(&self, id: TypeId) -> Option<Box<dyn Any>>;
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error>;
}
pub trait FilterFactory<F: Filter>: Sized {

View file

@ -65,7 +65,7 @@ pub(crate) struct IoStateInner {
pub(super) dispatch_task: LocalWaker,
pub(super) read_buf: Cell<Option<BytesMut>>,
pub(super) write_buf: Cell<Option<BytesMut>>,
pub(super) filter: Cell<&'static dyn Filter>,
filter: Cell<&'static dyn Filter>,
pub(super) handle: Cell<Option<Box<dyn Handle>>>,
on_disconnect: RefCell<Vec<Option<LocalWaker>>>,
}
@ -225,7 +225,7 @@ impl Io {
on_disconnect: RefCell::new(Vec::new()),
});
let filter = Box::new(DefaultFilter::new(inner.clone()));
let filter = Box::new(DefaultFilter::new(IoRef(inner.clone())));
let filter_ref: &'static dyn Filter = unsafe {
let filter: &dyn Filter = filter.as_ref();
std::mem::transmute(filter)
@ -293,6 +293,18 @@ impl IoRef {
self.0.flags.get()
}
#[inline]
/// Set flags
pub(crate) fn set_flags(&self, flags: Flags) {
self.0.flags.set(flags)
}
#[inline]
/// Get memory pool
pub(crate) fn filter(&self) -> &dyn Filter {
self.0.filter.get()
}
#[inline]
/// Get memory pool
pub fn memory_pool(&self) -> PoolRef {
@ -389,7 +401,7 @@ impl IoRef {
#[inline]
/// Query specific data
pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
if let Some(item) = self.0.filter.get().query(any::TypeId::of::<T>()) {
if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
types::QueryItem::new(item)
} else {
types::QueryItem::empty()
@ -420,10 +432,10 @@ impl IoRef {
where
U: Encoder,
{
let filter = self.0.filter.get();
let filter = self.filter();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.0.pool.get().get_write_buf());
.unwrap_or_else(|| self.memory_pool().get_write_buf());
let is_write_sleep = buf.is_empty();
codec.encode(item, &mut buf).map_err(Either::Left)?;
@ -626,7 +638,7 @@ impl<'a> WriteRef<'a> {
/// Check if write buffer is full
pub fn is_full(&self) -> bool {
if let Some(buf) = self.0 .0.read_buf.take() {
let hw = self.0 .0.pool.get().write_params_high();
let hw = self.0.memory_pool().write_params_high();
let result = buf.len() >= hw;
self.0 .0.write_buf.set(Some(buf));
result
@ -659,19 +671,16 @@ impl<'a> WriteRef<'a> {
where
F: FnOnce(&mut BytesMut) -> R,
{
let filter = self.0 .0.filter.get();
let filter = self.0.filter();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.0 .0.pool.get().get_write_buf());
.unwrap_or_else(|| self.0.memory_pool().get_write_buf());
if buf.is_empty() {
self.0 .0.write_task.wake();
}
let result = f(&mut buf);
let close = filter.release_write_buf(buf)?;
if close {
self.0 .0.init_shutdown(None, self.0);
}
filter.release_write_buf(buf)?;
Ok(result)
}
@ -690,12 +699,12 @@ impl<'a> WriteRef<'a> {
let flags = self.0 .0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
let filter = self.0 .0.filter.get();
let filter = self.0.filter();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.0 .0.pool.get().get_write_buf());
.unwrap_or_else(|| self.0.memory_pool().get_write_buf());
let is_write_sleep = buf.is_empty();
let (hw, lw) = self.0 .0.pool.get().write_params().unpack();
let (hw, lw) = self.0.memory_pool().write_params().unpack();
// make sure we've got room
let remaining = buf.capacity() - buf.len();
@ -710,15 +719,8 @@ impl<'a> WriteRef<'a> {
}
buf.len() < hw
});
match filter.release_write_buf(buf) {
Err(err) => {
self.0 .0.set_error(Some(err));
}
Ok(close) => {
if close {
self.0 .0.init_shutdown(None, self.0);
}
}
if let Err(err) = filter.release_write_buf(buf) {
self.0 .0.set_error(Some(err));
}
result
} else {
@ -734,15 +736,15 @@ impl<'a> WriteRef<'a> {
let flags = self.0 .0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
let filter = self.0 .0.filter.get();
let filter = self.0.filter();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.0 .0.pool.get().get_write_buf());
.unwrap_or_else(|| self.0.memory_pool().get_write_buf());
let is_write_sleep = buf.is_empty();
// write and wake write task
buf.extend_from_slice(src);
let result = buf.len() < self.0 .0.pool.get().write_params_high();
let result = buf.len() < self.0.memory_pool().write_params_high();
if is_write_sleep {
self.0 .0.write_task.wake();
}
@ -783,7 +785,7 @@ impl<'a> WriteRef<'a> {
self.0 .0.insert_flags(Flags::WR_WAIT);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
} else if len >= self.0 .0.pool.get().write_params_high() << 1 {
} else if len >= self.0.memory_pool().write_params_high() << 1 {
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
@ -828,7 +830,7 @@ impl<'a> ReadRef<'a> {
/// Check if read buffer is full
pub fn is_full(&self) -> bool {
if let Some(buf) = self.0 .0.read_buf.take() {
let result = buf.len() >= self.0 .0.pool.get().read_params_high();
let result = buf.len() >= self.0.memory_pool().read_params_high();
self.0 .0.read_buf.set(Some(buf));
result
} else {
@ -887,10 +889,10 @@ impl<'a> ReadRef<'a> {
.0
.read_buf
.take()
.unwrap_or_else(|| self.0 .0.pool.get().get_read_buf());
.unwrap_or_else(|| self.0.memory_pool().get_read_buf());
let res = f(&mut buf);
if buf.is_empty() {
self.0 .0.pool.get().release_read_buf(buf);
self.0.memory_pool().release_read_buf(buf);
} else {
self.0 .0.read_buf.set(Some(buf));
}
@ -1005,7 +1007,7 @@ mod tests {
use super::*;
use crate::testing::IoTest;
use crate::{Filter, FilterFactory, ReadFilter, WriteFilter, WriteReadiness};
use crate::{Filter, FilterFactory, WriteReadiness};
const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n";
const TEXT: &str = "GET /test HTTP/1\r\n\r\n";
@ -1114,7 +1116,7 @@ mod tests {
in_bytes: Rc<Cell<usize>>,
out_bytes: Rc<Cell<usize>>,
}
impl<F: ReadFilter + WriteFilter + 'static> Filter for Counter<F> {
impl<F: Filter> Filter for Counter<F> {
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
@ -1122,15 +1124,13 @@ mod tests {
fn query(&self, _: std::any::TypeId) -> Option<Box<dyn std::any::Any>> {
None
}
}
impl<F: ReadFilter> ReadFilter for Counter<F> {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
self.inner.poll_read_ready(cx)
}
fn read_closed(&self, err: Option<io::Error>) {
self.inner.read_closed(err)
fn closed(&self, err: Option<io::Error>) {
self.inner.closed(err)
}
fn get_read_buf(&self) -> Option<BytesMut> {
@ -1145,9 +1145,7 @@ mod tests {
self.in_bytes.set(self.in_bytes.get() + new_bytes);
self.inner.release_read_buf(buf, new_bytes)
}
}
impl<F: WriteFilter> WriteFilter for Counter<F> {
fn poll_write_ready(
&self,
cx: &mut Context<'_>,
@ -1155,10 +1153,6 @@ mod tests {
self.inner.poll_write_ready(cx)
}
fn write_closed(&self, err: Option<io::Error>) {
self.inner.write_closed(err)
}
fn get_write_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.get_write_buf() {
self.out_bytes.set(self.out_bytes.get() - buf.len());
@ -1186,8 +1180,8 @@ mod tests {
let in_bytes = self.0.clone();
let out_bytes = self.1.clone();
Ready::Ok(
io.map_filter::<CounterFactory, _>(|inner| {
Ok(Counter {
io.map_filter(|inner| {
Ok::<_, ()>(Counter {
inner,
in_bytes,
out_bytes,

View file

@ -9,27 +9,25 @@ pub struct ReadContext(pub(super) IoRef);
impl ReadContext {
#[inline]
pub fn memory_pool(&self) -> PoolRef {
self.0 .0.pool.get()
self.0.memory_pool()
}
#[inline]
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
self.0 .0.filter.get().poll_read_ready(cx)
self.0.filter().poll_read_ready(cx)
}
#[inline]
pub fn close(&self, err: Option<io::Error>) {
self.0 .0.filter.get().read_closed(err);
self.0.filter().closed(err);
}
#[inline]
pub fn get_read_buf(&self) -> BytesMut {
self.0
.0
.filter
.get()
.filter()
.get_read_buf()
.unwrap_or_else(|| self.0 .0.pool.get().get_read_buf())
.unwrap_or_else(|| self.0.memory_pool().get_read_buf())
}
#[inline]
@ -39,22 +37,20 @@ impl ReadContext {
new_bytes: usize,
) -> Result<(), io::Error> {
if buf.is_empty() {
self.0 .0.pool.get().release_read_buf(buf);
self.0.memory_pool().release_read_buf(buf);
Ok(())
} else {
let mut flags = self.0 .0.flags.get();
let mut flags = self.0.flags();
if new_bytes > 0 {
flags.insert(Flags::RD_READY);
self.0 .0.flags.set(flags);
self.0.set_flags(flags);
self.0 .0.dispatch_task.wake();
}
let close = self.0 .0.filter.get().release_read_buf(buf, new_bytes)?;
self.0.filter().release_read_buf(buf, new_bytes)?;
if flags.contains(Flags::IO_FILTERS) {
self.0 .0.shutdown_filters(&self.0)?;
} else if close {
self.0 .0.init_shutdown(None, &self.0);
}
Ok(())
}
@ -66,17 +62,17 @@ pub struct WriteContext(pub(super) IoRef);
impl WriteContext {
#[inline]
pub fn memory_pool(&self) -> PoolRef {
self.0 .0.pool.get()
self.0.memory_pool()
}
#[inline]
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), WriteReadiness>> {
self.0 .0.filter.get().poll_write_ready(cx)
self.0.filter().poll_write_ready(cx)
}
#[inline]
pub fn close(&self, err: Option<io::Error>) {
self.0 .0.filter.get().write_closed(err)
self.0.filter().closed(err)
}
#[inline]
@ -86,14 +82,14 @@ impl WriteContext {
#[inline]
pub fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
let pool = self.0 .0.pool.get();
let mut flags = self.0 .0.flags.get();
let pool = self.0.memory_pool();
let mut flags = self.0.flags();
if buf.is_empty() {
pool.release_write_buf(buf);
if flags.intersects(Flags::WR_WAIT | Flags::WR_BACKPRESSURE) {
flags.remove(Flags::WR_WAIT | Flags::WR_BACKPRESSURE);
self.0 .0.flags.set(flags);
self.0.set_flags(flags);
self.0 .0.dispatch_task.wake();
}
} else {
@ -102,7 +98,7 @@ impl WriteContext {
&& buf.len() < pool.write_params_high() << 1
{
flags.remove(Flags::WR_BACKPRESSURE);
self.0 .0.flags.set(flags);
self.0.set_flags(flags);
self.0 .0.dispatch_task.wake();
}
self.0 .0.write_buf.set(Some(buf))