Refactor framed io (#67)

* refactor framed io

* allow to add filters

* move io code to separate module

* add into_boxed()

* remove unneeded IO_STOP state

* simplify on_disconnect storage

* cleanup io state

* add io connector
This commit is contained in:
Nikolay Kim 2021-12-13 17:19:43 +06:00 committed by GitHub
parent 2188d92725
commit 841ad736d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 3839 additions and 408 deletions

View file

@ -8,7 +8,7 @@ jobs:
fail-fast: false
matrix:
version:
- 1.53.0 # MSRV
- 1.56.0 # MSRV
- stable
- nightly
@ -43,7 +43,7 @@ jobs:
key: ${{ matrix.version }}-x86_64-unknown-linux-gnu-cargo-index-trimmed-${{ hashFiles('**/Cargo.lock') }}
- name: Cache cargo tarpaulin
if: matrix.version == '1.53.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
if: matrix.version == '1.56.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
uses: actions/cache@v1
with:
path: ~/.cargo/bin
@ -57,19 +57,19 @@ jobs:
args: --all --all-features --no-fail-fast -- --nocapture
- name: Install tarpaulin
if: matrix.version == '1.53.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
if: matrix.version == '1.56.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
continue-on-error: true
run: |
cargo install cargo-tarpaulin
- name: Generate coverage report
if: matrix.version == '1.53.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
if: matrix.version == '1.56.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
continue-on-error: true
run: |
cargo tarpaulin --out Xml --all --all-features
- name: Upload to Codecov
if: matrix.version == '1.53.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
if: matrix.version == '1.56.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request')
continue-on-error: true
uses: codecov/codecov-action@v1
with:

View file

@ -3,6 +3,7 @@ members = [
"ntex",
"ntex-bytes",
"ntex-codec",
"ntex-io",
"ntex-router",
"ntex-rt",
"ntex-service",
@ -14,6 +15,7 @@ members = [
ntex = { path = "ntex" }
ntex-bytes = { path = "ntex-bytes" }
ntex-codec = { path = "ntex-codec" }
ntex-io = { path = "ntex-io" }
ntex-router = { path = "ntex-router" }
ntex-rt = { path = "ntex-rt" }
ntex-service = { path = "ntex-service" }

View file

@ -1,6 +1,6 @@
//! Provides abstractions for working with bytes.
//!
//! This is fork of bytes crate (https://github.com/tokio-rs/bytes)
//! This is fork of [bytes crate](https://github.com/tokio-rs/bytes)
//!
//! The `ntex-bytes` crate provides an efficient byte buffer structure
//! ([`Bytes`](struct.Bytes.html)) and traits for working with buffer

View file

@ -23,5 +23,5 @@ log = "0.4"
tokio = { version = "1", default-features = false }
[dev-dependencies]
ntex = "0.3.13"
ntex = "0.4.13"
futures = "0.3.13"

View file

@ -1,8 +1,8 @@
//! Utilities for encoding and decoding frames.
//!
//! Contains adapters to go from streams of bytes, [`AsyncRead`] and
//! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`].
//! Framed streams are also known as [transports].
//! [`AsyncWrite`], to framed streams implementing `Sink` and `Stream`.
//! Framed streams are also known as `transports`.
//!
//! [`AsyncRead`]: #
//! [`AsyncWrite`]: #

39
ntex-io/Cargo.toml Normal file
View file

@ -0,0 +1,39 @@
[package]
name = "ntex-io"
version = "0.1.0"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]
homepage = "https://ntex.rs"
repository = "https://github.com/ntex-rs/ntex.git"
documentation = "https://docs.rs/ntex-io/"
categories = ["network-programming", "asynchronous"]
license = "MIT"
edition = "2018"
[lib]
name = "ntex_io"
path = "src/lib.rs"
[features]
default = ["tokio"]
# tokio support
tokio = ["tok-io"]
[dependencies]
bitflags = "1.3"
fxhash = "0.2.1"
ntex-codec = "0.5.1"
ntex-bytes = "0.1.7"
ntex-util = "0.1.2"
ntex-service = "0.2.1"
log = "0.4"
pin-project-lite = "0.2"
tok-io = { version = "1", package = "tokio", default-features = false, features = ["net"], optional = true }
[dev-dependencies]
ntex = "0.4.13"
futures = "0.3.13"
rand = "0.8"

1
ntex-io/LICENSE Symbolic link
View file

@ -0,0 +1 @@
../LICENSE

904
ntex-io/src/dispatcher.rs Normal file
View file

@ -0,0 +1,904 @@
//! Framed transport dispatcher
use std::{
cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll, time,
};
use ntex_bytes::Pool;
use ntex_codec::{Decoder, Encoder};
use ntex_service::{IntoService, Service};
use ntex_util::time::{now, Seconds};
use ntex_util::{future::Either, spawn};
use super::{DispatchItem, IoBoxed, ReadRef, Timer, WriteRef};
type Response<U> = <U as Encoder>::Item;
pin_project_lite::pin_project! {
    /// Framed dispatcher - a future that reads frames from the bytes stream
    /// and passes them to the service.
    pub struct Dispatcher<S, U>
    where
        S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
        S::Error: 'static,
        S::Future: 'static,
        U: Encoder,
        U: Decoder,
        <U as Encoder>::Item: 'static,
    {
        // service that handles decoded `DispatchItem`s
        service: S,
        // dispatcher state (io handle, keep-alive bookkeeping, error slots)
        inner: DispatcherInner<S, U>,
        // first in-flight service call, polled inline; additional
        // concurrent calls are spawned as separate tasks
        #[pin]
        fut: Option<S::Future>,
    }
}
/// Mutable dispatcher state used by the `Dispatcher` future.
struct DispatcherInner<S, U>
where
    S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    // current position of the dispatcher state machine
    st: Cell<DispatcherState>,
    // boxed io handle (`Io<Box<dyn Filter>>`)
    state: IoBoxed,
    // timer used to register/unregister keep-alive expirations
    timer: Timer,
    // keep-alive timeout; zero disables keep-alive handling
    ka_timeout: Seconds,
    // instant of the last keep-alive timer registration
    ka_updated: Cell<time::Instant>,
    // first service error; reported when the dispatcher completes
    error: Cell<Option<S::Error>>,
    // set when poll_ready() failed; readiness is not polled again
    ready_err: Cell<bool>,
    // state shared with spawned service-call tasks
    shared: Rc<DispatcherShared<S, U>>,
    // memory pool handle, used to apply memory back-pressure
    pool: Pool,
}
/// State shared between the dispatcher future and spawned service calls.
struct DispatcherShared<S, U>
where
    S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
    U: Encoder + Decoder,
{
    // codec used to encode service responses into the write buffer
    codec: U,
    // first keep-alive / encoder / service error
    error: Cell<Option<DispatcherError<S::Error, <U as Encoder>::Error>>>,
    // number of service calls that have not completed yet
    inflight: Cell<usize>,
}
/// Dispatcher state machine.
#[derive(Copy, Clone, Debug)]
enum DispatcherState {
    /// Normal operation: decode frames and call the service
    Processing,
    /// Write buffer is full; waiting for it to drain
    Backpressure,
    /// Stopping: drain in-flight service calls, then shut down io
    Stop,
    /// Waiting for service shutdown to complete
    Shutdown,
}
/// Errors collected while the dispatcher is running.
enum DispatcherError<S, U> {
    /// Keep-alive timeout expired
    KeepAlive,
    /// Response encoding failed
    Encoder(U),
    /// Service call failed
    Service(S),
}
/// Outcome of polling the service for readiness.
enum PollService<U: Encoder + Decoder> {
    /// Item that must be dispatched to the service
    Item(DispatchItem<U>),
    /// Service failed or dispatcher is stopping; re-enter the state machine
    ServiceError,
    /// Service is ready; proceed with reading/decoding
    Ready,
}
impl<S, U> From<Either<S, U>> for DispatcherError<S, U> {
    /// Map an `encode_result` error: the left side carries a service
    /// error, the right side an encoder error.
    fn from(err: Either<S, U>) -> Self {
        match err {
            Either::Left(service_err) => Self::Service(service_err),
            Either::Right(encoder_err) => Self::Encoder(encoder_err),
        }
    }
}
impl<S, U> Dispatcher<S, U>
where
    S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + 'static,
    <U as Encoder>::Item: 'static,
{
    /// Construct new `Dispatcher` instance.
    ///
    /// The dispatcher starts in the `Processing` state with a default
    /// 30 second keep-alive timeout, registered with the supplied timer.
    pub fn new<F: IntoService<S>>(
        state: IoBoxed,
        codec: U,
        service: F,
        timer: Timer,
    ) -> Self {
        let updated = now();
        let ka_timeout = Seconds(30);

        // register keepalive timer
        let expire = updated + time::Duration::from(ka_timeout);
        timer.register(expire, expire, &state);

        Dispatcher {
            service: service.into_service(),
            fut: None,
            inner: DispatcherInner {
                pool: state.memory_pool().pool(),
                ka_updated: Cell::new(updated),
                error: Cell::new(None),
                ready_err: Cell::new(false),
                st: Cell::new(DispatcherState::Processing),
                shared: Rc::new(DispatcherShared {
                    codec,
                    error: Cell::new(None),
                    inflight: Cell::new(0),
                }),
                state,
                timer,
                ka_timeout,
            },
        }
    }

    /// Set keep-alive timeout.
    ///
    /// To disable timeout set value to 0.
    ///
    /// By default keep-alive timeout is set to 30 seconds.
    pub fn keepalive_timeout(mut self, timeout: Seconds) -> Self {
        // move (or remove) the timer entry that was registered with the
        // previous timeout value
        let prev = self.inner.ka_updated.get() + time::Duration::from(self.inner.ka());
        if timeout.is_zero() {
            self.inner.timer.unregister(prev, &self.inner.state);
        } else {
            let expire = self.inner.ka_updated.get() + time::Duration::from(timeout);
            self.inner.timer.register(expire, prev, &self.inner.state);
        }
        self.inner.ka_timeout = timeout;

        self
    }

    /// Set connection disconnect timeout in seconds.
    ///
    /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
    /// within this time, the connection gets dropped.
    ///
    /// To disable timeout set value to 0.
    ///
    /// By default disconnect timeout is set to 1 second.
    pub fn disconnect_timeout(self, val: Seconds) -> Self {
        self.inner.state.set_disconnect_timeout(val);
        self
    }
}
impl<S, U> DispatcherShared<S, U>
where
    S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
    S::Error: 'static,
    S::Future: 'static,
    U: Encoder + Decoder,
    <U as Encoder>::Item: 'static,
{
    /// Handle a completed service call from a spawned task: decrement the
    /// in-flight counter, encode the response and wake the dispatcher.
    fn handle_result(&self, item: Result<S::Response, S::Error>, write: WriteRef<'_>) {
        self.inflight.set(self.inflight.get() - 1);
        match write.encode_result(item, &self.codec) {
            Ok(true) => (),
            // write buffer is full; enable write back-pressure
            Ok(false) => write.enable_backpressure(None),
            // service or encoder error; stored for the dispatcher to pick up
            Err(err) => self.error.set(Some(err.into())),
        }
        write.wake_dispatcher();
    }
}
impl<S, U> Future for Dispatcher<S, U>
where
    S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
    U: Decoder + Encoder + 'static,
    <U as Encoder>::Item: 'static,
{
    type Output = Result<(), S::Error>;

    /// Drive the dispatcher state machine:
    /// `Processing` -> (`Backpressure`) -> `Stop` -> `Shutdown`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();
        let slf = &this.inner;
        let state = &slf.state;
        let read = state.read();
        let write = state.write();

        // handle service response future (the inline in-flight call)
        if let Some(fut) = this.fut.as_mut().as_pin_mut() {
            match fut.poll(cx) {
                Poll::Pending => (),
                Poll::Ready(item) => {
                    this.fut.set(None);
                    // the inline call was counted in-flight when it
                    // returned Pending; balance the counter here
                    slf.shared.inflight.set(slf.shared.inflight.get() - 1);
                    slf.handle_result(item, write);
                }
            }
        }

        // handle memory pool pressure: stop reading until memory is available
        if slf.pool.poll_ready(cx).is_pending() {
            read.pause(cx);
            return Poll::Pending;
        }

        loop {
            match slf.st.get() {
                DispatcherState::Processing => {
                    let result = match slf.poll_service(this.service, cx, read) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(result) => result,
                    };
                    let item = match result {
                        PollService::Ready => {
                            if !write.is_ready() {
                                // instruct write task to notify dispatcher when data is flushed
                                write.enable_backpressure(Some(cx));
                                slf.st.set(DispatcherState::Backpressure);
                                DispatchItem::WBackPressureEnabled
                            } else if read.is_ready() {
                                // decode incoming bytes if buffer is ready
                                match read.decode(&slf.shared.codec) {
                                    Ok(Some(el)) => {
                                        // a full frame arrived; push the keep-alive deadline forward
                                        slf.update_keepalive();
                                        DispatchItem::Item(el)
                                    }
                                    Ok(None) => {
                                        log::trace!("not enough data to decode next frame, register dispatch task");
                                        read.wake(cx);
                                        return Poll::Pending;
                                    }
                                    Err(err) => {
                                        slf.st.set(DispatcherState::Stop);
                                        slf.unregister_keepalive();
                                        DispatchItem::DecoderError(err)
                                    }
                                }
                            } else {
                                // no new events
                                state.register_dispatcher(cx);
                                return Poll::Pending;
                            }
                        }
                        PollService::Item(item) => item,
                        PollService::ServiceError => continue,
                    };

                    // call service
                    if this.fut.is_none() {
                        // optimize first service call: poll it inline
                        this.fut.set(Some(this.service.call(item)));
                        match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
                            Poll::Ready(res) => {
                                this.fut.set(None);
                                slf.handle_result(res, write);
                            }
                            Poll::Pending => {
                                slf.shared.inflight.set(slf.shared.inflight.get() + 1)
                            }
                        }
                    } else {
                        // an inline call is already pending; run concurrently
                        slf.spawn_service_call(this.service.call(item));
                    }
                }
                // handle write back-pressure
                DispatcherState::Backpressure => {
                    let result = match slf.poll_service(this.service, cx, read) {
                        Poll::Ready(result) => result,
                        Poll::Pending => return Poll::Pending,
                    };
                    let item = match result {
                        PollService::Ready => {
                            if write.is_ready() {
                                // write buffer drained; resume normal processing
                                slf.st.set(DispatcherState::Processing);
                                DispatchItem::WBackPressureDisabled
                            } else {
                                return Poll::Pending;
                            }
                        }
                        PollService::Item(item) => item,
                        PollService::ServiceError => continue,
                    };

                    // call service
                    if this.fut.is_none() {
                        // optimize first service call: poll it inline
                        this.fut.set(Some(this.service.call(item)));
                        match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) {
                            Poll::Ready(res) => {
                                this.fut.set(None);
                                slf.handle_result(res, write);
                            }
                            Poll::Pending => {
                                slf.shared.inflight.set(slf.shared.inflight.get() + 1)
                            }
                        }
                    } else {
                        slf.spawn_service_call(this.service.call(item));
                    }
                }
                // drain service responses
                DispatcherState::Stop => {
                    // service may rely on poll_ready for response results
                    if !this.inner.ready_err.get() {
                        let _ = this.service.poll_ready(cx);
                    }

                    if slf.shared.inflight.get() == 0 {
                        // all in-flight calls finished; start io shutdown
                        slf.st.set(DispatcherState::Shutdown);
                        state.shutdown(cx);
                    } else {
                        state.register_dispatcher(cx);
                        return Poll::Pending;
                    }
                }
                // shutdown service
                DispatcherState::Shutdown => {
                    let err = slf.error.take();

                    return if this.service.poll_shutdown(cx, err.is_some()).is_ready() {
                        log::trace!("service shutdown is completed, stop");
                        Poll::Ready(if let Some(err) = err {
                            Err(err)
                        } else {
                            Ok(())
                        })
                    } else {
                        // shutdown not finished; restore the error for the next poll
                        slf.error.set(err);
                        Poll::Pending
                    };
                }
            }
        }
    }
}
impl<S, U> DispatcherInner<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>> + 'static,
U: Decoder + Encoder + 'static,
{
/// spawn service call
fn spawn_service_call(&self, fut: S::Future) {
self.shared.inflight.set(self.shared.inflight.get() + 1);
let st = self.state.get_ref();
let shared = self.shared.clone();
spawn(async move {
let item = fut.await;
shared.handle_result(item, st.write());
});
}
fn handle_result(
&self,
item: Result<Option<<U as Encoder>::Item>, S::Error>,
write: WriteRef<'_>,
) {
match write.encode_result(item, &self.shared.codec) {
Ok(true) => (),
Ok(false) => write.enable_backpressure(None),
Err(Either::Left(err)) => {
self.error.set(Some(err));
}
Err(Either::Right(err)) => {
self.shared.error.set(Some(DispatcherError::Encoder(err)))
}
}
}
fn poll_service(
&self,
srv: &S,
cx: &mut Context<'_>,
read: ReadRef<'_>,
) -> Poll<PollService<U>> {
match srv.poll_ready(cx) {
Poll::Ready(Ok(_)) => {
// service is ready, wake io read task
read.resume();
// check keepalive timeout
self.check_keepalive();
// check for errors
Poll::Ready(if let Some(err) = self.shared.error.take() {
log::trace!("error occured, stopping dispatcher");
self.unregister_keepalive();
self.st.set(DispatcherState::Stop);
match err {
DispatcherError::KeepAlive => {
PollService::Item(DispatchItem::KeepAliveTimeout)
}
DispatcherError::Encoder(err) => {
PollService::Item(DispatchItem::EncoderError(err))
}
DispatcherError::Service(err) => {
self.error.set(Some(err));
PollService::ServiceError
}
}
} else if self.state.is_dispatcher_stopped() {
log::trace!("dispatcher is instructed to stop");
self.unregister_keepalive();
// process unhandled data
if let Ok(Some(el)) = read.decode(&self.shared.codec) {
PollService::Item(DispatchItem::Item(el))
} else {
self.st.set(DispatcherState::Stop);
// get io error
if let Some(err) = self.state.take_error() {
PollService::Item(DispatchItem::IoError(err))
} else {
PollService::ServiceError
}
}
} else {
PollService::Ready
})
}
// pause io read task
Poll::Pending => {
log::trace!("service is not ready, register dispatch task");
read.pause(cx);
Poll::Pending
}
// handle service readiness error
Poll::Ready(Err(err)) => {
log::trace!("service readiness check failed, stopping");
self.st.set(DispatcherState::Stop);
self.error.set(Some(err));
self.unregister_keepalive();
self.ready_err.set(true);
Poll::Ready(PollService::ServiceError)
}
}
}
fn ka(&self) -> Seconds {
self.ka_timeout
}
fn ka_enabled(&self) -> bool {
self.ka_timeout.non_zero()
}
/// check keepalive timeout
fn check_keepalive(&self) {
if self.state.is_keepalive() {
log::trace!("keepalive timeout");
if let Some(err) = self.shared.error.take() {
self.shared.error.set(Some(err));
} else {
self.shared.error.set(Some(DispatcherError::KeepAlive));
}
}
}
/// update keep-alive timer
fn update_keepalive(&self) {
if self.ka_enabled() {
let updated = now();
if updated != self.ka_updated.get() {
let ka = time::Duration::from(self.ka());
self.timer.register(
updated + ka,
self.ka_updated.get() + ka,
&self.state,
);
self.ka_updated.set(updated);
}
}
}
/// unregister keep-alive timer
fn unregister_keepalive(&self) {
if self.ka_enabled() {
self.timer.unregister(
self.ka_updated.get() + time::Duration::from(self.ka()),
&self.state,
);
}
}
}
#[cfg(test)]
mod tests {
use rand::Rng;
use std::sync::{atomic::AtomicBool, atomic::Ordering::Relaxed, Arc, Mutex};
use std::{cell::RefCell, time::Duration};
use ntex_bytes::{Bytes, PoolId, PoolRef};
use ntex_codec::BytesCodec;
use ntex_util::future::Ready;
use ntex_util::time::{sleep, Millis};
use crate::testing::IoTest;
use crate::{state::Flags, state::IoStateInner, Io, IoStream, WriteRef};
use super::*;
pub(crate) struct State(Rc<IoStateInner>);
impl State {
fn flags(&self) -> Flags {
self.0.flags.get()
}
fn write(&'_ self) -> WriteRef<'_> {
WriteRef(self.0.as_ref())
}
fn close(&self) {
self.0.insert_flags(Flags::DSP_STOP);
self.0.dispatch_task.wake();
}
fn set_memory_pool(&self, pool: PoolRef) {
self.0.pool.set(pool);
}
}
impl<S, U> Dispatcher<S, U>
where
S: Service<Request = DispatchItem<U>, Response = Option<Response<U>>>,
S::Error: 'static,
S::Future: 'static,
U: Decoder + Encoder + 'static,
<U as Encoder>::Item: 'static,
{
/// Construct new `Dispatcher` instance
pub(crate) fn debug<T: IoStream, F: IntoService<S>>(
io: T,
codec: U,
service: F,
) -> (Self, State) {
let state = Io::new(io);
let timer = Timer::default();
let ka_timeout = Seconds(1);
let ka_updated = now();
let shared = Rc::new(DispatcherShared {
codec: codec,
error: Cell::new(None),
inflight: Cell::new(0),
});
let inner = State(state.0 .0.clone());
let expire = ka_updated + Duration::from_millis(500);
timer.register(expire, expire, &state);
(
Dispatcher {
service: service.into_service(),
fut: None,
inner: DispatcherInner {
ka_updated: Cell::new(ka_updated),
error: Cell::new(None),
ready_err: Cell::new(false),
st: Cell::new(DispatcherState::Processing),
pool: state.memory_pool().pool(),
state: state.into_boxed(),
shared,
timer,
ka_timeout,
},
},
inner,
)
}
}
#[ntex::test]
async fn test_basic() {
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let (disp, _) = Dispatcher::debug(
server,
BytesCodec,
ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
sleep(Millis(50)).await;
if let DispatchItem::Item(msg) = msg {
Ok::<_, ()>(Some(msg.freeze()))
} else {
panic!()
}
}),
);
spawn(async move {
let _ = disp.await;
});
sleep(Millis(25)).await;
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
client.write("GET /test HTTP/1\r\n\r\n");
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
client.close().await;
assert!(client.is_server_dropped());
}
#[ntex::test]
async fn test_sink() {
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let (disp, st) = Dispatcher::debug(
server,
BytesCodec,
ntex_service::fn_service(|msg: DispatchItem<BytesCodec>| async move {
if let DispatchItem::Item(msg) = msg {
Ok::<_, ()>(Some(msg.freeze()))
} else {
panic!()
}
}),
);
spawn(async move {
let _ = disp.disconnect_timeout(Seconds(1)).await;
});
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
assert!(st
.write()
.encode(Bytes::from_static(b"test"), &mut BytesCodec)
.is_ok());
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"test"));
st.close();
sleep(Millis(1100)).await;
assert!(client.is_server_dropped());
}
#[ntex::test]
async fn test_err_in_service() {
let (client, server) = IoTest::create();
client.remote_buffer_cap(0);
client.write("GET /test HTTP/1\r\n\r\n");
let (disp, state) = Dispatcher::debug(
server,
BytesCodec,
ntex_service::fn_service(|_: DispatchItem<BytesCodec>| async move {
Err::<Option<Bytes>, _>(())
}),
);
state
.write()
.encode(
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
&mut BytesCodec,
)
.unwrap();
spawn(async move {
let _ = disp.await;
});
// buffer should be flushed
client.remote_buffer_cap(1024);
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
// write side must be closed, dispatcher waiting for read side to close
assert!(client.is_closed());
// close read side
client.close().await;
assert!(client.is_server_dropped());
}
#[ntex::test]
async fn test_err_in_service_ready() {
let (client, server) = IoTest::create();
client.remote_buffer_cap(0);
client.write("GET /test HTTP/1\r\n\r\n");
let counter = Rc::new(Cell::new(0));
struct Srv(Rc<Cell<usize>>);
impl Service for Srv {
type Request = DispatchItem<BytesCodec>;
type Response = Option<Response<BytesCodec>>;
type Error = ();
type Future = Ready<Option<Response<BytesCodec>>, ()>;
fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
self.0.set(self.0.get() + 1);
Poll::Ready(Err(()))
}
fn call(&self, _: DispatchItem<BytesCodec>) -> Self::Future {
Ready::Ok(None)
}
}
let (disp, state) = Dispatcher::debug(server, BytesCodec, Srv(counter.clone()));
state
.write()
.encode(
Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"),
&mut BytesCodec,
)
.unwrap();
spawn(async move {
let _ = disp.await;
});
// buffer should be flushed
client.remote_buffer_cap(1024);
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
// write side must be closed, dispatcher waiting for read side to close
assert!(client.is_closed());
// close read side
client.close().await;
assert!(client.is_server_dropped());
// service must be checked for readiness only once
assert_eq!(counter.get(), 1);
}
#[ntex::test]
async fn test_write_backpressure() {
let (client, server) = IoTest::create();
// do not allow to write to socket
client.remote_buffer_cap(0);
client.write("GET /test HTTP/1\r\n\r\n");
let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
let data2 = data.clone();
let (disp, state) = Dispatcher::debug(
server,
BytesCodec,
ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
let data = data2.clone();
async move {
match msg {
DispatchItem::Item(_) => {
data.lock().unwrap().borrow_mut().push(0);
let bytes = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(65_536)
.map(char::from)
.collect::<String>();
return Ok::<_, ()>(Some(Bytes::from(bytes)));
}
DispatchItem::WBackPressureEnabled => {
data.lock().unwrap().borrow_mut().push(1);
}
DispatchItem::WBackPressureDisabled => {
data.lock().unwrap().borrow_mut().push(2);
}
_ => (),
}
Ok(None)
}
}),
);
let pool = PoolId::P10.pool_ref();
pool.set_read_params(8 * 1024, 1024);
pool.set_write_params(16 * 1024, 1024);
state.set_memory_pool(pool);
spawn(async move {
let _ = disp.await;
});
let buf = client.read_any();
assert_eq!(buf, Bytes::from_static(b""));
client.write("GET /test HTTP/1\r\n\r\n");
sleep(Millis(25)).await;
// buf must be consumed
assert_eq!(client.remote_buffer(|buf| buf.len()), 0);
// response message
assert!(!state.write().is_ready());
assert_eq!(state.write().with_buf(|buf| buf.len()), 65536);
client.remote_buffer_cap(10240);
sleep(Millis(50)).await;
assert_eq!(state.write().with_buf(|buf| buf.len()), 55296);
client.remote_buffer_cap(45056);
sleep(Millis(50)).await;
assert_eq!(state.write().with_buf(|buf| buf.len()), 10240);
// backpressure disabled
assert!(state.write().is_ready());
assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]);
}
#[ntex::test]
async fn test_keepalive() {
let (client, server) = IoTest::create();
// do not allow to write to socket
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let data = Arc::new(Mutex::new(RefCell::new(Vec::new())));
let data2 = data.clone();
let (disp, state) = Dispatcher::debug(
server,
BytesCodec,
ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
let data = data2.clone();
async move {
match msg {
DispatchItem::Item(bytes) => {
data.lock().unwrap().borrow_mut().push(0);
return Ok::<_, ()>(Some(bytes.freeze()));
}
DispatchItem::KeepAliveTimeout => {
data.lock().unwrap().borrow_mut().push(1);
}
_ => (),
}
Ok(None)
}
}),
);
spawn(async move {
let _ = disp
.keepalive_timeout(Seconds::ZERO)
.keepalive_timeout(Seconds(1))
.await;
});
state.0.disconnect_timeout.set(Seconds(1));
let buf = client.read().await.unwrap();
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));
sleep(Millis(3500)).await;
// write side must be closed, dispatcher should fail with keep-alive
let flags = state.flags();
assert!(flags.contains(Flags::IO_SHUTDOWN));
assert!(flags.contains(Flags::DSP_KEEPALIVE));
assert!(client.is_closed());
assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]);
}
#[ntex::test]
async fn test_unhandled_data() {
let handled = Arc::new(AtomicBool::new(false));
let handled2 = handled.clone();
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let (disp, _) = Dispatcher::debug(
server,
BytesCodec,
ntex_service::fn_service(move |msg: DispatchItem<BytesCodec>| {
handled2.store(true, Relaxed);
async move {
sleep(Millis(50)).await;
if let DispatchItem::Item(msg) = msg {
Ok::<_, ()>(Some(msg.freeze()))
} else {
panic!()
}
}
}),
);
client.close().await;
spawn(async move {
let _ = disp.await;
});
sleep(Millis(50)).await;
assert!(handled.load(Relaxed));
}
}

151
ntex-io/src/filter.rs Normal file
View file

@ -0,0 +1,151 @@
use std::{io, rc::Rc, task::Context, task::Poll};
use ntex_bytes::BytesMut;
use super::state::{Flags, IoStateInner};
use super::{Filter, ReadFilter, WriteFilter, WriteReadiness};
/// Terminal filter that operates directly on the shared io state.
pub struct DefaultFilter(Rc<IoStateInner>);

impl DefaultFilter {
    /// Wrap the shared io state.
    pub(crate) fn new(inner: Rc<IoStateInner>) -> Self {
        DefaultFilter(inner)
    }
}

impl Filter for DefaultFilter {}
impl ReadFilter for DefaultFilter {
    #[inline]
    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
        let flags = self.0.flags.get();

        if flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) {
            // io failed or is shutting down; the read task must stop
            Poll::Ready(Err(()))
        } else if flags.intersects(Flags::RD_PAUSED) {
            // reading is paused; park the read task until resumed
            self.0.read_task.register(cx.waker());
            Poll::Pending
        } else {
            self.0.read_task.register(cx.waker());
            Poll::Ready(Ok(()))
        }
    }

    #[inline]
    fn read_closed(&self, err: Option<io::Error>) {
        // record the error (if any), mark io as failed and wake
        // the write and dispatcher tasks
        if err.is_some() {
            self.0.error.set(err);
        }
        self.0.write_task.wake();
        self.0.dispatch_task.wake();
        self.0.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
        self.0.notify_disconnect();
    }

    #[inline]
    fn get_read_buf(&self) -> Option<BytesMut> {
        self.0.read_buf.take()
    }

    #[inline]
    fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) {
        if new_bytes > 0 {
            // raise read back-pressure once the buffered data exceeds
            // the pool's high watermark
            if buf.len() > self.0.pool.get().read_params().high as usize {
                log::trace!(
                    "buffer is too large {}, enable read back-pressure",
                    buf.len()
                );
                self.0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL);
            } else {
                self.0.insert_flags(Flags::RD_READY);
            }
            // new data arrived; wake the dispatcher
            self.0.dispatch_task.wake();
        }
        self.0.read_buf.set(Some(buf));
    }
}
impl WriteFilter for DefaultFilter {
    #[inline]
    fn poll_write_ready(
        &self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), WriteReadiness>> {
        let flags = self.0.flags.get();

        if flags.contains(Flags::IO_ERR) {
            // io failed; write task must terminate immediately
            Poll::Ready(Err(WriteReadiness::Terminate))
        } else if flags.intersects(Flags::IO_SHUTDOWN) {
            // graceful shutdown requested
            Poll::Ready(Err(WriteReadiness::Shutdown))
        } else {
            self.0.write_task.register(cx.waker());
            Poll::Ready(Ok(()))
        }
    }

    #[inline]
    fn write_closed(&self, err: Option<io::Error>) {
        // record the error (if any), mark io as failed and wake
        // the read and dispatcher tasks
        if err.is_some() {
            self.0.error.set(err);
        }
        self.0.read_task.wake();
        self.0.dispatch_task.wake();
        self.0.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
        self.0.notify_disconnect();
    }

    #[inline]
    fn get_write_buf(&self) -> Option<BytesMut> {
        self.0.write_buf.take()
    }

    #[inline]
    fn release_write_buf(&self, buf: BytesMut) {
        let pool = self.0.pool.get();
        if buf.is_empty() {
            // nothing left to write; return the buffer to the memory pool
            pool.release_write_buf(buf);
        } else {
            self.0.write_buf.set(Some(buf));
        }
    }
}
/// Filter whose operations always report the io as terminated
/// (see the trait impls below).
pub(crate) struct NullFilter;

const NULL: NullFilter = NullFilter;

impl NullFilter {
    /// Static reference to the shared null filter instance.
    pub(super) fn get() -> &'static dyn Filter {
        &NULL
    }
}
impl Filter for NullFilter {}

impl ReadFilter for NullFilter {
    fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
        // never allows reading
        Poll::Ready(Err(()))
    }

    fn read_closed(&self, _: Option<io::Error>) {}

    fn get_read_buf(&self) -> Option<BytesMut> {
        None
    }

    fn release_read_buf(&self, _: BytesMut, _: usize) {}
}

impl WriteFilter for NullFilter {
    fn poll_write_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), WriteReadiness>> {
        // always instructs the write task to terminate
        Poll::Ready(Err(WriteReadiness::Terminate))
    }

    fn write_closed(&self, _: Option<io::Error>) {}

    fn get_write_buf(&self) -> Option<BytesMut> {
        None
    }

    fn release_write_buf(&self, _: BytesMut) {}
}

141
ntex-io/src/lib.rs Normal file
View file

@ -0,0 +1,141 @@
use std::{fmt, future::Future, io, task::Context, task::Poll};
pub mod testing;
mod dispatcher;
mod filter;
mod state;
mod tasks;
mod time;
mod utils;
#[cfg(feature = "tokio")]
mod tokio_impl;
use ntex_bytes::BytesMut;
use ntex_codec::{Decoder, Encoder};
pub use self::dispatcher::Dispatcher;
pub use self::state::{Io, IoRef, ReadRef, WriteRef};
pub use self::tasks::{ReadState, WriteState};
pub use self::time::Timer;
pub use self::utils::{from_iostream, into_boxed};
pub type IoBoxed = Io<Box<dyn Filter>>;
/// Reason the write task must stop writing, as reported by a filter.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum WriteReadiness {
    /// Graceful shutdown was requested
    Shutdown,
    /// Io must be terminated immediately
    Terminate,
}
/// Read-side half of an io filter.
pub trait ReadFilter {
    /// Check whether the read task may continue reading
    fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>>;

    /// Notify the filter that the read side is closed, with an optional io error
    fn read_closed(&self, err: Option<io::Error>);

    /// Take the current read buffer, if any
    fn get_read_buf(&self) -> Option<BytesMut>;

    /// Return the read buffer; `new_bytes` is the number of freshly read bytes
    fn release_read_buf(&self, buf: BytesMut, new_bytes: usize);
}
/// Write-side half of an io filter.
pub trait WriteFilter {
    /// Check whether the write task may continue writing
    fn poll_write_ready(&self, cx: &mut Context<'_>)
        -> Poll<Result<(), WriteReadiness>>;

    /// Notify the filter that the write side is closed, with an optional io error
    fn write_closed(&self, err: Option<io::Error>);

    /// Take the current write buffer, if any
    fn get_write_buf(&self) -> Option<BytesMut>;

    /// Return the write buffer after a write attempt
    fn release_write_buf(&self, buf: BytesMut);
}
/// Full-duplex io filter.
pub trait Filter: ReadFilter + WriteFilter {}

/// Factory that layers a new filter on top of an existing io object.
pub trait FilterFactory<F: Filter>: Sized {
    /// Filter produced by this factory
    type Filter: Filter;
    /// Error produced if filter construction fails
    type Error: fmt::Debug;
    /// Future resolving to the io object wrapped with the new filter
    type Future: Future<Output = Result<Io<Self::Filter>, Self::Error>>;

    fn create(&self, st: Io<F>) -> Self::Future;
}

/// A byte stream that can drive the io read and write tasks.
pub trait IoStream {
    fn start(self, _: ReadState, _: WriteState);
}
/// Framed transport item
pub enum DispatchItem<U: Encoder + Decoder> {
    /// Decoded frame
    Item(<U as Decoder>::Item),
    /// Write back-pressure enabled
    WBackPressureEnabled,
    /// Write back-pressure disabled
    WBackPressureDisabled,
    /// Keep alive timeout
    KeepAliveTimeout,
    /// Decoder parse error
    DecoderError(<U as Decoder>::Error),
    /// Encoder parse error
    EncoderError(<U as Encoder>::Error),
    /// Unexpected io error
    IoError(io::Error),
}
impl<U> fmt::Debug for DispatchItem<U>
where
    U: Encoder + Decoder,
    <U as Decoder>::Item: fmt::Debug,
{
    /// Writes the fully qualified variant name, formatting any payload
    /// via its own `Debug` implementation.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DispatchItem::Item(item) => {
                write!(fmt, "DispatchItem::Item({:?})", item)
            }
            DispatchItem::DecoderError(e) => {
                write!(fmt, "DispatchItem::DecoderError({:?})", e)
            }
            DispatchItem::EncoderError(e) => {
                write!(fmt, "DispatchItem::EncoderError({:?})", e)
            }
            DispatchItem::IoError(e) => {
                write!(fmt, "DispatchItem::IoError({:?})", e)
            }
            DispatchItem::KeepAliveTimeout => {
                write!(fmt, "DispatchItem::KeepAliveTimeout")
            }
            DispatchItem::WBackPressureEnabled => {
                write!(fmt, "DispatchItem::WBackPressureEnabled")
            }
            DispatchItem::WBackPressureDisabled => {
                write!(fmt, "DispatchItem::WBackPressureDisabled")
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ntex_codec::BytesCodec;

    /// Each `DispatchItem` variant's `Debug` output must contain its
    /// variant name (consumers match on these strings in logs).
    #[test]
    fn test_fmt() {
        type T = DispatchItem<BytesCodec>;

        let err = T::EncoderError(io::Error::new(io::ErrorKind::Other, "err"));
        assert!(format!("{:?}", err).contains("DispatchItem::Encoder"));
        let err = T::DecoderError(io::Error::new(io::ErrorKind::Other, "err"));
        assert!(format!("{:?}", err).contains("DispatchItem::Decoder"));
        let err = T::IoError(io::Error::new(io::ErrorKind::Other, "err"));
        assert!(format!("{:?}", err).contains("DispatchItem::IoError"));
        assert!(format!("{:?}", T::WBackPressureEnabled)
            .contains("DispatchItem::WBackPressureEnabled"));
        assert!(format!("{:?}", T::WBackPressureDisabled)
            .contains("DispatchItem::WBackPressureDisabled"));
        assert!(format!("{:?}", T::KeepAliveTimeout)
            .contains("DispatchItem::KeepAliveTimeout"));
    }
}

1097
ntex-io/src/state.rs Normal file

File diff suppressed because it is too large Load diff

98
ntex-io/src/tasks.rs Normal file
View file

@ -0,0 +1,98 @@
use std::{io, rc::Rc, task::Context, task::Poll};
use ntex_bytes::{BytesMut, PoolRef};
use ntex_util::time::Seconds;
use super::{state::Flags, state::IoStateInner, WriteReadiness};
/// Handle exposed to the io read task.
pub struct ReadState(pub(super) Rc<IoStateInner>);

impl ReadState {
    /// Current memory pool
    #[inline]
    pub fn memory_pool(&self) -> PoolRef {
        self.0.pool.get()
    }

    /// Check (via the installed filter) whether the read task may read
    #[inline]
    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
        self.0.filter.get().poll_read_ready(cx)
    }

    /// Notify that the read side is closed, with an optional io error
    #[inline]
    pub fn close(&self, err: Option<io::Error>) {
        self.0.filter.get().read_closed(err);
    }

    /// Get a buffer to read into; allocates a fresh buffer from the
    /// memory pool if the filter has none
    #[inline]
    pub fn get_read_buf(&self) -> BytesMut {
        self.0
            .filter
            .get()
            .get_read_buf()
            .unwrap_or_else(|| self.0.pool.get().get_read_buf())
    }

    /// Return the read buffer; empty buffers go back to the pool,
    /// non-empty ones to the filter
    #[inline]
    pub fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) {
        if buf.is_empty() {
            self.0.pool.get().release_read_buf(buf);
        } else {
            self.0.filter.get().release_read_buf(buf, new_bytes);
        }
    }
}
/// Write-side handle given to an io driver task; wraps the shared io state
/// and exposes the readiness/buffer operations the write loop needs.
pub struct WriteState(pub(super) Rc<IoStateInner>);

impl WriteState {
    /// Memory pool used for write-buffer allocations.
    #[inline]
    pub fn memory_pool(&self) -> PoolRef {
        self.0.pool.get()
    }

    /// Timeout to apply while shutting the connection down.
    #[inline]
    pub fn disconnect_timeout(&self) -> Seconds {
        self.0.disconnect_timeout.get()
    }

    /// Check write readiness; `Err` carries a shutdown/terminate
    /// instruction for the driver. Delegates to the installed filter.
    #[inline]
    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), WriteReadiness>> {
        self.0.filter.get().poll_write_ready(cx)
    }

    /// Notify the filter chain that the write side is closed, optionally
    /// with the io error that caused it.
    #[inline]
    pub fn close(&self, err: Option<io::Error>) {
        self.0.filter.get().write_closed(err)
    }

    /// Take the pending write buffer, leaving `None` in the shared state.
    #[inline]
    pub fn get_write_buf(&self) -> Option<BytesMut> {
        self.0.write_buf.take()
    }

    /// Return the write buffer after a flush attempt, updating the
    /// back-pressure flags and waking the dispatcher when appropriate.
    #[inline]
    pub fn release_write_buf(&self, buf: BytesMut) {
        let pool = self.0.pool.get();
        if buf.is_empty() {
            // fully flushed: recycle the buffer and release any task that
            // was waiting for the buffer to drain
            pool.release_write_buf(buf);
            let mut flags = self.0.flags.get();
            if flags.intersects(Flags::WR_WAIT | Flags::WR_BACKPRESSURE) {
                flags.remove(Flags::WR_WAIT | Flags::WR_BACKPRESSURE);
                self.0.flags.set(flags);
                self.0.dispatch_task.wake();
            }
        } else {
            // if the remaining data is below twice the high watermark
            // (`<<` binds tighter than `<`), turn off back-pressure
            if buf.len() < pool.write_params_high() << 1 {
                let mut flags = self.0.flags.get();
                if flags.contains(Flags::WR_BACKPRESSURE) {
                    flags.remove(Flags::WR_BACKPRESSURE);
                    self.0.flags.set(flags);
                    self.0.dispatch_task.wake();
                }
            }
            self.0.write_buf.set(Some(buf))
        }
    }
}

746
ntex-io/src/testing.rs Normal file
View file

@ -0,0 +1,746 @@
use std::cell::{Cell, RefCell};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::{cmp, fmt, io, mem};
use ntex_bytes::{BufMut, BytesMut};
use ntex_util::future::poll_fn;
use ntex_util::time::{sleep, Millis};
/// Cross-thread waker slot used by the in-memory test channel; a thin
/// `Mutex`-guarded `Option<Waker>` that can be taken and woken.
#[derive(Default)]
struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>);

impl AtomicWaker {
    /// Take the registered waker, if any, and wake it; a no-op when
    /// nothing is registered.
    fn wake(&self) {
        let registered = self.0.lock().unwrap().borrow_mut().take();
        if let Some(waker) = registered {
            waker.wake()
        }
    }
}

impl fmt::Debug for AtomicWaker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AtomicWaker")
    }
}
/// Async io stream
///
/// In-memory duplex stream used for testing io code without real sockets.
/// Each handle reads from its `local` channel and writes into `remote`;
/// the paired handle holds the same channels swapped.
#[derive(Debug)]
pub struct IoTest {
    // which endpoint (and whether it is a clone) this handle represents
    tp: Type,
    // drop-tracking state shared by both endpoints
    state: Arc<Cell<State>>,
    local: Arc<Mutex<RefCell<Channel>>>,
    remote: Arc<Mutex<RefCell<Channel>>>,
}

bitflags::bitflags! {
    struct Flags: u8 {
        const FLUSHED = 0b0000_0001;
        const CLOSED = 0b0000_0010;
    }
}

/// Which side of the pair a handle belongs to. Clones are tracked as
/// distinct variants so dropping a clone does not mark the endpoint
/// itself as dropped.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
enum Type {
    Client,
    Server,
    ClientClone,
    ServerClone,
}

/// Drop-tracking flags shared by both endpoints of the pair.
#[derive(Copy, Clone, Default, Debug)]
struct State {
    client_dropped: bool,
    server_dropped: bool,
}

/// One direction of the duplex pipe.
#[derive(Default, Debug)]
struct Channel {
    buf: BytesMut,
    // how many more bytes the peer is allowed to write into `buf`
    buf_cap: usize,
    flags: Flags,
    waker: AtomicWaker,
    // scripted outcome for the next read on this channel
    read: IoState,
    // scripted outcome for the next write on this channel
    write: IoState,
}

impl Channel {
    fn is_closed(&self) -> bool {
        self.flags.contains(Flags::CLOSED)
    }
}

impl Default for Flags {
    fn default() -> Self {
        Flags::empty()
    }
}

/// Scripted result for the next io operation on a channel; set by the
/// test via `read_pending`/`read_error`/`write_error`/`close`.
#[derive(Debug)]
enum IoState {
    Ok,
    Pending,
    Close,
    Err(io::Error),
}

impl Default for IoState {
    fn default() -> Self {
        IoState::Ok
    }
}
impl IoTest {
    /// Create a two interconnected streams
    pub fn create() -> (IoTest, IoTest) {
        let local = Arc::new(Mutex::new(RefCell::new(Channel::default())));
        let remote = Arc::new(Mutex::new(RefCell::new(Channel::default())));

        let state = Arc::new(Cell::new(State::default()));

        (
            IoTest {
                tp: Type::Client,
                local: local.clone(),
                remote: remote.clone(),
                state: state.clone(),
            },
            IoTest {
                state,
                tp: Type::Server,
                // channels are swapped: one side's local is the other's remote
                local: remote,
                remote: local,
            },
        )
    }

    /// Check whether the client endpoint has been dropped.
    pub fn is_client_dropped(&self) -> bool {
        self.state.get().client_dropped
    }

    /// Check whether the server endpoint has been dropped.
    pub fn is_server_dropped(&self) -> bool {
        self.state.get().server_dropped
    }

    /// Check if channel is closed from remote side
    pub fn is_closed(&self) -> bool {
        self.remote.lock().unwrap().borrow().is_closed()
    }

    /// Set read to Pending state
    pub fn read_pending(&self) {
        self.remote.lock().unwrap().borrow_mut().read = IoState::Pending;
    }

    /// Set read to error
    pub fn read_error(&self, err: io::Error) {
        self.remote.lock().unwrap().borrow_mut().read = IoState::Err(err);
    }

    /// Set write error on remote side
    pub fn write_error(&self, err: io::Error) {
        self.local.lock().unwrap().borrow_mut().write = IoState::Err(err);
    }

    /// Access read buffer.
    pub fn local_buffer<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut BytesMut) -> R,
    {
        let guard = self.local.lock().unwrap();
        let mut ch = guard.borrow_mut();
        f(&mut ch.buf)
    }

    /// Access remote buffer.
    pub fn remote_buffer<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut BytesMut) -> R,
    {
        let guard = self.remote.lock().unwrap();
        let mut ch = guard.borrow_mut();
        f(&mut ch.buf)
    }

    /// Closed remote side.
    pub async fn close(&self) {
        {
            let guard = self.remote.lock().unwrap();
            let mut remote = guard.borrow_mut();
            remote.read = IoState::Close;
            remote.waker.wake();
            log::trace!("close remote socket");
        }
        // give the peer's tasks a chance to observe the close
        sleep(Millis(35)).await;
    }

    /// Add extra data to the remote buffer and notify reader
    pub fn write<T: AsRef<[u8]>>(&self, data: T) {
        let guard = self.remote.lock().unwrap();
        let mut write = guard.borrow_mut();
        write.buf.extend_from_slice(data.as_ref());
        write.waker.wake();
    }

    /// Set remote buffer capacity (how many bytes the peer may write)
    /// and wake the remote side so it can retry a pending write.
    pub fn remote_buffer_cap(&self, cap: usize) {
        // change cap
        self.local.lock().unwrap().borrow_mut().buf_cap = cap;
        // wake remote
        self.remote.lock().unwrap().borrow().waker.wake();
    }

    /// Read any available data
    pub fn read_any(&self) -> BytesMut {
        self.local.lock().unwrap().borrow_mut().buf.split()
    }

    /// Read data, if data is not available wait for it
    pub async fn read(&self) -> Result<BytesMut, io::Error> {
        if self.local.lock().unwrap().borrow().buf.is_empty() {
            poll_fn(|cx| {
                let guard = self.local.lock().unwrap();
                let read = guard.borrow_mut();
                if read.buf.is_empty() {
                    let closed = match self.tp {
                        Type::Client | Type::ClientClone => {
                            self.is_server_dropped() || read.is_closed()
                        }
                        Type::Server | Type::ServerClone => self.is_client_dropped(),
                    };
                    if closed {
                        Poll::Ready(())
                    } else {
                        // register for wake-up before releasing the locks
                        *read.waker.0.lock().unwrap().borrow_mut() =
                            Some(cx.waker().clone());
                        drop(read);
                        drop(guard);
                        Poll::Pending
                    }
                } else {
                    Poll::Ready(())
                }
            })
            .await;
        }
        Ok(self.local.lock().unwrap().borrow_mut().buf.split())
    }

    /// Non-blocking read into `buf`; returns the number of bytes copied,
    /// `Ok(0)` on close, or the scripted error/pending outcome.
    pub fn poll_read_buf(
        &self,
        cx: &mut Context<'_>,
        buf: &mut BytesMut,
    ) -> Poll<io::Result<usize>> {
        let guard = self.local.lock().unwrap();
        let mut ch = guard.borrow_mut();
        *ch.waker.0.lock().unwrap().borrow_mut() = Some(cx.waker().clone());

        if !ch.buf.is_empty() {
            let size = std::cmp::min(ch.buf.len(), buf.remaining_mut());
            let b = ch.buf.split_to(size);
            buf.put_slice(&b);
            return Poll::Ready(Ok(size));
        }

        // `mem::take` resets the scripted state back to `Ok` after one use
        match mem::take(&mut ch.read) {
            IoState::Ok => Poll::Pending,
            IoState::Close => {
                // keep Close sticky so subsequent reads also see EOF
                ch.read = IoState::Close;
                Poll::Ready(Ok(0))
            }
            IoState::Pending => Poll::Pending,
            IoState::Err(e) => Poll::Ready(Err(e)),
        }
    }

    /// Non-blocking write of `buf` into the peer's read buffer, limited by
    /// the peer's `buf_cap`; registers a waker when no capacity is left.
    pub fn poll_write_buf(
        &self,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let guard = self.remote.lock().unwrap();
        let mut ch = guard.borrow_mut();

        match mem::take(&mut ch.write) {
            IoState::Ok => {
                let cap = cmp::min(buf.len(), ch.buf_cap);
                if cap > 0 {
                    ch.buf.extend(&buf[..cap]);
                    ch.buf_cap -= cap;
                    ch.flags.remove(Flags::FLUSHED);
                    ch.waker.wake();
                    Poll::Ready(Ok(cap))
                } else {
                    // no capacity: park on our own channel's waker until
                    // `remote_buffer_cap` raises the cap
                    *self
                        .local
                        .lock()
                        .unwrap()
                        .borrow_mut()
                        .waker
                        .0
                        .lock()
                        .unwrap()
                        .borrow_mut() = Some(cx.waker().clone());
                    Poll::Pending
                }
            }
            IoState::Close => Poll::Ready(Ok(0)),
            IoState::Pending => {
                *self
                    .local
                    .lock()
                    .unwrap()
                    .borrow_mut()
                    .waker
                    .0
                    .lock()
                    .unwrap()
                    .borrow_mut() = Some(cx.waker().clone());
                Poll::Pending
            }
            IoState::Err(e) => Poll::Ready(Err(e)),
        }
    }
}
impl Clone for IoTest {
    /// Clones share channels and state; primary endpoints become their
    /// `*Clone` counterparts so drop-tracking stays accurate.
    fn clone(&self) -> Self {
        IoTest {
            tp: match self.tp {
                Type::Server => Type::ServerClone,
                Type::Client => Type::ClientClone,
                other => other,
            },
            local: self.local.clone(),
            remote: self.remote.clone(),
            state: self.state.clone(),
        }
    }
}

impl Drop for IoTest {
    /// Mark the endpoint as dropped; dropping a clone changes nothing.
    fn drop(&mut self) {
        let mut st = self.state.get();
        if self.tp == Type::Server {
            st.server_dropped = true;
        } else if self.tp == Type::Client {
            st.client_dropped = true;
        }
        self.state.set(st);
    }
}
#[cfg(feature = "tokio")]
mod tokio {
    //! tokio `AsyncRead`/`AsyncWrite` adapters for the in-memory test
    //! stream, mirroring `IoTest::poll_read_buf`/`poll_write_buf`.
    use std::task::{Context, Poll};
    use std::{cmp, io, mem, pin::Pin};

    use tok_io::io::{AsyncRead, AsyncWrite, ReadBuf};

    use super::{Flags, IoState, IoTest};

    impl AsyncRead for IoTest {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let guard = self.local.lock().unwrap();
            let mut ch = guard.borrow_mut();
            // always (re-)register so a later `write` can wake this reader
            *ch.waker.0.lock().unwrap().borrow_mut() = Some(cx.waker().clone());

            if !ch.buf.is_empty() {
                let size = std::cmp::min(ch.buf.len(), buf.remaining());
                let b = ch.buf.split_to(size);
                buf.put_slice(&b);
                return Poll::Ready(Ok(()));
            }

            // `mem::take` resets the scripted state back to `Ok` after one use
            match mem::take(&mut ch.read) {
                IoState::Ok => Poll::Pending,
                IoState::Close => {
                    // keep Close sticky so subsequent reads also see EOF
                    ch.read = IoState::Close;
                    Poll::Ready(Ok(()))
                }
                IoState::Pending => Poll::Pending,
                IoState::Err(e) => Poll::Ready(Err(e)),
            }
        }
    }

    impl AsyncWrite for IoTest {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let guard = self.remote.lock().unwrap();
            let mut ch = guard.borrow_mut();

            match mem::take(&mut ch.write) {
                IoState::Ok => {
                    // writes are limited by the peer-controlled capacity
                    let cap = cmp::min(buf.len(), ch.buf_cap);
                    if cap > 0 {
                        ch.buf.extend(&buf[..cap]);
                        ch.buf_cap -= cap;
                        ch.flags.remove(Flags::FLUSHED);
                        ch.waker.wake();
                        Poll::Ready(Ok(cap))
                    } else {
                        // no capacity: park on our own channel's waker until
                        // `remote_buffer_cap` raises the cap
                        *self
                            .local
                            .lock()
                            .unwrap()
                            .borrow_mut()
                            .waker
                            .0
                            .lock()
                            .unwrap()
                            .borrow_mut() = Some(cx.waker().clone());
                        Poll::Pending
                    }
                }
                IoState::Close => Poll::Ready(Ok(0)),
                IoState::Pending => {
                    *self
                        .local
                        .lock()
                        .unwrap()
                        .borrow_mut()
                        .waker
                        .0
                        .lock()
                        .unwrap()
                        .borrow_mut() = Some(cx.waker().clone());
                    Poll::Pending
                }
                IoState::Err(e) => Poll::Ready(Err(e)),
            }
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<io::Result<()>> {
            // in-memory stream: writes are immediately visible, nothing to flush
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<io::Result<()>> {
            // mark our side closed; the peer observes it via `is_closed`
            self.local
                .lock()
                .unwrap()
                .borrow_mut()
                .flags
                .insert(Flags::CLOSED);
            Poll::Ready(Ok(()))
        }
    }
}
#[cfg(not(feature = "tokio"))]
mod non_tokio {
    //! `IoStream` driver for the in-memory test stream when the tokio
    //! integration is disabled: spawns dedicated read and write tasks.
    //
    // NOTE(review): this module references `IoStream`, `ReadState`,
    // `WriteState`, `WriteReadiness`, `Rc`, `Pin`, `Future`, `Sleep`,
    // `sleep` and `BytesMut` without local `use` items — confirm it
    // compiles when built without the "tokio" feature.
    impl IoStream for IoTest {
        fn start(self, read: ReadState, write: WriteState) {
            let io = Rc::new(self);

            ntex_util::spawn(ReadTask {
                io: io.clone(),
                state: read,
            });
            ntex_util::spawn(WriteTask {
                io,
                state: write,
                st: IoWriteState::Processing,
            });
        }
    }

    /// Read io task
    struct ReadTask {
        io: Rc<IoTest>,
        state: ReadState,
    }

    impl Future for ReadTask {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.as_ref();

            match this.state.poll_ready(cx) {
                Poll::Ready(Err(())) => {
                    log::trace!("read task is instructed to terminate");
                    Poll::Ready(())
                }
                Poll::Ready(Ok(())) => {
                    let io = &this.io;
                    let pool = this.state.memory_pool();
                    let mut buf = self.state.get_read_buf();
                    let (hw, lw) = pool.read_params().unpack();

                    // read data from socket
                    let mut new_bytes = 0;
                    loop {
                        // make sure we've got room
                        let remaining = buf.remaining_mut();
                        if remaining < lw {
                            buf.reserve(hw - remaining);
                        }

                        match io.poll_read_buf(cx, &mut buf) {
                            Poll::Pending => {
                                log::trace!("no more data in io stream");
                                break;
                            }
                            Poll::Ready(Ok(n)) => {
                                if n == 0 {
                                    // zero-length read means the peer closed
                                    log::trace!("io stream is disconnected");
                                    this.state.release_read_buf(buf, new_bytes);
                                    this.state.close(None);
                                    return Poll::Ready(());
                                } else {
                                    new_bytes += n;
                                    // stop once the high watermark is exceeded
                                    if buf.len() > hw {
                                        break;
                                    }
                                }
                            }
                            Poll::Ready(Err(err)) => {
                                log::trace!("read task failed on io {:?}", err);
                                this.state.release_read_buf(buf, new_bytes);
                                this.state.close(Some(err));
                                return Poll::Ready(());
                            }
                        }
                    }

                    this.state.release_read_buf(buf, new_bytes);
                    Poll::Pending
                }
                Poll::Pending => Poll::Pending,
            }
        }
    }

    /// Phase of the write task: normal processing, or staged shutdown.
    #[derive(Debug)]
    enum IoWriteState {
        Processing,
        Shutdown(Option<Sleep>, Shutdown),
    }

    /// Shutdown sub-phase: flush, close write side, then drain reads.
    #[derive(Debug)]
    enum Shutdown {
        None,
        Flushed,
        Stopping,
    }

    /// Write io task
    struct WriteTask {
        st: IoWriteState,
        io: Rc<IoTest>,
        state: WriteState,
    }

    impl Future for WriteTask {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let mut this = self.as_mut().get_mut();

            match this.st {
                IoWriteState::Processing => {
                    match this.state.poll_ready(cx) {
                        Poll::Ready(Ok(())) => {
                            // flush framed instance
                            match flush_io(&this.io, &this.state, cx) {
                                Poll::Pending | Poll::Ready(true) => Poll::Pending,
                                Poll::Ready(false) => Poll::Ready(()),
                            }
                        }
                        Poll::Ready(Err(WriteReadiness::Shutdown)) => {
                            log::trace!("write task is instructed to shutdown");
                            this.st = IoWriteState::Shutdown(
                                this.state.disconnect_timeout().map(sleep),
                                Shutdown::None,
                            );
                            // re-enter poll to start the shutdown sequence
                            self.poll(cx)
                        }
                        Poll::Ready(Err(WriteReadiness::Terminate)) => {
                            log::trace!("write task is instructed to terminate");
                            // shutdown WRITE side
                            this.io
                                .local
                                .lock()
                                .unwrap()
                                .borrow_mut()
                                .flags
                                .insert(Flags::CLOSED);
                            this.state.close(None);
                            Poll::Ready(())
                        }
                        Poll::Pending => Poll::Pending,
                    }
                }
                IoWriteState::Shutdown(ref mut delay, ref mut st) => {
                    // close WRITE side and wait for disconnect on read side.
                    // use disconnect timeout, otherwise it could hang forever.
                    loop {
                        match st {
                            Shutdown::None => {
                                // flush write buffer
                                match flush_io(&this.io, &this.state, cx) {
                                    Poll::Ready(true) => {
                                        *st = Shutdown::Flushed;
                                        continue;
                                    }
                                    Poll::Ready(false) => {
                                        log::trace!(
                                            "write task is closed with err during flush"
                                        );
                                        return Poll::Ready(());
                                    }
                                    _ => (),
                                }
                            }
                            Shutdown::Flushed => {
                                // shutdown WRITE side
                                this.io
                                    .local
                                    .lock()
                                    .unwrap()
                                    .borrow_mut()
                                    .flags
                                    .insert(Flags::CLOSED);
                                *st = Shutdown::Stopping;
                                continue;
                            }
                            Shutdown::Stopping => {
                                // read until 0 or err
                                let io = &this.io;
                                loop {
                                    let mut buf = BytesMut::new();
                                    match io.poll_read_buf(cx, &mut buf) {
                                        Poll::Ready(Err(e)) => {
                                            this.state.close(Some(e));
                                            log::trace!("write task is stopped");
                                            return Poll::Ready(());
                                        }
                                        Poll::Ready(Ok(n)) if n == 0 => {
                                            this.state.close(None);
                                            log::trace!("write task is stopped");
                                            return Poll::Ready(());
                                        }
                                        Poll::Pending => break,
                                        _ => (),
                                    }
                                }
                            }
                        }

                        // disconnect timeout
                        if let Some(ref delay) = delay {
                            if delay.poll_elapsed(cx).is_pending() {
                                return Poll::Pending;
                            }
                        }
                        log::trace!("write task is stopped after delay");
                        this.state.close(None);
                        return Poll::Ready(());
                    }
                }
            }
        }
    }

    /// Flush write buffer to underlying I/O stream.
    ///
    /// Returns `Ready(true)` once the buffer is fully flushed,
    /// `Ready(false)` on error/disconnect (the state is closed), and
    /// `Pending` when the peer has no capacity left.
    pub(super) fn flush_io(
        io: &IoTest,
        state: &WriteState,
        cx: &mut Context<'_>,
    ) -> Poll<bool> {
        let mut buf = if let Some(buf) = state.get_write_buf() {
            buf
        } else {
            // nothing buffered: trivially flushed
            return Poll::Ready(true);
        };
        let len = buf.len();
        let pool = state.memory_pool();

        if len != 0 {
            log::trace!("flushing framed transport: {}", len);

            let mut written = 0;
            while written < len {
                match io.poll_write_buf(cx, &buf[written..]) {
                    Poll::Ready(Ok(n)) => {
                        if n == 0 {
                            log::trace!(
                                "disconnected during flush, written {}",
                                written
                            );
                            pool.release_write_buf(buf);
                            state.close(Some(io::Error::new(
                                io::ErrorKind::WriteZero,
                                "failed to write frame to transport",
                            )));
                            return Poll::Ready(false);
                        } else {
                            written += n
                        }
                    }
                    Poll::Pending => break,
                    Poll::Ready(Err(e)) => {
                        log::trace!("error during flush: {}", e);
                        pool.release_write_buf(buf);
                        state.close(Some(e));
                        return Poll::Ready(false);
                    }
                }
            }
            log::trace!("flushed {} bytes", written);

            // remove written data
            if written == len {
                buf.clear();
                state.release_write_buf(buf);
                Poll::Ready(true)
            } else {
                buf.advance(written);
                state.release_write_buf(buf);
                Poll::Pending
            }
        } else {
            Poll::Ready(true)
        }
    }
}
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use super::*;

    /// Cloning must demote endpoint types to their `*Clone` variants, and
    /// only dropping a primary endpoint flips the shared dropped flags.
    #[ntex::test]
    async fn basic() {
        let (client, server) = IoTest::create();
        assert_eq!(client.tp, Type::Client);
        assert_eq!(client.clone().tp, Type::ClientClone);
        assert_eq!(server.tp, Type::Server);
        assert_eq!(server.clone().tp, Type::ServerClone);

        assert!(!server.is_client_dropped());
        drop(client);
        assert!(server.is_client_dropped());

        let server2 = server.clone();
        assert!(!server2.is_server_dropped());
        drop(server);
        assert!(server2.is_server_dropped());
    }
}

104
ntex-io/src/time.rs Normal file
View file

@ -0,0 +1,104 @@
use std::{
cell::RefCell, collections::BTreeMap, collections::HashSet, rc::Rc, time::Instant,
};
use ntex_util::spawn;
use ntex_util::time::{now, sleep, Millis};
use super::state::{Flags, IoRef, IoStateInner};
/// Shared keep-alive timer: batches io-state expirations into buckets and
/// drives them from a single periodically polled task.
pub struct Timer(Rc<RefCell<Inner>>);

struct Inner {
    // whether the driver task is currently spawned
    running: bool,
    // tick interval of the driver task
    resolution: Millis,
    // expiration instant -> io states to notify at (or before) that instant
    notifications: BTreeMap<Instant, HashSet<Rc<IoStateInner>, fxhash::FxBuildHasher>>,
}

impl Inner {
    fn new(resolution: Millis) -> Self {
        Inner {
            resolution,
            running: false,
            notifications: BTreeMap::default(),
        }
    }

    /// Remove `io` from the bucket at `expire`, dropping the bucket once
    /// it becomes empty so the driver loop can stop when idle.
    fn unregister(&mut self, expire: Instant, io: &IoRef) {
        if let Some(states) = self.notifications.get_mut(&expire) {
            states.remove(&io.0);
            if states.is_empty() {
                self.notifications.remove(&expire);
            }
        }
    }
}
impl Clone for Timer {
fn clone(&self) -> Self {
Timer(self.0.clone())
}
}
impl Default for Timer {
fn default() -> Self {
Timer::new(Millis::ONE_SEC)
}
}
impl Timer {
    /// Create new timer with resolution in milliseconds
    pub fn new(resolution: Millis) -> Timer {
        Timer(Rc::new(RefCell::new(Inner::new(resolution))))
    }

    /// Move `io`'s registration from `previous` to `expire`, lazily
    /// spawning the driver task on first use.
    pub fn register(&self, expire: Instant, previous: Instant, io: &IoRef) {
        let mut inner = self.0.borrow_mut();

        // drop the stale registration before adding the new one
        inner.unregister(previous, io);
        inner
            .notifications
            .entry(expire)
            .or_insert_with(HashSet::default)
            .insert(io.0.clone());

        if !inner.running {
            inner.running = true;
            let interval = inner.resolution;
            let inner = self.0.clone();

            spawn(async move {
                loop {
                    sleep(interval).await;
                    {
                        let mut i = inner.borrow_mut();
                        let now_time = now();

                        // notify io dispatcher
                        // BTreeMap keys are ordered, so draining from the
                        // front visits the earliest expirations first
                        while let Some(key) = i.notifications.keys().next() {
                            let key = *key;
                            if key <= now_time {
                                for st in i.notifications.remove(&key).unwrap() {
                                    st.dispatch_task.wake();
                                    st.insert_flags(Flags::DSP_KEEPALIVE);
                                }
                            } else {
                                break;
                            }
                        }

                        // new tick
                        // stop the task when nothing is registered; a later
                        // `register` call will respawn it
                        if i.notifications.is_empty() {
                            i.running = false;
                            break;
                        }
                    }
                }
            });
        }
    }

    /// Remove `io`'s registration at `expire`, if present.
    pub fn unregister(&self, expire: Instant, io: &IoRef) {
        self.0.borrow_mut().unregister(expire, io);
    }
}

314
ntex-io/src/tokio_impl.rs Normal file
View file

@ -0,0 +1,314 @@
use std::task::{Context, Poll};
use std::{cell::RefCell, future::Future, io, pin::Pin, rc::Rc};
use ntex_bytes::{Buf, BufMut};
use ntex_util::time::{sleep, Sleep};
use tok_io::{io::AsyncRead, io::AsyncWrite, io::ReadBuf};
use super::{IoStream, ReadState, WriteReadiness, WriteState};
/// Any tokio `AsyncRead + AsyncWrite` stream can drive the io state:
/// starting it spawns a dedicated read task and write task sharing the
/// stream behind `Rc<RefCell<_>>`.
impl<T> IoStream for T
where
    T: AsyncRead + AsyncWrite + Unpin + 'static,
{
    fn start(self, read: ReadState, write: WriteState) {
        let io = Rc::new(RefCell::new(self));

        ntex_util::spawn(ReadTask::new(io.clone(), read));
        ntex_util::spawn(WriteTask::new(io, write));
    }
}

/// Read io task
struct ReadTask<T> {
    io: Rc<RefCell<T>>,
    state: ReadState,
}

