Refactor async io support (#417)

This commit is contained in:
Nikolay Kim 2024-09-11 18:18:45 +05:00 committed by GitHub
parent db6d3a6e4c
commit 1d529fab3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 872 additions and 2508 deletions

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-async-std" name = "ntex-async-std"
version = "0.5.0" version = "0.5.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "async-std intergration for ntex framework" description = "async-std intergration for ntex framework"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -17,7 +17,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
ntex-bytes = "0.1" ntex-bytes = "0.1"
ntex-io = "2.0" ntex-io = "2.5"
ntex-util = "2.0" ntex-util = "2.0"
log = "0.4" log = "0.4"
async-std = { version = "1", features = ["unstable"] } async-std = { version = "1", features = ["unstable"] }

View file

@ -1,18 +1,24 @@
use std::{any, cell::RefCell, future::Future, io, pin::Pin, task::Context, task::Poll}; use std::{
any, cell::RefCell, future::poll_fn, io, pin::Pin, task::ready, task::Context,
use async_std::io::{Read, Write}; task::Poll,
use ntex_bytes::{Buf, BufMut, BytesVec};
use ntex_io::{
types, Handle, IoStream, ReadContext, ReadStatus, WriteContext, WriteStatus,
}; };
use ntex_util::{ready, time::sleep, time::Sleep};
use async_std::io::{Read as ARead, Write as AWrite};
use ntex_bytes::{Buf, BufMut, BytesVec};
use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf};
use crate::TcpStream; use crate::TcpStream;
impl IoStream for TcpStream { impl IoStream for TcpStream {
fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> { fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> {
async_std::task::spawn_local(ReadTask::new(self.clone(), read)); let mut rio = Read(RefCell::new(self.clone()));
async_std::task::spawn_local(WriteTask::new(self.clone(), write)); async_std::task::spawn_local(async move {
read.handle(&mut rio).await;
});
let mut wio = Write(RefCell::new(self.clone()));
async_std::task::spawn_local(async move {
write.handle(&mut wio).await;
});
Some(Box::new(self)) Some(Box::new(self))
} }
} }
@ -29,296 +35,111 @@ impl Handle for TcpStream {
} }
/// Read io task /// Read io task
struct ReadTask { struct Read(RefCell<TcpStream>);
io: RefCell<TcpStream>,
state: ReadContext,
}
impl ReadTask { impl ntex_io::AsyncRead for Read {
/// Create new read io task async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<usize>) {
fn new(io: TcpStream, state: ReadContext) -> Self { // read data from socket
Self { let result = poll_fn(|cx| {
state, let mut io = self.0.borrow_mut();
io: RefCell::new(io), poll_read_buf(Pin::new(&mut io.0), cx, &mut buf)
} })
.await;
(buf, result)
} }
} }
impl Future for ReadTask { struct Write(RefCell<TcpStream>);
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { impl ntex_io::AsyncWrite for Write {
let this = self.as_ref(); #[inline]
async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
match ready!(this.state.poll_ready(cx)) { poll_fn(|cx| {
ReadStatus::Ready => { if let Some(mut b) = buf.take() {
this.state.with_buf(|buf, hw, lw| { let result = flush_io(&mut self.0.borrow_mut().0, &mut b, cx);
// read data from socket buf.set(b);
let mut io = self.io.borrow_mut(); result
loop { } else {
// make sure we've got room Poll::Ready(Ok(()))
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
return match poll_read_buf(Pin::new(&mut io.0), cx, buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("async-std stream is disconnected");
Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else {
Poll::Pending
}
}
Poll::Ready(Err(err)) => {
log::trace!("async-std read task failed on io {:?}", err);
Poll::Ready(Err(err))
}
};
}
})
} }
ReadStatus::Terminate => { })
log::trace!("read task is instructed to shutdown"); .await
Poll::Ready(())
}
}
} }
}
#[derive(Debug)] #[inline]
enum IoWriteState { async fn flush(&mut self) -> io::Result<()> {
Processing(Option<Sleep>), Ok(())
Shutdown(Sleep, Shutdown),
}
#[derive(Debug)]
enum Shutdown {
None,
Stopping(u16),
}
/// Write io task
struct WriteTask {
st: IoWriteState,
io: TcpStream,
state: WriteContext,
}
impl WriteTask {
/// Create new write io task
fn new(io: TcpStream, state: WriteContext) -> Self {
Self {
io,
state,
st: IoWriteState::Processing(None),
}
} }
}
impl Future for WriteTask { #[inline]
type Output = (); async fn shutdown(&mut self) -> io::Result<()> {
self.0.borrow().0.shutdown(std::net::Shutdown::Both)
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().get_mut();
match this.st {
IoWriteState::Processing(ref mut delay) => {
match this.state.poll_ready(cx) {
Poll::Ready(WriteStatus::Ready) => {
if let Some(delay) = delay {
if delay.poll_elapsed(cx).is_ready() {
this.state.close(Some(io::Error::new(
io::ErrorKind::TimedOut,
"Operation timedout",
)));
return Poll::Ready(());
}
}
// flush io stream
let io = &mut this.io.0;
match ready!(this.state.with_buf(|buf| flush_io(io, buf, cx))) {
Ok(()) => Poll::Pending,
Err(e) => {
this.state.close(Some(e));
Poll::Ready(())
}
}
}
Poll::Ready(WriteStatus::Timeout(time)) => {
log::trace!("initiate timeout delay for {:?}", time);
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
Poll::Ready(WriteStatus::Shutdown(time)) => {
log::trace!("write task is instructed to shutdown");
let timeout = if let Some(delay) = delay.take() {
delay
} else {
sleep(time)
};
this.st = IoWriteState::Shutdown(timeout, Shutdown::None);
self.poll(cx)
}
Poll::Ready(WriteStatus::Terminate) => {
log::trace!("write task is instructed to terminate");
let _ = Pin::new(&mut this.io.0).poll_close(cx);
this.state.close(None);
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
IoWriteState::Shutdown(ref mut delay, ref mut st) => {
// close WRITE side and wait for disconnect on read side.
// use disconnect timeout, otherwise it could hang forever.
loop {
match st {
Shutdown::None => {
// flush write buffer
let io = &mut this.io.0;
match this.state.with_buf(|buf| flush_io(io, buf, cx)) {
Poll::Ready(Ok(())) => {
if let Err(e) =
this.io.0.shutdown(std::net::Shutdown::Write)
{
this.state.close(Some(e));
return Poll::Ready(());
}
*st = Shutdown::Stopping(0);
continue;
}
Poll::Ready(Err(err)) => {
log::trace!(
"write task is closed with err during flush, {:?}",
err
);
this.state.close(Some(err));
return Poll::Ready(());
}
Poll::Pending => (),
}
}
Shutdown::Stopping(ref mut count) => {
// read until 0 or err
let mut buf = [0u8; 512];
let io = &mut this.io;
loop {
match Pin::new(&mut io.0).poll_read(cx, &mut buf) {
Poll::Ready(Err(e)) => {
log::trace!("write task is stopped");
this.state.close(Some(e));
return Poll::Ready(());
}
Poll::Ready(Ok(0)) => {
log::trace!("async-std socket is disconnected");
this.state.close(None);
return Poll::Ready(());
}
Poll::Ready(Ok(n)) => {
*count += n as u16;
if *count > 4096 {
log::trace!(
"write task is stopped, too much input"
);
this.state.close(None);
return Poll::Ready(());
}
}
Poll::Pending => break,
}
}
}
}
// disconnect timeout
if delay.poll_elapsed(cx).is_pending() {
return Poll::Pending;
}
log::trace!("write task is stopped after delay");
this.state.close(None);
let _ = Pin::new(&mut this.io.0).poll_close(cx);
return Poll::Ready(());
}
}
}
} }
} }
/// Flush write buffer to underlying I/O stream. /// Flush write buffer to underlying I/O stream.
pub(super) fn flush_io<T: Read + Write + Unpin>( pub(super) fn flush_io<T: ARead + AWrite + Unpin>(
io: &mut T, io: &mut T,
buf: &mut Option<BytesVec>, buf: &mut BytesVec,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
if let Some(buf) = buf { let len = buf.len();
let len = buf.len();
if len != 0 { if len != 0 {
// log::trace!("flushing framed transport: {:?}", buf.len()); // log::trace!("flushing framed transport: {:?}", buf.len());
let mut written = 0; let mut written = 0;
let result = loop { let result = loop {
break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) {
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
log::trace!("Disconnected during flush, written {}", written); log::trace!("Disconnected during flush, written {}", written);
Poll::Ready(Err(io::Error::new( Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero, io::ErrorKind::WriteZero,
"failed to write frame to transport", "failed to write frame to transport",
))) )))
} else {
written += n;
if written == len {
buf.clear();
Poll::Ready(Ok(()))
} else { } else {
written += n; continue;
if written == len {
buf.clear();
Poll::Ready(Ok(()))
} else {
continue;
}
} }
} }
Poll::Pending => {
// remove written data
buf.advance(written);
Poll::Pending
}
Poll::Ready(Err(e)) => {
log::trace!("Error during flush: {}", e);
Poll::Ready(Err(e))
}
};
};
// log::trace!("flushed {} bytes", written);
// flush
return if written > 0 {
match Pin::new(&mut *io).poll_flush(cx) {
Poll::Ready(Ok(_)) => result,
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
log::trace!("error during flush: {}", e);
Poll::Ready(Err(e))
}
} }
} else { Poll::Pending => {
result // remove written data
buf.advance(written);
Poll::Pending
}
Poll::Ready(Err(e)) => {
log::trace!("Error during flush: {}", e);
Poll::Ready(Err(e))
}
}; };
};
// log::trace!("flushed {} bytes", written);
// flush
if written > 0 {
match Pin::new(&mut *io).poll_flush(cx) {
Poll::Ready(Ok(_)) => result,
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
log::trace!("error during flush: {}", e);
Poll::Ready(Err(e))
}
}
} else {
result
} }
} else {
Poll::Ready(Ok(()))
} }
Poll::Ready(Ok(()))
} }
pub fn poll_read_buf<T: Read>( pub fn poll_read_buf<T: ARead>(
io: Pin<&mut T>, io: Pin<&mut T>,
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut BytesVec, buf: &mut BytesVec,
@ -342,226 +163,58 @@ mod unixstream {
impl IoStream for UnixStream { impl IoStream for UnixStream {
fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> { fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> {
async_std::task::spawn_local(ReadTask::new(self.clone(), read)); let mut rio = Read(RefCell::new(self.clone()));
async_std::task::spawn_local(WriteTask::new(self, write)); async_std::task::spawn_local(async move {
read.handle(&mut rio).await;
});
let mut wio = Write(RefCell::new(self));
async_std::task::spawn_local(async move {
write.handle(&mut wio).await;
});
None None
} }
} }
/// Read io task /// Read io task
struct ReadTask { struct Read(RefCell<UnixStream>);
io: RefCell<UnixStream>,
state: ReadContext,
}
impl ReadTask { impl ntex_io::AsyncRead for Read {
/// Create new read io task async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<usize>) {
fn new(io: UnixStream, state: ReadContext) -> Self { // read data from socket
Self { let result = poll_fn(|cx| {
state, let mut io = self.0.borrow_mut();
io: RefCell::new(io), poll_read_buf(Pin::new(&mut io.0), cx, &mut buf)
} })
.await;
(buf, result)
} }
} }
impl Future for ReadTask { struct Write(RefCell<UnixStream>);
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { impl ntex_io::AsyncWrite for Write {
let this = self.as_ref(); #[inline]
async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
this.state.with_buf(|buf, hw, lw| { poll_fn(|cx| {
match ready!(this.state.poll_ready(cx)) { if let Some(mut b) = buf.take() {
ReadStatus::Ready => { let result = flush_io(&mut self.0.borrow_mut().0, &mut b, cx);
// read data from socket buf.set(b);
let mut io = this.io.borrow_mut(); result
loop { } else {
// make sure we've got room Poll::Ready(Ok(()))
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
return match poll_read_buf(Pin::new(&mut io.0), cx, buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("async-std stream is disconnected");
Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else {
Poll::Pending
}
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
Poll::Ready(Err(err))
}
};
}
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
Poll::Ready(Ok(()))
}
} }
}) })
.await
} }
}
/// Write io task #[inline]
struct WriteTask { async fn flush(&mut self) -> io::Result<()> {
st: IoWriteState, Ok(())
io: UnixStream,
state: WriteContext,
}
impl WriteTask {
/// Create new write io task
fn new(io: UnixStream, state: WriteContext) -> Self {
Self {
io,
state,
st: IoWriteState::Processing(None),
}
} }
}
impl Future for WriteTask { #[inline]
type Output = (); async fn shutdown(&mut self) -> io::Result<()> {
self.0.borrow().0.shutdown(std::net::Shutdown::Both)
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().get_mut();
match this.st {
IoWriteState::Processing(ref mut delay) => {
match this.state.poll_ready(cx) {
Poll::Ready(WriteStatus::Ready) => {
if let Some(delay) = delay {
if delay.poll_elapsed(cx).is_ready() {
this.state.close(Some(io::Error::new(
io::ErrorKind::TimedOut,
"Operation timedout",
)));
return Poll::Ready(());
}
}
// flush io stream
let io = &mut this.io.0;
match ready!(this.state.with_buf(|buf| flush_io(io, buf, cx))) {
Ok(()) => Poll::Pending,
Err(e) => {
this.state.close(Some(e));
Poll::Ready(())
}
}
}
Poll::Ready(WriteStatus::Timeout(time)) => {
log::trace!("initiate timeout delay for {:?}", time);
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
Poll::Ready(WriteStatus::Shutdown(time)) => {
log::trace!("write task is instructed to shutdown");
let timeout = if let Some(delay) = delay.take() {
delay
} else {
sleep(time)
};
this.st = IoWriteState::Shutdown(timeout, Shutdown::None);
self.poll(cx)
}
Poll::Ready(WriteStatus::Terminate) => {
log::trace!("write task is instructed to terminate");
let _ = Pin::new(&mut this.io.0).poll_close(cx);
this.state.close(None);
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
IoWriteState::Shutdown(ref mut delay, ref mut st) => {
// close WRITE side and wait for disconnect on read side.
// use disconnect timeout, otherwise it could hang forever.
loop {
match st {
Shutdown::None => {
// flush write buffer
let io = &mut this.io.0;
match this.state.with_buf(|buf| flush_io(io, buf, cx)) {
Poll::Ready(Ok(())) => {
if let Err(e) =
this.io.0.shutdown(std::net::Shutdown::Write)
{
this.state.close(Some(e));
return Poll::Ready(());
}
*st = Shutdown::Stopping(0);
continue;
}
Poll::Ready(Err(err)) => {
log::trace!(
"write task is closed with err during flush, {:?}",
err
);
this.state.close(Some(err));
return Poll::Ready(());
}
Poll::Pending => (),
}
}
Shutdown::Stopping(ref mut count) => {
// read until 0 or err
let mut buf = [0u8; 512];
let io = &mut this.io;
loop {
match Pin::new(&mut io.0).poll_read(cx, &mut buf) {
Poll::Ready(Err(e)) => {
log::trace!("write task is stopped");
this.state.close(Some(e));
return Poll::Ready(());
}
Poll::Ready(Ok(0)) => {
log::trace!(
"async-std unix socket is disconnected"
);
this.state.close(None);
return Poll::Ready(());
}
Poll::Ready(Ok(n)) => {
*count += n as u16;
if *count > 4096 {
log::trace!(
"write task is stopped, too much input"
);
this.state.close(None);
return Poll::Ready(());
}
}
Poll::Pending => break,
}
}
}
}
// disconnect timeout
if delay.poll_elapsed(cx).is_pending() {
return Poll::Pending;
}
log::trace!("write task is stopped after delay");
this.state.close(None);
let _ = Pin::new(&mut this.io.0).poll_close(cx);
return Poll::Ready(());
}
}
}
} }
} }
} }

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.1.2] - 2024-09-11
* Use new io api
## [0.1.1] - 2024-09-05 ## [0.1.1] - 2024-09-05
* Tune write task * Tune write task

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-compio" name = "ntex-compio"
version = "0.1.1" version = "0.1.2"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "compio runtime intergration for ntex framework" description = "compio runtime intergration for ntex framework"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -18,7 +18,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
ntex-bytes = "0.1" ntex-bytes = "0.1"
ntex-io = "2.3" ntex-io = "2.5"
ntex-util = "2" ntex-util = "2"
log = "0.4" log = "0.4"
compio-net = "0.4.1" compio-net = "0.4.1"

