Cleanup Filter trait, removed closed,want_read,want_shutdown methods

This commit is contained in:
Nikolay Kim 2021-12-29 15:10:24 +06:00
parent c5d43eb12d
commit dc17d00ed9
21 changed files with 331 additions and 507 deletions

View file

@ -74,6 +74,11 @@ jobs:
continue-on-error: true continue-on-error: true
run: | run: |
cargo tarpaulin --out Xml --all --all-features cargo tarpaulin --out Xml --all --all-features
- name: Generate coverage report (async-std)
if: matrix.version == '1.56.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
continue-on-error: true
run: |
cd ntex cd ntex
cargo tarpaulin --out Xml --output-dir=.. --no-default-features --features="async-std,cookie,url,compress,openssl,rustls" --lib cargo tarpaulin --out Xml --output-dir=.. --no-default-features --features="async-std,cookie,url,compress,openssl,rustls" --lib

View file

@ -1,6 +1,8 @@
# Changes # Changes
## [0.1.0-b.10] - 2021-12-xx ## [0.1.0-b.10] - 2021-12-30
* Cleanup Filter trait, removed closed,want_read,want_shutdown methods
* Cleanup internal flags on io error * Cleanup internal flags on io error

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-io" name = "ntex-io"
version = "0.1.0-b.9" version = "0.1.0-b.10"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames" description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]

View file