impl<T> ReadTask<T>
where
    T: AsyncRead + AsyncWrite + Unpin + 'static,
{
    /// Create new read io task
    fn new(io: Rc<RefCell<T>>, state: ReadState) -> Self {
        Self { io, state }
    }
}
impl<T> Future for ReadTask<T>
where
    T: AsyncRead + AsyncWrite + Unpin + 'static,
{
    type Output = ();

    /// Drain the io stream into the shared read buffer until the stream
    /// has no more data, the high watermark is reached, or the stream
    /// disconnects/errors (which closes the read state and ends the task).
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.as_ref();

        match this.state.poll_ready(cx) {
            Poll::Ready(Err(())) => {
                log::trace!("read task is instructed to shutdown");
                Poll::Ready(())
            }
            Poll::Ready(Ok(())) => {
                let pool = this.state.memory_pool();
                let mut io = this.io.borrow_mut();
                let mut buf = self.state.get_read_buf();
                let (hw, lw) = pool.read_params().unpack();

                // read data from socket
                let mut new_bytes = 0;
                loop {
                    // make sure we've got room
                    let remaining = buf.remaining_mut();
                    if remaining < lw {
                        buf.reserve(hw - remaining);
                    }

                    match ntex_codec::poll_read_buf(Pin::new(&mut *io), cx, &mut buf) {
                        Poll::Pending => break,
                        Poll::Ready(Ok(n)) => {
                            if n == 0 {
                                // zero-length read means the peer closed
                                log::trace!("io stream is disconnected");
                                this.state.release_read_buf(buf, new_bytes);
                                this.state.close(None);
                                return Poll::Ready(());
                            } else {
                                new_bytes += n;
                                // stop once the high watermark is exceeded
                                if buf.len() > hw {
                                    break;
                                }
                            }
                        }
                        Poll::Ready(Err(err)) => {
                            log::trace!("read task failed on io {:?}", err);
                            this.state.release_read_buf(buf, new_bytes);
                            this.state.close(Some(err));
                            return Poll::Ready(());
                        }
                    }
                }

                this.state.release_read_buf(buf, new_bytes);
                Poll::Pending
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Phase of the write task: normal processing, or staged shutdown with an
/// optional disconnect timeout.
#[derive(Debug)]
enum IoWriteState {
    Processing,
    Shutdown(Option<Sleep>, Shutdown),
}

/// Shutdown sub-phase: flush the buffer, shut the write side, then drain
/// reads until the peer disconnects.
#[derive(Debug)]
enum Shutdown {
    None,
    Flushed,
    Stopping,
}

/// Write io task
struct WriteTask<T> {
    st: IoWriteState,
    io: Rc<RefCell<T>>,
    state: WriteState,
}

impl<T> WriteTask<T>
where
    T: AsyncRead + AsyncWrite + Unpin + 'static,
{
    /// Create new write io task
    fn new(io: Rc<RefCell<T>>, state: WriteState) -> Self {
        Self {
            io,
            state,
            st: IoWriteState::Processing,
        }
    }
}
impl<T> Future for WriteTask<T>
where
    T: AsyncRead + AsyncWrite + Unpin + 'static,
{
    type Output = ();

    /// Flush the shared write buffer to the io stream; on a shutdown
    /// instruction, run the staged shutdown (flush, close write side,
    /// drain reads) guarded by the disconnect timeout.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.as_mut().get_mut();

        match this.st {
            IoWriteState::Processing => {
                match this.state.poll_ready(cx) {
                    Poll::Ready(Ok(())) => {
                        // flush framed instance
                        match flush_io(&mut *this.io.borrow_mut(), &this.state, cx) {
                            Poll::Pending | Poll::Ready(true) => Poll::Pending,
                            Poll::Ready(false) => Poll::Ready(()),
                        }
                    }
                    Poll::Ready(Err(WriteReadiness::Shutdown)) => {
                        log::trace!("write task is instructed to shutdown");
                        this.st = IoWriteState::Shutdown(
                            this.state.disconnect_timeout().map(sleep),
                            Shutdown::None,
                        );
                        // re-enter poll to start the shutdown sequence now
                        self.poll(cx)
                    }
                    Poll::Ready(Err(WriteReadiness::Terminate)) => {
                        log::trace!("write task is instructed to terminate");

                        // best-effort shutdown; the result is ignored on purpose
                        let _ = Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx);
                        this.state.close(None);
                        Poll::Ready(())
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
            IoWriteState::Shutdown(ref mut delay, ref mut st) => {
                // close WRITE side and wait for disconnect on read side.
                // use disconnect timeout, otherwise it could hang forever.
                loop {
                    match st {
                        Shutdown::None => {
                            // flush write buffer
                            match flush_io(&mut *this.io.borrow_mut(), &this.state, cx) {
                                Poll::Ready(true) => {
                                    *st = Shutdown::Flushed;
                                    continue;
                                }
                                Poll::Ready(false) => {
                                    log::trace!(
                                        "write task is closed with err during flush"
                                    );
                                    return Poll::Ready(());
                                }
                                _ => (),
                            }
                        }
                        Shutdown::Flushed => {
                            // shutdown WRITE side
                            match Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx)
                            {
                                Poll::Ready(Ok(_)) => {
                                    *st = Shutdown::Stopping;
                                    continue;
                                }
                                Poll::Ready(Err(e)) => {
                                    log::trace!(
                                        "write task is closed with err during shutdown"
                                    );
                                    this.state.close(Some(e));
                                    return Poll::Ready(());
                                }
                                _ => (),
                            }
                        }
                        Shutdown::Stopping => {
                            // read until 0 or err
                            let mut buf = [0u8; 512];
                            let mut io = this.io.borrow_mut();
                            loop {
                                let mut read_buf = ReadBuf::new(&mut buf);
                                match Pin::new(&mut *io).poll_read(cx, &mut read_buf) {
                                    // peer disconnected: EOF or error with
                                    // nothing read counts as fully stopped
                                    Poll::Ready(Err(_)) | Poll::Ready(Ok(_))
                                        if read_buf.filled().is_empty() =>
                                    {
                                        this.state.close(None);
                                        log::trace!("write task is stopped");
                                        return Poll::Ready(());
                                    }
                                    Poll::Pending => break,
                                    _ => (),
                                }
                            }
                        }
                    }

                    // disconnect timeout
                    if let Some(ref delay) = delay {
                        if delay.poll_elapsed(cx).is_pending() {
                            return Poll::Pending;
                        }
                    }
                    log::trace!("write task is stopped after delay");
                    this.state.close(None);
                    return Poll::Ready(());
                }
            }
        }
    }
}
/// Flush write buffer to underlying I/O stream.
///
/// Returns `Ready(true)` once the buffer is fully flushed, `Ready(false)`
/// on a write/flush error or disconnect (the state is closed with the
/// error), and `Pending` while the stream cannot accept more data.
pub(super) fn flush_io<T: AsyncRead + AsyncWrite + Unpin>(
    io: &mut T,
    state: &WriteState,
    cx: &mut Context<'_>,
) -> Poll<bool> {
    let mut buf = if let Some(buf) = state.get_write_buf() {
        buf
    } else {
        // nothing buffered: trivially flushed
        return Poll::Ready(true);
    };
    let len = buf.len();
    let pool = state.memory_pool();

    if len != 0 {
        // log::trace!("flushing framed transport: {:?}", buf);

        let mut written = 0;
        while written < len {
            match Pin::new(&mut *io).poll_write(cx, &buf[written..]) {
                Poll::Pending => break,
                Poll::Ready(Ok(n)) => {
                    if n == 0 {
                        // zero-length write means the peer closed
                        log::trace!("Disconnected during flush, written {}", written);
                        pool.release_write_buf(buf);
                        state.close(Some(io::Error::new(
                            io::ErrorKind::WriteZero,
                            "failed to write frame to transport",
                        )));
                        return Poll::Ready(false);
                    } else {
                        written += n
                    }
                }
                Poll::Ready(Err(e)) => {
                    log::trace!("Error during flush: {}", e);
                    pool.release_write_buf(buf);
                    state.close(Some(e));
                    return Poll::Ready(false);
                }
            }
        }
        // log::trace!("flushed {} bytes", written);

        // remove written data
        let result = if written == len {
            buf.clear();
            state.release_write_buf(buf);
            Poll::Ready(true)
        } else {
            buf.advance(written);
            state.release_write_buf(buf);
            Poll::Pending
        };

        // flush
        // a flush error overrides the write result and closes the state
        match Pin::new(&mut *io).poll_flush(cx) {
            Poll::Ready(Ok(_)) => result,
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => {
                log::trace!("error during flush: {}", e);
                state.close(Some(e));
                Poll::Ready(false)
            }
        }
    } else {
        Poll::Ready(true)
    }
}

49
ntex-io/src/utils.rs Normal file
View file

@ -0,0 +1,49 @@
use ntex_service::{fn_factory_with_config, into_service, Service, ServiceFactory};
use super::{Filter, Io, IoBoxed, IoStream};
/// Service that converts any Io<F> stream to IoBoxed stream
///
/// Wraps `srv` (which accepts type-erased `IoBoxed` streams) so it can be
/// used where concrete `Io<F>` streams are produced; each request is
/// boxed via `Io::into_boxed` before being forwarded.
pub fn into_boxed<F, S>(
    srv: S,
) -> impl ServiceFactory<
    Config = S::Config,
    Request = Io<F>,
    Response = S::Response,
    Error = S::Error,
    InitError = S::InitError,
>
where
    F: Filter + 'static,
    S: ServiceFactory<Request = IoBoxed>,
{
    fn_factory_with_config(move |cfg: S::Config| {
        let fut = srv.new_service(cfg);
        async move {
            let srv = fut.await?;
            Ok(into_service(move |io: Io<F>| srv.call(io.into_boxed())))
        }
    })
}
/// Service that converts IoStream stream to IoBoxed stream
///
/// Wraps `srv` (which accepts type-erased `IoBoxed` streams) so it can be
/// fed raw `IoStream` implementations; each request is wrapped in an
/// `Io` with the default filter and then boxed.
pub fn from_iostream<S, I>(
    srv: S,
) -> impl ServiceFactory<
    Config = S::Config,
    Request = I,
    Response = S::Response,
    Error = S::Error,
    InitError = S::InitError,
>
where
    I: IoStream,
    S: ServiceFactory<Request = IoBoxed>,
{
    fn_factory_with_config(move |cfg: S::Config| {
        let fut = srv.new_service(cfg);
        async move {
            let srv = fut.await?;
            Ok(into_service(move |io| srv.call(Io::new(io).into_boxed())))
        }
    })
}

View file

@ -16,5 +16,5 @@ syn = { version = "^1", features = ["full", "parsing"] }
proc-macro2 = "^1"
[dev-dependencies]
ntex = "0.3.1"
ntex = "0.4.10"
futures = "0.3.13"

View file

@ -75,7 +75,7 @@ impl Builder {
let (stop_tx, stop) = channel();
let (sys_sender, sys_receiver) = unbounded_channel();
let system = System::construct(
let _system = System::construct(
sys_sender,
Arbiter::new_system(local),
self.stop_on_panic,
@ -87,10 +87,7 @@ impl Builder {
// start the system arbiter
let _ = local.spawn_local(arb);
AsyncSystemRunner {
stop,
_system: system,
}
AsyncSystemRunner { stop, _system }
}
fn create_runtime<F>(self, f: F) -> SystemRunner
@ -108,7 +105,7 @@ impl Builder {
});
// system arbiter
let system = System::construct(
let _system = System::construct(
sys_sender,
Arbiter::new_system(rt.local()),
self.stop_on_panic,
@ -119,11 +116,7 @@ impl Builder {
// init system arbiter and run configuration method
rt.block_on(lazy(move |_| f()));
SystemRunner {
rt,
stop,
_system: system,
}
SystemRunner { rt, stop, _system }
}
}

View file

@ -32,9 +32,10 @@ pub fn sleep(duration: Duration) -> Sleep {
}
}
#[doc(hidden)]
/// Creates new [`Interval`] that yields with interval of `period` with the
/// first tick completing at `start`. The default [`MissedTickBehavior`] is
/// [`Burst`](MissedTickBehavior::Burst), but this can be configured
/// first tick completing at `start`. The default `MissedTickBehavior` is
/// `Burst`, but this can be configured
/// by calling [`set_missed_tick_behavior`](Interval::set_missed_tick_behavior).
#[inline]
pub fn interval_at(start: Instant, period: Duration) -> Interval {

View file

@ -24,8 +24,6 @@ futures-core = { version = "0.3.18", default-features = false, features = ["allo
futures-sink = { version = "0.3.18", default-features = false, features = ["alloc"] }
pin-project-lite = "0.2.6"
backtrace = "*"
[dev-dependencies]
ntex = "0.4.10"
ntex-rt = "0.3.2"

View file

@ -5,8 +5,8 @@ pub mod condition;
pub mod oneshot;
pub mod pool;
/// Error returned from a [`Receiver`](Receiver) when the corresponding
/// [`Sender`](Sender) is dropped.
/// Error returned from a `Receiver` when the corresponding
/// `Sender` is dropped.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct Canceled;

View file

@ -658,17 +658,19 @@ mod tests {
fut2.await;
let elapsed = Instant::now() - time;
#[cfg(not(target_os = "macos"))]
assert!(
elapsed > Duration::from_millis(200) && elapsed < Duration::from_millis(500),
elapsed > Duration::from_millis(200) && elapsed < Duration::from_millis(300),
"elapsed: {:?}",
elapsed
);
fut1.await;
let elapsed = Instant::now() - time;
#[cfg(not(target_os = "macos"))]
assert!(
elapsed > Duration::from_millis(1000)
&& elapsed < Duration::from_millis(3000), // osx
&& elapsed < Duration::from_millis(1200), // osx
"elapsed: {:?}",
elapsed
);
@ -676,8 +678,11 @@ mod tests {
let time = Instant::now();
sleep(Millis(25)).await;
let elapsed = Instant::now() - time;
#[cfg(not(target_os = "macos"))]
assert!(
elapsed > Duration::from_millis(20) && elapsed < Duration::from_millis(50)
elapsed > Duration::from_millis(20) && elapsed < Duration::from_millis(50),
"elapsed: {:?}",
elapsed
);
}
}

View file

@ -50,6 +50,7 @@ ntex-service = "0.2.1"
ntex-macros = "0.1.3"
ntex-util = "0.1.2"
ntex-bytes = "0.1.7"
ntex-io = { version = "0.1", features = ["tokio"] }
base64 = "0.13"
bitflags = "1.3"

148
ntex/src/connect/io.rs Normal file
View file

@ -0,0 +1,148 @@
use std::task::{Context, Poll};
use std::{future::Future, pin::Pin};
use crate::io::Io;
use crate::service::{Service, ServiceFactory};
use crate::util::{PoolId, PoolRef, Ready};
use super::service::ConnectServiceResponse;
use super::{Address, Connect, ConnectError, Connector};
/// TCP connector that yields an `Io` stream instead of a raw socket;
/// wraps the plain `Connector` and attaches a memory pool.
pub struct IoConnector<T> {
    inner: Connector<T>,
    pool: PoolRef,
}

impl<T> IoConnector<T> {
    /// Construct new connect service with the default resolver and the
    /// P0 memory pool.
    pub fn new() -> Self {
        IoConnector {
            inner: Connector::new(),
            pool: PoolId::P0.pool_ref(),
        }
    }

    /// Set memory pool.
    ///
    /// Use specified memory pool for memory allocations. By default P0
    /// memory pool is used.
    pub fn memory_pool(mut self, id: PoolId) -> Self {
        self.pool = id.pool_ref();
        self
    }
}
impl<T: Address> IoConnector<T> {
    /// Resolve and connect to remote host
    pub fn connect<U>(&self, message: U) -> IoConnectServiceResponse<T>
    where
        Connect<T>: From<U>,
    {
        IoConnectServiceResponse {
            inner: self.inner.call(message.into()),
            pool: self.pool,
        }
    }
}

impl<T> Default for IoConnector<T> {
    fn default() -> Self {
        IoConnector::new()
    }
}
impl<T> Clone for IoConnector<T> {
fn clone(&self) -> Self {
IoConnector {
inner: self.inner.clone(),
pool: self.pool,
}
}
}
/// The connector is its own service factory: creating a service is just a
/// cheap clone.
impl<T: Address> ServiceFactory for IoConnector<T> {
    type Request = Connect<T>;
    type Response = Io;
    type Error = ConnectError;
    type Config = ();
    type Service = IoConnector<T>;
    type InitError = ();
    type Future = Ready<Self::Service, Self::InitError>;

    #[inline]
    fn new_service(&self, _: ()) -> Self::Future {
        Ready::Ok(self.clone())
    }
}

impl<T: Address> Service for IoConnector<T> {
    type Request = Connect<T>;
    type Response = Io;
    type Error = ConnectError;
    type Future = IoConnectServiceResponse<T>;

    /// Always ready: connecting allocates no shared resources up front.
    #[inline]
    fn poll_ready(&self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    #[inline]
    fn call(&self, req: Connect<T>) -> Self::Future {
        self.connect(req)
    }
}
#[doc(hidden)]
/// Future resolving a connect request and wrapping the resulting stream
/// into an `Io` bound to the connector's memory pool.
pub struct IoConnectServiceResponse<T: Address> {
    inner: ConnectServiceResponse<T>,
    pool: PoolRef,
}

impl<T: Address> Future for IoConnectServiceResponse<T> {
    type Output = Result<Io, ConnectError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let pool = self.pool;
        Pin::new(&mut self.inner)
            .poll(cx)
            .map(|res| res.map(|stream| Io::with_memory_pool(stream, pool)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Connecting to an empty/invalid address must fail; connecting to a
    /// live test server must succeed, including when the first candidate
    /// address in the list is unreachable.
    #[crate::rt_test]
    async fn test_connect() {
        let server = crate::server::test_server(|| {
            crate::service::fn_service(|_| async { Ok::<_, ()>(()) })
        });

        let srv = IoConnector::default();
        let result = srv.connect("").await;
        assert!(result.is_err());
        let result = srv.connect("localhost:99999").await;
        assert!(result.is_err());

        let srv = IoConnector::default();
        let result = srv.connect(format!("{}", server.addr())).await;
        assert!(result.is_ok());

        // first addr is unreachable (port - 1); connect must fall through
        // to the working addr
        let msg = Connect::new(format!("{}", server.addr())).set_addrs(vec![
            format!("127.0.0.1:{}", server.addr().port() - 1)
                .parse()
                .unwrap(),
            server.addr(),
        ]);
        let result = crate::connect::connect(msg).await;
        assert!(result.is_ok());

        let msg = Connect::new(server.addr());
        let result = crate::connect::connect(msg).await;
        assert!(result.is_ok());
    }
}

View file

@ -2,6 +2,7 @@
use std::future::Future;
mod error;
mod io;
mod message;
mod resolve;
mod service;
@ -18,6 +19,7 @@ pub mod rustls;
use crate::rt::net::TcpStream;
pub use self::error::ConnectError;
pub use self::io::IoConnector;
pub use self::message::{Address, Connect};
pub use self::resolve::Resolver;
pub use self::service::Connector;

View file

@ -35,9 +35,7 @@ impl<T: Address> Connector<T> {
impl<T> Default for Connector<T> {
fn default() -> Self {
Connector {
resolver: Resolver::default(),
}
Connector::new()
}
}

View file

@ -39,7 +39,6 @@ pub mod framed;
#[cfg(feature = "http-framework")]
pub mod http;
pub mod server;
pub mod testing;
pub mod util;
#[cfg(feature = "http-framework")]
pub mod web;
@ -78,3 +77,13 @@ pub mod time {
//! Utilities for tracking time.
pub use ntex_util::time::*;
}
pub mod io {
//! IO streaming utilities.
pub use ntex_io::*;
}
pub mod testing {
//! IO testing utilities.
pub use ntex_io::testing::IoTest as Io;
}

View file

@ -1,373 +0,0 @@
use std::cell::{Cell, RefCell};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::{cmp, fmt, io, mem, pin::Pin};
use crate::codec::{AsyncRead, AsyncWrite, ReadBuf};
use crate::time::{sleep, Millis};
use crate::util::{poll_fn, BytesMut};
// Cross-thread slot holding at most one parked task's waker.
// NOTE(review): the inner `RefCell` is redundant under the `Mutex`, but the
// field is accessed as `.0.lock().unwrap().borrow_mut()` elsewhere in this
// module, so the layout must stay as-is.
#[derive(Default)]
struct AtomicWaker(Arc<Mutex<RefCell<Option<Waker>>>>);
impl AtomicWaker {
    /// Consume the registered waker, if any, and wake that task.
    fn wake(&self) {
        let slot = self.0.lock().unwrap();
        let waker = slot.borrow_mut().take();
        if let Some(w) = waker {
            w.wake()
        }
    }
}
impl fmt::Debug for AtomicWaker {
    // The stored waker carries no useful debug state; print the type name only.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AtomicWaker")
    }
}
/// Async io stream
///
/// One endpoint of an in-memory duplex pipe created by [`Io::create`].
/// Reads come from `local`, writes go into `remote`; the peer endpoint
/// holds the same two channels swapped.
#[derive(Debug)]
pub struct Io {
    // endpoint role; clones get a distinct `*Clone` variant (see `Drop`)
    tp: Type,
    // drop-tracking flags shared by both endpoints
    state: Arc<Cell<State>>,
    // channel this endpoint reads from
    local: Arc<Mutex<RefCell<Channel>>>,
    // channel this endpoint writes into (the peer's `local`)
    remote: Arc<Mutex<RefCell<Channel>>>,
}
bitflags::bitflags! {
    struct Flags: u8 {
        // cleared whenever `poll_write` buffers new data;
        // NOTE(review): nothing visible in this module ever sets it — confirm.
        const FLUSHED = 0b0000_0001;
        // set by `poll_shutdown`; the peer observes it via `is_closed`
        const CLOSED = 0b0000_0010;
    }
}
// Role of an `Io` endpoint. Clones are tagged separately so that dropping
// a clone does not mark the primary endpoint as dropped (see `Drop`).
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
enum Type {
    Client,
    Server,
    ClientClone,
    ServerClone,
}
// Drop-tracking flags shared by both endpoints through `Arc<Cell<State>>`.
#[derive(Copy, Clone, Default, Debug)]
struct State {
    // primary client endpoint has been dropped
    client_dropped: bool,
    // primary server endpoint has been dropped
    server_dropped: bool,
}
// One direction of the in-memory pipe.
#[derive(Default, Debug)]
struct Channel {
    // bytes buffered for the reader of this channel
    buf: BytesMut,
    // remaining capacity a writer may buffer (see `remote_buffer_cap`)
    buf_cap: usize,
    flags: Flags,
    // waker of the task parked on this channel
    waker: AtomicWaker,
    // scripted outcome for the next read poll
    read: IoState,
    // scripted outcome for the next write poll
    write: IoState,
}
impl Channel {
    // Whether this channel was shut down (CLOSED set by `poll_shutdown`).
    fn is_closed(&self) -> bool {
        self.flags.contains(Flags::CLOSED)
    }
}
impl Default for Flags {
fn default() -> Self {
Flags::empty()
}
}
// Scripted outcome for the next I/O poll on a channel. `poll_read` and
// `poll_write` consume it with `mem::take`, so `Pending` and `Err` fire at
// most once before resetting to `Ok`.
#[derive(Debug)]
enum IoState {
    Ok,
    Pending,
    Close,
    Err(io::Error),
}
impl Default for IoState {
fn default() -> Self {
IoState::Ok
}
}
impl Io {
    /// Create a two interconnected streams
    ///
    /// The first element plays the client role, the second the server role.
    /// Each endpoint's `local` channel is the other endpoint's `remote`.
    pub fn create() -> (Io, Io) {
        let local = Arc::new(Mutex::new(RefCell::new(Channel::default())));
        let remote = Arc::new(Mutex::new(RefCell::new(Channel::default())));
        let state = Arc::new(Cell::new(State::default()));
        (
            Io {
                tp: Type::Client,
                local: local.clone(),
                remote: remote.clone(),
                state: state.clone(),
            },
            Io {
                state,
                tp: Type::Server,
                // channels are swapped for the peer endpoint
                local: remote,
                remote: local,
            },
        )
    }
    /// `true` once the primary client endpoint has been dropped.
    pub fn is_client_dropped(&self) -> bool {
        self.state.get().client_dropped
    }
    /// `true` once the primary server endpoint has been dropped.
    pub fn is_server_dropped(&self) -> bool {
        self.state.get().server_dropped
    }
    /// Check if channel is closed from remote side
    pub fn is_closed(&self) -> bool {
        self.remote.lock().unwrap().borrow().is_closed()
    }
    /// Set read to Pending state
    ///
    /// One-shot: the peer's `poll_read` consumes the state with `mem::take`,
    /// resetting it to `IoState::Ok`.
    pub fn read_pending(&self) {
        self.remote.lock().unwrap().borrow_mut().read = IoState::Pending;
    }
    /// Set read to error
    ///
    /// The peer's next `poll_read` returns this error (once).
    pub fn read_error(&self, err: io::Error) {
        self.remote.lock().unwrap().borrow_mut().read = IoState::Err(err);
    }
    /// Set write error on remote side
    ///
    /// Stored in our `local` channel, which is where the peer's `poll_write`
    /// looks up its write state.
    pub fn write_error(&self, err: io::Error) {
        self.local.lock().unwrap().borrow_mut().write = IoState::Err(err);
    }
    /// Access read buffer.
    pub fn local_buffer<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut BytesMut) -> R,
    {
        let guard = self.local.lock().unwrap();
        let mut ch = guard.borrow_mut();
        f(&mut ch.buf)
    }
    /// Access remote buffer.
    pub fn remote_buffer<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut BytesMut) -> R,
    {
        let guard = self.remote.lock().unwrap();
        let mut ch = guard.borrow_mut();
        f(&mut ch.buf)
    }
    /// Closed remote side.
    ///
    /// Marks the peer's read state as `Close`, wakes it, then sleeps briefly
    /// so the peer task has a chance to observe the close.
    pub async fn close(&self) {
        {
            let guard = self.remote.lock().unwrap();
            let mut remote = guard.borrow_mut();
            remote.read = IoState::Close;
            remote.waker.wake();
        }
        sleep(Millis(35)).await;
    }
    /// Add extra data to the remote buffer and notify reader
    pub fn write<T: AsRef<[u8]>>(&self, data: T) {
        let guard = self.remote.lock().unwrap();
        let mut write = guard.borrow_mut();
        write.buf.extend_from_slice(data.as_ref());
        write.waker.wake();
    }
    /// Set the capacity the peer may write into our read buffer and wake a
    /// writer that is parked in `poll_write` waiting for capacity.
    pub fn remote_buffer_cap(&self, cap: usize) {
        // change cap
        self.local.lock().unwrap().borrow_mut().buf_cap = cap;
        // wake remote
        self.remote.lock().unwrap().borrow().waker.wake();
    }
    /// Read any available data
    pub fn read_any(&self) -> BytesMut {
        self.local.lock().unwrap().borrow_mut().buf.split()
    }
    /// Read data, if data is not available wait for it
    pub async fn read(&self) -> Result<BytesMut, io::Error> {
        if self.local.lock().unwrap().borrow().buf.is_empty() {
            poll_fn(|cx| {
                let guard = self.local.lock().unwrap();
                let read = guard.borrow_mut();
                if read.buf.is_empty() {
                    // Resolve immediately when the peer is gone (or, for the
                    // client side, when the channel was closed); otherwise
                    // park until a writer wakes us.
                    let closed = match self.tp {
                        Type::Client | Type::ClientClone => {
                            self.is_server_dropped() || read.is_closed()
                        }
                        Type::Server | Type::ServerClone => self.is_client_dropped(),
                    };
                    if closed {
                        Poll::Ready(())
                    } else {
                        *read.waker.0.lock().unwrap().borrow_mut() =
                            Some(cx.waker().clone());
                        drop(read);
                        drop(guard);
                        Poll::Pending
                    }
                } else {
                    Poll::Ready(())
                }
            })
            .await;
        }
        Ok(self.local.lock().unwrap().borrow_mut().buf.split())
    }
}
impl Clone for Io {
    fn clone(&self) -> Self {
        // A clone shares the channels and state but is tagged with a
        // `*Clone` variant so dropping it does not mark the primary
        // endpoint as dropped (see `Drop`).
        Io {
            tp: match self.tp {
                Type::Client => Type::ClientClone,
                Type::Server => Type::ServerClone,
                other => other,
            },
            local: self.local.clone(),
            remote: self.remote.clone(),
            state: self.state.clone(),
        }
    }
}
impl Drop for Io {
    fn drop(&mut self) {
        // Only the primary endpoints flip the shared dropped flags;
        // clones are inert on drop.
        let mut state = self.state.get();
        match self.tp {
            Type::Client => state.client_dropped = true,
            Type::Server => state.server_dropped = true,
            Type::ClientClone | Type::ServerClone => {}
        }
        self.state.set(state);
    }
}
impl AsyncRead for Io {
    // Serve buffered bytes from `local`; otherwise honour the scripted
    // read state for this channel.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let guard = self.local.lock().unwrap();
        let mut ch = guard.borrow_mut();
        // Always (re-)register the waker so a later `write` on the peer
        // side can wake this task.
        *ch.waker.0.lock().unwrap().borrow_mut() = Some(cx.waker().clone());
        if !ch.buf.is_empty() {
            let size = std::cmp::min(ch.buf.len(), buf.remaining());
            let b = ch.buf.split_to(size);
            buf.put_slice(&b);
            return Poll::Ready(Ok(()));
        }
        // `mem::take` resets the state to `Ok`, so `Pending` and `Err` are
        // delivered at most once; `Close` is stored back so EOF (Ready with
        // no bytes written) persists across polls.
        match mem::take(&mut ch.read) {
            IoState::Ok => Poll::Pending,
            IoState::Close => {
                ch.read = IoState::Close;
                Poll::Ready(Ok(()))
            }
            IoState::Pending => Poll::Pending,
            IoState::Err(e) => Poll::Ready(Err(e)),
        }
    }
}
impl AsyncWrite for Io {
    // Writes land in the peer-facing (`remote`) channel, limited by that
    // channel's `buf_cap` (raised by the peer via `remote_buffer_cap`).
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let guard = self.remote.lock().unwrap();
        let mut ch = guard.borrow_mut();
        // `mem::take` resets the state to `Ok`, so `Pending`/`Err` apply to
        // a single poll only.
        match mem::take(&mut ch.write) {
            IoState::Ok => {
                let cap = cmp::min(buf.len(), ch.buf_cap);
                if cap > 0 {
                    ch.buf.extend(&buf[..cap]);
                    ch.buf_cap -= cap;
                    ch.flags.remove(Flags::FLUSHED);
                    // wake a reader parked on this channel
                    ch.waker.wake();
                    Poll::Ready(Ok(cap))
                } else {
                    // No capacity: register in our `local` channel, which is
                    // the waker `remote_buffer_cap` on the peer side wakes.
                    *self
                        .local
                        .lock()
                        .unwrap()
                        .borrow_mut()
                        .waker
                        .0
                        .lock()
                        .unwrap()
                        .borrow_mut() = Some(cx.waker().clone());
                    Poll::Pending
                }
            }
            // closed channel reports a zero-byte write
            IoState::Close => Poll::Ready(Ok(0)),
            IoState::Pending => {
                *self
                    .local
                    .lock()
                    .unwrap()
                    .borrow_mut()
                    .waker
                    .0
                    .lock()
                    .unwrap()
                    .borrow_mut() = Some(cx.waker().clone());
                Poll::Pending
            }
            IoState::Err(e) => Poll::Ready(Err(e)),
        }
    }
    // In-memory stream: data is visible to the peer as soon as it is
    // buffered, so flushing is a no-op.
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
    // Mark our side closed; the peer observes this through `is_closed`.
    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.local
            .lock()
            .unwrap()
            .borrow_mut()
            .flags
            .insert(Flags::CLOSED);
        Poll::Ready(Ok(()))
    }
}
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use super::*;
    #[crate::rt_test]
    async fn basic() {
        let (client, server) = Io::create();
        // cloning an endpoint downgrades its type to the `*Clone` variant
        assert_eq!(client.tp, Type::Client);
        assert_eq!(client.clone().tp, Type::ClientClone);
        assert_eq!(server.tp, Type::Server);
        assert_eq!(server.clone().tp, Type::ServerClone);
        // only dropping a primary endpoint flips the shared dropped flag
        assert!(!server.is_client_dropped());
        drop(client);
        assert!(server.is_client_dropped());
        let server2 = server.clone();
        assert!(!server2.is_server_dropped());
        drop(server);
        assert!(server2.is_server_dropped());
    }
}

View file

@ -10,6 +10,7 @@ use crate::web::httprequest::HttpRequest;
#[derive(Clone, Debug)]
pub struct ResourceMap {
#[allow(dead_code)]
root: ResourceDef,
parent: RefCell<Option<Rc<ResourceMap>>>,
named: HashMap<String, ResourceDef>,

View file

@ -129,6 +129,7 @@ fn test_start() {
}
#[test]
#[allow(deprecated)]
fn test_configure() {
let addr1 = TestServer::unused_addr();
let addr2 = TestServer::unused_addr();
@ -179,6 +180,7 @@ fn test_configure() {
}
#[test]
#[allow(deprecated)]
fn test_configure_async() {
let addr1 = TestServer::unused_addr();
let addr2 = TestServer::unused_addr();
@ -255,7 +257,7 @@ fn test_on_worker_start() {
.bind("addr2", addr2)
.unwrap()
.listen("addr3", lst)
.apply_async(move |rt| {
.on_worker_start(move |rt| {
let num = num.clone();
async move {
rt.service("addr1", fn_service(|_| ok::<_, ()>(())));