View file

@ -4,17 +4,13 @@ use compio::buf::{BufResult, IoBuf, IoBufMut, SetBufInit};
use compio::io::{AsyncRead, AsyncWrite}; use compio::io::{AsyncRead, AsyncWrite};
use compio::net::TcpStream; use compio::net::TcpStream;
use ntex_bytes::{Buf, BufMut, BytesVec}; use ntex_bytes::{Buf, BufMut, BytesVec};
use ntex_io::{ use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf};
types, Handle, IoStream, ReadContext, ReadStatus, WriteContext, WriteStatus,
};
use ntex_util::{future::select, future::Either, time::sleep};
impl IoStream for crate::TcpStream { impl IoStream for crate::TcpStream {
fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> { fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> {
let mut io = self.0.clone(); let io = self.0.clone();
compio::runtime::spawn(async move { compio::runtime::spawn(async move {
run(&mut io, &read, write).await; run(io.clone(), &read, write).await;
match io.close().await { match io.close().await {
Ok(_) => log::debug!("{} Stream is closed", read.tag()), Ok(_) => log::debug!("{} Stream is closed", read.tag()),
Err(e) => log::error!("{} Stream is closed, {:?}", read.tag(), e), Err(e) => log::error!("{} Stream is closed, {:?}", read.tag(), e),
@ -29,11 +25,9 @@ impl IoStream for crate::TcpStream {
#[cfg(unix)] #[cfg(unix)]
impl IoStream for crate::UnixStream { impl IoStream for crate::UnixStream {
fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> { fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> {
let mut io = self.0;
compio::runtime::spawn(async move { compio::runtime::spawn(async move {
run(&mut io, &read, write).await; run(self.0.clone(), &read, write).await;
match self.0.close().await {
match io.close().await {
Ok(_) => log::debug!("{} Unix stream is closed", read.tag()), Ok(_) => log::debug!("{} Unix stream is closed", read.tag()),
Err(e) => log::error!("{} Unix stream is closed, {:?}", read.tag(), e), Err(e) => log::error!("{} Unix stream is closed, {:?}", read.tag(), e),
} }
@ -89,17 +83,18 @@ impl SetBufInit for CompioBuf {
} }
async fn run<T: AsyncRead + AsyncWrite + Clone + 'static>( async fn run<T: AsyncRead + AsyncWrite + Clone + 'static>(
io: &mut T, io: T,
read: &ReadContext, read: &ReadContext,
write: WriteContext, write: WriteContext,
) { ) {
let mut wr_io = io.clone(); let mut wr_io = WriteIo(io.clone());
let wr_task = compio::runtime::spawn(async move { let wr_task = compio::runtime::spawn(async move {
write_task(&mut wr_io, &write).await; write.handle(&mut wr_io).await;
log::debug!("{} Write task is stopped", write.tag()); log::debug!("{} Write task is stopped", write.tag());
}); });
let mut io = ReadIo(io);
read_task(io, read).await; read.handle(&mut io).await;
log::debug!("{} Read task is stopped", read.tag()); log::debug!("{} Read task is stopped", read.tag());
if !wr_task.is_finished() { if !wr_task.is_finished() {
@ -107,142 +102,63 @@ async fn run<T: AsyncRead + AsyncWrite + Clone + 'static>(
} }
} }
/// Read io task struct ReadIo<T>(T);
async fn read_task<T: AsyncRead>(io: &mut T, state: &ReadContext) {
loop {
match state.ready().await {
ReadStatus::Ready => {
let result = state
.with_buf_async(|buf| async {
let BufResult(result, buf) =
match select(io.read(CompioBuf(buf)), state.wait_for_close())
.await
{
Either::Left(res) => res,
Either::Right(_) => return (Default::default(), Ok(1)),
};
match result { impl<T> ntex_io::AsyncRead for ReadIo<T>
Ok(n) => { where
if n == 0 { T: AsyncRead,
log::trace!( {
"{}: Tcp stream is disconnected", #[inline]
state.tag() async fn read(&mut self, buf: BytesVec) -> (BytesVec, io::Result<usize>) {
); let BufResult(result, buf) = self.0.read(CompioBuf(buf)).await;
} (buf.0, result)
(buf.0, Ok(n))
}
Err(err) => {
log::trace!(
"{}: Read task failed on io {:?}",
state.tag(),
err
);
(buf.0, Err(err))
}
}
})
.await;
if result.is_ready() {
break;
}
}
ReadStatus::Terminate => {
log::trace!("{}: Read task is instructed to shutdown", state.tag());
break;
}
}
} }
} }
/// Write io task struct WriteIo<T>(T);
async fn write_task<T: AsyncWrite>(mut io: T, state: &WriteContext) {
let mut delay = None;
loop { impl<T> ntex_io::AsyncWrite for WriteIo<T>
let result = if let Some(ref mut sleep) = delay { where
let result = match select(sleep, state.ready()).await { T: AsyncWrite,
Either::Left(_) => { {
state.close(Some(io::Error::new( #[inline]
io::ErrorKind::TimedOut, async fn write(&mut self, wbuf: &mut WriteContextBuf) -> io::Result<()> {
"Operation timedout", if let Some(b) = wbuf.take() {
))); let mut buf = CompioBuf(b);
return;
}
Either::Right(res) => res,
};
delay = None;
result
} else {
state.ready().await
};
match result {
WriteStatus::Ready => {
// write io stream
match write(&mut io, state).await {
Ok(()) => continue,
Err(e) => {
state.close(Some(e));
}
}
}
WriteStatus::Timeout(time) => {
log::trace!("{}: Initiate timeout delay for {:?}", state.tag(), time);
delay = Some(sleep(time));
continue;
}
WriteStatus::Shutdown(time) => {
log::trace!("{}: Write task is instructed to shutdown", state.tag());
let fut = async {
write(&mut io, state).await?;
io.flush().await?;
io.shutdown().await?;
Ok(())
};
match select(sleep(time), fut).await {
Either::Left(_) => state.close(None),
Either::Right(res) => state.close(res.err()),
}
}
WriteStatus::Terminate => {
log::trace!("{}: Write task is instructed to terminate", state.tag());
state.close(io.shutdown().await.err());
}
}
break;
}
}
// write to io stream
async fn write<T: AsyncWrite>(io: &mut T, state: &WriteContext) -> io::Result<()> {
state
.with_buf_async(|buf| async {
let mut buf = CompioBuf(buf);
loop { loop {
let BufResult(result, buf1) = io.write(buf).await; let BufResult(result, buf1) = self.0.write(buf).await;
buf = buf1; buf = buf1;
return match result { let result = match result {
Ok(0) => Err(io::Error::new( Ok(0) => Err(io::Error::new(
io::ErrorKind::WriteZero, io::ErrorKind::WriteZero,
"failed to write frame to transport", "failed to write frame to transport",
)), )),
Ok(size) => { Ok(size) => {
if buf.0.len() == size { buf.0.advance(size);
// return io.flush().await; if buf.0.is_empty() {
state.memory_pool().release_write_buf(buf.0);
Ok(()) Ok(())
} else { } else {
buf.0.advance(size);
continue; continue;
} }
} }
Err(e) => Err(e), Err(e) => Err(e),
}; };
wbuf.set(buf.0);
return result;
} }
}) } else {
.await Ok(())
}
}
#[inline]
async fn flush(&mut self) -> io::Result<()> {
self.0.flush().await
}
#[inline]
async fn shutdown(&mut self) -> io::Result<()> {
self.0.shutdown().await
}
} }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-glommio" name = "ntex-glommio"
version = "0.5.0" version = "0.5.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "glommio intergration for ntex framework" description = "glommio intergration for ntex framework"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -17,7 +17,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
ntex-bytes = "0.1" ntex-bytes = "0.1"
ntex-io = "2.0" ntex-io = "2.5"
ntex-util = "2.0" ntex-util = "2.0"
futures-lite = "2.2" futures-lite = "2.2"
log = "0.4" log = "0.4"

View file

@ -1,28 +1,30 @@
use std::task::{Context, Poll}; use std::{any, future::poll_fn, io, pin::Pin, task::ready, task::Context, task::Poll};
use std::{any, future::Future, io, pin::Pin};
use futures_lite::future::FutureExt;
use futures_lite::io::{AsyncRead, AsyncWrite}; use futures_lite::io::{AsyncRead, AsyncWrite};
use ntex_bytes::{Buf, BufMut, BytesVec}; use ntex_bytes::{Buf, BufMut, BytesVec};
use ntex_io::{ use ntex_io::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf};
types, Handle, IoStream, ReadContext, ReadStatus, WriteContext, WriteStatus,
};
use ntex_util::{ready, time::sleep, time::Sleep};
use crate::net_impl::{TcpStream, UnixStream}; use crate::net_impl::{TcpStream, UnixStream};
impl IoStream for TcpStream { impl IoStream for TcpStream {
fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> { fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> {
glommio::spawn_local(ReadTask::new(self.clone(), read)).detach(); let mut rio = Read(self.clone());
glommio::spawn_local(WriteTask::new(self.clone(), write)).detach(); glommio::spawn_local(async move { read.handle(&mut rio).await }).detach();
let mut wio = Write(self.clone());
glommio::spawn_local(async move { write.handle(&mut wio).await }).detach();
Some(Box::new(self)) Some(Box::new(self))
} }
} }
impl IoStream for UnixStream { impl IoStream for UnixStream {
fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> { fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> {
glommio::spawn_local(UnixReadTask::new(self.clone(), read)).detach(); let mut rio = UnixRead(self.clone());
glommio::spawn_local(UnixWriteTask::new(self, write)).detach(); glommio::spawn_local(async move {
read.handle(&mut rio).await;
})
.detach();
let mut wio = UnixWrite(self);
glommio::spawn_local(async move { write.handle(&mut wio).await }).detach();
None None
} }
} }
@ -39,306 +41,150 @@ impl Handle for TcpStream {
} }
/// Read io task /// Read io task
struct ReadTask { struct Read(TcpStream);
io: TcpStream,
state: ReadContext,
}
impl ReadTask { impl ntex_io::AsyncRead for Read {
/// Create new read io task async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<usize>) {
fn new(io: TcpStream, state: ReadContext) -> Self { // read data from socket
Self { io, state } let result = poll_fn(|cx| {
let mut io = self.0 .0.borrow_mut();
poll_read_buf(Pin::new(&mut *io), cx, &mut buf)
})
.await;
(buf, result)
} }
} }
impl Future for ReadTask { struct Write(TcpStream);
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { impl ntex_io::AsyncWrite for Write {
let this = self.as_mut(); #[inline]
async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
this.state.with_buf(|buf, hw, lw| { poll_fn(|cx| {
match ready!(this.state.poll_ready(cx)) { if let Some(mut b) = buf.take() {
ReadStatus::Ready => { let result = flush_io(&mut *self.0 .0.borrow_mut(), &mut b, cx);
// read data from socket buf.set(b);
loop { result
// make sure we've got room } else {
let remaining = buf.remaining_mut(); Poll::Ready(Ok(()))
if remaining < lw {
buf.reserve(hw - remaining);
}
return match poll_read_buf(
Pin::new(&mut *this.io.0.borrow_mut()),
cx,
buf,
) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("glommio stream is disconnected");
Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else {
Poll::Pending
}
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
Poll::Ready(Err(err))
}
};
}
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
Poll::Ready(Ok(()))
}
} }
}) })
.await
}
#[inline]
async fn flush(&mut self) -> io::Result<()> {
Ok(())
}
#[inline]
async fn shutdown(&mut self) -> io::Result<()> {
poll_fn(|cx| Pin::new(&mut *self.0 .0.borrow_mut()).poll_close(cx)).await
} }
} }
enum IoWriteState { struct UnixRead(UnixStream);
Processing(Option<Sleep>),
Shutdown(Sleep, Shutdown),
}
enum Shutdown { impl ntex_io::AsyncRead for UnixRead {
Flush, async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<usize>) {
Close(Pin<Box<dyn Future<Output = glommio::Result<(), ()>>>>), // read data from socket
Stopping(u16), let result = poll_fn(|cx| {
} let mut io = self.0 .0.borrow_mut();
poll_read_buf(Pin::new(&mut *io), cx, &mut buf)
/// Write io task })
struct WriteTask { .await;
st: IoWriteState, (buf, result)
io: TcpStream,
state: WriteContext,
}
impl WriteTask {
/// Create new write io task
fn new(io: TcpStream, state: WriteContext) -> Self {
Self {
io,
state,
st: IoWriteState::Processing(None),
}
} }
} }
impl Future for WriteTask { struct UnixWrite(UnixStream);
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { impl ntex_io::AsyncWrite for UnixWrite {
let this = self.as_mut().get_mut(); #[inline]
async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
match this.st { poll_fn(|cx| {
IoWriteState::Processing(ref mut delay) => { if let Some(mut b) = buf.take() {
match this.state.poll_ready(cx) { let result = flush_io(&mut *self.0 .0.borrow_mut(), &mut b, cx);
Poll::Ready(WriteStatus::Ready) => { buf.set(b);
if let Some(delay) = delay { result
if delay.poll_elapsed(cx).is_ready() { } else {
this.state.close(Some(io::Error::new( Poll::Ready(Ok(()))
io::ErrorKind::TimedOut,
"Operation timedout",
)));
return Poll::Ready(());
}
}
// flush io stream
match ready!(this.state.with_buf(|buf| flush_io(
&mut *this.io.0.borrow_mut(),
buf,
cx
))) {
Ok(()) => Poll::Pending,
Err(e) => {
this.state.close(Some(e));
Poll::Ready(())
}
}
}
Poll::Ready(WriteStatus::Timeout(time)) => {
log::trace!("initiate timeout delay for {:?}", time);
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
Poll::Ready(WriteStatus::Shutdown(time)) => {
log::trace!("write task is instructed to shutdown");
let timeout = if let Some(delay) = delay.take() {
delay
} else {
sleep(time)
};
this.st = IoWriteState::Shutdown(timeout, Shutdown::Flush);
self.poll(cx)
}
Poll::Ready(WriteStatus::Terminate) => {
log::trace!("write task is instructed to terminate");
let _ = Pin::new(&mut *this.io.0.borrow_mut()).poll_close(cx);
this.state.close(None);
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
} }
IoWriteState::Shutdown(ref mut delay, ref mut st) => { })
// close WRITE side and wait for disconnect on read side. .await
// use disconnect timeout, otherwise it could hang forever. }
loop {
match st {
Shutdown::Flush => {
// flush write buffer
let mut io = this.io.0.borrow_mut();
match this.state.with_buf(|buf| flush_io(&mut *io, buf, cx)) {
Poll::Ready(Ok(())) => {
let io = this.io.clone();
#[allow(clippy::await_holding_refcell_ref)]
let fut = Box::pin(async move {
io.0.borrow()
.shutdown(std::net::Shutdown::Write)
.await
});
*st = Shutdown::Close(fut);
continue;
}
Poll::Ready(Err(err)) => {
log::trace!(
"write task is closed with err during flush, {:?}",
err
);
this.state.close(Some(err));
return Poll::Ready(());
}
Poll::Pending => (),
}
}
Shutdown::Close(ref mut fut) => {
if ready!(fut.poll(cx)).is_err() {
this.state.close(None);
return Poll::Ready(());
}
*st = Shutdown::Stopping(0);
continue;
}
Shutdown::Stopping(ref mut count) => {
// read until 0 or err
let mut buf = [0u8; 512];
let io = &mut this.io;
loop {
match Pin::new(&mut *io.0.borrow_mut())
.poll_read(cx, &mut buf)
{
Poll::Ready(Err(e)) => {
log::trace!("write task is stopped");
this.state.close(Some(e));
return Poll::Ready(());
}
Poll::Ready(Ok(0)) => {
log::trace!("glommio socket is disconnected");
this.state.close(None);
return Poll::Ready(());
}
Poll::Ready(Ok(n)) => {
*count += n as u16;
if *count > 4096 {
log::trace!(
"write task is stopped, too much input"
);
this.state.close(None);
return Poll::Ready(());
}
}
Poll::Pending => break,
}
}
}
}
// disconnect timeout #[inline]
if delay.poll_elapsed(cx).is_pending() { async fn flush(&mut self) -> io::Result<()> {
return Poll::Pending; Ok(())
} }
log::trace!("write task is stopped after delay");
this.state.close(None); #[inline]
let _ = Pin::new(&mut *this.io.0.borrow_mut()).poll_close(cx); async fn shutdown(&mut self) -> io::Result<()> {
return Poll::Ready(()); poll_fn(|cx| Pin::new(&mut *self.0 .0.borrow_mut()).poll_close(cx)).await
}
}
}
} }
} }
/// Flush write buffer to underlying I/O stream. /// Flush write buffer to underlying I/O stream.
pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>( pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
io: &mut T, io: &mut T,
buf: &mut Option<BytesVec>, buf: &mut BytesVec,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
if let Some(buf) = buf { let len = buf.len();
let len = buf.len();
if len != 0 { if len != 0 {
// log::trace!("flushing framed transport: {:?}", buf.len()); // log::trace!("flushing framed transport: {:?}", buf.len());
let mut written = 0; let mut written = 0;
let result = loop { let result = loop {
break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) {
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
log::trace!("Disconnected during flush, written {}", written); log::trace!("Disconnected during flush, written {}", written);
Poll::Ready(Err(io::Error::new( Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero, io::ErrorKind::WriteZero,
"failed to write frame to transport", "failed to write frame to transport",
))) )))
} else {
written += n;
if written == len {
buf.clear();
Poll::Ready(Ok(()))
} else { } else {
written += n; continue;
if written == len {
buf.clear();
Poll::Ready(Ok(()))
} else {
continue;
}
} }
} }
Poll::Pending => {
// remove written data
buf.advance(written);
Poll::Pending
}
Poll::Ready(Err(e)) => {
log::trace!("Error during flush: {}", e);
Poll::Ready(Err(e))
}
};
};
log::trace!("flushed {} bytes", written);
// flush
return if written > 0 {
match Pin::new(&mut *io).poll_flush(cx) {
Poll::Ready(Ok(_)) => result,
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
log::trace!("error during flush: {}", e);
Poll::Ready(Err(e))
}
} }
} else { Poll::Pending => {
result // remove written data
buf.advance(written);
Poll::Pending
}
Poll::Ready(Err(e)) => {
log::trace!("Error during flush: {}", e);
Poll::Ready(Err(e))
}
}; };
};
// log::trace!("flushed {} bytes", written);
// flush
if written > 0 {
match Pin::new(&mut *io).poll_flush(cx) {
Poll::Ready(Ok(_)) => result,
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
log::trace!("error during flush: {}", e);
Poll::Ready(Err(e))
}
}
} else {
result
} }
} else {
Poll::Ready(Ok(()))
} }
Poll::Ready(Ok(()))
} }
pub fn poll_read_buf<T: AsyncRead>( pub fn poll_read_buf<T: AsyncRead>(
@ -357,232 +203,3 @@ pub fn poll_read_buf<T: AsyncRead>(
Poll::Ready(Ok(n)) Poll::Ready(Ok(n))
} }
/// Read io task
struct UnixReadTask {
io: UnixStream,
state: ReadContext,
}
impl UnixReadTask {
/// Create new read io task
fn new(io: UnixStream, state: ReadContext) -> Self {
Self { io, state }
}
}
impl Future for UnixReadTask {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut();
this.state.with_buf(|buf, hw, lw| {
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
// read data from socket
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
return match poll_read_buf(
Pin::new(&mut *this.io.0.borrow_mut()),
cx,
buf,
) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("glommio stream is disconnected");
Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else {
Poll::Pending
}
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
Poll::Ready(Err(err))
}
};
}
}
ReadStatus::Terminate => {
log::trace!("read task is instructed to shutdown");
Poll::Ready(Ok(()))
}
}
})
}
}
/// Write io task
struct UnixWriteTask {
st: IoWriteState,
io: UnixStream,
state: WriteContext,
}
impl UnixWriteTask {
/// Create new write io task
fn new(io: UnixStream, state: WriteContext) -> Self {
Self {
io,
state,
st: IoWriteState::Processing(None),
}
}
}
impl Future for UnixWriteTask {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().get_mut();
match this.st {
IoWriteState::Processing(ref mut delay) => {
match this.state.poll_ready(cx) {
Poll::Ready(WriteStatus::Ready) => {
if let Some(delay) = delay {
if delay.poll_elapsed(cx).is_ready() {
this.state.close(Some(io::Error::new(
io::ErrorKind::TimedOut,
"Operation timedout",
)));
return Poll::Ready(());
}
}
// flush io stream
match ready!(this.state.with_buf(|buf| flush_io(
&mut *this.io.0.borrow_mut(),
buf,
cx
))) {
Ok(()) => Poll::Pending,
Err(e) => {
this.state.close(Some(e));
Poll::Ready(())
}
}
}
Poll::Ready(WriteStatus::Timeout(time)) => {
log::trace!("initiate timeout delay for {:?}", time);
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
Poll::Ready(WriteStatus::Shutdown(time)) => {
log::trace!("write task is instructed to shutdown");
let timeout = if let Some(delay) = delay.take() {
delay
} else {
sleep(time)
};
this.st = IoWriteState::Shutdown(timeout, Shutdown::Flush);
self.poll(cx)
}
Poll::Ready(WriteStatus::Terminate) => {
log::trace!("write task is instructed to terminate");
let _ = Pin::new(&mut *this.io.0.borrow_mut()).poll_close(cx);
this.state.close(None);
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
IoWriteState::Shutdown(ref mut delay, ref mut st) => {
// close WRITE side and wait for disconnect on read side.
// use disconnect timeout, otherwise it could hang forever.
loop {
match st {
Shutdown::Flush => {
// flush write buffer
let mut io = this.io.0.borrow_mut();
match this.state.with_buf(|buf| flush_io(&mut *io, buf, cx)) {
Poll::Ready(Ok(())) => {
let io = this.io.clone();
#[allow(clippy::await_holding_refcell_ref)]
let fut = Box::pin(async move {
io.0.borrow()
.shutdown(std::net::Shutdown::Write)
.await
});
*st = Shutdown::Close(fut);
continue;
}
Poll::Ready(Err(err)) => {
log::trace!(
"write task is closed with err during flush, {:?}",
err
);
this.state.close(Some(err));
return Poll::Ready(());
}
Poll::Pending => (),
}
}
Shutdown::Close(ref mut fut) => {
if ready!(fut.poll(cx)).is_err() {
this.state.close(None);
return Poll::Ready(());
}
*st = Shutdown::Stopping(0);
continue;
}
Shutdown::Stopping(ref mut count) => {
// read until 0 or err
let mut buf = [0u8; 512];
let io = &mut this.io;
loop {
match Pin::new(&mut *io.0.borrow_mut())
.poll_read(cx, &mut buf)
{
Poll::Ready(Err(e)) => {
log::trace!("write task is stopped");
this.state.close(Some(e));
return Poll::Ready(());
}
Poll::Ready(Ok(0)) => {
log::trace!("glommio unix socket is disconnected");
this.state.close(None);
return Poll::Ready(());
}
Poll::Ready(Ok(n)) => {
*count += n as u16;
if *count > 4096 {
log::trace!(
"write task is stopped, too much input"
);
this.state.close(None);
return Poll::Ready(());
}
}
Poll::Pending => break,
}
}
}
}
// disconnect timeout
if delay.poll_elapsed(cx).is_pending() {
return Poll::Pending;
}
log::trace!("write task is stopped after delay");
this.state.close(None);
let _ = Pin::new(&mut *this.io.0.borrow_mut()).poll_close(cx);
return Poll::Ready(());
}
}
}
}
}

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [2.5.0] - 2024-09-10
* Refactor async io support
## [2.3.1] - 2024-09-05 ## [2.3.1] - 2024-09-05
* Tune async io tasks support * Tune async io tasks support

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-io" name = "ntex-io"
version = "2.4.0" version = "2.5.0"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames" description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]

View file

@ -152,27 +152,6 @@ impl Stack {
} }
} }
pub(crate) fn with_read_source<F, R>(&self, io: &IoRef, f: F) -> R
where
F: FnOnce(&mut BytesVec) -> R,
{
let item = self.get_last_level();
let mut rb = item.0.take();
if rb.is_none() {
rb = Some(io.memory_pool().get_read_buf());
}
let result = f(rb.as_mut().unwrap());
if let Some(b) = rb {
if b.is_empty() {
io.memory_pool().release_read_buf(b);
} else {
item.0.set(Some(b));
}
}
result
}
pub(crate) fn with_read_destination<F, R>(&self, io: &IoRef, f: F) -> R pub(crate) fn with_read_destination<F, R>(&self, io: &IoRef, f: F) -> R
where where
F: FnOnce(&mut BytesVec) -> R, F: FnOnce(&mut BytesVec) -> R,
@ -226,6 +205,17 @@ impl Stack {
self.get_last_level().1.take() self.get_last_level().1.take()
} }
pub(crate) fn set_write_destination(&self, buf: BytesVec) -> Option<BytesVec> {
let b = self.get_last_level().1.take();
if b.is_some() {
self.get_last_level().1.set(b);
Some(buf)
} else {
self.get_last_level().1.set(Some(buf));
None
}
}
pub(crate) fn with_write_destination<F, R>(&self, io: &IoRef, f: F) -> R pub(crate) fn with_write_destination<F, R>(&self, io: &IoRef, f: F) -> R
where where
F: FnOnce(&mut Option<BytesVec>) -> R, F: FnOnce(&mut Option<BytesVec>) -> R,

View file

@ -93,26 +93,16 @@ impl Filter for Base {
#[inline] #[inline]
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> { fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> {
let mut flags = self.0.flags(); let flags = self.0.flags();
if flags.contains(Flags::IO_STOPPED) { if flags.is_stopped() {
Poll::Ready(WriteStatus::Terminate) Poll::Ready(WriteStatus::Terminate)
} else { } else {
self.0 .0.write_task.register(cx.waker()); self.0 .0.write_task.register(cx.waker());
if flags.intersects(Flags::IO_STOPPING) { if flags.contains(Flags::IO_STOPPING) {
Poll::Ready(WriteStatus::Shutdown( Poll::Ready(WriteStatus::Shutdown)
self.0 .0.disconnect_timeout.get().into(), } else if flags.contains(Flags::WR_PAUSED) {
))
} else if flags.contains(Flags::IO_STOPPING_FILTERS)
&& !flags.contains(Flags::IO_FILTERS_TIMEOUT)
{
flags.insert(Flags::IO_FILTERS_TIMEOUT);
self.0.set_flags(flags);
Poll::Ready(WriteStatus::Timeout(
self.0 .0.disconnect_timeout.get().into(),
))
} else if flags.intersects(Flags::WR_PAUSED) {
Poll::Pending Poll::Pending
} else { } else {
Poll::Ready(WriteStatus::Ready) Poll::Ready(WriteStatus::Ready)
@ -242,20 +232,13 @@ where
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
Poll::Ready(WriteStatus::Ready) => res2, Poll::Ready(WriteStatus::Ready) => res2,
Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate), Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate),
Poll::Ready(WriteStatus::Shutdown(t)) => { Poll::Ready(WriteStatus::Shutdown) => {
if res2 == Poll::Ready(WriteStatus::Terminate) { if res2 == Poll::Ready(WriteStatus::Terminate) {
Poll::Ready(WriteStatus::Terminate) Poll::Ready(WriteStatus::Terminate)
} else { } else {
Poll::Ready(WriteStatus::Shutdown(t)) Poll::Ready(WriteStatus::Shutdown)
} }
} }
Poll::Ready(WriteStatus::Timeout(t)) => match res2 {
Poll::Ready(WriteStatus::Terminate) => Poll::Ready(WriteStatus::Terminate),
Poll::Ready(WriteStatus::Shutdown(t)) => {
Poll::Ready(WriteStatus::Shutdown(t))
}
_ => Poll::Ready(WriteStatus::Timeout(t)),
},
} }
} }
} }

