use ntex-io instead of framed

This commit is contained in:
Nikolay Kim 2021-12-15 14:09:36 +06:00
parent dafd339817
commit 3dbba47ab1
62 changed files with 1545 additions and 5639 deletions

View file

@ -1,6 +1,6 @@
//! Framed transport dispatcher
use std::{
cell::Cell, future::Future, io, pin::Pin, rc::Rc, task::Context, task::Poll, time,
cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll, time,
};
use ntex_bytes::Pool;
@ -70,7 +70,6 @@ enum DispatcherError<S, U> {
KeepAlive,
Encoder(U),
Service(S),
Io(io::Error),
}
enum PollService<U: Encoder + Decoder> {
@ -157,7 +156,7 @@ where
///
/// By default disconnect timeout is set to 1 seconds.
pub fn disconnect_timeout(self, val: Seconds) -> Self {
self.inner.state.set_disconnect_timeout(val);
self.inner.state.set_disconnect_timeout(val.into());
self
}
}
@ -176,12 +175,7 @@ where
Ok(Some(val)) => match write.encode(val, &self.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Err(Either::Left(err)) => {
self.error.set(Some(DispatcherError::Encoder(err)))
}
Err(Either::Right(err)) => {
self.error.set(Some(DispatcherError::Io(err)))
}
Err(err) => self.error.set(Some(DispatcherError::Encoder(err))),
},
Err(err) => self.error.set(Some(DispatcherError::Service(err))),
Ok(None) => return,
@ -407,12 +401,7 @@ where
Ok(Some(item)) => match write.encode(item, &self.shared.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Err(Either::Left(err)) => {
self.shared.error.set(Some(DispatcherError::Encoder(err)))
}
Err(Either::Right(err)) => {
self.shared.error.set(Some(DispatcherError::Io(err)))
}
Err(err) => self.shared.error.set(Some(DispatcherError::Encoder(err))),
},
Err(err) => self.shared.error.set(Some(DispatcherError::Service(err))),
Ok(None) => (),
@ -443,9 +432,6 @@ where
DispatcherError::Encoder(err) => {
PollService::Item(DispatchItem::EncoderError(err))
}
DispatcherError::Io(err) => {
PollService::Item(DispatchItem::Disconnect(Some(err)))
}
DispatcherError::Service(err) => {
self.error.set(Some(err));
PollService::ServiceError

View file

@ -46,13 +46,7 @@ impl ReadFilter for DefaultFilter {
#[inline]
fn read_closed(&self, err: Option<io::Error>) {
if err.is_some() {
self.0.error.set(err);
}
self.0.write_task.wake();
self.0.dispatch_task.wake();
self.0.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
self.0.notify_disconnect();
self.0.set_error(err);
}
#[inline]
@ -109,13 +103,9 @@ impl WriteFilter for DefaultFilter {
#[inline]
fn write_closed(&self, err: Option<io::Error>) {
if err.is_some() {
self.0.error.set(err);
}
self.0.read_task.wake();
self.0.set_error(err);
self.0.insert_flags(Flags::IO_CLOSED);
self.0.dispatch_task.wake();
self.0.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
self.0.notify_disconnect();
}
#[inline]

View file

@ -1,4 +1,4 @@
use std::{fmt, future::Future, io, task::Context, task::Poll};
use std::{any::Any, any::TypeId, fmt, future::Future, io, task::Context, task::Poll};
pub mod testing;
@ -14,12 +14,12 @@ mod tokio_impl;
use ntex_bytes::BytesMut;
use ntex_codec::{Decoder, Encoder};
use ntex_util::time::Millis;
use ntex_util::{channel::oneshot::Receiver, future::Either, time::Millis};
pub use self::dispatcher::Dispatcher;
pub use self::filter::DefaultFilter;
pub use self::state::{Io, IoRef, ReadRef, WriteRef};
pub use self::tasks::{ReadState, WriteState};
pub use self::state::{Io, IoRef, OnDisconnect, ReadRef, WriteRef};
pub use self::tasks::{ReadContext, WriteContext};
pub use self::time::Timer;
pub use self::utils::{filter_factory, from_iostream, into_boxed, into_io};
@ -55,8 +55,15 @@ pub trait WriteFilter {
fn release_write_buf(&self, buf: BytesMut) -> Result<(), io::Error>;
}
pub trait Filter: ReadFilter + WriteFilter {
pub trait Filter: ReadFilter + WriteFilter + 'static {
fn shutdown(&self, st: &IoRef) -> Poll<Result<(), io::Error>>;
fn query(
&self,
id: TypeId,
) -> Either<Option<Box<dyn Any>>, Receiver<Option<Box<dyn Any>>>> {
Either::Left(None)
}
}
pub trait FilterFactory<F: Filter>: Sized {
@ -69,7 +76,7 @@ pub trait FilterFactory<F: Filter>: Sized {
}
pub trait IoStream {
fn start(self, _: ReadState, _: WriteState);
fn start(self, _: ReadContext, _: WriteContext);
}
/// Framed transport item

View file

@ -1,14 +1,14 @@
use std::cell::{Cell, RefCell};
use std::task::{Context, Poll};
use std::{future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc};
use std::{fmt, future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc};
use ntex_bytes::{BytesMut, PoolId, PoolRef};
use ntex_codec::{Decoder, Encoder};
use ntex_util::time::{Millis, Seconds};
use ntex_util::time::Millis;
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker};
use super::filter::{DefaultFilter, NullFilter};
use super::tasks::{ReadState, WriteState};
use super::tasks::{ReadContext, WriteContext};
use super::{Filter, FilterFactory, IoStream};
bitflags::bitflags! {
@ -21,13 +21,15 @@ bitflags::bitflags! {
const IO_FILTERS_TO = 0b0000_0000_0000_0100;
/// shutdown io tasks
const IO_SHUTDOWN = 0b0000_0000_0000_1000;
/// io object is closed
const IO_CLOSED = 0b0000_0000_0001_0000;
/// pause io read
const RD_PAUSED = 0b0000_0000_0000_1000;
const RD_PAUSED = 0b0000_0000_0010_0000;
/// new data is available
const RD_READY = 0b0000_0000_0001_0000;
const RD_READY = 0b0000_0000_0100_0000;
/// read buffer is full
const RD_BUF_FULL = 0b0000_0000_0010_0000;
const RD_BUF_FULL = 0b0000_0000_1000_0000;
/// wait write completion
const WR_WAIT = 0b0000_0001_0000_0000;
@ -103,8 +105,22 @@ impl IoStateInner {
}
#[inline]
fn is_io_err(&self) -> bool {
self.flags.get().contains(Flags::IO_ERR)
fn is_io_open(&self) -> bool {
!self.flags.get().intersects(
Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_SHUTDOWN | Flags::IO_CLOSED,
)
}
#[inline]
pub(super) fn set_error(&self, err: Option<io::Error>) {
if err.is_some() {
self.error.set(err);
}
self.read_task.wake();
self.write_task.wake();
self.dispatch_task.wake();
self.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
self.notify_disconnect();
}
#[inline]
@ -195,7 +211,7 @@ impl Io {
let io_ref = IoRef(inner);
// start io tasks
io.start(ReadState(io_ref.clone()), WriteState(io_ref.clone()));
io.start(ReadContext(io_ref.clone()), WriteContext(io_ref.clone()));
Io(io_ref, FilterItem::Ptr(Box::into_raw(filter)))
}
@ -218,8 +234,8 @@ impl<F> Io<F> {
#[inline]
/// Set io disconnect timeout in secs
pub fn set_disconnect_timeout(&self, timeout: Seconds) {
self.0 .0.disconnect_timeout.set(timeout.into());
pub fn set_disconnect_timeout(&self, timeout: Millis) {
self.0 .0.disconnect_timeout.set(timeout);
}
}
@ -242,31 +258,6 @@ impl<F> Io<F> {
pub fn register_dispatcher(&self, cx: &mut Context<'_>) {
self.0 .0.dispatch_task.register(cx.waker());
}
#[inline]
/// Mark dispatcher as stopped
pub fn dispatcher_stopped(&self) {
self.0 .0.insert_flags(Flags::DSP_STOP);
}
#[inline]
/// Gracefully shutdown read and write io tasks
pub fn init_shutdown(&self, cx: &mut Context<'_>) {
let flags = self.0 .0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
log::trace!("initiate io shutdown {:?}", flags);
self.0 .0.insert_flags(Flags::IO_FILTERS);
if let Err(err) = self.0 .0.shutdown_filters(&self.0) {
self.0 .0.error.set(Some(err));
self.0 .0.insert_flags(Flags::IO_ERR);
}
self.0 .0.read_task.wake();
self.0 .0.write_task.wake();
self.0 .0.dispatch_task.register(cx.waker());
}
}
}
impl IoRef {
@ -284,9 +275,9 @@ impl IoRef {
}
#[inline]
/// Check if io error occured in read or write task
pub fn is_io_err(&self) -> bool {
self.0.is_io_err()
/// Check if io is still active
pub fn is_io_open(&self) -> bool {
self.0.is_io_open()
}
#[inline]
@ -304,10 +295,13 @@ impl IoRef {
#[inline]
/// Check if io stream is closed
pub fn is_closed(&self) -> bool {
self.0
.flags
.get()
.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::DSP_STOP)
self.0.flags.get().intersects(
Flags::IO_ERR
| Flags::IO_SHUTDOWN
| Flags::IO_CLOSED
| Flags::IO_FILTERS
| Flags::DSP_STOP,
)
}
#[inline]
@ -316,6 +310,12 @@ impl IoRef {
self.0.error.take()
}
#[inline]
/// Mark dispatcher as stopped
pub fn stop_dispatcher(&self) {
self.0.insert_flags(Flags::DSP_STOP);
}
#[inline]
/// Reset keep-alive error
pub fn reset_keepalive(&self) {
@ -360,9 +360,15 @@ impl IoRef {
pub fn on_disconnect(&self) -> OnDisconnect {
OnDisconnect::new(self.0.clone(), self.0.flags.get().contains(Flags::IO_ERR))
}
#[inline]
/// Query specific data
pub fn query<T: 'static>(&self) -> Option<T> {
todo!()
}
}
impl<F> Io<F> {
impl IoRef {
#[inline]
/// Read incoming io stream and decode codec item.
pub async fn next<U>(
@ -375,18 +381,18 @@ impl<F> Io<F> {
let read = self.read();
loop {
let mut buf = self.0 .0.read_buf.take();
let mut buf = self.0.read_buf.take();
let item = if let Some(ref mut buf) = buf {
codec.decode(buf)
} else {
Ok(None)
};
self.0 .0.read_buf.set(buf);
self.0.read_buf.set(buf);
return match item {
Ok(Some(el)) => Ok(Some(el)),
Ok(None) => {
self.0 .0.remove_flags(Flags::RD_READY);
self.0.remove_flags(Flags::RD_READY);
if poll_fn(|cx| read.poll_ready(cx))
.await
.map_err(Either::Right)?
@ -411,53 +417,53 @@ impl<F> Io<F> {
where
U: Encoder,
{
let filter = self.0 .0.filter.get();
let filter = self.0.filter.get();
let mut buf = filter
.get_write_buf()
.unwrap_or_else(|| self.0 .0.pool.get().get_write_buf());
.unwrap_or_else(|| self.0.pool.get().get_write_buf());
let is_write_sleep = buf.is_empty();
codec.encode(item, &mut buf).map_err(Either::Left)?;
filter.release_write_buf(buf).map_err(Either::Right)?;
self.0 .0.insert_flags(Flags::WR_WAIT);
self.0.insert_flags(Flags::WR_WAIT);
if is_write_sleep {
self.0 .0.write_task.wake();
self.0.write_task.wake();
}
poll_fn(|cx| self.write().poll_flush(cx))
poll_fn(|cx| self.write().poll_flush(cx, true))
.await
.map_err(Either::Right)?;
Ok(())
}
#[inline]
/// Shuts down connection
pub async fn shutdown(&self) -> Result<(), io::Error> {
if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
Ok(())
} else {
poll_fn(|cx| {
let flags = self.flags();
if !flags.contains(Flags::IO_FILTERS) {
self.init_shutdown(cx);
}
/// Shut down connection
pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let flags = self.flags();
if self.flags().intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
if let Some(err) = self.0 .0.error.take() {
Poll::Ready(Err(err))
} else {
Poll::Ready(Ok(()))
}
} else {
self.0 .0.insert_flags(Flags::IO_FILTERS);
self.0 .0.dispatch_task.register(cx.waker());
Poll::Pending
}
})
.await
if flags.intersects(Flags::IO_ERR | Flags::IO_CLOSED) {
Poll::Ready(Ok(()))
} else {
if !flags.contains(Flags::IO_FILTERS) {
self.init_shutdown(cx);
}
self.0.insert_flags(Flags::IO_FILTERS);
if let Some(err) = self.0.error.take() {
Poll::Ready(Err(err))
} else {
self.0.dispatch_task.register(cx.waker());
Poll::Pending
}
}
}
#[inline]
/// Shut down connection
pub async fn shutdown(&self) -> Result<(), io::Error> {
poll_fn(|cx| self.poll_shutdown(cx)).await
}
#[inline]
#[allow(clippy::type_complexity)]
pub fn poll_next<U>(
@ -468,38 +474,48 @@ impl<F> Io<F> {
where
U: Decoder,
{
if self
.read()
.poll_ready(cx)
.map_err(Either::Right)?
.is_ready()
{
let mut buf = self.0 .0.read_buf.take();
let item = if let Some(ref mut buf) = buf {
codec.decode(buf)
} else {
Ok(None)
};
self.0 .0.read_buf.set(buf);
let read = self.read();
match item {
Ok(Some(el)) => Poll::Ready(Ok(Some(el))),
Ok(None) => {
if let Poll::Ready(res) =
self.read().poll_ready(cx).map_err(Either::Right)?
{
if res.is_none() {
return Poll::Ready(Ok(None));
}
match read.decode(codec) {
Ok(Some(el)) => Poll::Ready(Ok(Some(el))),
Ok(None) => {
if let Poll::Ready(res) = read.poll_ready(cx).map_err(Either::Right)? {
if res.is_none() {
return Poll::Ready(Ok(None));
}
Poll::Pending
}
Err(err) => Poll::Ready(Err(Either::Left(err))),
Poll::Pending
}
} else {
Poll::Pending
Err(err) => Poll::Ready(Err(Either::Left(err))),
}
}
#[inline]
/// Gracefully shutdown read and write io tasks
pub(super) fn init_shutdown(&self, cx: &mut Context<'_>) {
let flags = self.0.flags.get();
if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::IO_FILTERS) {
log::trace!("initiate io shutdown {:?}", flags);
self.0.insert_flags(Flags::IO_FILTERS);
if let Err(err) = self.0.shutdown_filters(self) {
self.0.error.set(Some(err));
self.0.insert_flags(Flags::IO_ERR);
}
self.0.read_task.wake();
self.0.write_task.wake();
self.0.dispatch_task.register(cx.waker());
}
}
}
impl fmt::Debug for IoRef {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("IoRef")
.field("open", &!self.is_closed())
.finish()
}
}
impl<F: Filter> Io<F> {
@ -576,7 +592,10 @@ impl<F: Filter> Io<F> {
impl<F> Drop for Io<F> {
fn drop(&mut self) {
log::trace!("stopping io stream");
log::trace!(
"io is dropped, force stopping io streams {:?}",
self.0.flags()
);
if let FilterItem::Ptr(p) = self.1 {
if p.is_null() {
return;
@ -635,7 +654,7 @@ impl<'a> WriteRef<'a> {
///
/// Write task must be waken up separately.
pub fn enable_backpressure(&self, cx: Option<&mut Context<'_>>) {
log::trace!("enable write back-pressure");
log::trace!("enable write back-pressure {:?}", cx.is_some());
self.0.insert_flags(Flags::WR_BACKPRESSURE);
if let Some(cx) = cx {
self.0.dispatch_task.register(cx.waker());
@ -669,7 +688,7 @@ impl<'a> WriteRef<'a> {
&self,
item: U::Item,
codec: &U,
) -> Result<bool, Either<<U as Encoder>::Error, io::Error>>
) -> Result<bool, <U as Encoder>::Error>
where
U: Encoder,
{
@ -690,28 +709,44 @@ impl<'a> WriteRef<'a> {
}
// encode item and wake write task
let result = codec
.encode(item, &mut buf)
.map(|_| {
if is_write_sleep {
self.0.write_task.wake();
}
buf.len() < hw
})
.map_err(Either::Left);
filter.release_write_buf(buf).map_err(Either::Right)?;
Ok(result?)
let result = codec.encode(item, &mut buf).map(|_| {
if is_write_sleep {
self.0.write_task.wake();
}
buf.len() < hw
});
if let Err(err) = filter.release_write_buf(buf) {
self.0.set_error(Some(err));
}
result
} else {
Ok(true)
}
}
#[inline]
/// Wake write task and instruct to write all data.
/// Wake write task and instruct to write data.
///
/// When write task is done wake dispatcher.
pub fn poll_flush(&self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.0.insert_flags(Flags::WR_WAIT);
/// If full is true then wake up dispatcher when all data is flushed
/// otherwise wake up when size of write buffer is lower than
/// buffer max size.
pub fn poll_flush(
&self,
cx: &mut Context<'_>,
full: bool,
) -> Poll<Result<(), io::Error>> {
// check io error
if !self.0.is_io_open() {
return Poll::Ready(Err(self.0.error.take().unwrap_or_else(|| {
io::Error::new(io::ErrorKind::Other, "disconnected")
})));
}
if full {
self.0.insert_flags(Flags::WR_WAIT);
} else {
self.0.insert_flags(Flags::WR_BACKPRESSURE);
}
if let Some(buf) = self.0.write_buf.take() {
if !buf.is_empty() {
@ -722,14 +757,16 @@ impl<'a> WriteRef<'a> {
}
}
if self.0.is_io_err() {
Poll::Ready(Err(self.0.error.take().unwrap_or_else(|| {
io::Error::new(io::ErrorKind::Other, "disconnected")
})))
} else {
self.0.dispatch_task.register(cx.waker());
Poll::Ready(Ok(()))
}
// self.0.dispatch_task.register(cx.waker());
Poll::Ready(Ok(()))
}
#[inline]
/// Wake write task and instruct to write data.
///
/// This is async version of .poll_flush() method.
pub async fn flush(&self, full: bool) -> Result<(), io::Error> {
poll_fn(|cx| self.poll_flush(cx, full)).await
}
}
@ -834,7 +871,7 @@ impl<'a> ReadRef<'a> {
let mut flags = self.0.flags.get();
let ready = flags.contains(Flags::RD_READY);
if self.0.is_io_err() {
if !self.0.is_io_open() {
if let Some(err) = self.0.error.take() {
Poll::Ready(Err(err))
} else {
@ -843,7 +880,6 @@ impl<'a> ReadRef<'a> {
} else if ready {
Poll::Ready(Ok(Some(())))
} else {
flags.remove(Flags::RD_READY);
if flags.contains(Flags::RD_BUF_FULL) {
log::trace!("read back-pressure is enabled, wake io task");
flags.remove(Flags::RD_BUF_FULL);
@ -939,7 +975,6 @@ mod tests {
#[ntex::test]
async fn utils() {
env_logger::init();
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
client.write(TEXT);
@ -1041,7 +1076,7 @@ mod tests {
in_bytes: Rc<Cell<usize>>,
out_bytes: Rc<Cell<usize>>,
}
impl<F: ReadFilter + WriteFilter> Filter for Counter<F> {
impl<F: ReadFilter + WriteFilter + 'static> Filter for Counter<F> {
fn shutdown(&self, _: &IoRef) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}

View file

@ -4,9 +4,9 @@ use ntex_bytes::{BytesMut, PoolRef};
use super::{state::Flags, IoRef, WriteReadiness};
pub struct ReadState(pub(super) IoRef);
pub struct ReadContext(pub(super) IoRef);
impl ReadState {
impl ReadContext {
#[inline]
pub fn memory_pool(&self) -> PoolRef {
self.0 .0.pool.get()
@ -60,9 +60,9 @@ impl ReadState {
}
}
pub struct WriteState(pub(super) IoRef);
pub struct WriteContext(pub(super) IoRef);
impl WriteState {
impl WriteContext {
#[inline]
pub fn memory_pool(&self) -> PoolRef {
self.0 .0.pool.get()

View file

@ -1,11 +1,13 @@
use std::cell::{Cell, RefCell};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::{cmp, fmt, io, mem};
use std::{cmp, fmt, future::Future, io, mem, pin::Pin, rc::Rc};
use ntex_bytes::{BufMut, BytesMut};
use ntex_bytes::{Buf, BufMut, BytesMut};
use ntex_util::future::poll_fn;
use ntex_util::time::{sleep, Millis};
use ntex_util::time::{sleep, Millis, Sleep};
use crate::{IoStream, ReadContext, WriteContext, WriteReadiness};
#[derive(Default)]
struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>);
@ -441,138 +443,181 @@ mod tokio {
}
}
#[cfg(not(feature = "tokio"))]
mod non_tokio {
impl IoStream for IoTest {
fn start(self, read: ReadState, write: WriteState) {
let io = Rc::new(self);
impl IoStream for IoTest {
fn start(self, read: ReadContext, write: WriteContext) {
let io = Rc::new(self);
ntex_util::spawn(ReadTask {
io: io.clone(),
state: read,
});
ntex_util::spawn(WriteTask {
io,
state: write,
st: IoWriteState::Processing,
});
}
ntex_util::spawn(ReadTask {
io: io.clone(),
state: read,
});
ntex_util::spawn(WriteTask {
io,
state: write,
st: IoWriteState::Processing(None),
});
}
}
/// Read io task
struct ReadTask {
io: Rc<IoTest>,
state: ReadState,
}
/// Read io task
struct ReadTask {
io: Rc<IoTest>,
state: ReadContext,
}
impl Future for ReadTask {
type Output = ();
impl Future for ReadTask {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_ref();
match this.state.poll_ready(cx) {
Poll::Ready(Err(())) => {
log::trace!("read task is instructed to terminate");
Poll::Ready(())
}
Poll::Ready(Ok(())) => {
let io = &this.io;
let pool = this.state.memory_pool();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
match this.state.poll_ready(cx) {
Poll::Ready(Err(())) => {
log::trace!("read task is instructed to terminate");
Poll::Ready(())
}
Poll::Ready(Ok(())) => {
let io = &this.io;
let pool = this.state.memory_pool();
let mut buf = self.state.get_read_buf();
let (hw, lw) = pool.read_params().unpack();
// read data from socket
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
match io.poll_read_buf(cx, &mut buf) {
Poll::Pending => {
log::trace!("no more data in io stream");
break;
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("io stream is disconnected");
this.state.release_read_buf(buf, new_bytes);
this.state.close(None);
return Poll::Ready(());
} else {
new_bytes += n;
if buf.len() > hw {
break;
}
}
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
}
}
// read data from socket
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
this.state.release_read_buf(buf, new_bytes);
Poll::Pending
}
Poll::Pending => Poll::Pending,
}
}
}
#[derive(Debug)]
enum IoWriteState {
Processing,
Shutdown(Option<Sleep>, Shutdown),
}
#[derive(Debug)]
enum Shutdown {
None,
Flushed,
Stopping,
}
/// Write io task
struct WriteTask {
st: IoWriteState,
io: Rc<IoTest>,
state: WriteState,
}
impl Future for WriteTask {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().get_mut();
match this.st {
IoWriteState::Processing => {
match this.state.poll_ready(cx) {
Poll::Ready(Ok(())) => {
// flush framed instance
match flush_io(&this.io, &this.state, cx) {
Poll::Pending | Poll::Ready(true) => Poll::Pending,
Poll::Ready(false) => Poll::Ready(()),
match io.poll_read_buf(cx, &mut buf) {
Poll::Pending => {
log::trace!("no more data in io stream");
break;
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("io stream is disconnected");
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(None);
return Poll::Ready(());
} else {
new_bytes += n;
if buf.len() > hw {
break;
}
}
}
Poll::Ready(Err(WriteReadiness::Shutdown)) => {
log::trace!("write task is instructed to shutdown");
this.st = IoWriteState::Shutdown(
this.state.disconnect_timeout().map(sleep),
Shutdown::None,
);
self.poll(cx)
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
let _ = this.state.release_read_buf(buf, new_bytes);
this.state.close(Some(err));
return Poll::Ready(());
}
Poll::Ready(Err(WriteReadiness::Terminate)) => {
log::trace!("write task is instructed to terminate");
}
}
let _ = this.state.release_read_buf(buf, new_bytes);
Poll::Pending
}
Poll::Pending => Poll::Pending,
}
}
}
#[derive(Debug)]
enum IoWriteState {
Processing(Option<Sleep>),
Shutdown(Option<Sleep>, Shutdown),
}
#[derive(Debug)]
enum Shutdown {
None,
Flushed,
Stopping,
}
/// Write io task
struct WriteTask {
st: IoWriteState,
io: Rc<IoTest>,
state: WriteContext,
}
impl Future for WriteTask {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().get_mut();
match this.st {
IoWriteState::Processing(ref mut delay) => {
match this.state.poll_ready(cx) {
Poll::Ready(Ok(())) => {
// flush framed instance
match flush_io(&this.io, &this.state, cx) {
Poll::Pending | Poll::Ready(true) => Poll::Pending,
Poll::Ready(false) => Poll::Ready(()),
}
}
Poll::Ready(Err(WriteReadiness::Timeout(time))) => {
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
Poll::Ready(Err(WriteReadiness::Shutdown(time))) => {
log::trace!("write task is instructed to shutdown");
let timeout = if let Some(delay) = delay.take() {
delay
} else {
sleep(time)
};
this.st = IoWriteState::Shutdown(Some(timeout), Shutdown::None);
self.poll(cx)
}
Poll::Ready(Err(WriteReadiness::Terminate)) => {
log::trace!("write task is instructed to terminate");
// shutdown WRITE side
this.io
.local
.lock()
.unwrap()
.borrow_mut()
.flags
.insert(Flags::CLOSED);
this.state.close(None);
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
IoWriteState::Shutdown(ref mut delay, ref mut st) => {
// close WRITE side and wait for disconnect on read side.
// use disconnect timeout, otherwise it could hang forever.
loop {
match st {
Shutdown::None => {
// flush write buffer
match flush_io(&this.io, &this.state, cx) {
Poll::Ready(true) => {
*st = Shutdown::Flushed;
continue;
}
Poll::Ready(false) => {
log::trace!(
"write task is closed with err during flush"
);
return Poll::Ready(());
}
_ => (),
}
}
Shutdown::Flushed => {
// shutdown WRITE side
this.io
.local
@ -581,143 +626,102 @@ mod non_tokio {
.borrow_mut()
.flags
.insert(Flags::CLOSED);
this.state.close(None);
Poll::Ready(())
*st = Shutdown::Stopping;
continue;
}
Poll::Pending => Poll::Pending,
}
}
IoWriteState::Shutdown(ref mut delay, ref mut st) => {
// close WRITE side and wait for disconnect on read side.
// use disconnect timeout, otherwise it could hang forever.
loop {
match st {
Shutdown::None => {
// flush write buffer
match flush_io(&this.io, &this.state, cx) {
Poll::Ready(true) => {
*st = Shutdown::Flushed;
continue;
}
Poll::Ready(false) => {
log::trace!(
"write task is closed with err during flush"
);
Shutdown::Stopping => {
// read until 0 or err
let io = &this.io;
loop {
let mut buf = BytesMut::new();
match io.poll_read_buf(cx, &mut buf) {
Poll::Ready(Err(e)) => {
this.state.close(Some(e));
log::trace!("write task is stopped");
return Poll::Ready(());
}
Poll::Ready(Ok(n)) if n == 0 => {
this.state.close(None);
log::trace!("write task is stopped");
return Poll::Ready(());
}
Poll::Pending => break,
_ => (),
}
}
Shutdown::Flushed => {
// shutdown WRITE side
this.io
.local
.lock()
.unwrap()
.borrow_mut()
.flags
.insert(Flags::CLOSED);
*st = Shutdown::Stopping;
continue;
}
Shutdown::Stopping => {
// read until 0 or err
let io = &this.io;
loop {
let mut buf = BytesMut::new();
match io.poll_read_buf(cx, &mut buf) {
Poll::Ready(Err(e)) => {
this.state.close(Some(e));
log::trace!("write task is stopped");
return Poll::Ready(());
}
Poll::Ready(Ok(n)) if n == 0 => {
this.state.close(None);
log::trace!("write task is stopped");
return Poll::Ready(());
}
Poll::Pending => break,
_ => (),
}
}
}
}
// disconnect timeout
if let Some(ref delay) = delay {
if delay.poll_elapsed(cx).is_pending() {
return Poll::Pending;
}
}
log::trace!("write task is stopped after delay");
this.state.close(None);
return Poll::Ready(());
}
// disconnect timeout
if let Some(ref delay) = delay {
if delay.poll_elapsed(cx).is_pending() {
return Poll::Pending;
}
}
log::trace!("write task is stopped after delay");
this.state.close(None);
return Poll::Ready(());
}
}
}
}
}
/// Flush write buffer to underlying I/O stream.
pub(super) fn flush_io(
io: &IoTest,
state: &WriteState,
cx: &mut Context<'_>,
) -> Poll<bool> {
let mut buf = if let Some(buf) = state.get_write_buf() {
buf
} else {
return Poll::Ready(true);
};
let len = buf.len();
let pool = state.memory_pool();
/// Flush write buffer to underlying I/O stream.
pub(super) fn flush_io(
io: &IoTest,
state: &WriteContext,
cx: &mut Context<'_>,
) -> Poll<bool> {
let mut buf = if let Some(buf) = state.get_write_buf() {
buf
} else {
return Poll::Ready(true);
};
let len = buf.len();
if len != 0 {
log::trace!("flushing framed transport: {}", len);
if len != 0 {
log::trace!("flushing framed transport: {}", len);
let mut written = 0;
while written < len {
match io.poll_write_buf(cx, &buf[written..]) {
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!(
"disconnected during flush, written {}",
written
);
pool.release_write_buf(buf);
state.close(Some(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write frame to transport",
)));
return Poll::Ready(false);
} else {
written += n
}
}
Poll::Pending => break,
Poll::Ready(Err(e)) => {
log::trace!("error during flush: {}", e);
pool.release_write_buf(buf);
state.close(Some(e));
let mut written = 0;
while written < len {
match io.poll_write_buf(cx, &buf[written..]) {
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("disconnected during flush, written {}", written);
let _ = state.release_write_buf(buf);
state.close(Some(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write frame to transport",
)));
return Poll::Ready(false);
} else {
written += n
}
}
Poll::Pending => break,
Poll::Ready(Err(e)) => {
log::trace!("error during flush: {}", e);
let _ = state.release_write_buf(buf);
state.close(Some(e));
return Poll::Ready(false);
}
}
log::trace!("flushed {} bytes", written);
// remove written data
if written == len {
buf.clear();
state.release_write_buf(buf);
Poll::Ready(true)
} else {
buf.advance(written);
state.release_write_buf(buf);
Poll::Pending
}
} else {
Poll::Ready(true)
}
log::trace!("flushed {} bytes", written);
// remove written data
if written == len {
buf.clear();
let _ = state.release_write_buf(buf);
Poll::Ready(true)
} else {
buf.advance(written);
let _ = state.release_write_buf(buf);
Poll::Pending
}
} else {
Poll::Ready(true)
}
}

View file

@ -3,15 +3,12 @@ use std::{cell::RefCell, future::Future, io, pin::Pin, rc::Rc};
use ntex_bytes::{Buf, BufMut};
use ntex_util::time::{sleep, Sleep};
use tok_io::{io::AsyncRead, io::AsyncWrite, io::ReadBuf};
use tok_io::{io::AsyncRead, io::AsyncWrite, io::ReadBuf, net::TcpStream};
use super::{IoStream, ReadState, WriteReadiness, WriteState};
use super::{IoStream, ReadContext, WriteContext, WriteReadiness};
impl<T> IoStream for T
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn start(self, read: ReadState, write: WriteState) {
impl IoStream for TcpStream {
fn start(self, read: ReadContext, write: WriteContext) {
let io = Rc::new(RefCell::new(self));
ntex_util::spawn(ReadTask::new(io.clone(), read));
@ -19,26 +16,29 @@ where
}
}
/// Read io task
struct ReadTask<T> {
io: Rc<RefCell<T>>,
state: ReadState,
#[cfg(unix)]
impl IoStream for tok_io::net::UnixStream {
fn start(self, _read: ReadContext, _write: WriteContext) {
let _io = Rc::new(RefCell::new(self));
todo!()
}
}
impl<T> ReadTask<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
/// Read io task
struct ReadTask {
io: Rc<RefCell<TcpStream>>,
state: ReadContext,
}
impl ReadTask {
/// Create new read io task
fn new(io: Rc<RefCell<T>>, state: ReadState) -> Self {
fn new(io: Rc<RefCell<TcpStream>>, state: ReadContext) -> Self {
Self { io, state }
}
}
impl<T> Future for ReadTask<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
impl Future for ReadTask {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
@ -119,18 +119,15 @@ enum Shutdown {
}
/// Write io task
struct WriteTask<T> {
struct WriteTask {
st: IoWriteState,
io: Rc<RefCell<T>>,
state: WriteState,
io: Rc<RefCell<TcpStream>>,
state: WriteContext,
}
impl<T> WriteTask<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
impl WriteTask {
/// Create new write io task
fn new(io: Rc<RefCell<T>>, state: WriteState) -> Self {
fn new(io: Rc<RefCell<TcpStream>>, state: WriteContext) -> Self {
Self {
io,
state,
@ -139,10 +136,7 @@ where
}
}
impl<T> Future for WriteTask<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
impl Future for WriteTask {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
@ -272,7 +266,7 @@ where
/// Flush write buffer to underlying I/O stream.
pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
io: &mut T,
state: &WriteState,
state: &WriteContext,
cx: &mut Context<'_>,
) -> Poll<bool> {
let mut buf = if let Some(buf) = state.get_write_buf() {
@ -284,12 +278,14 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
let pool = state.memory_pool();
if len != 0 {
// log::trace!("flushing framed transport: {:?}", buf);
//log::trace!("flushing framed transport: {:?}", buf);
let mut written = 0;
while written < len {
match Pin::new(&mut *io).poll_write(cx, &buf[written..]) {
Poll::Pending => break,
Poll::Pending => {
break;
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("Disconnected during flush, written {}", written);
@ -311,7 +307,7 @@ pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
}
}
}
// log::trace!("flushed {} bytes", written);
//log::trace!("flushed {} bytes", written);
// remove written data
let result = if written == len {