mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-04 05:17:39 +03:00
Add openssl filter (#69)
* add ntex-openssl * cleanup io api * add filter shutdown
This commit is contained in:
parent
841ad736d4
commit
dafd339817
19 changed files with 1178 additions and 303 deletions
|
@ -37,3 +37,4 @@ tok-io = { version = "1", package = "tokio", default-features = false, features
|
|||
ntex = "0.4.13"
|
||||
futures = "0.3.13"
|
||||
rand = "0.8"
|
||||
env_logger = "0.9"
|
|
@ -1,6 +1,6 @@
|
|||
//! Framed transport dispatcher
|
||||
use std::{
|
||||
cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll, time,
|
||||
cell::Cell, future::Future, io, pin::Pin, rc::Rc, task::Context, task::Poll, time,
|
||||
};
|
||||
|
||||
use ntex_bytes::Pool;
|
||||
|
@ -70,6 +70,7 @@ enum DispatcherError<S, U> {
|
|||
KeepAlive,
|
||||
Encoder(U),
|
||||
Service(S),
|
||||
Io(io::Error),
|
||||
}
|
||||
|
||||
enum PollService<U: Encoder + Decoder> {
|
||||
|
@ -171,10 +172,19 @@ where
|
|||
{
|
||||
fn handle_result(&self, item: Result<S::Response, S::Error>, write: WriteRef<'_>) {
|
||||
self.inflight.set(self.inflight.get() - 1);
|
||||
match write.encode_result(item, &self.codec) {
|
||||
Ok(true) => (),
|
||||
Ok(false) => write.enable_backpressure(None),
|
||||
Err(err) => self.error.set(Some(err.into())),
|
||||
match item {
|
||||
Ok(Some(val)) => match write.encode(val, &self.codec) {
|
||||
Ok(true) => (),
|
||||
Ok(false) => write.enable_backpressure(None),
|
||||
Err(Either::Left(err)) => {
|
||||
self.error.set(Some(DispatcherError::Encoder(err)))
|
||||
}
|
||||
Err(Either::Right(err)) => {
|
||||
self.error.set(Some(DispatcherError::Io(err)))
|
||||
}
|
||||
},
|
||||
Err(err) => self.error.set(Some(DispatcherError::Service(err))),
|
||||
Ok(None) => return,
|
||||
}
|
||||
write.wake_dispatcher();
|
||||
}
|
||||
|
@ -217,7 +227,10 @@ where
|
|||
match slf.st.get() {
|
||||
DispatcherState::Processing => {
|
||||
let result = match slf.poll_service(this.service, cx, read) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Pending => {
|
||||
let _ = read.poll_ready(cx);
|
||||
return Poll::Pending;
|
||||
}
|
||||
Poll::Ready(result) => result,
|
||||
};
|
||||
|
||||
|
@ -237,8 +250,20 @@ where
|
|||
}
|
||||
Ok(None) => {
|
||||
log::trace!("not enough data to decode next frame, register dispatch task");
|
||||
read.wake(cx);
|
||||
return Poll::Pending;
|
||||
// service is ready, wake io read task
|
||||
match read.poll_ready(cx) {
|
||||
Poll::Pending
|
||||
| Poll::Ready(Ok(Some(()))) => {
|
||||
read.resume();
|
||||
return Poll::Pending;
|
||||
}
|
||||
Poll::Ready(Ok(None)) => {
|
||||
DispatchItem::Disconnect(None)
|
||||
}
|
||||
Poll::Ready(Err(err)) => {
|
||||
DispatchItem::Disconnect(Some(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
slf.st.set(DispatcherState::Stop);
|
||||
|
@ -248,8 +273,18 @@ where
|
|||
}
|
||||
} else {
|
||||
// no new events
|
||||
state.register_dispatcher(cx);
|
||||
return Poll::Pending;
|
||||
match read.poll_ready(cx) {
|
||||
Poll::Pending | Poll::Ready(Ok(Some(()))) => {
|
||||
read.resume();
|
||||
return Poll::Pending;
|
||||
}
|
||||
Poll::Ready(Ok(None)) => {
|
||||
DispatchItem::Disconnect(None)
|
||||
}
|
||||
Poll::Ready(Err(err)) => {
|
||||
DispatchItem::Disconnect(Some(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
PollService::Item(item) => item,
|
||||
|
@ -318,7 +353,7 @@ where
|
|||
|
||||
if slf.shared.inflight.get() == 0 {
|
||||
slf.st.set(DispatcherState::Shutdown);
|
||||
state.shutdown(cx);
|
||||
state.init_shutdown(cx);
|
||||
} else {
|
||||
state.register_dispatcher(cx);
|
||||
return Poll::Pending;
|
||||
|
@ -368,15 +403,19 @@ where
|
|||
item: Result<Option<<U as Encoder>::Item>, S::Error>,
|
||||
write: WriteRef<'_>,
|
||||
) {
|
||||
match write.encode_result(item, &self.shared.codec) {
|
||||
Ok(true) => (),
|
||||
Ok(false) => write.enable_backpressure(None),
|
||||
Err(Either::Left(err)) => {
|
||||
self.error.set(Some(err));
|
||||
}
|
||||
Err(Either::Right(err)) => {
|
||||
self.shared.error.set(Some(DispatcherError::Encoder(err)))
|
||||
}
|
||||
match item {
|
||||
Ok(Some(item)) => match write.encode(item, &self.shared.codec) {
|
||||
Ok(true) => (),
|
||||
Ok(false) => write.enable_backpressure(None),
|
||||
Err(Either::Left(err)) => {
|
||||
self.shared.error.set(Some(DispatcherError::Encoder(err)))
|
||||
}
|
||||
Err(Either::Right(err)) => {
|
||||
self.shared.error.set(Some(DispatcherError::Io(err)))
|
||||
}
|
||||
},
|
||||
Err(err) => self.shared.error.set(Some(DispatcherError::Service(err))),
|
||||
Ok(None) => (),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -388,9 +427,6 @@ where
|
|||
) -> Poll<PollService<U>> {
|
||||
match srv.poll_ready(cx) {
|
||||
Poll::Ready(Ok(_)) => {
|
||||
// service is ready, wake io read task
|
||||
read.resume();
|
||||
|
||||
// check keepalive timeout
|
||||
self.check_keepalive();
|
||||
|
||||
|
@ -407,6 +443,9 @@ where
|
|||
DispatcherError::Encoder(err) => {
|
||||
PollService::Item(DispatchItem::EncoderError(err))
|
||||
}
|
||||
DispatcherError::Io(err) => {
|
||||
PollService::Item(DispatchItem::Disconnect(Some(err)))
|
||||
}
|
||||
DispatcherError::Service(err) => {
|
||||
self.error.set(Some(err));
|
||||
PollService::ServiceError
|
||||
|
@ -425,7 +464,7 @@ where
|
|||
|
||||
// get io error
|
||||
if let Some(err) = self.state.take_error() {
|
||||
PollService::Item(DispatchItem::IoError(err))
|
||||
PollService::Item(DispatchItem::Disconnect(Some(err)))
|
||||
} else {
|
||||
PollService::ServiceError
|
||||
}
|
||||
|
@ -803,15 +842,15 @@ mod tests {
|
|||
|
||||
// response message
|
||||
assert!(!state.write().is_ready());
|
||||
assert_eq!(state.write().with_buf(|buf| buf.len()), 65536);
|
||||
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 65536);
|
||||
|
||||
client.remote_buffer_cap(10240);
|
||||
sleep(Millis(50)).await;
|
||||
assert_eq!(state.write().with_buf(|buf| buf.len()), 55296);
|
||||
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 55296);
|
||||
|
||||
client.remote_buffer_cap(45056);
|
||||
sleep(Millis(50)).await;
|
||||
assert_eq!(state.write().with_buf(|buf| buf.len()), 10240);
|
||||
assert_eq!(state.write().with_buf(|buf| buf.len()).unwrap(), 10240);
|
||||
|
||||
// backpressure disabled
|
||||
assert!(state.write().is_ready());
|
||||
|
@ -821,7 +860,6 @@ mod tests {
|
|||
#[ntex::test]
|
||||
async fn test_keepalive() {
|
||||
let (client, server) = IoTest::create();
|
||||
// do not allow to write to socket
|
||||
client.remote_buffer_cap(1024);
|
||||
client.write("GET /test HTTP/1\r\n\r\n");
|
||||
|
||||
|
@ -854,8 +892,7 @@ mod tests {
|
|||
.keepalive_timeout(Seconds(1))
|
||||
.await;
|
||||
});
|
||||
|
||||
state.0.disconnect_timeout.set(Seconds(1));
|
||||
state.0.disconnect_timeout.set(Millis::ONE_SEC);
|
||||
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::{io, rc::Rc, task::Context, task::Poll};
|
|||
|
||||
use ntex_bytes::BytesMut;
|
||||
|
||||
use super::state::{Flags, IoStateInner};
|
||||
use super::state::{Flags, IoRef, IoStateInner};
|
||||
use super::{Filter, ReadFilter, WriteFilter, WriteReadiness};
|
||||
|
||||
pub struct DefaultFilter(Rc<IoStateInner>);
|
||||
|
@ -13,7 +13,20 @@ impl DefaultFilter {
|
|||
}
|
||||
}
|
||||
|
||||
impl Filter for DefaultFilter {}
|
||||
impl Filter for DefaultFilter {
|
||||
#[inline]
|
||||
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
|
||||
let mut flags = self.0.flags.get();
|
||||
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
|
||||
flags.insert(Flags::IO_SHUTDOWN);
|
||||
self.0.flags.set(flags);
|
||||
self.0.read_task.wake();
|
||||
self.0.write_task.wake();
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadFilter for DefaultFilter {
|
||||
#[inline]
|
||||
|
@ -48,20 +61,20 @@ impl ReadFilter for DefaultFilter {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) {
|
||||
if new_bytes > 0 {
|
||||
if buf.len() > self.0.pool.get().read_params().high as usize {
|
||||
log::trace!(
|
||||
"buffer is too large {}, enable read back-pressure",
|
||||
buf.len()
|
||||
);
|
||||
self.0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
|
||||
} else {
|
||||
self.0.insert_flags(Flags::RD_READY);
|
||||
}
|
||||
self.0.dispatch_task.wake();
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
buf: BytesMut,
|
||||
new_bytes: usize,
|
||||
) -> Result<(), io::Error> {
|
||||
if new_bytes > 0 && buf.len() > self.0.pool.get().read_params().high as usize {
|
||||
log::trace!(
|
||||
"buffer is too large {}, enable read back-pressure",
|
||||
buf.len()
|
||||
);
|
||||
self.0.insert_flags(Flags::RD_BUF_FULL);
|
||||
}
|
||||
self.0.read_buf.set(Some(buf));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -71,12 +84,23 @@ impl WriteFilter for DefaultFilter {
|
|||
&self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), WriteReadiness>> {
|
||||
let flags = self.0.flags.get();
|
||||
let mut flags = self.0.flags.get();
|
||||
|
||||
if flags.contains(Flags::IO_ERR) {
|
||||
Poll::Ready(Err(WriteReadiness::Terminate))
|
||||
} else if flags.intersects(Flags::IO_SHUTDOWN) {
|
||||
Poll::Ready(Err(WriteReadiness::Shutdown))
|
||||
Poll::Ready(Err(WriteReadiness::Shutdown(
|
||||
self.0.disconnect_timeout.get(),
|
||||
)))
|
||||
} else if flags.contains(Flags::IO_FILTERS)
|
||||
&& !flags.contains(Flags::IO_FILTERS_TO)
|
||||
{
|
||||
flags.insert(Flags::IO_FILTERS_TO);
|
||||
self.0.flags.set(flags);
|
||||
self.0.write_task.register(cx.waker());
|
||||
Poll::Ready(Err(WriteReadiness::Timeout(
|
||||
self.0.disconnect_timeout.get(),
|
||||
)))
|
||||
} else {
|
||||
self.0.write_task.register(cx.waker());
|
||||
Poll::Ready(Ok(()))
|
||||
|
@ -100,13 +124,15 @@ impl WriteFilter for DefaultFilter {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn release_write_buf(&self, buf: BytesMut) {
|
||||
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
|
||||
let pool = self.0.pool.get();
|
||||
if buf.is_empty() {
|
||||
pool.release_write_buf(buf);
|
||||
} else {
|
||||
self.0.write_buf.set(Some(buf));
|
||||
self.0.write_task.wake();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -120,7 +146,11 @@ impl NullFilter {
|
|||
}
|
||||
}
|
||||
|
||||
impl Filter for NullFilter {}
|
||||
impl Filter for NullFilter {
|
||||
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadFilter for NullFilter {
|
||||
fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
|
||||
|
@ -133,7 +163,9 @@ impl ReadFilter for NullFilter {
|
|||
None
|
||||
}
|
||||
|
||||
fn release_read_buf(&self, _: BytesMut, _: usize) {}
|
||||
fn release_read_buf(&self, _: BytesMut, _: usize) -> Result<(), io::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl WriteFilter for NullFilter {
|
||||
|
@ -147,5 +179,7 @@ impl WriteFilter for NullFilter {
|
|||
None
|
||||
}
|
||||
|
||||
fn release_write_buf(&self, _: BytesMut) {}
|
||||
fn release_write_buf(&self, _: BytesMut) -> Result<(), io::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,19 +14,22 @@ mod tokio_impl;
|
|||
|
||||
use ntex_bytes::BytesMut;
|
||||
use ntex_codec::{Decoder, Encoder};
|
||||
use ntex_util::time::Millis;
|
||||
|
||||
pub use self::dispatcher::Dispatcher;
|
||||
pub use self::filter::DefaultFilter;
|
||||
pub use self::state::{Io, IoRef, ReadRef, WriteRef};
|
||||
pub use self::tasks::{ReadState, WriteState};
|
||||
pub use self::time::Timer;
|
||||
|
||||
pub use self::utils::{from_iostream, into_boxed};
|
||||
pub use self::utils::{filter_factory, from_iostream, into_boxed, into_io};
|
||||
|
||||
pub type IoBoxed = Io<Box<dyn Filter>>;
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum WriteReadiness {
|
||||
Shutdown,
|
||||
Timeout(Millis),
|
||||
Shutdown(Millis),
|
||||
Terminate,
|
||||
}
|
||||
|
||||
|
@ -37,7 +40,8 @@ pub trait ReadFilter {
|
|||
|
||||
fn get_read_buf(&self) -> Option<BytesMut>;
|
||||
|
||||
fn release_read_buf(&self, buf: BytesMut, new_bytes: usize);
|
||||
fn release_read_buf(&self, buf: BytesMut, new_bytes: usize)
|
||||
-> Result<(), io::Error>;
|
||||
}
|
||||
|
||||
pub trait WriteFilter {
|
||||
|
@ -48,10 +52,12 @@ pub trait WriteFilter {
|
|||
|
||||
fn get_write_buf(&self) -> Option<BytesMut>;
|
||||
|
||||
fn release_write_buf(&self, buf: BytesMut);
|
||||
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error>;
|
||||
}
|
||||
|
||||
pub trait Filter: ReadFilter + WriteFilter {}
|
||||
pub trait Filter: ReadFilter + WriteFilter {
|
||||
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), io::Error>>;
|
||||
}
|
||||
|
||||
pub trait FilterFactory<F: Filter>: Sized {
|
||||
type Filter: Filter;
|
||||
|
@ -59,7 +65,7 @@ pub trait FilterFactory<F: Filter>: Sized {
|
|||
type Error: fmt::Debug;
|
||||
type Future: Future<Output = Result<Io<Self::Filter>, Self::Error>>;
|
||||
|
||||
fn create(&self, st: Io<F>) -> Self::Future;
|
||||
fn create(self, st: Io<F>) -> Self::Future;
|
||||
}
|
||||
|
||||
pub trait IoStream {
|
||||
|
@ -79,8 +85,8 @@ pub enum DispatchItem<U: Encoder + Decoder> {
|
|||
DecoderError(<U as Decoder>::Error),
|
||||
/// Encoder parse error
|
||||
EncoderError(<U as Encoder>::Error),
|
||||
/// Unexpected io error
|
||||
IoError(io::Error),
|
||||
/// Socket is disconnected
|
||||
Disconnect(Option<io::Error>),
|
||||
}
|
||||
|
||||
impl<U> fmt::Debug for DispatchItem<U>
|
||||
|
@ -108,8 +114,8 @@ where
|
|||
DispatchItem::DecoderError(ref e) => {
|
||||
write!(fmt, "DispatchItem::DecoderError({:?})", e)
|
||||
}
|
||||
DispatchItem::IoError(ref e) => {
|
||||
write!(fmt, "DispatchItem::IoError({:?})", e)
|
||||
DispatchItem::Disconnect(ref e) => {
|
||||
write!(fmt, "DispatchItem::Disconnect({:?})", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -128,8 +134,8 @@ mod tests {
|
|||
assert!(format!("{:?}", err).contains("DispatchItem::Encoder"));
|
||||
let err = T::DecoderError(io::Error::new(io::ErrorKind::Other, "err"));
|
||||
assert!(format!("{:?}", err).contains("DispatchItem::Decoder"));
|
||||
let err = T::IoError(io::Error::new(io::ErrorKind::Other, "err"));
|
||||
assert!(format!("{:?}", err).contains("DispatchItem::IoError"));
|
||||
let err = T::Disconnect(Some(io::Error::new(io::ErrorKind::Other, "err")));
|
||||
assert!(format!("{:?}", err).contains("DispatchItem::Disconnect"));
|
||||
|
||||
assert!(format!("{:?}", T::WBackPressureEnabled)
|
||||
.contains("DispatchItem::WBackPressureEnabled"));
|
||||
|
|
|
@ -4,7 +4,8 @@ use std::{future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc};
|
|||
|
||||
use ntex_bytes::{BytesMut, PoolId, PoolRef};
|
||||
use ntex_codec::{Decoder, Encoder};
|
||||
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker, time::Seconds};
|
||||
use ntex_util::time::{Millis, Seconds};
|
||||
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker};
|
||||
|
||||
use super::filter::{DefaultFilter, NullFilter};
|
||||
use super::tasks::{ReadState, WriteState};
|
||||
|
@ -14,8 +15,12 @@ bitflags::bitflags! {
|
|||
pub struct Flags: u16 {
|
||||
/// io error occured
|
||||
const IO_ERR = 0b0000_0000_0000_0001;
|
||||
/// shuting down filters
|
||||
const IO_FILTERS = 0b0000_0000_0000_0010;
|
||||
/// shuting down filters timeout
|
||||
const IO_FILTERS_TO = 0b0000_0000_0000_0100;
|
||||
/// shutdown io tasks
|
||||
const IO_SHUTDOWN = 0b0000_0000_0000_0100;
|
||||
const IO_SHUTDOWN = 0b0000_0000_0000_1000;
|
||||
|
||||
/// pause io read
|
||||
const RD_PAUSED = 0b0000_0000_0000_1000;
|
||||
|
@ -51,7 +56,7 @@ pub struct IoRef(pub(super) Rc<IoStateInner>);
|
|||
pub(crate) struct IoStateInner {
|
||||
pub(super) flags: Cell<Flags>,
|
||||
pub(super) pool: Cell<PoolRef>,
|
||||
pub(super) disconnect_timeout: Cell<Seconds>,
|
||||
pub(super) disconnect_timeout: Cell<Millis>,
|
||||
pub(super) error: Cell<Option<io::Error>>,
|
||||
pub(super) read_task: LocalWaker,
|
||||
pub(super) write_task: LocalWaker,
|
||||
|
@ -77,6 +82,16 @@ impl IoStateInner {
|
|||
self.flags.set(flags);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) fn notify_keepalive(&self) {
|
||||
let mut flags = self.flags.get();
|
||||
if !flags.contains(Flags::DSP_KEEPALIVE) {
|
||||
flags.insert(Flags::DSP_KEEPALIVE);
|
||||
self.flags.set(flags);
|
||||
self.dispatch_task.wake();
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) fn notify_disconnect(&self) {
|
||||
let mut on_disconnect = self.on_disconnect.borrow_mut();
|
||||
|
@ -86,6 +101,36 @@ impl IoStateInner {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_io_err(&self) -> bool {
|
||||
self.flags.get().contains(Flags::IO_ERR)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(super) fn shutdown_filters(&self, st: &IoRef) -> Result<(), io::Error> {
|
||||
let mut flags = self.flags.get();
|
||||
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
|
||||
let result = match self.filter.get().shutdown(st) {
|
||||
Poll::Pending => return Ok(()),
|
||||
Poll::Ready(Ok(())) => {
|
||||
flags.insert(Flags::IO_SHUTDOWN);
|
||||
Ok(())
|
||||
}
|
||||
Poll::Ready(Err(err)) => {
|
||||
flags.insert(Flags::IO_ERR);
|
||||
self.dispatch_task.wake();
|
||||
Err(err)
|
||||
}
|
||||
};
|
||||
self.flags.set(flags);
|
||||
self.read_task.wake();
|
||||
self.write_task.wake();
|
||||
result
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for IoStateInner {}
|
||||
|
@ -130,7 +175,7 @@ impl Io {
|
|||
pool: Cell::new(pool),
|
||||
flags: Cell::new(Flags::empty()),
|
||||
error: Cell::new(None),
|
||||
disconnect_timeout: Cell::new(Seconds(1)),
|
||||
disconnect_timeout: Cell::new(Millis::ONE_SEC),
|
||||
dispatch_task: LocalWaker::new(),
|
||||
read_task: LocalWaker::new(),
|
||||
write_task: LocalWaker::new(),
|
||||
|
@ -147,10 +192,12 @@ impl Io {
|
|||
};
|
||||
inner.filter.replace(filter_ref);
|
||||
|
||||
// start io tasks
|
||||
io.start(ReadState(inner.clone()), WriteState(inner.clone()));
|
||||
let io_ref = IoRef(inner);
|
||||
|
||||
Io(IoRef(inner), FilterItem::Ptr(Box::into_raw(filter)))
|
||||
// start io tasks
|
||||
io.start(ReadState(io_ref.clone()), WriteState(io_ref.clone()));
|
||||
|
||||
Io(io_ref, FilterItem::Ptr(Box::into_raw(filter)))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -172,7 +219,7 @@ impl<F> Io<F> {
|
|||
#[inline]
|
||||
/// Set io disconnect timeout in secs
|
||||
pub fn set_disconnect_timeout(&self, timeout: Seconds) {
|
||||
self.0 .0.disconnect_timeout.set(timeout);
|
||||
self.0 .0.disconnect_timeout.set(timeout.into());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -201,6 +248,25 @@ impl<F> Io<F> {
|
|||
pub fn dispatcher_stopped(&self) {
|
||||
self.0 .0.insert_flags(Flags::DSP_STOP);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Gracefully shutdown read and write io tasks
|
||||
pub fn init_shutdown(&self, cx: &mut Context<'_>) {
|
||||
let flags = self.0 .0.flags.get();
|
||||
|
||||
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
|
||||
log::trace!("initiate io shutdown {:?}", flags);
|
||||
self.0 .0.insert_flags(Flags::IO_FILTERS);
|
||||
if let Err(err) = self.0 .0.shutdown_filters(&self.0) {
|
||||
self.0 .0.error.set(Some(err));
|
||||
self.0 .0.insert_flags(Flags::IO_ERR);
|
||||
}
|
||||
|
||||
self.0 .0.read_task.wake();
|
||||
self.0 .0.write_task.wake();
|
||||
self.0 .0.dispatch_task.register(cx.waker());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IoRef {
|
||||
|
@ -220,7 +286,7 @@ impl IoRef {
|
|||
#[inline]
|
||||
/// Check if io error occured in read or write task
|
||||
pub fn is_io_err(&self) -> bool {
|
||||
self.0.flags.get().contains(Flags::IO_ERR)
|
||||
self.0.is_io_err()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -244,20 +310,6 @@ impl IoRef {
|
|||
.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::DSP_STOP)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Gracefully shutdown read and write io tasks
|
||||
pub fn shutdown(&self, cx: &mut Context<'_>) {
|
||||
let flags = self.0.flags.get();
|
||||
|
||||
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
|
||||
log::trace!("initiate io shutdown {:?}", flags);
|
||||
self.0.insert_flags(Flags::IO_SHUTDOWN);
|
||||
self.0.read_task.wake();
|
||||
self.0.write_task.wake();
|
||||
self.0.dispatch_task.register(cx.waker());
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Take io error if any occured
|
||||
pub fn take_error(&self) -> Option<io::Error> {
|
||||
|
@ -331,32 +383,21 @@ impl<F> Io<F> {
|
|||
};
|
||||
self.0 .0.read_buf.set(buf);
|
||||
|
||||
let result = match item {
|
||||
return match item {
|
||||
Ok(Some(el)) => Ok(Some(el)),
|
||||
Ok(None) => {
|
||||
self.0 .0.remove_flags(Flags::RD_READY);
|
||||
poll_fn(|cx| {
|
||||
if read.is_ready() {
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
read.wake(cx);
|
||||
Poll::Pending
|
||||
}
|
||||
})
|
||||
.await;
|
||||
if self.is_io_err() {
|
||||
if let Some(err) = self.take_error() {
|
||||
Err(Either::Right(err))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
if poll_fn(|cx| read.poll_ready(cx))
|
||||
.await
|
||||
.map_err(Either::Right)?
|
||||
.is_none()
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Err(err) => Err(Either::Left(err)),
|
||||
};
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -364,8 +405,8 @@ impl<F> Io<F> {
|
|||
/// Encode item, send to a peer
|
||||
pub async fn send<U>(
|
||||
&self,
|
||||
codec: &U,
|
||||
item: U::Item,
|
||||
codec: &U,
|
||||
) -> Result<(), Either<U::Error, io::Error>>
|
||||
where
|
||||
U: Encoder,
|
||||
|
@ -374,31 +415,46 @@ impl<F> Io<F> {
|
|||
let mut buf = filter
|
||||
.get_write_buf()
|
||||
.unwrap_or_else(|| self.0 .0.pool.get().get_write_buf());
|
||||
|
||||
let is_write_sleep = buf.is_empty();
|
||||
codec.encode(item, &mut buf).map_err(Either::Left)?;
|
||||
filter.release_write_buf(buf);
|
||||
filter.release_write_buf(buf).map_err(Either::Right)?;
|
||||
self.0 .0.insert_flags(Flags::WR_WAIT);
|
||||
if is_write_sleep {
|
||||
self.0 .0.write_task.wake();
|
||||
}
|
||||
|
||||
poll_fn(|cx| {
|
||||
if !self.0 .0.flags.get().contains(Flags::WR_WAIT) || self.is_io_err() {
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
self.register_dispatcher(cx);
|
||||
Poll::Pending
|
||||
}
|
||||
})
|
||||
.await;
|
||||
poll_fn(|cx| self.write().poll_flush(cx))
|
||||
.await
|
||||
.map_err(Either::Right)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
if self.is_io_err() {
|
||||
let err = self.0 .0.error.take().unwrap_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::Other, "Internal error")
|
||||
});
|
||||
Err(Either::Right(err))
|
||||
} else {
|
||||
#[inline]
|
||||
/// Shuts down connection
|
||||
pub async fn shutdown(&self) -> Result<(), io::Error> {
|
||||
if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
|
||||
Ok(())
|
||||
} else {
|
||||
poll_fn(|cx| {
|
||||
let flags = self.flags();
|
||||
if !flags.contains(Flags::IO_FILTERS) {
|
||||
self.init_shutdown(cx);
|
||||
}
|
||||
|
||||
if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
|
||||
if let Some(err) = self.0 .0.error.take() {
|
||||
Poll::Ready(Err(err))
|
||||
} else {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
} else {
|
||||
self.0 .0.insert_flags(Flags::IO_FILTERS);
|
||||
self.0 .0.dispatch_task.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -412,26 +468,52 @@ impl<F> Io<F> {
|
|||
where
|
||||
U: Decoder,
|
||||
{
|
||||
let mut buf = self.0 .0.read_buf.take();
|
||||
let item = if let Some(ref mut buf) = buf {
|
||||
codec.decode(buf)
|
||||
} else {
|
||||
Ok(None)
|
||||
};
|
||||
self.0 .0.read_buf.set(buf);
|
||||
if self
|
||||
.read()
|
||||
.poll_ready(cx)
|
||||
.map_err(Either::Right)?
|
||||
.is_ready()
|
||||
{
|
||||
let mut buf = self.0 .0.read_buf.take();
|
||||
let item = if let Some(ref mut buf) = buf {
|
||||
codec.decode(buf)
|
||||
} else {
|
||||
Ok(None)
|
||||
};
|
||||
self.0 .0.read_buf.set(buf);
|
||||
|
||||
match item {
|
||||
Ok(Some(el)) => Poll::Ready(Ok(Some(el))),
|
||||
Ok(None) => {
|
||||
self.read().wake(cx);
|
||||
Poll::Pending
|
||||
match item {
|
||||
Ok(Some(el)) => Poll::Ready(Ok(Some(el))),
|
||||
Ok(None) => {
|
||||
if let Poll::Ready(res) =
|
||||
self.read().poll_ready(cx).map_err(Either::Right)?
|
||||
{
|
||||
if res.is_none() {
|
||||
return Poll::Ready(Ok(None));
|
||||
}
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
Err(err) => Poll::Ready(Err(Either::Left(err))),
|
||||
}
|
||||
Err(err) => Poll::Ready(Err(Either::Left(err))),
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Filter> Io<F> {
|
||||
#[inline]
|
||||
/// Get referece to filter
|
||||
pub fn filter(&self) -> &F {
|
||||
if let FilterItem::Ptr(p) = self.1 {
|
||||
if let Some(r) = unsafe { p.as_ref() } {
|
||||
return r;
|
||||
}
|
||||
}
|
||||
panic!()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn into_boxed(mut self) -> crate::IoBoxed
|
||||
where
|
||||
|
@ -457,18 +539,18 @@ impl<F: Filter> Io<F> {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn add_filter<T>(self, factory: &T) -> Result<Io<T::Filter>, T::Error>
|
||||
pub fn add_filter<T>(self, factory: T) -> T::Future
|
||||
where
|
||||
T: FilterFactory<F>,
|
||||
{
|
||||
factory.create(self).await
|
||||
factory.create(self)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn map_filter<T, U>(mut self, map: T) -> Io<U>
|
||||
pub fn map_filter<T, U>(mut self, map: U) -> Result<Io<T::Filter>, T::Error>
|
||||
where
|
||||
T: FnOnce(F) -> U,
|
||||
U: Filter,
|
||||
T: FilterFactory<F>,
|
||||
U: FnOnce(F) -> Result<T::Filter, T::Error>,
|
||||
{
|
||||
// replace current filter
|
||||
let filter = unsafe {
|
||||
|
@ -477,7 +559,7 @@ impl<F: Filter> Io<F> {
|
|||
FilterItem::Boxed(_) => panic!(),
|
||||
FilterItem::Ptr(p) => {
|
||||
assert!(!p.is_null());
|
||||
Box::new(map(*Box::from_raw(p)))
|
||||
Box::new(map(*Box::from_raw(p))?)
|
||||
}
|
||||
};
|
||||
let filter_ref: &'static dyn Filter = {
|
||||
|
@ -488,7 +570,7 @@ impl<F: Filter> Io<F> {
|
|||
filter
|
||||
};
|
||||
|
||||
Io(self.0.clone(), FilterItem::Ptr(Box::into_raw(filter)))
|
||||
Ok(Io(self.0.clone(), FilterItem::Ptr(Box::into_raw(filter))))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -562,7 +644,7 @@ impl<'a> WriteRef<'a> {
|
|||
|
||||
#[inline]
|
||||
/// Get mut access to write buffer
|
||||
pub fn with_buf<F, R>(&self, f: F) -> R
|
||||
pub fn with_buf<F, R>(&self, f: F) -> Result<R, io::Error>
|
||||
where
|
||||
F: FnOnce(&mut BytesMut) -> R,
|
||||
{
|
||||
|
@ -575,8 +657,8 @@ impl<'a> WriteRef<'a> {
|
|||
}
|
||||
|
||||
let result = f(&mut buf);
|
||||
filter.release_write_buf(buf);
|
||||
result
|
||||
filter.release_write_buf(buf)?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -587,7 +669,7 @@ impl<'a> WriteRef<'a> {
|
|||
&self,
|
||||
item: U::Item,
|
||||
codec: &U,
|
||||
) -> Result<bool, <U as Encoder>::Error>
|
||||
) -> Result<bool, Either<<U as Encoder>::Error, io::Error>>
|
||||
where
|
||||
U: Encoder,
|
||||
{
|
||||
|
@ -608,70 +690,45 @@ impl<'a> WriteRef<'a> {
|
|||
}
|
||||
|
||||
// encode item and wake write task
|
||||
let result = codec.encode(item, &mut buf).map(|_| {
|
||||
if is_write_sleep {
|
||||
self.0.write_task.wake();
|
||||
}
|
||||
buf.len() < hw
|
||||
});
|
||||
filter.release_write_buf(buf);
|
||||
result
|
||||
let result = codec
|
||||
.encode(item, &mut buf)
|
||||
.map(|_| {
|
||||
if is_write_sleep {
|
||||
self.0.write_task.wake();
|
||||
}
|
||||
buf.len() < hw
|
||||
})
|
||||
.map_err(Either::Left);
|
||||
filter.release_write_buf(buf).map_err(Either::Right)?;
|
||||
Ok(result?)
|
||||
} else {
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Write item to a buf and wake up io task
|
||||
pub fn encode_result<U, E>(
|
||||
&self,
|
||||
item: Result<Option<U::Item>, E>,
|
||||
codec: &U,
|
||||
) -> Result<bool, Either<E, U::Error>>
|
||||
where
|
||||
U: Encoder,
|
||||
{
|
||||
let flags = self.0.flags.get();
|
||||
/// Wake write task and instruct to write all data.
|
||||
///
|
||||
/// When write task is done wake dispatcher.
|
||||
pub fn poll_flush(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
self.0.insert_flags(Flags::WR_WAIT);
|
||||
|
||||
if !flags.intersects(Flags::IO_ERR | Flags::DSP_ERR) {
|
||||
match item {
|
||||
Ok(Some(item)) => {
|
||||
let filter = self.0.filter.get();
|
||||
let mut buf = filter
|
||||
.get_write_buf()
|
||||
.unwrap_or_else(|| self.0.pool.get().get_write_buf());
|
||||
let is_write_sleep = buf.is_empty();
|
||||
let (hw, lw) = self.0.pool.get().write_params().unpack();
|
||||
|
||||
// make sure we've got room
|
||||
let remaining = buf.capacity() - buf.len();
|
||||
if remaining < lw {
|
||||
buf.reserve(hw - remaining);
|
||||
}
|
||||
|
||||
// encode item
|
||||
if let Err(err) = codec.encode(item, &mut buf) {
|
||||
log::trace!("Encoder error: {:?}", err);
|
||||
filter.release_write_buf(buf);
|
||||
self.0.insert_flags(Flags::DSP_STOP | Flags::DSP_ERR);
|
||||
self.0.dispatch_task.wake();
|
||||
return Err(Either::Right(err));
|
||||
} else if is_write_sleep {
|
||||
self.0.write_task.wake();
|
||||
}
|
||||
let result = Ok(buf.len() < hw);
|
||||
filter.release_write_buf(buf);
|
||||
result
|
||||
}
|
||||
Err(err) => {
|
||||
self.0.insert_flags(Flags::DSP_STOP | Flags::DSP_ERR);
|
||||
self.0.dispatch_task.wake();
|
||||
Err(Either::Left(err))
|
||||
}
|
||||
_ => Ok(true),
|
||||
if let Some(buf) = self.0.write_buf.take() {
|
||||
if !buf.is_empty() {
|
||||
self.0.write_buf.set(Some(buf));
|
||||
self.0.write_task.wake();
|
||||
self.0.dispatch_task.register(cx.waker());
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
if self.0.is_io_err() {
|
||||
Poll::Ready(Err(self.0.error.take().unwrap_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::Other, "disconnected")
|
||||
})))
|
||||
} else {
|
||||
Ok(true)
|
||||
self.0.dispatch_task.register(cx.waker());
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -720,28 +777,6 @@ impl<'a> ReadRef<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Wake read task and instruct to read more data
|
||||
///
|
||||
/// Only wakes if back-pressure is enabled on read task
|
||||
/// otherwise read is already awake.
|
||||
pub fn wake(&self, cx: &mut Context<'_>) {
|
||||
let mut flags = self.0.flags.get();
|
||||
flags.remove(Flags::RD_READY);
|
||||
if flags.contains(Flags::RD_BUF_FULL) {
|
||||
log::trace!("read back-pressure is enabled, wake io task");
|
||||
flags.remove(Flags::RD_BUF_FULL);
|
||||
self.0.read_task.wake();
|
||||
}
|
||||
if flags.contains(Flags::RD_PAUSED) {
|
||||
log::trace!("read is paused, wake io task");
|
||||
flags.remove(Flags::RD_PAUSED);
|
||||
self.0.read_task.wake();
|
||||
}
|
||||
self.0.flags.set(flags);
|
||||
self.0.dispatch_task.register(cx.waker());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Attempts to decode a frame from the read buffer.
|
||||
pub fn decode<U>(
|
||||
|
@ -753,7 +788,11 @@ impl<'a> ReadRef<'a> {
|
|||
{
|
||||
let mut buf = self.0.read_buf.take();
|
||||
let result = if let Some(ref mut buf) = buf {
|
||||
codec.decode(buf)
|
||||
let result = codec.decode(buf);
|
||||
if result.as_ref().map(|v| v.is_none()).unwrap_or(false) {
|
||||
self.0.remove_flags(Flags::RD_READY);
|
||||
}
|
||||
result
|
||||
} else {
|
||||
self.0.remove_flags(Flags::RD_READY);
|
||||
Ok(None)
|
||||
|
@ -775,12 +814,46 @@ impl<'a> ReadRef<'a> {
|
|||
.unwrap_or_else(|| self.0.pool.get().get_read_buf());
|
||||
let res = f(&mut buf);
|
||||
if buf.is_empty() {
|
||||
self.0.remove_flags(Flags::RD_READY);
|
||||
self.0.pool.get().release_read_buf(buf);
|
||||
} else {
|
||||
self.0.read_buf.set(Some(buf));
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Wake read task and instruct to read more data
|
||||
///
|
||||
/// Only wakes if back-pressure is enabled on read task
|
||||
/// otherwise read is already awake.
|
||||
pub fn poll_ready(
|
||||
&self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<Option<()>, io::Error>> {
|
||||
let mut flags = self.0.flags.get();
|
||||
let ready = flags.contains(Flags::RD_READY);
|
||||
|
||||
if self.0.is_io_err() {
|
||||
if let Some(err) = self.0.error.take() {
|
||||
Poll::Ready(Err(err))
|
||||
} else {
|
||||
Poll::Ready(Ok(None))
|
||||
}
|
||||
} else if ready {
|
||||
Poll::Ready(Ok(Some(())))
|
||||
} else {
|
||||
flags.remove(Flags::RD_READY);
|
||||
if flags.contains(Flags::RD_BUF_FULL) {
|
||||
log::trace!("read back-pressure is enabled, wake io task");
|
||||
flags.remove(Flags::RD_BUF_FULL);
|
||||
self.0.read_task.wake();
|
||||
}
|
||||
self.0.flags.set(flags);
|
||||
self.0.dispatch_task.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// OnDisconnect future resolves when socket get disconnected
|
||||
|
@ -866,6 +939,7 @@ mod tests {
|
|||
|
||||
#[ntex::test]
|
||||
async fn utils() {
|
||||
env_logger::init();
|
||||
let (client, server) = IoTest::create();
|
||||
client.remote_buffer_cap(1024);
|
||||
client.write(TEXT);
|
||||
|
@ -907,14 +981,14 @@ mod tests {
|
|||
client.remote_buffer_cap(1024);
|
||||
let state = Io::new(server);
|
||||
state
|
||||
.send(&BytesCodec, Bytes::from_static(b"test"))
|
||||
.send(Bytes::from_static(b"test"), &BytesCodec)
|
||||
.await
|
||||
.unwrap();
|
||||
let buf = client.read().await.unwrap();
|
||||
assert_eq!(buf, Bytes::from_static(b"test"));
|
||||
|
||||
client.write_error(io::Error::new(io::ErrorKind::Other, "err"));
|
||||
let res = state.send(&BytesCodec, Bytes::from_static(b"test")).await;
|
||||
let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
|
||||
assert!(res.is_err());
|
||||
assert!(state.flags().contains(Flags::IO_ERR));
|
||||
assert!(state.flags().contains(Flags::DSP_STOP));
|
||||
|
@ -967,7 +1041,11 @@ mod tests {
|
|||
in_bytes: Rc<Cell<usize>>,
|
||||
out_bytes: Rc<Cell<usize>>,
|
||||
}
|
||||
impl<F: ReadFilter + WriteFilter> Filter for Counter<F> {}
|
||||
impl<F: ReadFilter + WriteFilter> Filter for Counter<F> {
|
||||
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: ReadFilter> ReadFilter for Counter<F> {
|
||||
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
|
||||
|
@ -982,9 +1060,13 @@ mod tests {
|
|||
self.inner.get_read_buf()
|
||||
}
|
||||
|
||||
fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) {
|
||||
fn release_read_buf(
|
||||
&self,
|
||||
buf: BytesMut,
|
||||
new_bytes: usize,
|
||||
) -> Result<(), io::Error> {
|
||||
self.in_bytes.set(self.in_bytes.get() + new_bytes);
|
||||
self.inner.release_read_buf(buf, new_bytes);
|
||||
self.inner.release_read_buf(buf, new_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1009,9 +1091,9 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
fn release_write_buf(&self, buf: BytesMut) {
|
||||
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
|
||||
self.out_bytes.set(self.out_bytes.get() + buf.len());
|
||||
self.inner.release_write_buf(buf);
|
||||
self.inner.release_write_buf(buf)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1023,14 +1105,19 @@ mod tests {
|
|||
type Error = ();
|
||||
type Future = Ready<Io<Counter<F>>, Self::Error>;
|
||||
|
||||
fn create(&self, st: Io<F>) -> Self::Future {
|
||||
fn create(self, io: Io<F>) -> Self::Future {
|
||||
let in_bytes = self.0.clone();
|
||||
let out_bytes = self.1.clone();
|
||||
Ready::Ok(st.map_filter(|inner| Counter {
|
||||
inner,
|
||||
in_bytes,
|
||||
out_bytes,
|
||||
}))
|
||||
Ready::Ok(
|
||||
io.map_filter::<CounterFactory, _>(|inner| {
|
||||
Ok(Counter {
|
||||
inner,
|
||||
in_bytes,
|
||||
out_bytes,
|
||||
})
|
||||
})
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1041,7 +1128,7 @@ mod tests {
|
|||
let factory = CounterFactory(in_bytes.clone(), out_bytes.clone());
|
||||
|
||||
let (client, server) = IoTest::create();
|
||||
let state = Io::new(server).add_filter(&factory).await.unwrap();
|
||||
let state = Io::new(server).add_filter(factory).await.unwrap();
|
||||
|
||||
client.remote_buffer_cap(1024);
|
||||
client.write(TEXT);
|
||||
|
@ -1049,7 +1136,7 @@ mod tests {
|
|||
assert_eq!(msg, Bytes::from_static(BIN));
|
||||
|
||||
state
|
||||
.send(&BytesCodec, Bytes::from_static(b"test"))
|
||||
.send(Bytes::from_static(b"test"), &BytesCodec)
|
||||
.await
|
||||
.unwrap();
|
||||
let buf = client.read().await.unwrap();
|
||||
|
@ -1066,10 +1153,10 @@ mod tests {
|
|||
|
||||
let (client, server) = IoTest::create();
|
||||
let state = Io::new(server)
|
||||
.add_filter(&CounterFactory(in_bytes.clone(), out_bytes.clone()))
|
||||
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
|
||||
.await
|
||||
.unwrap()
|
||||
.add_filter(&CounterFactory(in_bytes.clone(), out_bytes.clone()))
|
||||
.add_filter(CounterFactory(in_bytes.clone(), out_bytes.clone()))
|
||||
.await
|
||||
.unwrap();
|
||||
let state = state.into_boxed();
|
||||
|
@ -1080,7 +1167,7 @@ mod tests {
|
|||
assert_eq!(msg, Bytes::from_static(BIN));
|
||||
|
||||
state
|
||||
.send(&BytesCodec, Bytes::from_static(b"test"))
|
||||
.send(Bytes::from_static(b"test"), &BytesCodec)
|
||||
.await
|
||||
.unwrap();
|
||||
let buf = client.read().await.unwrap();
|
||||
|
|
|
@ -1,98 +1,115 @@
|
|||
use std::{io, rc::Rc, task::Context, task::Poll};
|
||||
use std::{io, task::Context, task::Poll};
|
||||
|
||||
use ntex_bytes::{BytesMut, PoolRef};
|
||||
use ntex_util::time::Seconds;
|
||||
|
||||
use super::{state::Flags, state::IoStateInner, WriteReadiness};
|
||||
use super::{state::Flags, IoRef, WriteReadiness};
|
||||
|
||||
pub struct ReadState(pub(super) Rc<IoStateInner>);
|
||||
pub struct ReadState(pub(super) IoRef);
|
||||
|
||||
impl ReadState {
|
||||
#[inline]
|
||||
pub fn memory_pool(&self) -> PoolRef {
|
||||
self.0.pool.get()
|
||||
self.0 .0.pool.get()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
|
||||
self.0.filter.get().poll_read_ready(cx)
|
||||
self.0 .0.filter.get().poll_read_ready(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn close(&self, err: Option<io::Error>) {
|
||||
self.0.filter.get().read_closed(err);
|
||||
self.0 .0.filter.get().read_closed(err);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_read_buf(&self) -> BytesMut {
|
||||
self.0
|
||||
.0
|
||||
.filter
|
||||
.get()
|
||||
.get_read_buf()
|
||||
.unwrap_or_else(|| self.0.pool.get().get_read_buf())
|
||||
.unwrap_or_else(|| self.0 .0.pool.get().get_read_buf())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) {
|
||||
pub fn release_read_buf(
|
||||
&self,
|
||||
buf: BytesMut,
|
||||
new_bytes: usize,
|
||||
) -> Result<(), io::Error> {
|
||||
if buf.is_empty() {
|
||||
self.0.pool.get().release_read_buf(buf);
|
||||
self.0 .0.pool.get().release_read_buf(buf);
|
||||
Ok(())
|
||||
} else {
|
||||
self.0.filter.get().release_read_buf(buf, new_bytes);
|
||||
let mut flags = self.0 .0.flags.get();
|
||||
|
||||
// notify dispatcher
|
||||
if new_bytes > 0 {
|
||||
flags.insert(Flags::RD_READY);
|
||||
self.0 .0.flags.set(flags);
|
||||
self.0 .0.dispatch_task.wake();
|
||||
}
|
||||
self.0 .0.filter.get().release_read_buf(buf, new_bytes)?;
|
||||
|
||||
if flags.contains(Flags::IO_FILTERS) {
|
||||
self.0 .0.shutdown_filters(&self.0)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WriteState(pub(super) Rc<IoStateInner>);
|
||||
pub struct WriteState(pub(super) IoRef);
|
||||
|
||||
impl WriteState {
|
||||
#[inline]
|
||||
pub fn memory_pool(&self) -> PoolRef {
|
||||
self.0.pool.get()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn disconnect_timeout(&self) -> Seconds {
|
||||
self.0.disconnect_timeout.get()
|
||||
self.0 .0.pool.get()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), WriteReadiness>> {
|
||||
self.0.filter.get().poll_write_ready(cx)
|
||||
self.0 .0.filter.get().poll_write_ready(cx)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn close(&self, err: Option<io::Error>) {
|
||||
self.0.filter.get().write_closed(err)
|
||||
self.0 .0.filter.get().write_closed(err)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_write_buf(&self) -> Option<BytesMut> {
|
||||
self.0.write_buf.take()
|
||||
self.0 .0.write_buf.take()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn release_write_buf(&self, buf: BytesMut) {
|
||||
let pool = self.0.pool.get();
|
||||
pub fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error> {
|
||||
let pool = self.0 .0.pool.get();
|
||||
let mut flags = self.0 .0.flags.get();
|
||||
|
||||
if buf.is_empty() {
|
||||
pool.release_write_buf(buf);
|
||||
|
||||
let mut flags = self.0.flags.get();
|
||||
if flags.intersects(Flags::WR_WAIT | Flags::WR_BACKPRESSURE) {
|
||||
flags.remove(Flags::WR_WAIT | Flags::WR_BACKPRESSURE);
|
||||
self.0.flags.set(flags);
|
||||
self.0.dispatch_task.wake();
|
||||
self.0 .0.flags.set(flags);
|
||||
self.0 .0.dispatch_task.wake();
|
||||
}
|
||||
} else {
|
||||
// if write buffer is smaller than high watermark value, turn off back-pressure
|
||||
if buf.len() < pool.write_params_high() << 1 {
|
||||
let mut flags = self.0.flags.get();
|
||||
if flags.contains(Flags::WR_BACKPRESSURE) {
|
||||
flags.remove(Flags::WR_BACKPRESSURE);
|
||||
self.0.flags.set(flags);
|
||||
self.0.dispatch_task.wake();
|
||||
}
|
||||
if buf.len() < pool.write_params_high() << 1
|
||||
&& flags.contains(Flags::WR_BACKPRESSURE)
|
||||
{
|
||||
flags.remove(Flags::WR_BACKPRESSURE);
|
||||
self.0 .0.flags.set(flags);
|
||||
self.0 .0.dispatch_task.wake();
|
||||
}
|
||||
self.0.write_buf.set(Some(buf))
|
||||
self.0 .0.write_buf.set(Some(buf))
|
||||
}
|
||||
|
||||
if flags.contains(Flags::IO_FILTERS) {
|
||||
self.0 .0.shutdown_filters(&self.0)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::{
|
|||
use ntex_util::spawn;
|
||||
use ntex_util::time::{now, sleep, Millis};
|
||||
|
||||
use super::state::{Flags, IoRef, IoStateInner};
|
||||
use super::state::{IoRef, IoStateInner};
|
||||
|
||||
pub struct Timer(Rc<RefCell<Inner>>);
|
||||
|
||||
|
@ -79,8 +79,7 @@ impl Timer {
|
|||
let key = *key;
|
||||
if key <= now_time {
|
||||
for st in i.notifications.remove(&key).unwrap() {
|
||||
st.dispatch_task.wake();
|
||||
st.insert_flags(Flags::DSP_KEEPALIVE);
|
||||
st.notify_keepalive();
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
|
|
|
@ -69,8 +69,13 @@ where
|
|||
Poll::Ready(Ok(n)) => {
|
||||
if n == 0 {
|
||||
log::trace!("io stream is disconnected");
|
||||
this.state.release_read_buf(buf, new_bytes);
|
||||
this.state.close(None);
|
||||
if let Err(e) =
|
||||
this.state.release_read_buf(buf, new_bytes)
|
||||
{
|
||||
this.state.close(Some(e));
|
||||
} else {
|
||||
this.state.close(None);
|
||||
}
|
||||
return Poll::Ready(());
|
||||
} else {
|
||||
new_bytes += n;
|
||||
|
@ -81,15 +86,19 @@ where
|
|||
}
|
||||
Poll::Ready(Err(err)) => {
|
||||
log::trace!("read task failed on io {:?}", err);
|
||||
this.state.release_read_buf(buf, new_bytes);
|
||||
let _ = this.state.release_read_buf(buf, new_bytes);
|
||||
this.state.close(Some(err));
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.state.release_read_buf(buf, new_bytes);
|
||||
Poll::Pending
|
||||
if let Err(e) = this.state.release_read_buf(buf, new_bytes) {
|
||||
this.state.close(Some(e));
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
|
@ -98,8 +107,8 @@ where
|
|||
|
||||
#[derive(Debug)]
|
||||
enum IoWriteState {
|
||||
Processing,
|
||||
Shutdown(Option<Sleep>, Shutdown),
|
||||
Processing(Option<Sleep>),
|
||||
Shutdown(Sleep, Shutdown),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -125,7 +134,7 @@ where
|
|||
Self {
|
||||
io,
|
||||
state,
|
||||
st: IoWriteState::Processing,
|
||||
st: IoWriteState::Processing(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -140,22 +149,41 @@ where
|
|||
let mut this = self.as_mut().get_mut();
|
||||
|
||||
match this.st {
|
||||
IoWriteState::Processing => {
|
||||
IoWriteState::Processing(ref mut delay) => {
|
||||
match this.state.poll_ready(cx) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
if let Some(delay) = delay {
|
||||
if delay.poll_elapsed(cx).is_ready() {
|
||||
this.state.close(Some(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"Operation timedout",
|
||||
)));
|
||||
return Poll::Ready(());
|
||||
}
|
||||
}
|
||||
|
||||
// flush framed instance
|
||||
match flush_io(&mut *this.io.borrow_mut(), &this.state, cx) {
|
||||
Poll::Pending | Poll::Ready(true) => Poll::Pending,
|
||||
Poll::Ready(false) => Poll::Ready(()),
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(WriteReadiness::Shutdown)) => {
|
||||
Poll::Ready(Err(WriteReadiness::Timeout(time))) => {
|
||||
if delay.is_none() {
|
||||
*delay = Some(sleep(time));
|
||||
}
|
||||
self.poll(cx)
|
||||
}
|
||||
Poll::Ready(Err(WriteReadiness::Shutdown(time))) => {
|
||||
log::trace!("write task is instructed to shutdown");
|
||||
|
||||
this.st = IoWriteState::Shutdown(
|
||||
this.state.disconnect_timeout().map(sleep),
|
||||
Shutdown::None,
|
||||
);
|
||||
let timeout = if let Some(delay) = delay.take() {
|
||||
delay
|
||||
} else {
|
||||
sleep(time)
|
||||
};
|
||||
|
||||
this.st = IoWriteState::Shutdown(timeout, Shutdown::None);
|
||||
self.poll(cx)
|
||||
}
|
||||
Poll::Ready(Err(WriteReadiness::Terminate)) => {
|
||||
|
@ -229,10 +257,8 @@ where
|
|||
}
|
||||
|
||||
// disconnect timeout
|
||||
if let Some(ref delay) = delay {
|
||||
if delay.poll_elapsed(cx).is_pending() {
|
||||
return Poll::Pending;
|
||||
}
|
||||
if delay.poll_elapsed(cx).is_pending() {
|
||||
return Poll::Pending;
|
||||
}
|
||||
log::trace!("write task is stopped after delay");
|
||||
this.state.close(None);
|
||||
|
@ -290,11 +316,17 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
|
|||
// remove written data
|
||||
let result = if written == len {
|
||||
buf.clear();
|
||||
state.release_write_buf(buf);
|
||||
if let Err(e) = state.release_write_buf(buf) {
|
||||
state.close(Some(e));
|
||||
return Poll::Ready(false);
|
||||
}
|
||||
Poll::Ready(true)
|
||||
} else {
|
||||
buf.advance(written);
|
||||
state.release_write_buf(buf);
|
||||
if let Err(e) = state.release_write_buf(buf) {
|
||||
state.close(Some(e));
|
||||
return Poll::Ready(false);
|
||||
}
|
||||
Poll::Pending
|
||||
};
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
use ntex_service::{fn_factory_with_config, into_service, Service, ServiceFactory};
|
||||
use std::{io, marker::PhantomData, task::Context, task::Poll};
|
||||
|
||||
use super::{Filter, Io, IoBoxed, IoStream};
|
||||
use ntex_service::{fn_factory_with_config, into_service, Service, ServiceFactory};
|
||||
use ntex_util::future::Ready;
|
||||
|
||||
use super::{Filter, FilterFactory, Io, IoBoxed, IoStream};
|
||||
|
||||
/// Service that converts any Io<F> stream to IoBoxed stream
|
||||
pub fn into_boxed<F, S>(
|
||||
|
@ -47,3 +50,81 @@ where
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Service that converts IoStream stream to Io stream
|
||||
pub fn into_io<I>() -> impl ServiceFactory<
|
||||
Config = (),
|
||||
Request = I,
|
||||
Response = Io,
|
||||
Error = io::Error,
|
||||
InitError = (),
|
||||
>
|
||||
where
|
||||
I: IoStream,
|
||||
{
|
||||
fn_factory_with_config(move |_: ()| {
|
||||
Ready::Ok(into_service(move |io| Ready::Ok(Io::new(io))))
|
||||
})
|
||||
}
|
||||
|
||||
/// Create filter factory service
|
||||
pub fn filter_factory<T, F>(filter: T) -> FilterServiceFactory<T, F>
|
||||
where
|
||||
T: FilterFactory<F> + Clone,
|
||||
F: Filter,
|
||||
{
|
||||
FilterServiceFactory {
|
||||
filter,
|
||||
_t: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FilterServiceFactory<T, F> {
|
||||
filter: T,
|
||||
_t: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<T, F> ServiceFactory for FilterServiceFactory<T, F>
|
||||
where
|
||||
T: FilterFactory<F> + Clone,
|
||||
F: Filter,
|
||||
{
|
||||
type Config = ();
|
||||
type Request = Io<F>;
|
||||
type Response = Io<T::Filter>;
|
||||
type Error = T::Error;
|
||||
type Service = FilterService<T, F>;
|
||||
type InitError = ();
|
||||
type Future = Ready<Self::Service, Self::InitError>;
|
||||
|
||||
fn new_service(&self, _: ()) -> Self::Future {
|
||||
Ready::Ok(FilterService {
|
||||
filter: self.filter.clone(),
|
||||
_t: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FilterService<T, F> {
|
||||
filter: T,
|
||||
_t: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<T, F> Service for FilterService<T, F>
|
||||
where
|
||||
T: FilterFactory<F> + Clone,
|
||||
F: Filter,
|
||||
{
|
||||
type Request = Io<F>;
|
||||
type Response = Io<T::Filter>;
|
||||
type Error = T::Error;
|
||||
type Future = T::Future;
|
||||
|
||||
fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&self, req: Io<F>) -> Self::Future {
|
||||
req.add_filter(self.filter.clone())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue