Cleanup Filter trait, removed closed,want_read,want_shutdown methods

This commit is contained in:
Nikolay Kim 2021-12-29 15:10:24 +06:00
parent c5d43eb12d
commit dc17d00ed9
21 changed files with 331 additions and 507 deletions

View file

@ -1,6 +1,8 @@
# Changes
## [0.1.0-b.10] - 2021-12-xx
## [0.1.0-b.10] - 2021-12-30
* Cleanup Filter trait, removed closed,want_read,want_shutdown methods
* Cleanup internal flags on io error

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-io"
version = "0.1.0-b.9"
version = "0.1.0-b.10"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]

View file

@ -63,7 +63,6 @@ enum DispatcherState {
}
enum DispatcherError<S, U> {
KeepAlive,
Encoder(U),
Service(S),
}
@ -176,7 +175,7 @@ where
Err(err) => self.error.set(Some(DispatcherError::Service(err))),
Ok(None) => return,
}
io.wake_dispatcher();
io.wake();
}
}
@ -382,18 +381,12 @@ where
) -> Poll<PollService<U>> {
match srv.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
// check keepalive timeout
self.check_keepalive();
// check for errors
Poll::Ready(if let Some(err) = self.shared.error.take() {
log::trace!("error occured, stopping dispatcher");
self.st.set(DispatcherState::Stop);
match err {
DispatcherError::KeepAlive => {
PollService::Item(DispatchItem::KeepAliveTimeout)
}
DispatcherError::Encoder(err) => {
PollService::Item(DispatchItem::EncoderError(err))
}
@ -431,18 +424,6 @@ where
self.ka_timeout.get().non_zero()
}
/// check keepalive timeout
fn check_keepalive(&self) {
if self.io.is_keepalive() {
log::trace!("keepalive timeout");
if let Some(err) = self.shared.error.take() {
self.shared.error.set(Some(err));
} else {
self.shared.error.set(Some(DispatcherError::KeepAlive));
}
}
}
/// update keep-alive timer
fn update_keepalive(&self) {
if self.ka_enabled() {
@ -790,6 +771,7 @@ mod tests {
#[ntex::test]
async fn test_keepalive() {
env_logger::init();
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
@ -831,7 +813,7 @@ mod tests {
// write side must be closed, dispatcher should fail with keep-alive
let flags = state.flags();
assert!(flags.contains(Flags::IO_SHUTDOWN));
assert!(flags.contains(Flags::IO_STOPPING));
assert!(flags.contains(Flags::DSP_KEEPALIVE));
assert!(client.is_closed());
assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);

View file

@ -14,13 +14,6 @@ impl Base {
}
impl Filter for Base {
#[inline]
fn closed(&self, err: Option<io::Error>) {
self.0 .0.set_error(err);
self.0 .0.handle.take();
self.0 .0.dispatch_task.wake();
}
#[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
if let Some(hnd) = self.0 .0.handle.take() {
@ -32,25 +25,8 @@ impl Filter for Base {
}
}
#[inline]
fn want_read(&self) {
let flags = self.0.flags();
if flags.intersects(Flags::RD_PAUSED | Flags::RD_BUF_FULL) {
self.0
.0
.remove_flags(Flags::RD_PAUSED | Flags::RD_BUF_FULL);
self.0 .0.read_task.wake();
}
}
#[inline]
fn want_shutdown(&self, err: Option<io::Error>) {
self.0 .0.init_shutdown(err);
}
#[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.want_shutdown(None);
Poll::Ready(Ok(()))
}
@ -58,7 +34,7 @@ impl Filter for Base {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
let flags = self.0.flags();
if flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
if flags.intersects(Flags::IO_STOPPING) {
Poll::Ready(ReadStatus::Terminate)
} else if flags.intersects(Flags::RD_PAUSED | Flags::RD_BUF_FULL) {
self.0 .0.read_task.register(cx.waker());
@ -73,13 +49,14 @@ impl Filter for Base {
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
let mut flags = self.0.flags();
if flags.contains(Flags::IO_ERR) {
if flags.contains(Flags::IO_STOPPED) {
Poll::Ready(WriteStatus::Terminate)
} else if flags.intersects(Flags::IO_SHUTDOWN) {
} else if flags.intersects(Flags::IO_STOPPING) {
Poll::Ready(WriteStatus::Shutdown(self.0 .0.disconnect_timeout.get()))
} else if flags.contains(Flags::IO_FILTERS) && !flags.contains(Flags::IO_FILTERS_TO)
} else if flags.contains(Flags::IO_STOPPING_FILTERS)
&& !flags.contains(Flags::IO_FILTERS_TIMEOUT)
{
flags.insert(Flags::IO_FILTERS_TO);
flags.insert(Flags::IO_FILTERS_TIMEOUT);
self.0.set_flags(flags);
self.0 .0.write_task.register(cx.waker());
Poll::Ready(WriteStatus::Timeout(self.0 .0.disconnect_timeout.get()))
@ -102,6 +79,7 @@ impl Filter for Base {
#[inline]
fn release_read_buf(
&self,
_: &IoRef,
buf: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
@ -145,12 +123,6 @@ impl Filter for NullFilter {
None
}
fn closed(&self, _: Option<io::Error>) {}
fn want_read(&self) {}
fn want_shutdown(&self, _: Option<io::Error>) {}
fn poll_shutdown(&self) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
@ -173,6 +145,7 @@ impl Filter for NullFilter {
fn release_read_buf(
&self,
_: &IoRef,
_: BytesMut,
_: &mut Option<BytesMut>,
_: usize,

View file

@ -13,33 +13,33 @@ use super::{Filter, FilterFactory, Handle, IoStream, RecvError};
bitflags::bitflags! {
pub struct Flags: u16 {
/// io error occured
const IO_ERR = 0b0000_0000_0000_0001;
/// shuting down filters
const IO_FILTERS = 0b0000_0000_0000_0010;
/// shuting down filters timeout
const IO_FILTERS_TO = 0b0000_0000_0000_0100;
/// io is closed
const IO_STOPPED = 0b0000_0000_0000_0001;
/// shutdown io tasks
const IO_SHUTDOWN = 0b0000_0000_0000_1000;
const IO_STOPPING = 0b0000_0000_0000_0010;
/// shuting down filters
const IO_STOPPING_FILTERS = 0b0000_0000_0000_0100;
/// initiate filters shutdown timeout in write task
const IO_FILTERS_TIMEOUT = 0b0000_0000_0000_1000;
/// pause io read
const RD_PAUSED = 0b0000_0000_0010_0000;
const RD_PAUSED = 0b0000_0000_0001_0000;
/// new data is available
const RD_READY = 0b0000_0000_0100_0000;
const RD_READY = 0b0000_0000_0010_0000;
/// read buffer is full
const RD_BUF_FULL = 0b0000_0000_1000_0000;
const RD_BUF_FULL = 0b0000_0000_0100_0000;
/// wait write completion
const WR_WAIT = 0b0000_0001_0000_0000;
const WR_WAIT = 0b0000_0000_1000_0000;
/// write buffer is full
const WR_BACKPRESSURE = 0b0000_0010_0000_0000;
const WR_BACKPRESSURE = 0b0000_0001_0000_0000;
/// dispatcher is marked stopped
const DSP_STOP = 0b0001_0000_0000_0000;
const DSP_STOP = 0b0000_0010_0000_0000;
/// keep-alive timeout occured
const DSP_KEEPALIVE = 0b0010_0000_0000_0000;
const DSP_KEEPALIVE = 0b0000_0100_0000_0000;
/// dispatcher returned error
const DSP_ERR = 0b0100_0000_0000_0000;
const DSP_ERR = 0b0000_1000_0000_0000;
}
}
@ -104,15 +104,7 @@ impl IoState {
}
#[inline]
pub(super) fn is_io_open(&self) -> bool {
!self
.flags
.get()
.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN)
}
#[inline]
pub(super) fn set_error(&self, err: Option<io::Error>) {
pub(super) fn io_stopped(&self, err: Option<io::Error>) {
if err.is_some() {
self.error.set(err);
}
@ -120,52 +112,49 @@ impl IoState {
self.write_task.wake();
self.dispatch_task.wake();
self.notify_disconnect();
let mut flags = self.flags.get();
flags.insert(Flags::IO_ERR);
flags.remove(
Flags::DSP_KEEPALIVE
| Flags::RD_PAUSED
| Flags::RD_READY
| Flags::RD_BUF_FULL
| Flags::WR_WAIT
| Flags::WR_BACKPRESSURE,
self.handle.take();
self.insert_flags(
Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
);
self.flags.set(flags);
}
#[inline]
/// Gracefully shutdown read and write io tasks
pub(super) fn init_shutdown(&self, err: Option<io::Error>) {
let flags = self.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
log::trace!("initiate io shutdown {:?} {:?}", flags, err);
self.insert_flags(Flags::IO_FILTERS);
if err.is_some() {
self.io_stopped(err);
} else if !self
.flags
.get()
.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
{
log::trace!("initiate io shutdown {:?}", self.flags.get());
self.insert_flags(Flags::IO_STOPPING_FILTERS);
self.read_task.wake();
self.write_task.wake();
if let Some(err) = err {
self.error.set(Some(err));
}
self.dispatch_task.wake();
}
}
#[inline]
pub(super) fn shutdown_filters(&self) {
let mut flags = self.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
if !self
.flags
.get()
.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING)
{
match self.filter.get().poll_shutdown() {
Poll::Pending => return,
Poll::Ready(Ok(())) => {
flags.insert(Flags::IO_SHUTDOWN);
self.read_task.wake();
self.write_task.wake();
self.dispatch_task.wake();
self.insert_flags(Flags::IO_STOPPING);
}
Poll::Ready(Err(err)) => {
flags.insert(Flags::IO_ERR);
self.error.set(Some(err));
self.io_stopped(Some(err));
}
Poll::Pending => (),
}
self.flags.set(flags);
self.read_task.wake();
self.write_task.wake();
self.dispatch_task.wake();
}
}
@ -264,7 +253,7 @@ impl Io {
let io_ref = IoRef(inner);
// start io tasks
let hnd = io.start(ReadContext(io_ref.clone()), WriteContext(io_ref.clone()));
let hnd = io.start(ReadContext::new(&io_ref), WriteContext::new(&io_ref));
io_ref.0.handle.set(hnd);
Io(io_ref, FilterItem::Ptr(Box::into_raw(filter)))
@ -331,6 +320,11 @@ impl<F> Io<F> {
pub fn reset_keepalive(&self) {
self.0 .0.remove_flags(Flags::DSP_KEEPALIVE)
}
/// Get current io error
fn error(&self) -> Option<io::Error> {
self.0 .0.error.take()
}
}
impl Io<Sealed> {
@ -478,13 +472,13 @@ impl<F> Io<F> {
/// Wake write task and instruct to flush data.
///
/// This is async version of .poll_flush() method.
pub async fn flush(&self, full: bool) -> Result<(), io::Error> {
pub async fn flush(&self, full: bool) -> io::Result<()> {
poll_fn(|cx| self.poll_flush(cx, full)).await
}
#[inline]
/// Shut down io stream
pub async fn shutdown(&self) -> Result<(), io::Error> {
pub async fn shutdown(&self) -> io::Result<()> {
poll_fn(|cx| self.poll_shutdown(cx)).await
}
@ -503,16 +497,13 @@ impl<F> Io<F> {
/// `Poll::Ready(Ok(None))` if io stream is disconnected
/// `Some(Poll::Ready(Err(e)))` if an error is encountered.
pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
if !self.0 .0.is_io_open() {
if let Some(err) = self.0 .0.error.take() {
Poll::Ready(Err(err))
} else {
Poll::Ready(Ok(None))
}
let mut flags = self.0 .0.flags.get();
if flags.contains(Flags::IO_STOPPED) {
Poll::Ready(self.error().map(Err).unwrap_or(Ok(None)))
} else {
self.0 .0.dispatch_task.register(cx.waker());
let mut flags = self.0 .0.flags.get();
let ready = flags.contains(Flags::RD_READY);
if flags.intersects(Flags::RD_BUF_FULL | Flags::RD_PAUSED) {
if flags.intersects(Flags::RD_BUF_FULL) {
@ -540,7 +531,6 @@ impl<F> Io<F> {
}
#[inline]
#[allow(clippy::type_complexity)]
/// Decode codec item from incoming bytes stream.
///
/// Wake read task and request to read more data if data is not enough for decoding.
@ -556,13 +546,10 @@ impl<F> Io<F> {
match self.decode(codec) {
Ok(Some(el)) => Poll::Ready(Ok(el)),
Ok(None) => {
if !self.0 .0.is_io_open() {
return Poll::Ready(Err(RecvError::PeerGone(
self.0 .0.error.take(),
)));
}
let flags = self.flags();
if flags.contains(Flags::DSP_STOP) {
if flags.contains(Flags::IO_STOPPED) {
Poll::Ready(Err(RecvError::PeerGone(self.error())))
} else if flags.contains(Flags::DSP_STOP) {
Poll::Ready(Err(RecvError::Stop))
} else if flags.contains(Flags::DSP_KEEPALIVE) {
Poll::Ready(Err(RecvError::KeepAlive))
@ -594,64 +581,51 @@ impl<F> Io<F> {
/// otherwise wake up when size of write buffer is lower than
/// buffer max size.
pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
// check io error
if !self.0 .0.is_io_open() {
self.0 .0.remove_flags(
Flags::DSP_KEEPALIVE
| Flags::RD_PAUSED
| Flags::RD_READY
| Flags::RD_BUF_FULL
| Flags::WR_WAIT
| Flags::WR_BACKPRESSURE,
);
return Poll::Ready(Err(self
let flags = self.flags();
if flags.contains(Flags::IO_STOPPED) {
Poll::Ready(self.error().map(Err).unwrap_or(Ok(())))
} else {
let len = self
.0
.0
.error
.take()
.unwrap_or_else(|| io::Error::new(io::ErrorKind::Other, "disconnected"))));
}
.with_write_buf(|buf| buf.as_ref().map(|b| b.len()).unwrap_or(0));
let len = self
.0
.0
.with_write_buf(|buf| buf.as_ref().map(|b| b.len()).unwrap_or(0));
if len > 0 {
if full {
self.0 .0.insert_flags(Flags::WR_WAIT);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
} else if len >= self.0.memory_pool().write_params_high() << 1 {
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
if len > 0 {
if full {
self.0 .0.insert_flags(Flags::WR_WAIT);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
} else if len >= self.0.memory_pool().write_params_high() << 1 {
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
}
}
self.0
.0
.remove_flags(Flags::WR_WAIT | Flags::WR_BACKPRESSURE);
Poll::Ready(Ok(()))
}
self.0
.0
.remove_flags(Flags::WR_WAIT | Flags::WR_BACKPRESSURE);
Poll::Ready(Ok(()))
}
#[inline]
/// Shut down io stream
pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let flags = self.flags();
if flags.intersects(Flags::IO_ERR) {
Poll::Ready(Ok(()))
} else {
if !flags.contains(Flags::IO_FILTERS) {
self.0 .0.init_shutdown(None);
}
if let Some(err) = self.0 .0.error.take() {
if flags.intersects(Flags::IO_STOPPED) {
if let Some(err) = self.error() {
Poll::Ready(Err(err))
} else {
self.0 .0.dispatch_task.register(cx.waker());
Poll::Pending
Poll::Ready(Ok(()))
}
} else {
if !flags.contains(Flags::IO_STOPPING_FILTERS) {
self.0 .0.init_shutdown(None);
}
self.0 .0.dispatch_task.register(cx.waker());
Poll::Pending
}
}
}
@ -708,7 +682,7 @@ pub struct OnDisconnect {
impl OnDisconnect {
pub(super) fn new(inner: Rc<IoState>) -> Self {
Self::new_inner(inner.flags.get().contains(Flags::IO_ERR), inner)
Self::new_inner(inner.flags.get().contains(Flags::IO_STOPPED), inner)
}
fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {

View file

@ -32,73 +32,10 @@ impl IoRef {
self.0.pool.get()
}
#[inline]
/// Check if io is still active
pub fn is_io_open(&self) -> bool {
self.0.is_io_open()
}
#[inline]
/// Check if keep-alive timeout occured
pub fn is_keepalive(&self) -> bool {
self.0.flags.get().contains(Flags::DSP_KEEPALIVE)
}
#[inline]
/// Check if io stream is closed
pub fn is_closed(&self) -> bool {
self.0.flags.get().intersects(
Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS | Flags::DSP_STOP,
)
}
#[inline]
/// Take io error if any occured
pub fn take_error(&self) -> Option<io::Error> {
self.0.error.take()
}
#[inline]
/// Wake dispatcher task
pub fn wake_dispatcher(&self) {
self.0.dispatch_task.wake();
}
#[inline]
/// Gracefully close connection
///
/// First stop dispatcher, then dispatcher stops io tasks
pub fn close(&self) {
self.0.insert_flags(Flags::DSP_STOP);
self.0.dispatch_task.wake();
}
#[inline]
/// Force close connection
///
/// Dispatcher does not wait for uncompleted responses, but flushes io buffers.
pub fn force_close(&self) {
log::trace!("force close io stream object");
self.0.insert_flags(Flags::DSP_STOP | Flags::IO_SHUTDOWN);
self.0.read_task.wake();
self.0.write_task.wake();
self.0.dispatch_task.wake();
}
#[inline]
/// Notify when io stream get disconnected
pub fn on_disconnect(&self) -> OnDisconnect {
OnDisconnect::new(self.0.clone())
}
#[inline]
/// Query specific data
pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
types::QueryItem::new(item)
} else {
types::QueryItem::empty()
}
self.0.flags.get().contains(Flags::IO_STOPPING)
}
#[inline]
@ -131,6 +68,55 @@ impl IoRef {
len >= self.memory_pool().read_params_high()
}
#[inline]
/// Wake dispatcher task
pub fn wake(&self) {
self.0.dispatch_task.wake();
}
#[inline]
/// Gracefully close connection
///
/// First stop dispatcher, then dispatcher stops io tasks
pub fn close(&self) {
self.0.insert_flags(Flags::DSP_STOP);
self.0.dispatch_task.wake();
}
#[inline]
/// Force close connection
///
/// Dispatcher does not wait for uncompleted responses, but flushes io buffers.
pub fn force_close(&self) {
log::trace!("force close io stream object");
self.0.insert_flags(Flags::DSP_STOP | Flags::IO_STOPPING);
self.0.read_task.wake();
self.0.write_task.wake();
self.0.dispatch_task.wake();
}
#[inline]
/// Gracefully shutdown io stream
pub fn want_shutdown(&self, err: Option<io::Error>) {
self.0.init_shutdown(err);
}
#[inline]
/// Notify when io stream get disconnected
pub fn on_disconnect(&self) -> OnDisconnect {
OnDisconnect::new(self.0.clone())
}
#[inline]
/// Query specific data
pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
types::QueryItem::new(item)
} else {
types::QueryItem::empty()
}
}
#[inline]
/// Get mut access to write buffer
pub fn with_write_buf<F, R>(&self, f: F) -> Result<R, io::Error>
@ -176,7 +162,7 @@ impl IoRef {
{
let flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
if !flags.contains(Flags::IO_STOPPING) {
self.with_write_buf(|buf| {
let (hw, lw) = self.memory_pool().write_params().unpack();
@ -191,7 +177,7 @@ impl IoRef {
})
.map_or_else(
|err| {
self.0.set_error(Some(err));
self.0.io_stopped(Some(err));
Ok(())
},
|item| item,
@ -223,7 +209,7 @@ impl IoRef {
pub fn write(&self, src: &[u8]) -> io::Result<()> {
let flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
if !flags.intersects(Flags::IO_STOPPING) {
self.with_write_buf(|buf| {
buf.extend_from_slice(src);
})
@ -283,7 +269,7 @@ mod tests {
client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
let msg = state.recv(&BytesCodec).await;
assert!(msg.is_err());
assert!(state.flags().contains(Flags::IO_ERR));
assert!(state.flags().contains(Flags::IO_STOPPED));
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
@ -293,7 +279,7 @@ mod tests {
let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
if let Poll::Ready(msg) = res {
assert!(msg.is_err());
assert!(state.flags().contains(Flags::IO_ERR));
assert!(state.flags().contains(Flags::IO_STOPPED));
assert!(state.flags().contains(Flags::DSP_STOP));
}
@ -310,14 +296,14 @@ mod tests {
client.write_error(io::Error::new(io::ErrorKind::Other, "err"));
let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
assert!(res.is_err());
assert!(state.flags().contains(Flags::IO_ERR));
assert!(state.flags().contains(Flags::IO_STOPPED));
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
let state = Io::new(server);
state.force_close();
assert!(state.flags().contains(Flags::DSP_STOP));
assert!(state.flags().contains(Flags::IO_SHUTDOWN));
assert!(state.flags().contains(Flags::IO_STOPPING));
}
#[ntex::test]
@ -389,10 +375,6 @@ mod tests {
Poll::Ready(Ok(()))
}
fn want_read(&self) {}
fn want_shutdown(&self, _: Option<io::Error>) {}
fn query(&self, _: std::any::TypeId) -> Option<Box<dyn std::any::Any>> {
None
}
@ -401,21 +383,18 @@ mod tests {
self.inner.poll_read_ready(cx)
}
fn closed(&self, err: Option<io::Error>) {
self.inner.closed(err)
}
fn get_read_buf(&self) -> Option<BytesMut> {
self.inner.get_read_buf()
}
fn release_read_buf(
&self,
io: &IoRef,
buf: BytesMut,
dst: &mut Option<BytesMut>,
new_bytes: usize,
) -> io::Result<usize> {
let result = self.inner.release_read_buf(buf, dst, new_bytes)?;
let result = self.inner.release_read_buf(io, buf, dst, new_bytes)?;
self.read_order.borrow_mut().push(self.idx);
self.in_bytes.set(self.in_bytes.get() + result);
Ok(result)

View file

@ -55,24 +55,13 @@ pub enum WriteStatus {
pub trait Filter: 'static {
fn query(&self, id: TypeId) -> Option<Box<dyn Any>>;
/// Filter needs incoming data from io stream
fn want_read(&self);
/// Filter wants gracefully shutdown io stream
fn want_shutdown(&self, err: Option<sio::Error>);
fn poll_shutdown(&self) -> Poll<sio::Result<()>>;
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus>;
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus>;
fn get_read_buf(&self) -> Option<BytesMut>;
fn get_write_buf(&self) -> Option<BytesMut>;
fn release_read_buf(
&self,
io: &IoRef,
src: BytesMut,
dst: &mut Option<BytesMut>,
nbytes: usize,
@ -80,7 +69,11 @@ pub trait Filter: 'static {
fn release_write_buf(&self, buf: BytesMut) -> sio::Result<()>;
fn closed(&self, err: Option<sio::Error>);
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus>;
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus>;
fn poll_shutdown(&self) -> Poll<sio::Result<()>>;
}
pub trait FilterFactory<F: Filter>: Sized {

View file

@ -4,9 +4,13 @@ use ntex_bytes::{BytesMut, PoolRef};
use super::{io::Flags, IoRef, ReadStatus, WriteStatus};
pub struct ReadContext(pub(super) IoRef);
pub struct ReadContext(IoRef);
impl ReadContext {
pub(crate) fn new(io: &IoRef) -> Self {
Self(io.clone())
}
#[inline]
pub fn memory_pool(&self) -> PoolRef {
self.0.memory_pool()
@ -17,11 +21,6 @@ impl ReadContext {
self.0.filter().poll_read_ready(cx)
}
#[inline]
pub fn close(&self, err: Option<io::Error>) {
self.0.filter().closed(err);
}
#[inline]
pub fn get_read_buf(&self) -> BytesMut {
self.0
@ -37,7 +36,7 @@ impl ReadContext {
} else {
let filter = self.0.filter();
let mut dst = self.0 .0.read_buf.take();
let result = filter.release_read_buf(buf, &mut dst, nbytes);
let result = filter.release_read_buf(&self.0, buf, &mut dst, nbytes);
let nbytes = result.as_ref().map(|i| *i).unwrap_or(0);
if let Some(dst) = dst {
@ -63,19 +62,28 @@ impl ReadContext {
if let Err(err) = result {
self.0 .0.dispatch_task.wake();
self.0 .0.insert_flags(Flags::RD_READY);
filter.want_shutdown(Some(err));
self.0.want_shutdown(Some(err));
}
}
if self.0.flags().contains(Flags::IO_FILTERS) {
if self.0.flags().contains(Flags::IO_STOPPING_FILTERS) {
self.0 .0.shutdown_filters();
}
}
#[inline]
pub fn close(&self, err: Option<io::Error>) {
self.0 .0.io_stopped(err);
}
}
pub struct WriteContext(pub(super) IoRef);
pub struct WriteContext(IoRef);
impl WriteContext {
pub(crate) fn new(io: &IoRef) -> Self {
Self(io.clone())
}
#[inline]
pub fn memory_pool(&self) -> PoolRef {
self.0.memory_pool()
@ -86,11 +94,6 @@ impl WriteContext {
self.0.filter().poll_write_ready(cx)
}
#[inline]
pub fn close(&self, err: Option<io::Error>) {
self.0.filter().closed(err)
}
#[inline]
pub fn get_write_buf(&self) -> Option<BytesMut> {
self.0 .0.write_buf.take()
@ -120,9 +123,15 @@ impl WriteContext {
self.0 .0.write_buf.set(Some(buf))
}
if flags.contains(Flags::IO_FILTERS) {
if self.0.flags().contains(Flags::IO_STOPPING_FILTERS) {
self.0 .0.shutdown_filters();
}
Ok(())
}
#[inline]
pub fn close(&self, err: Option<io::Error>) {
self.0 .0.io_stopped(err);
}
}

View file

@ -638,6 +638,7 @@ pub(super) fn flush_io(
Poll::Pending
}
} else {
let _ = state.release_write_buf(buf);
Poll::Ready(true)
}
}

View file

@ -359,6 +359,9 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
Poll::Ready(false)
}
}
} else if let Err(e) = state.release_write_buf(buf) {
state.close(Some(e));
Poll::Ready(false)
} else {
Poll::Ready(true)
}