@ -63,7 +63,6 @@ enum DispatcherState {
} }
enum DispatcherError<S, U> { enum DispatcherError<S, U> {
KeepAlive,
Encoder(U), Encoder(U),
Service(S), Service(S),
} }
@ -176,7 +175,7 @@ where
Err(err) => self.error.set(Some(DispatcherError::Service(err))), Err(err) => self.error.set(Some(DispatcherError::Service(err))),
Ok(None) => return, Ok(None) => return,
} }
io.wake_dispatcher(); io.wake();
} }
} }
@ -382,18 +381,12 @@ where
) -> Poll<PollService<U>> { ) -> Poll<PollService<U>> {
match srv.poll_ready(cx) { match srv.poll_ready(cx) {
Poll::Ready(Ok(_)) => { Poll::Ready(Ok(_)) => {
// check keepalive timeout
self.check_keepalive();
// check for errors // check for errors
Poll::Ready(if let Some(err) = self.shared.error.take() { Poll::Ready(if let Some(err) = self.shared.error.take() {
log::trace!("error occured, stopping dispatcher"); log::trace!("error occured, stopping dispatcher");
self.st.set(DispatcherState::Stop); self.st.set(DispatcherState::Stop);
match err { match err {
DispatcherError::KeepAlive => {
PollService::Item(DispatchItem::KeepAliveTimeout)
}
DispatcherError::Encoder(err) => { DispatcherError::Encoder(err) => {
PollService::Item(DispatchItem::EncoderError(err)) PollService::Item(DispatchItem::EncoderError(err))
} }
@ -431,18 +424,6 @@ where
self.ka_timeout.get().non_zero() self.ka_timeout.get().non_zero()
} }
/// check keepalive timeout
fn check_keepalive(&self) {
if self.io.is_keepalive() {
log::trace!("keepalive timeout");
if let Some(err) = self.shared.error.take() {
self.shared.error.set(Some(err));
} else {
self.shared.error.set(Some(DispatcherError::KeepAlive));
}
}
}
/// update keep-alive timer /// update keep-alive timer
fn update_keepalive(&self) { fn update_keepalive(&self) {
if self.ka_enabled() { if self.ka_enabled() {
@ -790,6 +771,7 @@ mod tests {
#[ntex::test] #[ntex::test]
async fn test_keepalive() { async fn test_keepalive() {
env_logger::init();
let (client, server) = IoTest::create(); let (client, server) = IoTest::create();
client.remote_buffer_cap(1024); client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n"); client.write("GET /test HTTP/1\r\n\r\n");
@ -831,7 +813,7 @@ mod tests {
// write side must be closed, dispatcher should fail with keep-alive // write side must be closed, dispatcher should fail with keep-alive
let flags = state.flags(); let flags = state.flags();
assert!(flags.contains(Flags::IO_SHUTDOWN)); assert!(flags.contains(Flags::IO_STOPPING));
assert!(flags.contains(Flags::DSP_KEEPALIVE)); assert!(flags.contains(Flags::DSP_KEEPALIVE));
assert!(client.is_closed()); assert!(client.is_closed());
assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]); assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);

View file

@ -14,13 +14,6 @@ impl Base {
} }
impl Filter for Base { impl Filter for Base {
#[inline]
fn closed(&self, err: Option<io::Error>) {
self.0 .0.set_error(err);
self.0 .0.handle.take();
self.0 .0.dispatch_task.wake();
}
#[inline] #[inline]
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> { fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
if let Some(hnd) = self.0 .0.handle.take() { if let Some(hnd) = self.0 .0.handle.take() {
@ -32,25 +25,8 @@ impl Filter for Base {
} }
} }
#[inline]
fn want_read(&self) {
let flags = self.0.flags();
if flags.intersects(Flags::RD_PAUSED | Flags::RD_BUF_FULL) {
self.0
.0
.remove_flags(Flags::RD_PAUSED | Flags::RD_BUF_FULL);
self.0 .0.read_task.wake();
}
}
#[inline]
fn want_shutdown(&self, err: Option<io::Error>) {
self.0 .0.init_shutdown(err);
}
#[inline] #[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> { fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.want_shutdown(None);
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
@ -58,7 +34,7 @@ impl Filter for Base {
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> { fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
let flags = self.0.flags(); let flags = self.0.flags();
if flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { if flags.intersects(Flags::IO_STOPPING) {
Poll::Ready(ReadStatus::Terminate) Poll::Ready(ReadStatus::Terminate)
} else if flags.intersects(Flags::RD_PAUSED | Flags::RD_BUF_FULL) { } else if flags.intersects(Flags::RD_PAUSED | Flags::RD_BUF_FULL) {
self.0 .0.read_task.register(cx.waker()); self.0 .0.read_task.register(cx.waker());
@ -73,13 +49,14 @@ impl Filter for Base {
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> { fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
let mut flags = self.0.flags(); let mut flags = self.0.flags();
if flags.contains(Flags::IO_ERR) { if flags.contains(Flags::IO_STOPPED) {
Poll::Ready(WriteStatus::Terminate) Poll::Ready(WriteStatus::Terminate)
} else if flags.intersects(Flags::IO_SHUTDOWN) { } else if flags.intersects(Flags::IO_STOPPING) {
Poll::Ready(WriteStatus::Shutdown(self.0 .0.disconnect_timeout.get())) Poll::Ready(WriteStatus::Shutdown(self.0 .0.disconnect_timeout.get()))
} else if flags.contains(Flags::IO_FILTERS) && !flags.contains(Flags::IO_FILTERS_TO) } else if flags.contains(Flags::IO_STOPPING_FILTERS)
&& !flags.contains(Flags::IO_FILTERS_TIMEOUT)
{ {
flags.insert(Flags::IO_FILTERS_TO); flags.insert(Flags::IO_FILTERS_TIMEOUT);
self.0.set_flags(flags); self.0.set_flags(flags);
self.0 .0.write_task.register(cx.waker()); self.0 .0.write_task.register(cx.waker());
Poll::Ready(WriteStatus::Timeout(self.0 .0.disconnect_timeout.get())) Poll::Ready(WriteStatus::Timeout(self.0 .0.disconnect_timeout.get()))
@ -102,6 +79,7 @@ impl Filter for Base {
#[inline] #[inline]
fn release_read_buf( fn release_read_buf(
&self, &self,
_: &IoRef,
buf: BytesMut, buf: BytesMut,
dst: &mut Option<BytesMut>, dst: &mut Option<BytesMut>,
nbytes: usize, nbytes: usize,
@ -145,12 +123,6 @@ impl Filter for NullFilter {
None None
} }
fn closed(&self, _: Option<io::Error>) {}
fn want_read(&self) {}
fn want_shutdown(&self, _: Option<io::Error>) {}
fn poll_shutdown(&self) -> Poll<io::Result<()>> { fn poll_shutdown(&self) -> Poll<io::Result<()>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
@ -173,6 +145,7 @@ impl Filter for NullFilter {
fn release_read_buf( fn release_read_buf(
&self, &self,
_: &IoRef,
_: BytesMut, _: BytesMut,
_: &mut Option<BytesMut>, _: &mut Option<BytesMut>,
_: usize, _: usize,

View file

@ -13,33 +13,33 @@ use super::{Filter, FilterFactory, Handle, IoStream, RecvError};
bitflags::bitflags! { bitflags::bitflags! {
pub struct Flags: u16 { pub struct Flags: u16 {
/// io error occured /// io is closed
const IO_ERR = 0b0000_0000_0000_0001; const IO_STOPPED = 0b0000_0000_0000_0001;
/// shuting down filters
const IO_FILTERS = 0b0000_0000_0000_0010;
/// shuting down filters timeout
const IO_FILTERS_TO = 0b0000_0000_0000_0100;
/// shutdown io tasks /// shutdown io tasks
const IO_SHUTDOWN = 0b0000_0000_0000_1000; const IO_STOPPING = 0b0000_0000_0000_0010;
/// shuting down filters
const IO_STOPPING_FILTERS = 0b0000_0000_0000_0100;
/// initiate filters shutdown timeout in write task
const IO_FILTERS_TIMEOUT = 0b0000_0000_0000_1000;
/// pause io read /// pause io read
const RD_PAUSED = 0b0000_0000_0010_0000; const RD_PAUSED = 0b0000_0000_0001_0000;
/// new data is available /// new data is available
const RD_READY = 0b0000_0000_0100_0000; const RD_READY = 0b0000_0000_0010_0000;
/// read buffer is full /// read buffer is full
const RD_BUF_FULL = 0b0000_0000_1000_0000; const RD_BUF_FULL = 0b0000_0000_0100_0000;
/// wait write completion /// wait write completion
const WR_WAIT = 0b0000_0001_0000_0000; const WR_WAIT = 0b0000_0000_1000_0000;
/// write buffer is full /// write buffer is full
const WR_BACKPRESSURE = 0b0000_0010_0000_0000; const WR_BACKPRESSURE = 0b0000_0001_0000_0000;
/// dispatcher is marked stopped /// dispatcher is marked stopped
const DSP_STOP = 0b0001_0000_0000_0000; const DSP_STOP = 0b0000_0010_0000_0000;
/// keep-alive timeout occured /// keep-alive timeout occured
const DSP_KEEPALIVE = 0b0010_0000_0000_0000; const DSP_KEEPALIVE = 0b0000_0100_0000_0000;
/// dispatcher returned error /// dispatcher returned error
const DSP_ERR = 0b0100_0000_0000_0000; const DSP_ERR = 0b0000_1000_0000_0000;
} }
} }
@ -104,15 +104,7 @@ impl IoState {
} }
#[inline] #[inline]
pub(super) fn is_io_open(&self) -> bool { pub(super) fn io_stopped(&self, err: Option<io::Error>) {
!self
.flags
.get()
.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN)
}
#[inline]
pub(super) fn set_error(&self, err: Option<io::Error>) {
if err.is_some() { if err.is_some() {
self.error.set(err); self.error.set(err);
} }
@ -120,52 +112,49 @@ impl IoState {
self.write_task.wake(); self.write_task.wake();
self.dispatch_task.wake(); self.dispatch_task.wake();
self.notify_disconnect(); self.notify_disconnect();
let mut flags = self.flags.get(); self.handle.take();
flags.insert(Flags::IO_ERR); self.insert_flags(
flags.remove( Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS,
Flags::DSP_KEEPALIVE
| Flags::RD_PAUSED
| Flags::RD_READY
| Flags::RD_BUF_FULL
| Flags::WR_WAIT
| Flags::WR_BACKPRESSURE,
); );
self.flags.set(flags);
} }
#[inline] #[inline]
/// Gracefully shutdown read and write io tasks /// Gracefully shutdown read and write io tasks
pub(super) fn init_shutdown(&self, err: Option<io::Error>) { pub(super) fn init_shutdown(&self, err: Option<io::Error>) {
let flags = self.flags.get(); if err.is_some() {
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) { self.io_stopped(err);
log::trace!("initiate io shutdown {:?} {:?}", flags, err); } else if !self
self.insert_flags(Flags::IO_FILTERS); .flags
.get()
.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING | Flags::IO_STOPPING_FILTERS)
{
log::trace!("initiate io shutdown {:?}", self.flags.get());
self.insert_flags(Flags::IO_STOPPING_FILTERS);
self.read_task.wake(); self.read_task.wake();
self.write_task.wake(); self.write_task.wake();
if let Some(err) = err { self.dispatch_task.wake();
self.error.set(Some(err));
}
} }
} }
#[inline] #[inline]
pub(super) fn shutdown_filters(&self) { pub(super) fn shutdown_filters(&self) {
let mut flags = self.flags.get(); if !self
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { .flags
.get()
.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING)
{
match self.filter.get().poll_shutdown() { match self.filter.get().poll_shutdown() {
Poll::Pending => return,
Poll::Ready(Ok(())) => { Poll::Ready(Ok(())) => {
flags.insert(Flags::IO_SHUTDOWN); self.read_task.wake();
self.write_task.wake();
self.dispatch_task.wake();
self.insert_flags(Flags::IO_STOPPING);
} }
Poll::Ready(Err(err)) => { Poll::Ready(Err(err)) => {
flags.insert(Flags::IO_ERR); self.io_stopped(Some(err));
self.error.set(Some(err));
} }
Poll::Pending => (),
} }
self.flags.set(flags);
self.read_task.wake();
self.write_task.wake();
self.dispatch_task.wake();
} }
} }
@ -264,7 +253,7 @@ impl Io {
let io_ref = IoRef(inner); let io_ref = IoRef(inner);
// start io tasks // start io tasks
let hnd = io.start(ReadContext(io_ref.clone()), WriteContext(io_ref.clone())); let hnd = io.start(ReadContext::new(&io_ref), WriteContext::new(&io_ref));
io_ref.0.handle.set(hnd); io_ref.0.handle.set(hnd);
Io(io_ref, FilterItem::Ptr(Box::into_raw(filter))) Io(io_ref, FilterItem::Ptr(Box::into_raw(filter)))
@ -331,6 +320,11 @@ impl<F> Io<F> {
pub fn reset_keepalive(&self) { pub fn reset_keepalive(&self) {
self.0 .0.remove_flags(Flags::DSP_KEEPALIVE) self.0 .0.remove_flags(Flags::DSP_KEEPALIVE)
} }
/// Get current io error
fn error(&self) -> Option<io::Error> {
self.0 .0.error.take()
}
} }
impl Io<Sealed> { impl Io<Sealed> {
@ -478,13 +472,13 @@ impl<F> Io<F> {
/// Wake write task and instruct to flush data. /// Wake write task and instruct to flush data.
/// ///
/// This is async version of .poll_flush() method. /// This is async version of .poll_flush() method.
pub async fn flush(&self, full: bool) -> Result<(), io::Error> { pub async fn flush(&self, full: bool) -> io::Result<()> {
poll_fn(|cx| self.poll_flush(cx, full)).await poll_fn(|cx| self.poll_flush(cx, full)).await
} }
#[inline] #[inline]
/// Shut down io stream /// Shut down io stream
pub async fn shutdown(&self) -> Result<(), io::Error> { pub async fn shutdown(&self) -> io::Result<()> {
poll_fn(|cx| self.poll_shutdown(cx)).await poll_fn(|cx| self.poll_shutdown(cx)).await
} }
@ -503,16 +497,13 @@ impl<F> Io<F> {
/// `Poll::Ready(Ok(None))` if io stream is disconnected /// `Poll::Ready(Ok(None))` if io stream is disconnected
/// `Some(Poll::Ready(Err(e)))` if an error is encountered. /// `Some(Poll::Ready(Err(e)))` if an error is encountered.
pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> { pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<Option<()>>> {
if !self.0 .0.is_io_open() { let mut flags = self.0 .0.flags.get();
if let Some(err) = self.0 .0.error.take() {
Poll::Ready(Err(err)) if flags.contains(Flags::IO_STOPPED) {
} else { Poll::Ready(self.error().map(Err).unwrap_or(Ok(None)))
Poll::Ready(Ok(None))
}
} else { } else {
self.0 .0.dispatch_task.register(cx.waker()); self.0 .0.dispatch_task.register(cx.waker());
let mut flags = self.0 .0.flags.get();
let ready = flags.contains(Flags::RD_READY); let ready = flags.contains(Flags::RD_READY);
if flags.intersects(Flags::RD_BUF_FULL | Flags::RD_PAUSED) { if flags.intersects(Flags::RD_BUF_FULL | Flags::RD_PAUSED) {
if flags.intersects(Flags::RD_BUF_FULL) { if flags.intersects(Flags::RD_BUF_FULL) {
@ -540,7 +531,6 @@ impl<F> Io<F> {
} }
#[inline] #[inline]
#[allow(clippy::type_complexity)]
/// Decode codec item from incoming bytes stream. /// Decode codec item from incoming bytes stream.
/// ///
/// Wake read task and request to read more data if data is not enough for decoding. /// Wake read task and request to read more data if data is not enough for decoding.
@ -556,13 +546,10 @@ impl<F> Io<F> {
match self.decode(codec) { match self.decode(codec) {
Ok(Some(el)) => Poll::Ready(Ok(el)), Ok(Some(el)) => Poll::Ready(Ok(el)),
Ok(None) => { Ok(None) => {
if !self.0 .0.is_io_open() {
return Poll::Ready(Err(RecvError::PeerGone(
self.0 .0.error.take(),
)));
}
let flags = self.flags(); let flags = self.flags();
if flags.contains(Flags::DSP_STOP) { if flags.contains(Flags::IO_STOPPED) {
Poll::Ready(Err(RecvError::PeerGone(self.error())))
} else if flags.contains(Flags::DSP_STOP) {
Poll::Ready(Err(RecvError::Stop)) Poll::Ready(Err(RecvError::Stop))
} else if flags.contains(Flags::DSP_KEEPALIVE) { } else if flags.contains(Flags::DSP_KEEPALIVE) {
Poll::Ready(Err(RecvError::KeepAlive)) Poll::Ready(Err(RecvError::KeepAlive))
@ -594,64 +581,51 @@ impl<F> Io<F> {
/// otherwise wake up when size of write buffer is lower than /// otherwise wake up when size of write buffer is lower than
/// buffer max size. /// buffer max size.
pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> { pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
// check io error let flags = self.flags();
if !self.0 .0.is_io_open() {
self.0 .0.remove_flags( if flags.contains(Flags::IO_STOPPED) {
Flags::DSP_KEEPALIVE Poll::Ready(self.error().map(Err).unwrap_or(Ok(())))
| Flags::RD_PAUSED } else {
| Flags::RD_READY let len = self
| Flags::RD_BUF_FULL
| Flags::WR_WAIT
| Flags::WR_BACKPRESSURE,
);
return Poll::Ready(Err(self
.0 .0
.0 .0
.error .with_write_buf(|buf| buf.as_ref().map(|b| b.len()).unwrap_or(0));
.take()
.unwrap_or_else(|| io::Error::new(io::ErrorKind::Other, "disconnected"))));
}
let len = self if len > 0 {
.0 if full {
.0 self.0 .0.insert_flags(Flags::WR_WAIT);
.with_write_buf(|buf| buf.as_ref().map(|b| b.len()).unwrap_or(0)); self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
if len > 0 { } else if len >= self.0.memory_pool().write_params_high() << 1 {
if full { self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
self.0 .0.insert_flags(Flags::WR_WAIT); self.0 .0.dispatch_task.register(cx.waker());
self.0 .0.dispatch_task.register(cx.waker()); return Poll::Pending;
return Poll::Pending; }
} else if len >= self.0.memory_pool().write_params_high() << 1 {
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
} }
self.0
.0
.remove_flags(Flags::WR_WAIT | Flags::WR_BACKPRESSURE);
Poll::Ready(Ok(()))
} }
self.0
.0
.remove_flags(Flags::WR_WAIT | Flags::WR_BACKPRESSURE);
Poll::Ready(Ok(()))
} }
#[inline] #[inline]
/// Shut down io stream /// Shut down io stream
pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let flags = self.flags(); let flags = self.flags();
if flags.intersects(Flags::IO_ERR) { if flags.intersects(Flags::IO_STOPPED) {
Poll::Ready(Ok(())) if let Some(err) = self.error() {
} else {
if !flags.contains(Flags::IO_FILTERS) {
self.0 .0.init_shutdown(None);
}
if let Some(err) = self.0 .0.error.take() {
Poll::Ready(Err(err)) Poll::Ready(Err(err))
} else { } else {
self.0 .0.dispatch_task.register(cx.waker()); Poll::Ready(Ok(()))
Poll::Pending
} }
} else {
if !flags.contains(Flags::IO_STOPPING_FILTERS) {
self.0 .0.init_shutdown(None);
}
self.0 .0.dispatch_task.register(cx.waker());
Poll::Pending
} }
} }
} }
@ -708,7 +682,7 @@ pub struct OnDisconnect {
impl OnDisconnect { impl OnDisconnect {
pub(super) fn new(inner: Rc<IoState>) -> Self { pub(super) fn new(inner: Rc<IoState>) -> Self {
Self::new_inner(inner.flags.get().contains(Flags::IO_ERR), inner) Self::new_inner(inner.flags.get().contains(Flags::IO_STOPPED), inner)
} }
fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self { fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {

View file

@ -32,73 +32,10 @@ impl IoRef {
self.0.pool.get() self.0.pool.get()
} }
#[inline]
/// Check if io is still active
pub fn is_io_open(&self) -> bool {
self.0.is_io_open()
}
#[inline]
/// Check if keep-alive timeout occured
pub fn is_keepalive(&self) -> bool {
self.0.flags.get().contains(Flags::DSP_KEEPALIVE)
}
#[inline] #[inline]
/// Check if io stream is closed /// Check if io stream is closed
pub fn is_closed(&self) -> bool { pub fn is_closed(&self) -> bool {
self.0.flags.get().intersects( self.0.flags.get().contains(Flags::IO_STOPPING)
Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS | Flags::DSP_STOP,
)
}
#[inline]
/// Take io error if any occured
pub fn take_error(&self) -> Option<io::Error> {
self.0.error.take()
}
#[inline]
/// Wake dispatcher task
pub fn wake_dispatcher(&self) {
self.0.dispatch_task.wake();
}
#[inline]
/// Gracefully close connection
///
/// First stop dispatcher, then dispatcher stops io tasks
pub fn close(&self) {
self.0.insert_flags(Flags::DSP_STOP);
self.0.dispatch_task.wake();
}
#[inline]
/// Force close connection
///
/// Dispatcher does not wait for uncompleted responses, but flushes io buffers.
pub fn force_close(&self) {
log::trace!("force close io stream object");
self.0.insert_flags(Flags::DSP_STOP | Flags::IO_SHUTDOWN);
self.0.read_task.wake();
self.0.write_task.wake();
self.0.dispatch_task.wake();
}
#[inline]
/// Notify when io stream get disconnected
pub fn on_disconnect(&self) -> OnDisconnect {
OnDisconnect::new(self.0.clone())
}
#[inline]
/// Query specific data
pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
types::QueryItem::new(item)
} else {
types::QueryItem::empty()
}
} }
#[inline] #[inline]
@ -131,6 +68,55 @@ impl IoRef {
len >= self.memory_pool().read_params_high() len >= self.memory_pool().read_params_high()
} }
#[inline]
/// Wake dispatcher task
pub fn wake(&self) {
self.0.dispatch_task.wake();
}
#[inline]
/// Gracefully close connection
///
/// First stop dispatcher, then dispatcher stops io tasks
pub fn close(&self) {
self.0.insert_flags(Flags::DSP_STOP);
self.0.dispatch_task.wake();
}
#[inline]
/// Force close connection
///
/// Dispatcher does not wait for uncompleted responses, but flushes io buffers.
pub fn force_close(&self) {
log::trace!("force close io stream object");
self.0.insert_flags(Flags::DSP_STOP | Flags::IO_STOPPING);
self.0.read_task.wake();
self.0.write_task.wake();
self.0.dispatch_task.wake();
}
#[inline]
/// Gracefully shutdown io stream
pub fn want_shutdown(&self, err: Option<io::Error>) {
self.0.init_shutdown(err);
}
#[inline]
/// Notify when io stream get disconnected
pub fn on_disconnect(&self) -> OnDisconnect {
OnDisconnect::new(self.0.clone())
}
#[inline]
/// Query specific data
pub fn query<T: 'static>(&self) -> types::QueryItem<T> {
if let Some(item) = self.filter().query(any::TypeId::of::<T>()) {
types::QueryItem::new(item)
} else {
types::QueryItem::empty()
}
}
#[inline] #[inline]
/// Get mut access to write buffer /// Get mut access to write buffer
pub fn with_write_buf<F, R>(&self, f: F) -> Result<R, io::Error> pub fn with_write_buf<F, R>(&self, f: F) -> Result<R, io::Error>
@ -176,7 +162,7 @@ impl IoRef {
{ {
let flags = self.0.flags.get(); let flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { if !flags.contains(Flags::IO_STOPPING) {
self.with_write_buf(|buf| { self.with_write_buf(|buf| {
let (hw, lw) = self.memory_pool().write_params().unpack(); let (hw, lw) = self.memory_pool().write_params().unpack();
@ -191,7 +177,7 @@ impl IoRef {
}) })
.map_or_else( .map_or_else(
|err| { |err| {
self.0.set_error(Some(err)); self.0.io_stopped(Some(err));
Ok(()) Ok(())
}, },
|item| item, |item| item,
@ -223,7 +209,7 @@ impl IoRef {
pub fn write(&self, src: &[u8]) -> io::Result<()> { pub fn write(&self, src: &[u8]) -> io::Result<()> {
let flags = self.0.flags.get(); let flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { if !flags.intersects(Flags::IO_STOPPING) {
self.with_write_buf(|buf| { self.with_write_buf(|buf| {
buf.extend_from_slice(src); buf.extend_from_slice(src);
}) })
@ -283,7 +269,7 @@ mod tests {
client.read_error(io::Error::new(io::ErrorKind::Other, "err")); client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
let msg = state.recv(&BytesCodec).await; let msg = state.recv(&BytesCodec).await;
assert!(msg.is_err()); assert!(msg.is_err());
assert!(state.flags().contains(Flags::IO_ERR)); assert!(state.flags().contains(Flags::IO_STOPPED));
let (client, server) = IoTest::create(); let (client, server) = IoTest::create();
client.remote_buffer_cap(1024); client.remote_buffer_cap(1024);
@ -293,7 +279,7 @@ mod tests {
let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await; let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
if let Poll::Ready(msg) = res { if let Poll::Ready(msg) = res {
assert!(msg.is_err()); assert!(msg.is_err());
assert!(state.flags().contains(Flags::IO_ERR)); assert!(state.flags().contains(Flags::IO_STOPPED));
assert!(state.flags().contains(Flags::DSP_STOP)); assert!(state.flags().contains(Flags::DSP_STOP));
} }
@ -310,14 +296,14 @@ mod tests {
client.write_error(io::Error::new(io::ErrorKind::Other, "err")); client.write_error(io::Error::new(io::ErrorKind::Other, "err"));
let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await; let res = state.send(Bytes::from_static(b"test"), &BytesCodec).await;
assert!(res.is_err()); assert!(res.is_err());
assert!(state.flags().contains(Flags::IO_ERR)); assert!(state.flags().contains(Flags::IO_STOPPED));
let (client, server) = IoTest::create(); let (client, server) = IoTest::create();
client.remote_buffer_cap(1024); client.remote_buffer_cap(1024);
let state = Io::new(server); let state = Io::new(server);
state.force_close(); state.force_close();
assert!(state.flags().contains(Flags::DSP_STOP)); assert!(state.flags().contains(Flags::DSP_STOP));
assert!(state.flags().contains(Flags::IO_SHUTDOWN)); assert!(state.flags().contains(Flags::IO_STOPPING));
} }
#[ntex::test] #[ntex::test]
@ -389,10 +375,6 @@ mod tests {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn want_read(&self) {}
fn want_shutdown(&self, _: Option<io::Error>) {}
fn query(&self, _: std::any::TypeId) -> Option<Box<dyn std::any::Any>> { fn query(&self, _: std::any::TypeId) -> Option<Box<dyn std::any::Any>> {
None None
} }
@ -401,21 +383,18 @@ mod tests {
self.inner.poll_read_ready(cx) self.inner.poll_read_ready(cx)
} }
fn closed(&self, err: Option<io::Error>) {
self.inner.closed(err)
}
fn get_read_buf(&self) -> Option<BytesMut> { fn get_read_buf(&self) -> Option<BytesMut> {
self.inner.get_read_buf() self.inner.get_read_buf()
} }
fn release_read_buf( fn release_read_buf(
&self, &self,
io: &IoRef,
buf: BytesMut, buf: BytesMut,
dst: &mut Option<BytesMut>, dst: &mut Option<BytesMut>,
new_bytes: usize, new_bytes: usize,
) -> io::Result<usize> { ) -> io::Result<usize> {
let result = self.inner.release_read_buf(buf, dst, new_bytes)?; let result = self.inner.release_read_buf(io, buf, dst, new_bytes)?;
self.read_order.borrow_mut().push(self.idx); self.read_order.borrow_mut().push(self.idx);
self.in_bytes.set(self.in_bytes.get() + result); self.in_bytes.set(self.in_bytes.get() + result);
Ok(result) Ok(result)

View file

@ -55,24 +55,13 @@ pub enum WriteStatus {
pub trait Filter: 'static { pub trait Filter: 'static {
fn query(&self, id: TypeId) -> Option<Box<dyn Any>>; fn query(&self, id: TypeId) -> Option<Box<dyn Any>>;
/// Filter needs incoming data from io stream
fn want_read(&self);
/// Filter wants gracefully shutdown io stream
fn want_shutdown(&self, err: Option<sio::Error>);
fn poll_shutdown(&self) -> Poll<sio::Result<()>>;
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus>;
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus>;
fn get_read_buf(&self) -> Option<BytesMut>; fn get_read_buf(&self) -> Option<BytesMut>;
fn get_write_buf(&self) -> Option<BytesMut>; fn get_write_buf(&self) -> Option<BytesMut>;
fn release_read_buf( fn release_read_buf(
&self, &self,
io: &IoRef,
src: BytesMut, src: BytesMut,
dst: &mut Option<BytesMut>, dst: &mut Option<BytesMut>,
nbytes: usize, nbytes: usize,
@ -80,7 +69,11 @@ pub trait Filter: 'static {
fn release_write_buf(&self, buf: BytesMut) -> sio::Result<()>; fn release_write_buf(&self, buf: BytesMut) -> sio::Result<()>;
fn closed(&self, err: Option<sio::Error>); fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus>;
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus>;
fn poll_shutdown(&self) -> Poll<sio::Result<()>>;
} }
pub trait FilterFactory<F: Filter>: Sized { pub trait FilterFactory<F: Filter>: Sized {

View file

@ -4,9 +4,13 @@ use ntex_bytes::{BytesMut, PoolRef};
use super::{io::Flags, IoRef, ReadStatus, WriteStatus}; use super::{io::Flags, IoRef, ReadStatus, WriteStatus};
pub struct ReadContext(pub(super) IoRef); pub struct ReadContext(IoRef);
impl ReadContext { impl ReadContext {
pub(crate) fn new(io: &IoRef) -> Self {
Self(io.clone())
}
#[inline] #[inline]
pub fn memory_pool(&self) -> PoolRef { pub fn memory_pool(&self) -> PoolRef {
self.0.memory_pool() self.0.memory_pool()
@ -17,11 +21,6 @@ impl ReadContext {
self.0.filter().poll_read_ready(cx) self.0.filter().poll_read_ready(cx)
} }
#[inline]
pub fn close(&self, err: Option<io::Error>) {
self.0.filter().closed(err);
}
#[inline] #[inline]
pub fn get_read_buf(&self) -> BytesMut { pub fn get_read_buf(&self) -> BytesMut {
self.0 self.0
@ -37,7 +36,7 @@ impl ReadContext {
} else { } else {
let filter = self.0.filter(); let filter = self.0.filter();
let mut dst = self.0 .0.read_buf.take(); let mut dst = self.0 .0.read_buf.take();
let result = filter.release_read_buf(buf, &mut dst, nbytes); let result = filter.release_read_buf(&self.0, buf, &mut dst, nbytes);
let nbytes = result.as_ref().map(|i| *i).unwrap_or(0); let nbytes = result.as_ref().map(|i| *i).unwrap_or(0);
if let Some(dst) = dst { if let Some(dst) = dst {
@ -63,19 +62,28 @@ impl ReadContext {
if let Err(err) = result { if let Err(err) = result {
self.0 .0.dispatch_task.wake(); self.0 .0.dispatch_task.wake();
self.0 .0.insert_flags(Flags::RD_READY); self.0 .0.insert_flags(Flags::RD_READY);
filter.want_shutdown(Some(err)); self.0.want_shutdown(Some(err));
} }
} }
if self.0.flags().contains(Flags::IO_FILTERS) { if self.0.flags().contains(Flags::IO_STOPPING_FILTERS) {
self.0 .0.shutdown_filters(); self.0 .0.shutdown_filters();
} }
} }
#[inline]
pub fn close(&self, err: Option<io::Error>) {
self.0 .0.io_stopped(err);
}
} }
pub struct WriteContext(pub(super) IoRef); pub struct WriteContext(IoRef);
impl WriteContext { impl WriteContext {
pub(crate) fn new(io: &IoRef) -> Self {
Self(io.clone())
}
#[inline] #[inline]
pub fn memory_pool(&self) -> PoolRef { pub fn memory_pool(&self) -> PoolRef {
self.0.memory_pool() self.0.memory_pool()
@ -86,11 +94,6 @@ impl WriteContext {
self.0.filter().poll_write_ready(cx) self.0.filter().poll_write_ready(cx)
} }
#[inline]
pub fn close(&self, err: Option<io::Error>) {
self.0.filter().closed(err)
}
#[inline] #[inline]
pub fn get_write_buf(&self) -> Option<BytesMut> { pub fn get_write_buf(&self) -> Option<BytesMut> {
self.0 .0.write_buf.take() self.0 .0.write_buf.take()
@ -120,9 +123,15 @@ impl WriteContext {
self.0 .0.write_buf.set(Some(buf)) self.0 .0.write_buf.set(Some(buf))
} }
if flags.contains(Flags::IO_FILTERS) { if self.0.flags().contains(Flags::IO_STOPPING_FILTERS) {
self.0 .0.shutdown_filters(); self.0 .0.shutdown_filters();
} }
Ok(()) Ok(())
} }
#[inline]
pub fn close(&self, err: Option<io::Error>) {
self.0 .0.io_stopped(err);
}
} }

View file

@ -638,6 +638,7 @@ pub(super) fn flush_io(
Poll::Pending Poll::Pending
} }
} else { } else {
let _ = state.release_write_buf(buf);
Poll::Ready(true) Poll::Ready(true)
} }
} }

View file

@ -359,6 +359,9 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
Poll::Ready(false) Poll::Ready(false)
} }
} }
} else if let Err(e) = state.release_write_buf(buf) {
state.close(Some(e));
Poll::Ready(false)
} else { } else {
Poll::Ready(true) Poll::Ready(true)
} }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-tls" name = "ntex-tls"
version = "0.1.0-b.6" version = "0.1.0-b.7"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL" description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -26,7 +26,7 @@ rustls = ["tls_rust"]
[dependencies] [dependencies]
ntex-bytes = "0.1.8" ntex-bytes = "0.1.8"
ntex-io = "0.1.0-b.8" ntex-io = "0.1.0-b.10"
ntex-util = "0.1.5" ntex-util = "0.1.5"
ntex-service = "0.3.0-b.0" ntex-service = "0.3.0-b.0"
pin-project-lite = "0.2" pin-project-lite = "0.2"

View file

@ -6,7 +6,7 @@ use std::{
}; };
use ntex_bytes::{BufMut, BytesMut, PoolRef}; use ntex_bytes::{BufMut, BytesMut, PoolRef};
use ntex_io::{Base, Filter, FilterFactory, Io, ReadStatus, WriteStatus}; use ntex_io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus};
use ntex_util::{future::poll_fn, ready, time, time::Millis}; use ntex_util::{future::poll_fn, ready, time, time::Millis};
use tls_openssl::ssl::{self, SslStream}; use tls_openssl::ssl::{self, SslStream};
use tls_openssl::x509::X509; use tls_openssl::x509::X509;
@ -137,21 +137,6 @@ impl<F: Filter> Filter for SslFilter<F> {
self.inner.borrow().get_ref().inner.poll_write_ready(cx) self.inner.borrow().get_ref().inner.poll_write_ready(cx)
} }
#[inline]
fn closed(&self, err: Option<io::Error>) {
self.inner.borrow().get_ref().inner.closed(err)
}
#[inline]
fn want_read(&self) {
self.inner.borrow().get_ref().inner.want_read()
}
#[inline]
fn want_shutdown(&self, err: Option<io::Error>) {
self.inner.borrow().get_ref().inner.want_shutdown(err)
}
#[inline] #[inline]
fn get_read_buf(&self) -> Option<BytesMut> { fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().get_mut().read_buf.take() { if let Some(buf) = self.inner.borrow_mut().get_mut().read_buf.take() {
@ -174,6 +159,7 @@ impl<F: Filter> Filter for SslFilter<F> {
fn release_read_buf( fn release_read_buf(
&self, &self,
io: &IoRef,
src: BytesMut, src: BytesMut,
dst: &mut Option<BytesMut>, dst: &mut Option<BytesMut>,
nbytes: usize, nbytes: usize,
@ -185,9 +171,9 @@ impl<F: Filter> Filter for SslFilter<F> {
let result = inner let result = inner
.get_ref() .get_ref()
.inner .inner
.release_read_buf(src, &mut dst, nbytes); .release_read_buf(io, src, &mut dst, nbytes);
if let Err(err) = result { if let Err(err) = result {
self.want_shutdown(Some(err)); io.want_shutdown(Some(err));
} }
if dst.is_some() { if dst.is_some() {
inner.get_mut().read_buf = dst; inner.get_mut().read_buf = dst;
@ -233,7 +219,7 @@ impl<F: Filter> Filter for SslFilter<F> {
Ok(new_bytes) Ok(new_bytes)
} }
Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => { Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => {
self.want_shutdown(None); io.want_shutdown(None);
Ok(new_bytes) Ok(new_bytes)
} }
Err(e) => Err(map_to_ioerr(e)), Err(e) => Err(map_to_ioerr(e)),
@ -258,10 +244,6 @@ impl<F: Filter> Filter for SslFilter<F> {
} }
return match e.code() { return match e.code() {
ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => Ok(()), ssl::ErrorCode::WANT_READ | ssl::ErrorCode::WANT_WRITE => Ok(()),
ssl::ErrorCode::ZERO_RETURN => {
self.want_shutdown(None);
Ok(())
}
_ => Err(map_to_ioerr(e)), _ => Err(map_to_ioerr(e)),
}; };
} }

View file

@ -3,7 +3,7 @@ use std::io::{self, Read as IoRead, Write as IoWrite};
use std::{any, cell::RefCell, cmp, sync::Arc, task::Context, task::Poll}; use std::{any, cell::RefCell, cmp, sync::Arc, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesMut, PoolRef}; use ntex_bytes::{BufMut, BytesMut, PoolRef};
use ntex_io::{Filter, Io, ReadStatus, WriteStatus}; use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus};
use ntex_util::{future::poll_fn, ready}; use ntex_util::{future::poll_fn, ready};
use tls_rust::{ClientConfig, ClientConnection, ServerName}; use tls_rust::{ClientConfig, ClientConnection, ServerName};
@ -46,16 +46,6 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
} }
} }
#[inline]
fn want_read(&self) {
self.inner.borrow().inner.want_read()
}
#[inline]
fn want_shutdown(&self, err: Option<io::Error>) {
self.inner.borrow().inner.want_shutdown(err)
}
#[inline] #[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> { fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.inner.borrow().inner.poll_shutdown() self.inner.borrow().inner.poll_shutdown()
@ -71,11 +61,6 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
self.inner.borrow().inner.poll_write_ready(cx) self.inner.borrow().inner.poll_write_ready(cx)
} }
#[inline]
fn closed(&self, err: Option<io::Error>) {
self.inner.borrow().inner.closed(err)
}
#[inline] #[inline]
fn get_read_buf(&self) -> Option<BytesMut> { fn get_read_buf(&self) -> Option<BytesMut> {
if let Some(buf) = self.inner.borrow_mut().read_buf.take() { if let Some(buf) = self.inner.borrow_mut().read_buf.take() {
@ -98,6 +83,7 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
fn release_read_buf( fn release_read_buf(
&self, &self,
io: &IoRef,
src: BytesMut, src: BytesMut,
dst: &mut Option<BytesMut>, dst: &mut Option<BytesMut>,
nbytes: usize, nbytes: usize,
@ -111,8 +97,8 @@ impl<F: Filter> Filter for TlsClientFilter<F> {
} else { } else {
let mut src = { let mut src = {
let mut dst = None; let mut dst = None;
if let Err(err) = inner.inner.release_read_buf(src, &mut dst, nbytes) { if let Err(err) = inner.inner.release_read_buf(io, src, &mut dst, nbytes) {
self.want_shutdown(Some(err)); io.want_shutdown(Some(err));
} }
if let Some(dst) = dst { if let Some(dst) = dst {

View file

@ -4,7 +4,7 @@ use std::sync::Arc;
use std::{any, future::Future, io, pin::Pin, task::Context, task::Poll}; use std::{any, future::Future, io, pin::Pin, task::Context, task::Poll};
use ntex_bytes::BytesMut; use ntex_bytes::BytesMut;
use ntex_io::{Base, Filter, FilterFactory, Io, ReadStatus, WriteStatus}; use ntex_io::{Base, Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus};
use ntex_util::time::Millis; use ntex_util::time::Millis;
use tls_rust::{ClientConfig, ServerConfig, ServerName}; use tls_rust::{ClientConfig, ServerConfig, ServerName};
@ -60,30 +60,6 @@ impl<F: Filter> Filter for TlsFilter<F> {
} }
} }
#[inline]
fn closed(&self, err: Option<io::Error>) {
match self.inner {
InnerTlsFilter::Server(ref f) => f.closed(err),
InnerTlsFilter::Client(ref f) => f.closed(err),
}
}
#[inline]
fn want_read(&self) {
match self.inner {
InnerTlsFilter::Server(ref f) => f.want_read(),
InnerTlsFilter::Client(ref f) => f.want_read(),
}
}
#[inline]
fn want_shutdown(&self, err: Option<io::Error>) {
match self.inner {
InnerTlsFilter::Server(ref f) => f.want_shutdown(err),
InnerTlsFilter::Client(ref f) => f.want_shutdown(err),
}
}
#[inline] #[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> { fn poll_shutdown(&self) -> Poll<io::Result<()>> {
match self.inner { match self.inner {
@ -127,13 +103,14 @@ impl<F: Filter> Filter for TlsFilter<F> {
#[inline] #[inline]
fn release_read_buf( fn release_read_buf(
&self, &self,
io: &IoRef,
src: BytesMut, src: BytesMut,
dst: &mut Option<BytesMut>, dst: &mut Option<BytesMut>,
nb: usize, nb: usize,
) -> io::Result<usize> { ) -> io::Result<usize> {
match self.inner { match self.inner {
InnerTlsFilter::Server(ref f) => f.release_read_buf(src, dst, nb), InnerTlsFilter::Server(ref f) => f.release_read_buf(io, src, dst, nb),
InnerTlsFilter::Client(ref f) => f.release_read_buf(src, dst, nb), InnerTlsFilter::Client(ref f) => f.release_read_buf(io, src, dst, nb),
} }
} }

View file

@ -4,7 +4,7 @@ use std::sync::Arc;
use std::{any, cell::RefCell, cmp, task::Context, task::Poll}; use std::{any, cell::RefCell, cmp, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesMut, PoolRef}; use ntex_bytes::{BufMut, BytesMut, PoolRef};
use ntex_io::{Filter, Io, ReadStatus, WriteStatus}; use ntex_io::{Filter, Io, IoRef, ReadStatus, WriteStatus};
use ntex_util::{future::poll_fn, ready, time, time::Millis}; use ntex_util::{future::poll_fn, ready, time, time::Millis};
use tls_rust::{ServerConfig, ServerConnection}; use tls_rust::{ServerConfig, ServerConnection};
@ -46,21 +46,6 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
} }
} }
#[inline]
fn closed(&self, err: Option<io::Error>) {
self.inner.borrow().inner.closed(err)
}
#[inline]
fn want_read(&self) {
self.inner.borrow().inner.want_read()
}
#[inline]
fn want_shutdown(&self, err: Option<io::Error>) {
self.inner.borrow().inner.want_shutdown(err)
}
#[inline] #[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> { fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.inner.borrow().inner.poll_shutdown() self.inner.borrow().inner.poll_shutdown()
@ -98,6 +83,7 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
fn release_read_buf( fn release_read_buf(
&self, &self,
io: &IoRef,
src: BytesMut, src: BytesMut,
dst: &mut Option<BytesMut>, dst: &mut Option<BytesMut>,
nbytes: usize, nbytes: usize,
@ -111,8 +97,8 @@ impl<F: Filter> Filter for TlsServerFilter<F> {
} else { } else {
let mut src = { let mut src = {
let mut dst = None; let mut dst = None;
if let Err(e) = inner.inner.release_read_buf(src, &mut dst, nbytes) { if let Err(e) = inner.inner.release_read_buf(io, src, &mut dst, nbytes) {
self.want_shutdown(Some(e)); io.want_shutdown(Some(e));
} }
if let Some(dst) = dst { if let Some(dst) = dst {

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.5.0-b.10] - 2021-12-30
* Update ntex-io to 0.1.0-b.10
## [0.5.0-b.6] - 2021-12-29 ## [0.5.0-b.6] - 2021-12-29
* Add `async-std` support * Add `async-std` support

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex" name = "ntex"
version = "0.5.0-b.6" version = "0.5.0-b.7"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services" description = "Framework for composable network services"
readme = "README.md" readme = "README.md"
@ -51,9 +51,9 @@ ntex-service = "0.3.0-b.0"
ntex-macros = "0.1.3" ntex-macros = "0.1.3"
ntex-util = "0.1.5" ntex-util = "0.1.5"
ntex-bytes = "0.1.8" ntex-bytes = "0.1.8"
ntex-tls = "0.1.0-b.6" ntex-tls = "0.1.0-b.7"
ntex-rt = "0.4.0-b.3" ntex-rt = "0.4.0-b.3"
ntex-io = { version = "0.1.0-b.9", features = ["tokio-traits"] } ntex-io = { version = "0.1.0-b.10", features = ["tokio-traits"] }
base64 = "0.13" base64 = "0.13"
bitflags = "1.3" bitflags = "1.3"

View file

@ -163,8 +163,8 @@ pub enum DispatchError {
Upgrade(Box<dyn std::error::Error>), Upgrade(Box<dyn std::error::Error>),
/// Peer is disconnected, error indicates that peer is disconnected because of it /// Peer is disconnected, error indicates that peer is disconnected because of it
#[display(fmt = "Disconnect: {:?}", _0)] #[display(fmt = "Disconnected: {:?}", _0)]
Disconnect(Option<io::Error>), PeerGone(Option<io::Error>),
/// Http request parse error. /// Http request parse error.
#[display(fmt = "Parse error: {}", _0)] #[display(fmt = "Parse error: {}", _0)]
@ -212,7 +212,7 @@ impl std::error::Error for DispatchError {}
impl From<io::Error> for DispatchError { impl From<io::Error> for DispatchError {
fn from(err: io::Error) -> Self { fn from(err: io::Error) -> Self {
DispatchError::Disconnect(Some(err)) DispatchError::PeerGone(Some(err))
} }
} }

View file

@ -312,8 +312,7 @@ where
{ {
log::trace!("peer is gone with {:?}", err); log::trace!("peer is gone with {:?}", err);
*this.st = State::Stop; *this.st = State::Stop;
this.inner.error = this.inner.error = Some(DispatchError::PeerGone(Some(err)));
Some(DispatchError::Disconnect(Some(err)));
} }
} }
Err(RecvError::Decoder(err)) => { Err(RecvError::Decoder(err)) => {
@ -326,7 +325,7 @@ where
Err(RecvError::PeerGone(err)) => { Err(RecvError::PeerGone(err)) => {
log::trace!("peer is gone with {:?}", err); log::trace!("peer is gone with {:?}", err);
*this.st = State::Stop; *this.st = State::Stop;
this.inner.error = Some(DispatchError::Disconnect(err)); this.inner.error = Some(DispatchError::PeerGone(err));
} }
Err(RecvError::Stop) => { Err(RecvError::Stop) => {
log::trace!("dispatcher is instructed to stop"); log::trace!("dispatcher is instructed to stop");
@ -350,16 +349,16 @@ where
// consume request's payload // consume request's payload
State::ReadPayload => { State::ReadPayload => {
if let Err(e) = ready!(this.inner.poll_request_payload(cx)) { if let Err(e) = ready!(this.inner.poll_request_payload(cx)) {
set_error!(this, e); *this.st = State::Stop;
this.inner.error = Some(e);
} else { } else {
*this.st = this.inner.switch_to_read_request(); *this.st = this.inner.switch_to_read_request();
} }
} }
// send response body // send response body
State::SendPayload { ref mut body } => { State::SendPayload { ref mut body } => {
if !this.inner.state.is_io_open() { if this.inner.io().is_closed() {
let e = this.inner.state.take_error().into(); *this.st = State::Stop;
set_error!(this, e);
} else { } else {
if let Poll::Ready(Err(err)) = this.inner.poll_request_payload(cx) { if let Poll::Ready(Err(err)) = this.inner.poll_request_payload(cx) {
this.inner.error = Some(err); this.inner.error = Some(err);
@ -394,29 +393,18 @@ where
State::Stop => { State::Stop => {
this.inner.unregister_keepalive(); this.inner.unregister_keepalive();
if this return if let Err(e) =
.inner ready!(this.inner.io.as_ref().unwrap().poll_shutdown(cx))
.io
.as_ref()
.unwrap()
.poll_shutdown(cx)?
.is_ready()
{ {
// get io error // get io error
if this.inner.error.is_none() { if let Some(e) = this.inner.error.take() {
this.inner.error = Some(DispatchError::Disconnect( Poll::Ready(Err(e))
this.inner.state.take_error(),
));
}
return Poll::Ready(if let Some(err) = this.inner.error.take() {
Err(err)
} else { } else {
Ok(()) Poll::Ready(Err(DispatchError::PeerGone(Some(e))))
}); }
} else { } else {
return Poll::Pending; Poll::Ready(Ok(()))
} };
} }
} }
} }
@ -494,7 +482,9 @@ where
// we dont need to process responses if socket is disconnected // we dont need to process responses if socket is disconnected
// but we still want to handle requests with app service // but we still want to handle requests with app service
// so we skip response processing for droppped connection // so we skip response processing for droppped connection
if self.state.is_io_open() { if self.state.is_closed() {
State::Stop
} else {
let result = self let result = self
.io() .io()
.encode(Message::Item((msg, body.size())), &self.codec) .encode(Message::Item((msg, body.size())), &self.codec)
@ -523,8 +513,6 @@ where
_ => State::SendPayload { body }, _ => State::SendPayload { body },
} }
} }
} else {
State::Stop
} }
} }
@ -571,95 +559,88 @@ where
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<Result<(), DispatchError>> { ) -> Poll<Result<(), DispatchError>> {
// check if payload data is required // check if payload data is required
if let Some(ref mut payload) = self.payload { let payload = if let Some(ref mut payload) = self.payload {
match payload.1.poll_data_required(cx) { payload
PayloadStatus::Read => { } else {
let io = self.io.as_ref().unwrap(); return Poll::Ready(Ok(()));
};
match payload.1.poll_data_required(cx) {
PayloadStatus::Read => {
let io = self.io.as_ref().unwrap();
// read request payload // read request payload
let mut updated = false; let mut updated = false;
loop { loop {
let res = io.poll_recv(&payload.0, cx); let res = io.poll_recv(&payload.0, cx);
match res { match res {
Poll::Ready(Ok(PayloadItem::Chunk(chunk))) => { Poll::Ready(Ok(PayloadItem::Chunk(chunk))) => {
updated = true; updated = true;
payload.1.feed_data(chunk); payload.1.feed_data(chunk);
}
Poll::Ready(Ok(PayloadItem::Eof)) => {
updated = true;
payload.1.feed_eof();
self.payload = None;
break;
}
Poll::Ready(Err(err)) => {
let err = match err {
RecvError::WriteBackpressure => {
if io.poll_flush(cx, false)?.is_pending() {
break;
} else {
continue;
}
}
RecvError::KeepAlive => {
payload
.1
.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
io::Error::new(io::ErrorKind::Other, "Keep-alive")
.into()
}
RecvError::Stop => {
payload
.1
.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
io::Error::new(
io::ErrorKind::Other,
"Dispatcher stopped",
)
.into()
}
RecvError::PeerGone(err) => {
payload
.1
.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
if let Some(err) = err {
DispatchError::Disconnect(Some(err))
} else {
ParseError::Incomplete.into()
}
}
RecvError::Decoder(e) => {
payload
.1
.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
DispatchError::Parse(e)
}
};
return Poll::Ready(Err(err));
}
Poll::Pending => break,
} }
} Poll::Ready(Ok(PayloadItem::Eof)) => {
if updated { updated = true;
Poll::Ready(Ok(())) payload.1.feed_eof();
} else { self.payload = None;
Poll::Pending break;
}
Poll::Ready(Err(err)) => {
let err = match err {
RecvError::WriteBackpressure => {
if io.poll_flush(cx, false)?.is_pending() {
break;
} else {
continue;
}
}
RecvError::KeepAlive => {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
io::Error::new(io::ErrorKind::Other, "Keep-alive")
.into()
}
RecvError::Stop => {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
io::Error::new(
io::ErrorKind::Other,
"Dispatcher stopped",
)
.into()
}
RecvError::PeerGone(err) => {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
if let Some(err) = err {
DispatchError::PeerGone(Some(err))
} else {
ParseError::Incomplete.into()
}
}
RecvError::Decoder(e) => {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
DispatchError::Parse(e)
}
};
return Poll::Ready(Err(err));
}
Poll::Pending => break,
} }
} }
PayloadStatus::Pause => Poll::Pending, if updated {
PayloadStatus::Dropped => { Poll::Ready(Ok(()))
// service call is not interested in payload } else {
// wait until future completes and then close Poll::Pending
// connection
self.payload = None;
Poll::Ready(Err(DispatchError::PayloadIsNotConsumed))
} }
} }
} else { PayloadStatus::Pause => Poll::Pending,
Poll::Ready(Ok(())) PayloadStatus::Dropped => {
// service call is not interested in payload
// wait until future completes and then close
// connection
self.payload = None;
Poll::Ready(Err(DispatchError::PayloadIsNotConsumed))
}
} }
} }
} }
@ -803,7 +784,7 @@ mod tests {
client.close().await; client.close().await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
assert!(!h1.inner.state.is_io_open()); assert!(h1.inner.state.is_closed());
} }
#[crate::rt_test] #[crate::rt_test]
@ -947,6 +928,7 @@ mod tests {
let mut decoder = ClientCodec::default(); let mut decoder = ClientCodec::default();
// generate large http message
let data = rand::thread_rng() let data = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric) .sample_iter(&rand::distributions::Alphanumeric)
.take(70_000) .take(70_000)
@ -960,7 +942,7 @@ mod tests {
sleep(Millis(50)).await; sleep(Millis(50)).await;
// required because io shutdown is async oper // required because io shutdown is async oper
let _ = lazy(|cx| Pin::new(&mut h1).poll(cx)).await; let _ = lazy(|cx| Pin::new(&mut h1).poll(cx)).await;
sleep(Millis(50)).await; sleep(Millis(550)).await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
assert!(h1.inner.state.is_closed()); assert!(h1.inner.state.is_closed());
@ -1130,7 +1112,7 @@ mod tests {
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
sleep(Millis(50)).await; sleep(Millis(50)).await;
assert!(!h1.inner.state.is_io_open()); assert!(h1.inner.state.is_closed());
let buf = client.local_buffer(|buf| buf.split().freeze()); let buf = client.local_buffer(|buf| buf.split().freeze());
assert_eq!(&buf[..28], b"HTTP/1.1 500 Internal Server"); assert_eq!(&buf[..28], b"HTTP/1.1 500 Internal Server");
assert_eq!(&buf[buf.len() - 5..], b"error"); assert_eq!(&buf[buf.len() - 5..], b"error");

View file

@ -2,7 +2,7 @@
use std::{any, cell::Cell, io, task::Context, task::Poll}; use std::{any, cell::Cell, io, task::Context, task::Poll};
use crate::codec::{Decoder, Encoder}; use crate::codec::{Decoder, Encoder};
use crate::io::{Filter, FilterFactory, Io, ReadStatus, WriteStatus}; use crate::io::{Filter, FilterFactory, Io, IoRef, ReadStatus, WriteStatus};
use crate::util::{BufMut, BytesMut, PoolRef, Ready}; use crate::util::{BufMut, BytesMut, PoolRef, Ready};
use super::{Codec, Frame, Item, Message}; use super::{Codec, Frame, Item, Message};
@ -59,16 +59,6 @@ impl<F: Filter> Filter for WsTransport<F> {
self.inner.query(id) self.inner.query(id)
} }
#[inline]
fn want_read(&self) {
self.inner.want_read()
}
#[inline]
fn want_shutdown(&self, err: Option<io::Error>) {
self.inner.want_shutdown(err)
}
#[inline] #[inline]
fn poll_shutdown(&self) -> Poll<io::Result<()>> { fn poll_shutdown(&self) -> Poll<io::Result<()>> {
self.inner.poll_shutdown() self.inner.poll_shutdown()
@ -84,11 +74,6 @@ impl<F: Filter> Filter for WsTransport<F> {
self.inner.poll_write_ready(cx) self.inner.poll_write_ready(cx)
} }
#[inline]
fn closed(&self, err: Option<io::Error>) {
self.inner.closed(err)
}
#[inline] #[inline]
fn get_read_buf(&self) -> Option<BytesMut> { fn get_read_buf(&self) -> Option<BytesMut> {
self.inner.get_read_buf().or_else(|| self.read_buf.take()) self.inner.get_read_buf().or_else(|| self.read_buf.take())
@ -101,14 +86,15 @@ impl<F: Filter> Filter for WsTransport<F> {
fn release_read_buf( fn release_read_buf(
&self, &self,
io: &IoRef,
src: BytesMut, src: BytesMut,
dst: &mut Option<BytesMut>, dst: &mut Option<BytesMut>,
nbytes: usize, nbytes: usize,
) -> io::Result<usize> { ) -> io::Result<usize> {
let mut src = { let mut src = {
let mut dst = None; let mut dst = None;
if let Err(err) = self.inner.release_read_buf(src, &mut dst, nbytes) { if let Err(err) = self.inner.release_read_buf(io, src, &mut dst, nbytes) {
self.want_shutdown(Some(err)); io.want_shutdown(Some(err));
} }
if let Some(dst) = dst { if let Some(dst) = dst {
dst dst