View file

@ -7,8 +7,6 @@ bitflags::bitflags! {
const IO_STOPPING = 0b0000_0000_0000_0010; const IO_STOPPING = 0b0000_0000_0000_0010;
/// shuting down filters /// shuting down filters
const IO_STOPPING_FILTERS = 0b0000_0000_0000_0100; const IO_STOPPING_FILTERS = 0b0000_0000_0000_0100;
/// initiate filters shutdown timeout in write task
const IO_FILTERS_TIMEOUT = 0b0000_0000_0000_1000;
/// pause io read /// pause io read
const RD_PAUSED = 0b0000_0000_0001_0000; const RD_PAUSED = 0b0000_0000_0001_0000;
@ -36,6 +34,10 @@ bitflags::bitflags! {
} }
impl Flags { impl Flags {
pub(crate) fn is_stopped(&self) -> bool {
self.intersects(Flags::IO_STOPPED)
}
pub(crate) fn is_waiting_for_write(&self) -> bool { pub(crate) fn is_waiting_for_write(&self) -> bool {
self.intersects(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE) self.intersects(Flags::BUF_W_MUST_FLUSH | Flags::BUF_W_BACKPRESSURE)
} }

View file

@ -165,7 +165,7 @@ impl Io {
let inner = Rc::new(IoState { let inner = Rc::new(IoState {
filter: FilterPtr::null(), filter: FilterPtr::null(),
pool: Cell::new(pool), pool: Cell::new(pool),
flags: Cell::new(Flags::empty()), flags: Cell::new(Flags::WR_PAUSED),
error: Cell::new(None), error: Cell::new(None),
dispatch_task: LocalWaker::new(), dispatch_task: LocalWaker::new(),
read_task: LocalWaker::new(), read_task: LocalWaker::new(),
@ -421,7 +421,7 @@ impl<F> Io<F> {
let st = self.st(); let st = self.st();
let mut flags = st.flags.get(); let mut flags = st.flags.get();
if flags.contains(Flags::IO_STOPPED) { if flags.is_stopped() {
Poll::Ready(self.error().map(Err).unwrap_or(Ok(None))) Poll::Ready(self.error().map(Err).unwrap_or(Ok(None)))
} else { } else {
st.dispatch_task.register(cx.waker()); st.dispatch_task.register(cx.waker());
@ -531,7 +531,7 @@ impl<F> Io<F> {
} else { } else {
let st = self.st(); let st = self.st();
let flags = st.flags.get(); let flags = st.flags.get();
if flags.contains(Flags::IO_STOPPED) { if flags.is_stopped() {
Err(RecvError::PeerGone(self.error())) Err(RecvError::PeerGone(self.error()))
} else if flags.contains(Flags::DSP_STOP) { } else if flags.contains(Flags::DSP_STOP) {
st.remove_flags(Flags::DSP_STOP); st.remove_flags(Flags::DSP_STOP);
@ -568,7 +568,7 @@ impl<F> Io<F> {
pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> { pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll<io::Result<()>> {
let flags = self.flags(); let flags = self.flags();
if flags.contains(Flags::IO_STOPPED) { if flags.is_stopped() {
Poll::Ready(self.error().map(Err).unwrap_or(Ok(()))) Poll::Ready(self.error().map(Err).unwrap_or(Ok(())))
} else { } else {
let st = self.st(); let st = self.st();
@ -595,7 +595,7 @@ impl<F> Io<F> {
let st = self.st(); let st = self.st();
let flags = st.flags.get(); let flags = st.flags.get();
if flags.intersects(Flags::IO_STOPPED) { if flags.is_stopped() {
if let Some(err) = self.error() { if let Some(err) = self.error() {
Poll::Ready(Err(err)) Poll::Ready(Err(err))
} else { } else {
@ -700,7 +700,7 @@ impl<F> Drop for Io<F> {
if st.filter.is_set() { if st.filter.is_set() {
// filter is unsafe and must be dropped explicitly, // filter is unsafe and must be dropped explicitly,
// and wont be dropped without special attention // and wont be dropped without special attention
if !st.flags.get().contains(Flags::IO_STOPPED) { if !st.flags.get().is_stopped() {
log::trace!( log::trace!(
"{}: Io is dropped, force stopping io streams {:?}", "{}: Io is dropped, force stopping io streams {:?}",
st.tag.get(), st.tag.get(),
@ -884,7 +884,7 @@ pub struct OnDisconnect {
impl OnDisconnect { impl OnDisconnect {
pub(super) fn new(inner: Rc<IoState>) -> Self { pub(super) fn new(inner: Rc<IoState>) -> Self {
Self::new_inner(inner.flags.get().contains(Flags::IO_STOPPED), inner) Self::new_inner(inner.flags.get().is_stopped(), inner)
} }
fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self { fn new_inner(disconnected: bool, inner: Rc<IoState>) -> Self {
@ -909,7 +909,7 @@ impl OnDisconnect {
#[inline] #[inline]
/// Check if connection is disconnected /// Check if connection is disconnected
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> { pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> {
if self.token == usize::MAX || self.inner.flags.get().contains(Flags::IO_STOPPED) { if self.token == usize::MAX || self.inner.flags.get().is_stopped() {
Poll::Ready(()) Poll::Ready(())
} else if let Some(on_disconnect) = self.inner.on_disconnect.take() { } else if let Some(on_disconnect) = self.inner.on_disconnect.take() {
on_disconnect[self.token].register(cx.waker()); on_disconnect[self.token].register(cx.waker());

View file

@ -14,12 +14,6 @@ impl IoRef {
self.0.flags.get() self.0.flags.get()
} }
#[inline]
/// Set flags
pub(crate) fn set_flags(&self, flags: Flags) {
self.0.flags.set(flags)
}
#[inline] #[inline]
/// Get current filter /// Get current filter
pub(crate) fn filter(&self) -> &dyn Filter { pub(crate) fn filter(&self) -> &dyn Filter {
@ -41,10 +35,6 @@ impl IoRef {
.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED) .intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
} }
pub(crate) fn is_io_closed(&self) -> bool {
self.0.flags.get().intersects(Flags::IO_STOPPED)
}
#[inline] #[inline]
/// Check if write back-pressure is enabled /// Check if write back-pressure is enabled
pub fn is_wr_backpressure(&self) -> bool { pub fn is_wr_backpressure(&self) -> bool {

View file

@ -1,5 +1,6 @@
//! Utilities for abstructing io streams //! Utilities for abstructing io streams
#![deny(rust_2018_idioms, unreachable_pub, missing_debug_implementations)] #![deny(rust_2018_idioms, unreachable_pub, missing_debug_implementations)]
#![allow(async_fn_in_trait)]
use std::{ use std::{
any::Any, any::TypeId, fmt, io as sio, io::Error as IoError, task::Context, task::Poll, any::Any, any::TypeId, fmt, io as sio, io::Error as IoError, task::Context, task::Poll,
@ -20,8 +21,8 @@ mod tasks;
mod timer; mod timer;
mod utils; mod utils;
use ntex_bytes::BytesVec;
use ntex_codec::{Decoder, Encoder}; use ntex_codec::{Decoder, Encoder};
use ntex_util::time::Millis;
pub use self::buf::{ReadBuf, WriteBuf}; pub use self::buf::{ReadBuf, WriteBuf};
pub use self::dispatcher::{Dispatcher, DispatcherConfig}; pub use self::dispatcher::{Dispatcher, DispatcherConfig};
@ -29,13 +30,27 @@ pub use self::filter::{Base, Filter, Layer};
pub use self::framed::Framed; pub use self::framed::Framed;
pub use self::io::{Io, IoRef, OnDisconnect}; pub use self::io::{Io, IoRef, OnDisconnect};
pub use self::seal::{IoBoxed, Sealed}; pub use self::seal::{IoBoxed, Sealed};
pub use self::tasks::{ReadContext, WriteContext}; pub use self::tasks::{ReadContext, WriteContext, WriteContextBuf};
pub use self::timer::TimerHandle; pub use self::timer::TimerHandle;
pub use self::utils::{seal, Decoded}; pub use self::utils::{seal, Decoded};
#[doc(hidden)] #[doc(hidden)]
pub use self::flags::Flags; pub use self::flags::Flags;
#[doc(hidden)]
pub trait AsyncRead {
async fn read(&mut self, buf: BytesVec) -> (BytesVec, sio::Result<usize>);
}
#[doc(hidden)]
pub trait AsyncWrite {
async fn write(&mut self, buf: &mut WriteContextBuf) -> sio::Result<()>;
async fn flush(&mut self) -> sio::Result<()>;
async fn shutdown(&mut self) -> sio::Result<()>;
}
/// Status for read task /// Status for read task
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum ReadStatus { pub enum ReadStatus {
@ -48,10 +63,8 @@ pub enum ReadStatus {
pub enum WriteStatus { pub enum WriteStatus {
/// Write task is clear to proceed with write operation /// Write task is clear to proceed with write operation
Ready, Ready,
/// Initiate timeout for normal write operations, shutdown connection after timeout /// Initiate graceful io shutdown operation
Timeout(Millis), Shutdown,
/// Initiate graceful io shutdown operation with timeout
Shutdown(Millis),
/// Immediately terminate connection /// Immediately terminate connection
Terminate, Terminate,
} }

View file

@ -1,16 +1,22 @@
use std::{future::poll_fn, future::Future, io, task::Context, task::Poll}; use std::{cell::Cell, fmt, future::poll_fn, io, task::Context, task::Poll};
use ntex_bytes::{BufMut, BytesVec, PoolRef}; use ntex_bytes::{BufMut, BytesVec};
use ntex_util::{future::lazy, future::select, future::Either, time::sleep, time::Sleep};
use crate::{Flags, IoRef, ReadStatus, WriteStatus}; use crate::{AsyncRead, AsyncWrite, Flags, IoRef, ReadStatus, WriteStatus};
#[derive(Debug)]
/// Context for io read task /// Context for io read task
pub struct ReadContext(IoRef); pub struct ReadContext(IoRef, Cell<Option<Sleep>>);
impl fmt::Debug for ReadContext {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ReadContext").field("io", &self.0).finish()
}
}
impl ReadContext { impl ReadContext {
pub(crate) fn new(io: &IoRef) -> Self { pub(crate) fn new(io: &IoRef) -> Self {
Self(io.clone()) Self(io.clone(), Cell::new(None))
} }
#[inline] #[inline]
@ -19,15 +25,8 @@ impl ReadContext {
self.0.tag() self.0.tag()
} }
#[inline]
/// Check readiness for read operations
pub async fn ready(&self) -> ReadStatus {
poll_fn(|cx| self.0.filter().poll_read_ready(cx)).await
}
#[inline]
/// Wait when io get closed or preparing for close /// Wait when io get closed or preparing for close
pub async fn wait_for_close(&self) { async fn wait_for_close(&self) {
poll_fn(|cx| { poll_fn(|cx| {
let flags = self.0.flags(); let flags = self.0.flags();
@ -36,7 +35,7 @@ impl ReadContext {
} else { } else {
self.0 .0.read_task.register(cx.waker()); self.0 .0.read_task.register(cx.waker());
if flags.contains(Flags::IO_STOPPING_FILTERS) { if flags.contains(Flags::IO_STOPPING_FILTERS) {
shutdown_filters(&self.0); self.shutdown_filters(cx);
} }
Poll::Pending Poll::Pending
} }
@ -44,222 +43,169 @@ impl ReadContext {
.await .await
} }
#[inline] /// Handle read io operations
/// Check readiness for read operations pub async fn handle<T>(&self, io: &mut T)
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<ReadStatus> {
self.0.filter().poll_read_ready(cx)
}
/// Get read buffer
pub fn with_buf<F>(&self, f: F) -> Poll<()>
where where
F: FnOnce(&mut BytesVec, usize, usize) -> Poll<io::Result<()>>, T: AsyncRead,
{ {
let inner = &self.0 .0; let inner = &self.0 .0;
let (hw, lw) = self.0.memory_pool().read_params().unpack();
let (result, nbytes, total) = inner.buffer.with_read_source(&self.0, |buf| { loop {
let result = poll_fn(|cx| self.0.filter().poll_read_ready(cx)).await;
if result == ReadStatus::Terminate {
log::trace!("{}: Read task is instructed to shutdown", self.tag());
break;
}
let mut buf = if inner.flags.get().is_read_buf_ready() {
// read buffer is still not read by dispatcher
// we cannot touch it
inner.pool.get().get_read_buf()
} else {
inner
.buffer
.get_read_source()
.unwrap_or_else(|| inner.pool.get().get_read_buf())
};
// make sure we've got room
let (hw, lw) = self.0.memory_pool().read_params().unpack();
let remaining = buf.remaining_mut();
if remaining <= lw {
buf.reserve(hw - remaining);
}
let total = buf.len(); let total = buf.len();
// call provided callback // call provided callback
let result = f(buf, hw, lw); let (buf, result) = match select(io.read(buf), self.wait_for_close()).await {
Either::Left(res) => res,
Either::Right(_) => {
log::trace!("{}: Read io is closed, stop read task", self.tag());
break;
}
};
// handle incoming data
let total2 = buf.len(); let total2 = buf.len();
let nbytes = if total2 > total { total2 - total } else { 0 }; let nbytes = if total2 > total { total2 - total } else { 0 };
(result, nbytes, total2) let total = total2;
});
// handle buffer changes if let Some(mut first_buf) = inner.buffer.get_read_source() {
if nbytes > 0 { first_buf.extend_from_slice(&buf);
let filter = self.0.filter(); inner.buffer.set_read_source(&self.0, first_buf);
let _ = filter } else {
.process_read_buf(&self.0, &inner.buffer, 0, nbytes) inner.buffer.set_read_source(&self.0, buf);
.and_then(|status| { }
if status.nbytes > 0 {
// dest buffer has new data, wake up dispatcher // handle buffer changes
if inner.buffer.read_destination_size() >= hw { if nbytes > 0 {
log::trace!( let filter = self.0.filter();
let res = match filter.process_read_buf(&self.0, &inner.buffer, 0, nbytes) {
Ok(status) => {
if status.nbytes > 0 {
// check read back-pressure
if hw < inner.buffer.read_destination_size() {
log::trace!(
"{}: Io read buffer is too large {}, enable read back-pressure", "{}: Io read buffer is too large {}, enable read back-pressure",
self.0.tag(), self.0.tag(),
total total
); );
inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL); inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
} else { } else {
inner.insert_flags(Flags::BUF_R_READY); inner.insert_flags(Flags::BUF_R_READY);
if nbytes >= hw {
// read task is paused because of read back-pressure
// but there is no new data in top most read buffer
// so we need to wake up read task to read more data
// otherwise read task would sleep forever
inner.read_task.wake();
} }
} log::trace!(
log::trace!( "{}: New {} bytes available, wakeup dispatcher",
"{}: New {} bytes available, wakeup dispatcher", self.0.tag(),
self.0.tag(), nbytes
nbytes );
); // dest buffer has new data, wake up dispatcher
inner.dispatch_task.wake(); inner.dispatch_task.wake();
} else { } else if inner.flags.get().contains(Flags::RD_NOTIFY) {
if nbytes >= hw {
// read task is paused because of read back-pressure
// but there is no new data in top most read buffer
// so we need to wake up read task to read more data
// otherwise read task would sleep forever
inner.read_task.wake();
}
if inner.flags.get().contains(Flags::RD_NOTIFY) {
// in case of "notify" we must wake up dispatch task // in case of "notify" we must wake up dispatch task
// if we read any data from source // if we read any data from source
inner.dispatch_task.wake(); inner.dispatch_task.wake();
} }
}
// while reading, filter wrote some data // while reading, filter wrote some data
// in that case filters need to process write buffers // in that case filters need to process write buffers
// and potentialy wake write task // and potentialy wake write task
if status.need_write { if status.need_write {
filter.process_write_buf(&self.0, &inner.buffer, 0) filter.process_write_buf(&self.0, &inner.buffer, 0)
} else { } else {
Ok(()) Ok(())
}
} }
}) Err(err) => Err(err),
.map_err(|err| { };
if let Err(err) = res {
inner.dispatch_task.wake(); inner.dispatch_task.wake();
inner.io_stopped(Some(err)); inner.io_stopped(Some(err));
inner.insert_flags(Flags::BUF_R_READY); inner.insert_flags(Flags::BUF_R_READY);
});
}
match result {
Poll::Ready(Ok(())) => {
inner.io_stopped(None);
Poll::Ready(())
}
Poll::Ready(Err(e)) => {
inner.io_stopped(Some(e));
Poll::Ready(())
}
Poll::Pending => {
if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) {
shutdown_filters(&self.0);
} }
Poll::Pending }
match result {
Ok(0) => {
log::trace!("{}: Tcp stream is disconnected", self.tag());
inner.io_stopped(None);
break;
}
Ok(_) => {
if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) {
lazy(|cx| self.shutdown_filters(cx)).await;
}
}
Err(err) => {
log::trace!("{}: Read task failed on io {:?}", self.tag(), err);
inner.io_stopped(Some(err));
break;
}
} }
} }
} }
/// Get read buffer (async) fn shutdown_filters(&self, cx: &mut Context<'_>) {
pub async fn with_buf_async<F, R>(&self, f: F) -> Poll<()> let st = &self.0 .0;
where let filter = self.0.filter();
F: FnOnce(BytesVec) -> R,
R: Future<Output = (BytesVec, io::Result<usize>)>,
{
let inner = &self.0 .0;
// // we already pushed new data to read buffer, match filter.shutdown(&self.0, &st.buffer, 0) {
// // we have to wait for dispatcher to read data from buffer Ok(Poll::Ready(())) => {
// if inner.flags.get().is_read_buf_ready() { st.dispatch_task.wake();
// ntex_util::task::yield_to().await; st.insert_flags(Flags::IO_STOPPING);
// }
let mut buf = if inner.flags.get().is_read_buf_ready() {
// read buffer is still not read by dispatcher
// we cannot touch it
inner.pool.get().get_read_buf()
} else {
inner
.buffer
.get_read_source()
.unwrap_or_else(|| inner.pool.get().get_read_buf())
};
// make sure we've got room
let (hw, lw) = self.0.memory_pool().read_params().unpack();
let remaining = buf.remaining_mut();
if remaining <= lw {
buf.reserve(hw - remaining);
}
let total = buf.len();
// call provided callback
let (buf, result) = f(buf).await;
let total2 = buf.len();
let nbytes = if total2 > total { total2 - total } else { 0 };
let total = total2;
if let Some(mut first_buf) = inner.buffer.get_read_source() {
first_buf.extend_from_slice(&buf);
inner.buffer.set_read_source(&self.0, first_buf);
} else {
inner.buffer.set_read_source(&self.0, buf);
}
// handle buffer changes
if nbytes > 0 {
let filter = self.0.filter();
let res = match filter.process_read_buf(&self.0, &inner.buffer, 0, nbytes) {
Ok(status) => {
if status.nbytes > 0 {
// check read back-pressure
if hw < inner.buffer.read_destination_size() {
log::trace!(
"{}: Io read buffer is too large {}, enable read back-pressure",
self.0.tag(),
total
);
inner.insert_flags(Flags::BUF_R_READY | Flags::BUF_R_FULL);
} else {
inner.insert_flags(Flags::BUF_R_READY);
}
log::trace!(
"{}: New {} bytes available, wakeup dispatcher",
self.0.tag(),
nbytes
);
// dest buffer has new data, wake up dispatcher
inner.dispatch_task.wake();
} else if inner.flags.get().contains(Flags::RD_NOTIFY) {
// in case of "notify" we must wake up dispatch task
// if we read any data from source
inner.dispatch_task.wake();
}
// while reading, filter wrote some data
// in that case filters need to process write buffers
// and potentialy wake write task
if status.need_write {
filter.process_write_buf(&self.0, &inner.buffer, 0)
} else {
Ok(())
}
}
Err(err) => Err(err),
};
if let Err(err) = res {
inner.dispatch_task.wake();
inner.io_stopped(Some(err));
inner.insert_flags(Flags::BUF_R_READY);
} }
} Ok(Poll::Pending) => {
let flags = st.flags.get();
match result { // check read buffer, if buffer is not consumed it is unlikely
Ok(n) => { // that filter will properly complete shutdown
if n == 0 { if flags.contains(Flags::RD_PAUSED)
inner.io_stopped(None); || flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
Poll::Ready(()) {
st.dispatch_task.wake();
st.insert_flags(Flags::IO_STOPPING);
} else { } else {
if inner.flags.get().contains(Flags::IO_STOPPING_FILTERS) { // filter shutdown timeout
shutdown_filters(&self.0); let timeout = self
.1
.take()
.unwrap_or_else(|| sleep(st.disconnect_timeout.get()));
if timeout.poll_elapsed(cx).is_ready() {
st.dispatch_task.wake();
st.insert_flags(Flags::IO_STOPPING);
} else {
self.1.set(Some(timeout));
} }
Poll::Pending
} }
} }
Err(e) => { Err(err) => {
inner.io_stopped(Some(e)); st.io_stopped(Some(err));
Poll::Ready(())
} }
} }
if let Err(err) = filter.process_write_buf(&self.0, &st.buffer, 0) {
st.io_stopped(Some(err));
}
} }
} }
@ -267,6 +213,13 @@ impl ReadContext {
/// Context for io write task /// Context for io write task
pub struct WriteContext(IoRef); pub struct WriteContext(IoRef);
#[derive(Debug)]
/// Context buf for io write task
pub struct WriteContextBuf {
io: IoRef,
buf: Option<BytesVec>,
}
impl WriteContext { impl WriteContext {
pub(crate) fn new(io: &IoRef) -> Self { pub(crate) fn new(io: &IoRef) -> Self {
Self(io.clone()) Self(io.clone())
@ -278,104 +231,92 @@ impl WriteContext {
self.0.tag() self.0.tag()
} }
#[inline]
/// Return memory pool for this context
pub fn memory_pool(&self) -> PoolRef {
self.0.memory_pool()
}
#[inline]
/// Check readiness for write operations /// Check readiness for write operations
pub async fn ready(&self) -> WriteStatus { async fn ready(&self) -> WriteStatus {
poll_fn(|cx| self.0.filter().poll_write_ready(cx)).await poll_fn(|cx| self.0.filter().poll_write_ready(cx)).await
} }
#[inline] /// Indicate that write io task is stopped
/// Check readiness for write operations fn close(&self, err: Option<io::Error>) {
pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<WriteStatus> { self.0 .0.io_stopped(err);
self.0.filter().poll_write_ready(cx)
} }
#[inline]
/// Check if io is closed /// Check if io is closed
pub fn poll_close(&self, cx: &mut Context<'_>) -> Poll<()> { async fn when_stopped(&self) {
if self.0.is_io_closed() { poll_fn(|cx| {
Poll::Ready(()) if self.0.flags().is_stopped() {
} else { Poll::Ready(())
self.0 .0.write_task.register(cx.waker());
Poll::Pending
}
}
/// Get write buffer
pub fn with_buf<F>(&self, f: F) -> Poll<io::Result<()>>
where
F: FnOnce(&mut Option<BytesVec>) -> Poll<io::Result<()>>,
{
let inner = &self.0 .0;
// call provided callback
let (result, len) = inner.buffer.with_write_destination(&self.0, |buf| {
let result = f(buf);
(result, buf.as_ref().map(|b| b.len()).unwrap_or(0))
});
// if write buffer is smaller than high watermark value, turn off back-pressure
let mut flags = inner.flags.get();
if len == 0 {
if flags.is_waiting_for_write() {
flags.waiting_for_write_is_done();
inner.dispatch_task.wake();
}
} else if flags.contains(Flags::BUF_W_BACKPRESSURE)
&& len < inner.pool.get().write_params_high() << 1
{
flags.remove(Flags::BUF_W_BACKPRESSURE);
inner.dispatch_task.wake();
}
match result {
Poll::Pending => flags.remove(Flags::WR_PAUSED),
Poll::Ready(Ok(())) => flags.insert(Flags::WR_PAUSED),
Poll::Ready(Err(_)) => {}
}
inner.flags.set(flags);
result
}
/// Get write buffer (async)
pub async fn with_buf_async<F, R>(&self, f: F) -> io::Result<()>
where
F: FnOnce(BytesVec) -> R,
R: Future<Output = io::Result<()>>,
{
let inner = &self.0 .0;
// running
let mut flags = inner.flags.get();
if flags.contains(Flags::WR_PAUSED) {
flags.remove(Flags::WR_PAUSED);
inner.flags.set(flags);
}
// buffer
let buf = inner.buffer.get_write_destination();
// call provided callback
let result = if let Some(buf) = buf {
if !buf.is_empty() {
f(buf).await
} else { } else {
Ok(()) self.0 .0.write_task.register(cx.waker());
Poll::Pending
} }
} else { })
Ok(()) .await
}
/// Handle write io operations
pub async fn handle<T>(&self, io: &mut T)
where
T: AsyncWrite,
{
let mut buf = WriteContextBuf {
io: self.0.clone(),
buf: None,
}; };
loop {
match self.ready().await {
WriteStatus::Ready => {
// write io stream
match select(io.write(&mut buf), self.when_stopped()).await {
Either::Left(Ok(_)) => continue,
Either::Left(Err(e)) => self.close(Some(e)),
Either::Right(_) => return,
}
}
WriteStatus::Shutdown => {
log::trace!("{}: Write task is instructed to shutdown", self.tag());
let fut = async {
// write io stream
io.write(&mut buf).await?;
io.flush().await?;
io.shutdown().await?;
Ok(())
};
match select(sleep(self.0 .0.disconnect_timeout.get()), fut).await {
Either::Left(_) => self.close(None),
Either::Right(res) => self.close(res.err()),
}
}
WriteStatus::Terminate => {
log::trace!("{}: Write task is instructed to terminate", self.tag());
self.close(io.shutdown().await.err());
}
}
return;
}
}
}
impl WriteContextBuf {
pub fn set(&mut self, mut buf: BytesVec) {
if buf.is_empty() {
self.io.memory_pool().release_write_buf(buf);
} else if let Some(b) = self.buf.take() {
buf.extend_from_slice(&b);
self.io.memory_pool().release_write_buf(b);
self.buf = Some(buf);
} else if let Some(b) = self.io.0.buffer.set_write_destination(buf) {
// write buffer is already set
self.buf = Some(b);
}
// if write buffer is smaller than high watermark value, turn off back-pressure // if write buffer is smaller than high watermark value, turn off back-pressure
let inner = &self.io.0;
let len = self.buf.as_ref().map(|b| b.len()).unwrap_or_default()
+ inner.buffer.write_destination_size();
let mut flags = inner.flags.get(); let mut flags = inner.flags.get();
let len = inner.buffer.write_destination_size();
if len == 0 { if len == 0 {
if flags.is_waiting_for_write() { if flags.is_waiting_for_write() {
@ -391,44 +332,13 @@ impl WriteContext {
inner.flags.set(flags); inner.flags.set(flags);
inner.dispatch_task.wake(); inner.dispatch_task.wake();
} }
result
} }
#[inline] pub fn take(&mut self) -> Option<BytesVec> {
/// Indicate that write io task is stopped if let Some(buf) = self.buf.take() {
pub fn close(&self, err: Option<io::Error>) { Some(buf)
self.0 .0.io_stopped(err); } else {
} self.io.0.buffer.get_write_destination()
}
fn shutdown_filters(io: &IoRef) {
let st = &io.0;
let flags = st.flags.get();
if !flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
let filter = io.filter();
match filter.shutdown(io, &st.buffer, 0) {
Ok(Poll::Ready(())) => {
st.dispatch_task.wake();
st.insert_flags(Flags::IO_STOPPING);
}
Ok(Poll::Pending) => {
// check read buffer, if buffer is not consumed it is unlikely
// that filter will properly complete shutdown
if flags.contains(Flags::RD_PAUSED)
|| flags.contains(Flags::BUF_R_FULL | Flags::BUF_R_READY)
{
st.dispatch_task.wake();
st.insert_flags(Flags::IO_STOPPING);
}
}
Err(err) => {
st.io_stopped(Some(err));
}
}
if let Err(err) = filter.process_write_buf(io, &st.buffer, 0) {
st.io_stopped(Some(err));
} }
} }
} }

View file

@ -1,14 +1,13 @@
//! utilities and helpers for testing //! utilities and helpers for testing
#![allow(clippy::let_underscore_future)] #![allow(clippy::let_underscore_future)]
use std::future::{poll_fn, Future};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::task::{ready, Context, Poll, Waker}; use std::task::{Context, Poll, Waker};
use std::{any, cell::RefCell, cmp, fmt, io, mem, net, pin::Pin, rc::Rc}; use std::{any, cell::RefCell, cmp, fmt, future::poll_fn, io, mem, net, rc::Rc};
use ntex_bytes::{Buf, BufMut, Bytes, BytesVec}; use ntex_bytes::{Buf, BufMut, Bytes, BytesVec};
use ntex_util::time::{sleep, Millis, Sleep}; use ntex_util::time::{sleep, Millis};
use crate::{types, Handle, IoStream, ReadContext, ReadStatus, WriteContext, WriteStatus}; use crate::{types, Handle, IoStream, ReadContext, WriteContext, WriteContextBuf};
#[derive(Default)] #[derive(Default)]
struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>); struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>);
@ -356,14 +355,14 @@ impl IoStream for IoTest {
fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> { fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> {
let io = Rc::new(self); let io = Rc::new(self);
let _ = ntex_util::spawn(ReadTask { let mut rio = Read(io.clone());
io: io.clone(), let _ = ntex_util::spawn(async move {
state: read, read.handle(&mut rio).await;
}); });
let _ = ntex_util::spawn(WriteTask {
io: io.clone(), let mut wio = Write(io.clone());
state: write, let _ = ntex_util::spawn(async move {
st: IoWriteState::Processing(None), write.handle(&mut wio).await;
}); });
Some(Box::new(io)) Some(Box::new(io))
@ -382,271 +381,97 @@ impl Handle for Rc<IoTest> {
} }
/// Read io task /// Read io task
struct ReadTask { struct Read(Rc<IoTest>);
io: Rc<IoTest>,
state: ReadContext,
}
impl Future for ReadTask { impl crate::AsyncRead for Read {
type Output = (); async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<usize>) {
// read data from socket
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { let result = poll_fn(|cx| self.0.poll_read_buf(cx, &mut buf)).await;
let this = self.as_ref(); (buf, result)
this.state.with_buf(|buf, hw, lw| {
match this.state.poll_ready(cx) {
Poll::Ready(ReadStatus::Terminate) => {
log::trace!("read task is instructed to terminate");
Poll::Ready(Ok(()))
}
Poll::Ready(ReadStatus::Ready) => {
let io = &this.io;
// read data from socket
let mut new_bytes = 0;
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
match io.poll_read_buf(cx, buf) {
Poll::Pending => {
log::trace!(
"no more data in io stream, read: {:?}",
new_bytes
);
break;
}
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("io stream is disconnected");
return Poll::Ready(Ok(()));
} else {
new_bytes += n;
if buf.len() >= hw {
log::trace!(
"high water mark pause reading, read: {:?}",
new_bytes
);
break;
}
}
}
Poll::Ready(Err(err)) => {
log::trace!("read task failed on io {:?}", err);
return Poll::Ready(Err(err));
}
}
}
Poll::Pending
}
Poll::Pending => Poll::Pending,
}
})
} }
} }
#[derive(Debug)] /// Write
enum IoWriteState { struct Write(Rc<IoTest>);
Processing(Option<Sleep>),
Shutdown(Option<Sleep>, Shutdown),
}
#[derive(Debug)] impl crate::AsyncWrite for Write {
enum Shutdown { async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
None, poll_fn(|cx| {
Flushed, if let Some(mut b) = buf.take() {
Stopping, let result = write_io(&self.0, &mut b, cx);
} buf.set(b);
result
/// Write io task } else {
struct WriteTask { Poll::Ready(Ok(()))
st: IoWriteState,
io: Rc<IoTest>,
state: WriteContext,
}
impl Future for WriteTask {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().get_mut();
match this.st {
IoWriteState::Processing(ref mut delay) => {
match this.state.poll_ready(cx) {
Poll::Ready(WriteStatus::Ready) => {
// flush framed instance
match ready!(flush_io(&this.io, &this.state, cx)) {
Ok(()) => Poll::Pending,
Err(e) => {
this.state.close(Some(e));
Poll::Ready(())
}
}
}
Poll::Ready(WriteStatus::Timeout(time)) => {
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
Poll::Ready(WriteStatus::Shutdown(time)) => {
log::trace!("write task is instructed to shutdown");
let timeout = if let Some(delay) = delay.take() {
delay
} else {
sleep(time)
};
this.st = IoWriteState::Shutdown(Some(timeout), Shutdown::None);
self.poll(cx)
}
Poll::Ready(WriteStatus::Terminate) => {
log::trace!("write task is instructed to terminate");
// shutdown WRITE side
this.io
.local
.lock()
.unwrap()
.borrow_mut()
.flags
.insert(IoTestFlags::CLOSED);
this.state.close(None);
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
} }
IoWriteState::Shutdown(ref mut delay, ref mut st) => { })
// close WRITE side and wait for disconnect on read side. .await
// use disconnect timeout, otherwise it could hang forever. }
loop {
match st {
Shutdown::None => {
// flush write buffer
match flush_io(&this.io, &this.state, cx) {
Poll::Ready(Ok(())) => {
*st = Shutdown::Flushed;
continue;
}
Poll::Ready(Err(err)) => {
log::trace!(
"write task is closed with err during flush {:?}",
err
);
this.state.close(Some(err));
return Poll::Ready(());
}
Poll::Pending => (),
}
}
Shutdown::Flushed => {
// shutdown WRITE side
this.io
.local
.lock()
.unwrap()
.borrow_mut()
.flags
.insert(IoTestFlags::CLOSED);
*st = Shutdown::Stopping;
continue;
}
Shutdown::Stopping => {
// read until 0 or err
let io = &this.io;
loop {
let mut buf = BytesVec::new();
match io.poll_read_buf(cx, &mut buf) {
Poll::Ready(Err(e)) => {
this.state.close(Some(e));
log::trace!("write task is stopped");
return Poll::Ready(());
}
Poll::Ready(Ok(0)) => {
this.state.close(None);
log::trace!("write task is stopped");
return Poll::Ready(());
}
Poll::Pending => break,
_ => (),
}
}
}
}
// disconnect timeout async fn flush(&mut self) -> io::Result<()> {
if let Some(ref delay) = delay { Ok(())
if delay.poll_elapsed(cx).is_pending() { }
return Poll::Pending;
} async fn shutdown(&mut self) -> io::Result<()> {
} // shutdown WRITE side
log::trace!("write task is stopped after delay"); self.0
this.state.close(None); .local
return Poll::Ready(()); .lock()
} .unwrap()
} .borrow_mut()
} .flags
.insert(IoTestFlags::CLOSED);
Ok(())
} }
} }
/// Flush write buffer to underlying I/O stream. /// Flush write buffer to underlying I/O stream.
pub(super) fn flush_io( pub(super) fn write_io(
io: &IoTest, io: &IoTest,
state: &WriteContext, buf: &mut BytesVec,
cx: &mut Context<'_>, cx: &mut Context<'_>,
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
state.with_buf(|buf| { let len = buf.len();
if let Some(buf) = buf {
let len = buf.len();
if len != 0 { if len != 0 {
log::trace!("flushing framed transport: {}", len); log::trace!("flushing framed transport: {}", len);
let mut written = 0; let mut written = 0;
let result = loop { let result = loop {
break match io.poll_write_buf(cx, &buf[written..]) { break match io.poll_write_buf(cx, &buf[written..]) {
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
log::trace!( log::trace!("disconnected during flush, written {}", written);
"disconnected during flush, written {}", Poll::Ready(Err(io::Error::new(
written io::ErrorKind::WriteZero,
); "failed to write frame to transport",
Poll::Ready(Err(io::Error::new( )))
io::ErrorKind::WriteZero, } else {
"failed to write frame to transport", written += n;
))) if written == len {
} else { buf.clear();
written += n; Poll::Ready(Ok(()))
if written == len { } else {
buf.clear(); continue;
Poll::Ready(Ok(()))
} else {
continue;
}
}
} }
Poll::Pending => { }
// remove written data }
buf.advance(written); Poll::Pending => {
Poll::Pending // remove written data
} buf.advance(written);
Poll::Ready(Err(e)) => { Poll::Pending
log::trace!("error during flush: {}", e); }
Poll::Ready(Err(e)) Poll::Ready(Err(e)) => {
} log::trace!("error during flush: {}", e);
}; Poll::Ready(Err(e))
}; }
log::trace!("flushed {} bytes", written); };
return result; };
} log::trace!("flushed {} bytes", written);
} result
} else {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
}) }
} }
#[cfg(test)] #[cfg(test)]

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-net" name = "ntex-net"
version = "2.1.0" version = "2.2.0"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "ntexwork utils for ntex framework" description = "ntexwork utils for ntex framework"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -34,14 +34,14 @@ async-std = ["ntex-rt/async-std", "ntex-async-std"]
ntex-service = "3" ntex-service = "3"
ntex-bytes = "0.1" ntex-bytes = "0.1"
ntex-http = "0.1" ntex-http = "0.1"
ntex-io = "2.4" ntex-io = "2.5"
ntex-rt = "0.4.14" ntex-rt = "0.4.14"
ntex-util = "2" ntex-util = "2"
ntex-tokio = { version = "0.5.1", optional = true } ntex-tokio = { version = "0.5.2", optional = true }
ntex-compio = { version = "0.1", optional = true } ntex-compio = { version = "0.1.2", optional = true }
ntex-glommio = { version = "0.5", optional = true } ntex-glommio = { version = "0.5.1", optional = true }
ntex-async-std = { version = "0.5", optional = true } ntex-async-std = { version = "0.5.1", optional = true }
log = "0.4" log = "0.4"
thiserror = "1" thiserror = "1"

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.5.2] - 2024-09-11
* Use new io api
## [0.5.1] - 2024-09-06 ## [0.5.1] - 2024-09-06
* Stop write task if io is closed * Stop write task if io is closed

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-tokio" name = "ntex-tokio"
version = "0.5.1" version = "0.5.2"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "tokio intergration for ntex framework" description = "tokio intergration for ntex framework"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -17,7 +17,7 @@ path = "src/lib.rs"
[dependencies] [dependencies]
ntex-bytes = "0.1" ntex-bytes = "0.1"
ntex-io = "2.4" ntex-io = "2.5"
ntex-util = "2" ntex-util = "2"
log = "0.4" log = "0.4"
tokio = { version = "1", default-features = false, features = ["rt", "net", "sync", "signal"] } tokio = { version = "1", default-features = false, features = ["rt", "net", "sync", "signal"] }

View file

@ -1,12 +1,12 @@
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{any, cell::RefCell, cmp, future::Future, io, mem, pin::Pin, rc::Rc, rc::Weak}; use std::{any, cell::RefCell, cmp, future::poll_fn, io, mem, pin::Pin, rc::Rc, rc::Weak};
use ntex_bytes::{Buf, BufMut, BytesVec}; use ntex_bytes::{Buf, BufMut, BytesVec};
use ntex_io::{ use ntex_io::{
types, Filter, Handle, Io, IoBoxed, IoStream, ReadContext, ReadStatus, WriteContext, types, Filter, Handle, Io, IoBoxed, IoStream, ReadContext, WriteContext,
WriteStatus, WriteContextBuf,
}; };
use ntex_util::{ready, time::sleep, time::Millis, time::Sleep}; use ntex_util::{ready, time::Millis};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@ -14,8 +14,14 @@ impl IoStream for crate::TcpStream {
fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> { fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> {
let io = Rc::new(RefCell::new(self.0)); let io = Rc::new(RefCell::new(self.0));
tokio::task::spawn_local(ReadTask::new(io.clone(), read)); let mut rio = Read(io.clone());
tokio::task::spawn_local(WriteTask::new(io.clone(), write)); tokio::task::spawn_local(async move {
read.handle(&mut rio).await;
});
let mut wio = Write(io.clone());
tokio::task::spawn_local(async move {
write.handle(&mut wio).await;
});
Some(Box::new(HandleWrapper(io))) Some(Box::new(HandleWrapper(io)))
} }
} }
@ -36,345 +42,149 @@ impl Handle for HandleWrapper {
} }
/// Read io task /// Read io task
struct ReadTask { struct Read(Rc<RefCell<TcpStream>>);
io: Rc<RefCell<TcpStream>>,
state: ReadContext,
}
impl ReadTask { impl ntex_io::AsyncRead for Read {
/// Create new read io task #[inline]
fn new(io: Rc<RefCell<TcpStream>>, state: ReadContext) -> Self { async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<usize>) {
Self { io, state } // read data from socket
} let result = poll_fn(|cx| {
} let mut n = 0;
let mut io = self.0.borrow_mut();
impl Future for ReadTask { loop {
type Output = (); return match poll_read_buf(Pin::new(&mut *io), cx, &mut buf)? {
Poll::Pending => {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { if n > 0 {
let this = self.as_ref(); Poll::Ready(Ok(n))
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
this.state.with_buf(|buf, hw, lw| {
// read data from socket
let mut io = this.io.borrow_mut();
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
return match poll_read_buf(Pin::new(&mut *io), cx, buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!(
"{}: Tcp stream is disconnected",
this.state.tag()
);
Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else {
Poll::Pending
}
}
Poll::Ready(Err(err)) => {
log::trace!(
"{}: Read task failed on io {:?}",
this.state.tag(),
err
);
Poll::Ready(Err(err))
}
};
}
})
}
ReadStatus::Terminate => {
log::trace!("{}: Read task is instructed to shutdown", this.state.tag());
Poll::Ready(())
}
}
}
}
#[derive(Debug)]
enum IoWriteState {
Processing(Option<Sleep>),
Shutdown(Sleep, Shutdown),
}
#[derive(Debug)]
enum Shutdown {
None,
Flushed,
Stopping(u16),
}
/// Write io task
struct WriteTask {
st: IoWriteState,
io: Rc<RefCell<TcpStream>>,
state: WriteContext,
}
impl WriteTask {
/// Create new write io task
fn new(io: Rc<RefCell<TcpStream>>, state: WriteContext) -> Self {
Self {
io,
state,
st: IoWriteState::Processing(None),
}
}
}
impl Future for WriteTask {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.as_mut().get_mut();
if this.state.poll_close(cx).is_ready() {
return Poll::Ready(());
}
match this.st {
IoWriteState::Processing(ref mut delay) => {
match ready!(this.state.poll_ready(cx)) {
WriteStatus::Ready => {
if let Some(delay) = delay {
if delay.poll_elapsed(cx).is_ready() {
this.state.close(Some(io::Error::new(
io::ErrorKind::TimedOut,
"Operation timedout",
)));
return Poll::Ready(());
}
}
// flush io stream
match ready!(this.state.with_buf(|buf| flush_io(
&mut *this.io.borrow_mut(),
buf,
cx,
&this.state
))) {
Ok(()) => Poll::Pending,
Err(e) => {
this.state.close(Some(e));
Poll::Ready(())
}
}
}
WriteStatus::Timeout(time) => {
log::trace!(
"{}: Initiate timeout delay for {:?}",
this.state.tag(),
time
);
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
WriteStatus::Shutdown(time) => {
log::trace!(
"{}: Write task is instructed to shutdown",
this.state.tag()
);
let timeout = if let Some(delay) = delay.take() {
delay
} else { } else {
sleep(time) Poll::Pending
};
this.st = IoWriteState::Shutdown(timeout, Shutdown::None);
self.poll(cx)
}
WriteStatus::Terminate => {
log::trace!(
"{}: Write task is instructed to terminate",
this.state.tag()
);
if !matches!(
this.io.borrow().linger(),
Ok(Some(std::time::Duration::ZERO))
) {
// call shutdown to prevent flushing data on terminated Io. when
// linger is set to zero, closing will reset the connection, so
// shutdown is not neccessary.
let _ = Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx);
} }
this.state.close(None);
Poll::Ready(())
} }
} Poll::Ready(size) => {
n += size;
if n > 0 && buf.remaining_mut() > 0 {
continue;
}
Poll::Ready(Ok(n))
}
};
} }
IoWriteState::Shutdown(ref mut delay, ref mut st) => { })
// close WRITE side and wait for disconnect on read side. .await;
// use disconnect timeout, otherwise it could hang forever.
loop {
if this.state.poll_close(cx).is_ready() {
return Poll::Ready(());
}
match st {
Shutdown::None => {
// flush write buffer
let mut io = this.io.borrow_mut();
match this
.state
.with_buf(|buf| flush_io(&mut *io, buf, cx, &this.state))
{
Poll::Ready(Ok(())) => {
*st = Shutdown::Flushed;
continue;
}
Poll::Ready(Err(err)) => {
log::trace!(
"{}: Write task is closed with err during flush, {:?}", this.state.tag(),
err
);
this.state.close(Some(err));
return Poll::Ready(());
}
Poll::Pending => (),
}
}
Shutdown::Flushed => {
// shutdown WRITE side
match Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx) {
Poll::Ready(Ok(_)) => {
*st = Shutdown::Stopping(0);
continue;
}
Poll::Ready(Err(e)) => {
log::trace!(
"{}: Write task is closed with err during shutdown",
this.state.tag()
);
this.state.close(Some(e));
return Poll::Ready(());
}
_ => (),
}
}
Shutdown::Stopping(ref mut count) => {
// read until 0 or err
let mut buf = [0u8; 512];
loop {
let mut read_buf = ReadBuf::new(&mut buf);
match Pin::new(&mut *this.io.borrow_mut())
.poll_read(cx, &mut read_buf)
{
Poll::Ready(Err(_)) | Poll::Ready(Ok(_))
if read_buf.filled().is_empty() =>
{
this.state.close(None);
log::trace!(
"{}: Tokio write task is stopped",
this.state.tag()
);
return Poll::Ready(());
}
Poll::Pending => {
*count += read_buf.filled().len() as u16;
if *count > 4096 {
log::trace!("{}: Tokio write task is stopped, too much input", this.state.tag());
this.state.close(None);
return Poll::Ready(());
}
break;
}
_ => (),
}
}
}
}
// disconnect timeout (buf, result)
if delay.poll_elapsed(cx).is_pending() {
return Poll::Pending;
}
log::trace!("{}: Write task is stopped after delay", this.state.tag());
this.state.close(None);
return Poll::Ready(());
}
}
}
} }
} }
struct Write(Rc<RefCell<TcpStream>>);
impl ntex_io::AsyncWrite for Write {
#[inline]
async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
poll_fn(|cx| {
if let Some(mut b) = buf.take() {
let result = flush_io(&mut *self.0.borrow_mut(), &mut b, cx);
buf.set(b);
result
} else {
Poll::Ready(Ok(()))
}
})
.await
}
#[inline]
async fn flush(&mut self) -> io::Result<()> {
Ok(())
}
#[inline]
async fn shutdown(&mut self) -> io::Result<()> {
poll_fn(|cx| Pin::new(&mut *self.0.borrow_mut()).poll_shutdown(cx)).await
}
}
pub fn poll_read_buf<T: AsyncRead>(
io: Pin<&mut T>,
cx: &mut Context<'_>,
buf: &mut BytesVec,
) -> Poll<io::Result<usize>> {
let n = {
let dst =
unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [mem::MaybeUninit<u8>]) };
let mut buf = ReadBuf::uninit(dst);
let ptr = buf.filled().as_ptr();
if io.poll_read(cx, &mut buf)?.is_pending() {
return Poll::Pending;
}
// Ensure the pointer does not change from under us
assert_eq!(ptr, buf.filled().as_ptr());
buf.filled().len()
};
// Safety: This is guaranteed to be the number of initialized (and read)
// bytes due to the invariants provided by `ReadBuf::filled`.
unsafe {
buf.advance_mut(n);
}
Poll::Ready(Ok(n))
}
/// Flush write buffer to underlying I/O stream. /// Flush write buffer to underlying I/O stream.
pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>( pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
io: &mut T, io: &mut T,
buf: &mut Option<BytesVec>, buf: &mut BytesVec,
cx: &mut Context<'_>, cx: &mut Context<'_>,
st: &WriteContext,
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
if let Some(buf) = buf { let len = buf.len();
let len = buf.len();
if len != 0 { if len != 0 {
// log::trace!("{}: Flushing framed transport: {:?}", st.tag(), buf.len()); // log::trace!("{}: Flushing framed transport: {:?}", st.tag(), buf.len());
let mut written = 0; let mut written = 0;
let result = loop { let result = loop {
break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { break match Pin::new(&mut *io).poll_write(cx, &buf[written..]) {
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == 0 { if n == 0 {
log::trace!( Poll::Ready(Err(io::Error::new(
"{}: Disconnected during flush, written {}", io::ErrorKind::WriteZero,
st.tag(), "failed to write frame to transport",
written )))
); } else {
Poll::Ready(Err(io::Error::new( written += n;
io::ErrorKind::WriteZero, if written == len {
"failed to write frame to transport", buf.clear();
))) Poll::Ready(Ok(()))
} else { } else {
written += n; continue;
if written == len {
buf.clear();
Poll::Ready(Ok(()))
} else {
continue;
}
} }
} }
Poll::Pending => {
// remove written data
buf.advance(written);
Poll::Pending
}
Poll::Ready(Err(e)) => {
log::trace!("{}: Error during flush: {}", st.tag(), e);
Poll::Ready(Err(e))
}
};
};
// log::trace!("{}: flushed {} bytes", st.tag(), written);
// flush
return if written > 0 {
match Pin::new(&mut *io).poll_flush(cx) {
Poll::Ready(Ok(_)) => result,
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
log::trace!("{}: Error during flush: {}", st.tag(), e);
Poll::Ready(Err(e))
}
} }
} else { Poll::Pending => {
result // remove written data
buf.advance(written);
Poll::Pending
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
}; };
};
// log::trace!("{}: flushed {} bytes", st.tag(), written);
// flush
if written > 0 {
match Pin::new(&mut *io).poll_flush(cx) {
Poll::Ready(Ok(_)) => result,
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
}
} else {
result
} }
} else {
Poll::Ready(Ok(()))
} }
Poll::Ready(Ok(()))
} }
pub struct TokioIoBoxed(IoBoxed); pub struct TokioIoBoxed(IoBoxed);
@ -472,294 +282,77 @@ mod unixstream {
fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> { fn start(self, read: ReadContext, write: WriteContext) -> Option<Box<dyn Handle>> {
let io = Rc::new(RefCell::new(self.0)); let io = Rc::new(RefCell::new(self.0));
tokio::task::spawn_local(ReadTask::new(io.clone(), read)); let mut rio = Read(io.clone());
tokio::task::spawn_local(WriteTask::new(io, write)); tokio::task::spawn_local(async move {
read.handle(&mut rio).await;
});
let mut wio = Write(io.clone());
tokio::task::spawn_local(async move {
write.handle(&mut wio).await;
});
None None
} }
} }
/// Read io task struct Read(Rc<RefCell<UnixStream>>);
struct ReadTask {
io: Rc<RefCell<UnixStream>>,
state: ReadContext,
}
impl ReadTask { impl ntex_io::AsyncRead for Read {
/// Create new read io task #[inline]
fn new(io: Rc<RefCell<UnixStream>>, state: ReadContext) -> Self { async fn read(&mut self, mut buf: BytesVec) -> (BytesVec, io::Result<usize>) {
Self { io, state } // read data from socket
} let result = poll_fn(|cx| {
} let mut n = 0;
let mut io = self.0.borrow_mut();
impl Future for ReadTask { loop {
type Output = (); return match poll_read_buf(Pin::new(&mut *io), cx, &mut buf)? {
Poll::Pending => {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { if n > 0 {
let this = self.as_ref(); Poll::Ready(Ok(n))
} else {
this.state.with_buf(|buf, hw, lw| { Poll::Pending
match ready!(this.state.poll_ready(cx)) {
ReadStatus::Ready => {
// read data from socket
let mut io = this.io.borrow_mut();
loop {
// make sure we've got room
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
} }
return match poll_read_buf(Pin::new(&mut *io), cx, buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!(
"{}: Tokio unix stream is disconnected",
this.state.tag()
);
Poll::Ready(Ok(()))
} else if buf.len() < hw {
continue;
} else {
Poll::Pending
}
}
Poll::Ready(Err(err)) => {
log::trace!(
"{}: Unix stream read task failed {:?}",
this.state.tag(),
err
);
Poll::Ready(Err(err))
}
};
} }
} Poll::Ready(size) => {
ReadStatus::Terminate => { n += size;
log::trace!( if n > 0 && buf.remaining_mut() > 0 {
"{}: Read task is instructed to shutdown", continue;
this.state.tag() }
); Poll::Ready(Ok(n))
Poll::Ready(Ok(())) }
} };
} }
}) })
.await;
(buf, result)
} }
} }
/// Write io task struct Write(Rc<RefCell<UnixStream>>);
struct WriteTask {
st: IoWriteState,
io: Rc<RefCell<UnixStream>>,
state: WriteContext,
}
impl WriteTask { impl ntex_io::AsyncWrite for Write {
/// Create new write io task #[inline]
fn new(io: Rc<RefCell<UnixStream>>, state: WriteContext) -> Self { async fn write(&mut self, buf: &mut WriteContextBuf) -> io::Result<()> {
Self { poll_fn(|cx| {
io, if let Some(mut b) = buf.take() {
state, let result = flush_io(&mut *self.0.borrow_mut(), &mut b, cx);
st: IoWriteState::Processing(None), buf.set(b);
} result
} else {
Poll::Ready(Ok(()))
}
})
.await
} }
}
impl Future for WriteTask { #[inline]
type Output = (); async fn flush(&mut self) -> io::Result<()> {
Ok(())
}
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { #[inline]
let this = self.as_mut().get_mut(); async fn shutdown(&mut self) -> io::Result<()> {
poll_fn(|cx| Pin::new(&mut *self.0.borrow_mut()).poll_shutdown(cx)).await
if this.state.poll_close(cx).is_ready() {
return Poll::Ready(());
}
match this.st {
IoWriteState::Processing(ref mut delay) => {
match this.state.poll_ready(cx) {
Poll::Ready(WriteStatus::Ready) => {
if let Some(delay) = delay {
if delay.poll_elapsed(cx).is_ready() {
this.state.close(Some(io::Error::new(
io::ErrorKind::TimedOut,
"Operation timedout",
)));
return Poll::Ready(());
}
}
// flush io stream
match ready!(this.state.with_buf(|buf| flush_io(
&mut *this.io.borrow_mut(),
buf,
cx,
&this.state
))) {
Ok(()) => Poll::Pending,
Err(e) => {
this.state.close(Some(e));
Poll::Ready(())
}
}
}
Poll::Ready(WriteStatus::Timeout(time)) => {
if delay.is_none() {
*delay = Some(sleep(time));
}
self.poll(cx)
}
Poll::Ready(WriteStatus::Shutdown(time)) => {
log::trace!(
"{}: Write task is instructed to shutdown",
this.state.tag()
);
let timeout = if let Some(delay) = delay.take() {
delay
} else {
sleep(time)
};
this.st = IoWriteState::Shutdown(timeout, Shutdown::None);
self.poll(cx)
}
Poll::Ready(WriteStatus::Terminate) => {
log::trace!(
"{}: Write task is instructed to terminate",
this.state.tag()
);
let _ = Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx);
this.state.close(None);
Poll::Ready(())
}
Poll::Pending => Poll::Pending,
}
}
IoWriteState::Shutdown(ref mut delay, ref mut st) => {
// close WRITE side and wait for disconnect on read side.
// use disconnect timeout, otherwise it could hang forever.
loop {
if this.state.poll_close(cx).is_ready() {
return Poll::Ready(());
}
match st {
Shutdown::None => {
// flush write buffer
let mut io = this.io.borrow_mut();
match this.state.with_buf(|buf| {
flush_io(&mut *io, buf, cx, &this.state)
}) {
Poll::Ready(Ok(())) => {
*st = Shutdown::Flushed;
continue;
}
Poll::Ready(Err(err)) => {
log::trace!(
"{}: Write task is closed with err during flush, {:?}", this.state.tag(),
err
);
this.state.close(Some(err));
return Poll::Ready(());
}
Poll::Pending => (),
}
}
Shutdown::Flushed => {
// shutdown WRITE side
match Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx)
{
Poll::Ready(Ok(_)) => {
*st = Shutdown::Stopping(0);
continue;
}
Poll::Ready(Err(e)) => {
log::trace!(
"{}: Write task is closed with err during shutdown", this.state.tag()
);
this.state.close(Some(e));
return Poll::Ready(());
}
_ => (),
}
}
Shutdown::Stopping(ref mut count) => {
// read until 0 or err
let mut buf = [0u8; 512];
loop {
let mut read_buf = ReadBuf::new(&mut buf);
match Pin::new(&mut *this.io.borrow_mut())
.poll_read(cx, &mut read_buf)
{
Poll::Ready(Err(_)) | Poll::Ready(Ok(_))
if read_buf.filled().is_empty() =>
{
this.state.close(None);
log::trace!(
"{}: Write task is stopped",
this.state.tag()
);
return Poll::Ready(());
}
Poll::Pending => {
*count += read_buf.filled().len() as u16;
if *count > 4096 {
log::trace!(
"{}: Write task is stopped, too much input", this.state.tag()
);
this.state.close(None);
return Poll::Ready(());
}
break;
}
_ => (),
}
}
}
}
// disconnect timeout
if delay.poll_elapsed(cx).is_pending() {
return Poll::Pending;
}
log::trace!(
"{}: Write task is stopped after delay",
this.state.tag()
);
this.state.close(None);
return Poll::Ready(());
}
}
}
} }
} }
} }
pub fn poll_read_buf<T: AsyncRead>(
io: Pin<&mut T>,
cx: &mut Context<'_>,
buf: &mut BytesVec,
) -> Poll<io::Result<usize>> {
let n = {
let dst =
unsafe { &mut *(buf.chunk_mut() as *mut _ as *mut [mem::MaybeUninit<u8>]) };
let mut buf = ReadBuf::uninit(dst);
let ptr = buf.filled().as_ptr();
if io.poll_read(cx, &mut buf)?.is_pending() {
return Poll::Pending;
}
// Ensure the pointer does not change from under us
assert_eq!(ptr, buf.filled().as_ptr());
buf.filled().len()
};
// Safety: This is guaranteed to be the number of initialized (and read)
// bytes due to the invariants provided by `ReadBuf::filled`.
unsafe {
buf.advance_mut(n);
}
Poll::Ready(Ok(n))
}

View file

@ -4,10 +4,8 @@ use ntex_bytes::PoolRef;
use ntex_io::Io; use ntex_io::Io;
mod io; mod io;
mod signals;
pub use self::io::{SocketOptions, TokioIoBoxed}; pub use self::io::{SocketOptions, TokioIoBoxed};
pub use self::signals::{signal, Signal};
struct TcpStream(tokio::net::TcpStream); struct TcpStream(tokio::net::TcpStream);

View file

@ -1,138 +0,0 @@
use std::{
cell::RefCell, future::Future, mem, pin::Pin, rc::Rc, task::Context, task::Poll,
};
use tokio::sync::oneshot;
use tokio::task::spawn_local;
thread_local! {
static SRUN: RefCell<bool> = const { RefCell::new(false) };
static SHANDLERS: Rc<RefCell<Vec<oneshot::Sender<Signal>>>> = Default::default();
}
/// Different types of process signals
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
pub enum Signal {
/// SIGHUP
Hup,
/// SIGINT
Int,
/// SIGTERM
Term,
/// SIGQUIT
Quit,
}
/// Register signal handler.
///
/// Signals are handled by oneshots, you have to re-register
/// after each signal.
pub fn signal() -> Option<oneshot::Receiver<Signal>> {
if !SRUN.with(|v| *v.borrow()) {
spawn_local(Signals::new());
}
SHANDLERS.with(|handlers| {
let (tx, rx) = oneshot::channel();
handlers.borrow_mut().push(tx);
Some(rx)
})
}
struct Signals {
#[cfg(not(unix))]
signal: Pin<Box<dyn Future<Output = std::io::Result<()>>>>,
#[cfg(unix)]
signals: Vec<(
Signal,
tokio::signal::unix::Signal,
tokio::signal::unix::SignalKind,
)>,
}
impl Signals {
fn new() -> Signals {
SRUN.with(|h| *h.borrow_mut() = true);
#[cfg(not(unix))]
{
Signals {
signal: Box::pin(tokio::signal::ctrl_c()),
}
}
#[cfg(unix)]
{
use tokio::signal::unix;
let sig_map = [
(unix::SignalKind::interrupt(), Signal::Int),
(unix::SignalKind::hangup(), Signal::Hup),
(unix::SignalKind::terminate(), Signal::Term),
(unix::SignalKind::quit(), Signal::Quit),
];
let mut signals = Vec::new();
for (kind, sig) in sig_map.iter() {
match unix::signal(*kind) {
Ok(stream) => signals.push((*sig, stream, *kind)),
Err(e) => log::error!(
"Cannot initialize stream handler for {:?} err: {}",
sig,
e
),
}
}
Signals { signals }
}
}
}
impl Drop for Signals {
fn drop(&mut self) {
SRUN.with(|h| *h.borrow_mut() = false);
}
}
impl Future for Signals {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
#[cfg(not(unix))]
{
if self.signal.as_mut().poll(cx).is_ready() {
let handlers = SHANDLERS.with(|h| mem::take(&mut *h.borrow_mut()));
for sender in handlers {
let _ = sender.send(Signal::Int);
}
}
Poll::Pending
}
#[cfg(unix)]
{
for (sig, stream, kind) in self.signals.iter_mut() {
loop {
if Pin::new(&mut *stream).poll_recv(cx).is_ready() {
let handlers = SHANDLERS.with(|h| mem::take(&mut *h.borrow_mut()));
for sender in handlers {
let _ = sender.send(*sig);
}
match tokio::signal::unix::signal(*kind) {
Ok(s) => {
*stream = s;
continue;
}
Err(e) => log::error!(
"Cannot initialize stream handler for {:?} err: {}",
sig,
e
),
}
}
break;
}
}
Poll::Pending
}
}
}

View file

@ -71,7 +71,7 @@ ntex-bytes = "0.1.27"
ntex-server = "2.3" ntex-server = "2.3"
ntex-h2 = "1.1" ntex-h2 = "1.1"
ntex-rt = "0.4.15" ntex-rt = "0.4.15"
ntex-io = "2.4" ntex-io = "2.5"
ntex-net = "2.1" ntex-net = "2.1"
ntex-tls = "2.1" ntex-tls = "2.1"