diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 5132bd29..1b3ed78d 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -8,7 +8,7 @@ jobs: fail-fast: false matrix: version: - - 1.53.0 # MSRV + - 1.56.0 # MSRV - stable - nightly @@ -43,7 +43,7 @@ jobs: key: ${{ matrix.version }}-x86_64-unknown-linux-gnu-cargo-index-trimmed-${{ hashFiles('**/Cargo.lock') }} - name: Cache cargo tarpaulin - if: matrix.version == '1.53.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') + if: matrix.version == '1.56.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') uses: actions/cache@v1 with: path: ~/.cargo/bin @@ -57,19 +57,19 @@ jobs: args: --all --all-features --no-fail-fast -- --nocapture - name: Install tarpaulin - if: matrix.version == '1.53.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') + if: matrix.version == '1.56.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') continue-on-error: true run: | cargo install cargo-tarpaulin - name: Generate coverage report - if: matrix.version == '1.53.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') + if: matrix.version == '1.56.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') continue-on-error: true run: | cargo tarpaulin --out Xml --all --all-features - name: Upload to Codecov - if: matrix.version == '1.53.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') + if: matrix.version == '1.56.0' && (github.ref == 'refs/heads/master' || github.event_name == 'pull_request') continue-on-error: true uses: codecov/codecov-action@v1 with: diff --git a/Cargo.toml b/Cargo.toml index 308cf7f8..a90cbbcf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "ntex", "ntex-bytes", "ntex-codec", + "ntex-io", "ntex-router", "ntex-rt", "ntex-service", @@ -14,6 +15,7 @@ members = [ ntex = { path = "ntex" } ntex-bytes = { path = "ntex-bytes" } ntex-codec = { path = "ntex-codec" } +ntex-io = { path = "ntex-io" } ntex-router = { path = "ntex-router" } ntex-rt = { path = "ntex-rt" } ntex-service = { path = "ntex-service" } diff --git a/ntex-bytes/src/lib.rs b/ntex-bytes/src/lib.rs index 877b4700..90f83316 100644 --- a/ntex-bytes/src/lib.rs +++ b/ntex-bytes/src/lib.rs @@ -1,6 +1,6 @@ //! Provides abstractions for working with bytes. //! -//! This is fork of bytes crate (https://github.com/tokio-rs/bytes) +//! This is fork of [bytes crate](https://github.com/tokio-rs/bytes) //! //! The `ntex-bytes` crate provides an efficient byte buffer structure //! ([`Bytes`](struct.Bytes.html)) and traits for working with buffer diff --git a/ntex-codec/Cargo.toml b/ntex-codec/Cargo.toml index 5810d779..84d70d91 100644 --- a/ntex-codec/Cargo.toml +++ b/ntex-codec/Cargo.toml @@ -23,5 +23,5 @@ log = "0.4" tokio = { version = "1", default-features = false } [dev-dependencies] -ntex = "0.3.13" +ntex = "0.4.13" futures = "0.3.13" diff --git a/ntex-codec/src/lib.rs b/ntex-codec/src/lib.rs index fc9344b6..779bb7e2 100644 --- a/ntex-codec/src/lib.rs +++ b/ntex-codec/src/lib.rs @@ -1,8 +1,8 @@ //! Utilities for encoding and decoding frames. //! //! Contains adapters to go from streams of bytes, [`AsyncRead`] and -//! [`AsyncWrite`], to framed streams implementing [`Sink`] and [`Stream`]. -//! Framed streams are also known as [transports]. +//! [`AsyncWrite`], to framed streams implementing `Sink` and `Stream`. +//! Framed streams are also known as `transports`. //! //! [`AsyncRead`]: # //! [`AsyncWrite`]: # diff --git a/ntex-io/Cargo.toml b/ntex-io/Cargo.toml new file mode 100644 index 00000000..51d7ec33 --- /dev/null +++ b/ntex-io/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "ntex-io" +version = "0.1.0" +authors = ["ntex contributors "] +description = "Utilities for encoding and decoding frames" +keywords = ["network", "framework", "async", "futures"] +homepage = "https://ntex.rs" +repository = "https://github.com/ntex-rs/ntex.git" +documentation = "https://docs.rs/ntex-io/" +categories = ["network-programming", "asynchronous"] +license = "MIT" +edition = "2018" + +[lib] +name = "ntex_io" +path = "src/lib.rs" + +[features] +default = ["tokio"] + +# tokio support +tokio = ["tok-io"] + +[dependencies] +bitflags = "1.3" +fxhash = "0.2.1" +ntex-codec = "0.5.1" +ntex-bytes = "0.1.7" +ntex-util = "0.1.2" +ntex-service = "0.2.1" +log = "0.4" +pin-project-lite = "0.2" + +tok-io = { version = "1", package = "tokio", default-features = false, features = ["net"], optional = true } + +[dev-dependencies] +ntex = "0.4.13" +futures = "0.3.13" +rand = "0.8" diff --git a/ntex-io/LICENSE b/ntex-io/LICENSE new file mode 120000 index 00000000..ea5b6064 --- /dev/null +++ b/ntex-io/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/ntex-io/src/dispatcher.rs b/ntex-io/src/dispatcher.rs new file mode 100644 index 00000000..31f3b9e2 --- /dev/null +++ b/ntex-io/src/dispatcher.rs @@ -0,0 +1,904 @@ +//! Framed transport dispatcher +use std::{ + cell::Cell, future::Future, pin::Pin, rc::Rc, task::Context, task::Poll, time, +}; + +use ntex_bytes::Pool; +use ntex_codec::{Decoder, Encoder}; +use ntex_service::{IntoService, Service}; +use ntex_util::time::{now, Seconds}; +use ntex_util::{future::Either, spawn}; + +use super::{DispatchItem, IoBoxed, ReadRef, Timer, WriteRef}; + +type Response = ::Item; + +pin_project_lite::pin_project! { + /// Framed dispatcher - is a future that reads frames from bytes stream + /// and pass then to the service. + pub struct Dispatcher + where + S: Service, Response = Option>>, + S::Error: 'static, + S::Future: 'static, + U: Encoder, + U: Decoder, + ::Item: 'static, + { + service: S, + inner: DispatcherInner, + #[pin] + fut: Option, + } +} + +struct DispatcherInner +where + S: Service, Response = Option>>, + U: Encoder + Decoder, +{ + st: Cell, + state: IoBoxed, + timer: Timer, + ka_timeout: Seconds, + ka_updated: Cell, + error: Cell>, + ready_err: Cell, + shared: Rc>, + pool: Pool, +} + +struct DispatcherShared +where + S: Service, Response = Option>>, + U: Encoder + Decoder, +{ + codec: U, + error: Cell::Error>>>, + inflight: Cell, +} + +#[derive(Copy, Clone, Debug)] +enum DispatcherState { + Processing, + Backpressure, + Stop, + Shutdown, +} + +enum DispatcherError { + KeepAlive, + Encoder(U), + Service(S), +} + +enum PollService { + Item(DispatchItem), + ServiceError, + Ready, +} + +impl From> for DispatcherError { + fn from(err: Either) -> Self { + match err { + Either::Left(err) => DispatcherError::Service(err), + Either::Right(err) => DispatcherError::Encoder(err), + } + } +} + +impl Dispatcher +where + S: Service, Response = Option>> + 'static, + U: Decoder + Encoder + 'static, + ::Item: 'static, +{ + /// Construct new `Dispatcher` instance. + pub fn new>( + state: IoBoxed, + codec: U, + service: F, + timer: Timer, + ) -> Self { + let updated = now(); + let ka_timeout = Seconds(30); + + // register keepalive timer + let expire = updated + time::Duration::from(ka_timeout); + timer.register(expire, expire, &state); + + Dispatcher { + service: service.into_service(), + fut: None, + inner: DispatcherInner { + pool: state.memory_pool().pool(), + ka_updated: Cell::new(updated), + error: Cell::new(None), + ready_err: Cell::new(false), + st: Cell::new(DispatcherState::Processing), + shared: Rc::new(DispatcherShared { + codec, + error: Cell::new(None), + inflight: Cell::new(0), + }), + state, + timer, + ka_timeout, + }, + } + } + + /// Set keep-alive timeout. + /// + /// To disable timeout set value to 0. + /// + /// By default keep-alive timeout is set to 30 seconds. + pub fn keepalive_timeout(mut self, timeout: Seconds) -> Self { + // register keepalive timer + let prev = self.inner.ka_updated.get() + time::Duration::from(self.inner.ka()); + if timeout.is_zero() { + self.inner.timer.unregister(prev, &self.inner.state); + } else { + let expire = self.inner.ka_updated.get() + time::Duration::from(timeout); + self.inner.timer.register(expire, prev, &self.inner.state); + } + self.inner.ka_timeout = timeout; + + self + } + + /// Set connection disconnect timeout in seconds. + /// + /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete + /// within this time, the connection get dropped. + /// + /// To disable timeout set value to 0. + /// + /// By default disconnect timeout is set to 1 seconds. + pub fn disconnect_timeout(self, val: Seconds) -> Self { + self.inner.state.set_disconnect_timeout(val); + self + } +} + +impl DispatcherShared +where + S: Service, Response = Option>>, + S::Error: 'static, + S::Future: 'static, + U: Encoder + Decoder, + ::Item: 'static, +{ + fn handle_result(&self, item: Result, write: WriteRef<'_>) { + self.inflight.set(self.inflight.get() - 1); + match write.encode_result(item, &self.codec) { + Ok(true) => (), + Ok(false) => write.enable_backpressure(None), + Err(err) => self.error.set(Some(err.into())), + } + write.wake_dispatcher(); + } +} + +impl Future for Dispatcher +where + S: Service, Response = Option>> + 'static, + U: Decoder + Encoder + 'static, + ::Item: 'static, +{ + type Output = Result<(), S::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.as_mut().project(); + let slf = &this.inner; + let state = &slf.state; + let read = state.read(); + let write = state.write(); + + // handle service response future + if let Some(fut) = this.fut.as_mut().as_pin_mut() { + match fut.poll(cx) { + Poll::Pending => (), + Poll::Ready(item) => { + this.fut.set(None); + slf.shared.inflight.set(slf.shared.inflight.get() - 1); + slf.handle_result(item, write); + } + } + } + + // handle memory pool pressure + if slf.pool.poll_ready(cx).is_pending() { + read.pause(cx); + return Poll::Pending; + } + + loop { + match slf.st.get() { + DispatcherState::Processing => { + let result = match slf.poll_service(this.service, cx, read) { + Poll::Pending => return Poll::Pending, + Poll::Ready(result) => result, + }; + + let item = match result { + PollService::Ready => { + if !write.is_ready() { + // instruct write task to notify dispatcher when data is flushed + write.enable_backpressure(Some(cx)); + slf.st.set(DispatcherState::Backpressure); + DispatchItem::WBackPressureEnabled + } else if read.is_ready() { + // decode incoming bytes if buffer is ready + match read.decode(&slf.shared.codec) { + Ok(Some(el)) => { + slf.update_keepalive(); + DispatchItem::Item(el) + } + Ok(None) => { + log::trace!("not enough data to decode next frame, register dispatch task"); + read.wake(cx); + return Poll::Pending; + } + Err(err) => { + slf.st.set(DispatcherState::Stop); + slf.unregister_keepalive(); + DispatchItem::DecoderError(err) + } + } + } else { + // no new events + state.register_dispatcher(cx); + return Poll::Pending; + } + } + PollService::Item(item) => item, + PollService::ServiceError => continue, + }; + + // call service + if this.fut.is_none() { + // optimize first service call + this.fut.set(Some(this.service.call(item))); + match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) { + Poll::Ready(res) => { + this.fut.set(None); + slf.handle_result(res, write); + } + Poll::Pending => { + slf.shared.inflight.set(slf.shared.inflight.get() + 1) + } + } + } else { + slf.spawn_service_call(this.service.call(item)); + } + } + // handle write back-pressure + DispatcherState::Backpressure => { + let result = match slf.poll_service(this.service, cx, read) { + Poll::Ready(result) => result, + Poll::Pending => return Poll::Pending, + }; + let item = match result { + PollService::Ready => { + if write.is_ready() { + slf.st.set(DispatcherState::Processing); + DispatchItem::WBackPressureDisabled + } else { + return Poll::Pending; + } + } + PollService::Item(item) => item, + PollService::ServiceError => continue, + }; + + // call service + if this.fut.is_none() { + // optimize first service call + this.fut.set(Some(this.service.call(item))); + match this.fut.as_mut().as_pin_mut().unwrap().poll(cx) { + Poll::Ready(res) => { + this.fut.set(None); + slf.handle_result(res, write); + } + Poll::Pending => { + slf.shared.inflight.set(slf.shared.inflight.get() + 1) + } + } + } else { + slf.spawn_service_call(this.service.call(item)); + } + } + // drain service responses + DispatcherState::Stop => { + // service may relay on poll_ready for response results + if !this.inner.ready_err.get() { + let _ = this.service.poll_ready(cx); + } + + if slf.shared.inflight.get() == 0 { + slf.st.set(DispatcherState::Shutdown); + state.shutdown(cx); + } else { + state.register_dispatcher(cx); + return Poll::Pending; + } + } + // shutdown service + DispatcherState::Shutdown => { + let err = slf.error.take(); + + return if this.service.poll_shutdown(cx, err.is_some()).is_ready() { + log::trace!("service shutdown is completed, stop"); + + Poll::Ready(if let Some(err) = err { + Err(err) + } else { + Ok(()) + }) + } else { + slf.error.set(err); + Poll::Pending + }; + } + } + } + } +} + +impl DispatcherInner +where + S: Service, Response = Option>> + 'static, + U: Decoder + Encoder + 'static, +{ + /// spawn service call + fn spawn_service_call(&self, fut: S::Future) { + self.shared.inflight.set(self.shared.inflight.get() + 1); + + let st = self.state.get_ref(); + let shared = self.shared.clone(); + spawn(async move { + let item = fut.await; + shared.handle_result(item, st.write()); + }); + } + + fn handle_result( + &self, + item: Result::Item>, S::Error>, + write: WriteRef<'_>, + ) { + match write.encode_result(item, &self.shared.codec) { + Ok(true) => (), + Ok(false) => write.enable_backpressure(None), + Err(Either::Left(err)) => { + self.error.set(Some(err)); + } + Err(Either::Right(err)) => { + self.shared.error.set(Some(DispatcherError::Encoder(err))) + } + } + } + + fn poll_service( + &self, + srv: &S, + cx: &mut Context<'_>, + read: ReadRef<'_>, + ) -> Poll> { + match srv.poll_ready(cx) { + Poll::Ready(Ok(_)) => { + // service is ready, wake io read task + read.resume(); + + // check keepalive timeout + self.check_keepalive(); + + // check for errors + Poll::Ready(if let Some(err) = self.shared.error.take() { + log::trace!("error occured, stopping dispatcher"); + self.unregister_keepalive(); + self.st.set(DispatcherState::Stop); + + match err { + DispatcherError::KeepAlive => { + PollService::Item(DispatchItem::KeepAliveTimeout) + } + DispatcherError::Encoder(err) => { + PollService::Item(DispatchItem::EncoderError(err)) + } + DispatcherError::Service(err) => { + self.error.set(Some(err)); + PollService::ServiceError + } + } + } else if self.state.is_dispatcher_stopped() { + log::trace!("dispatcher is instructed to stop"); + + self.unregister_keepalive(); + + // process unhandled data + if let Ok(Some(el)) = read.decode(&self.shared.codec) { + PollService::Item(DispatchItem::Item(el)) + } else { + self.st.set(DispatcherState::Stop); + + // get io error + if let Some(err) = self.state.take_error() { + PollService::Item(DispatchItem::IoError(err)) + } else { + PollService::ServiceError + } + } + } else { + PollService::Ready + }) + } + // pause io read task + Poll::Pending => { + log::trace!("service is not ready, register dispatch task"); + read.pause(cx); + Poll::Pending + } + // handle service readiness error + Poll::Ready(Err(err)) => { + log::trace!("service readiness check failed, stopping"); + self.st.set(DispatcherState::Stop); + self.error.set(Some(err)); + self.unregister_keepalive(); + self.ready_err.set(true); + Poll::Ready(PollService::ServiceError) + } + } + } + + fn ka(&self) -> Seconds { + self.ka_timeout + } + + fn ka_enabled(&self) -> bool { + self.ka_timeout.non_zero() + } + + /// check keepalive timeout + fn check_keepalive(&self) { + if self.state.is_keepalive() { + log::trace!("keepalive timeout"); + if let Some(err) = self.shared.error.take() { + self.shared.error.set(Some(err)); + } else { + self.shared.error.set(Some(DispatcherError::KeepAlive)); + } + } + } + + /// update keep-alive timer + fn update_keepalive(&self) { + if self.ka_enabled() { + let updated = now(); + if updated != self.ka_updated.get() { + let ka = time::Duration::from(self.ka()); + self.timer.register( + updated + ka, + self.ka_updated.get() + ka, + &self.state, + ); + self.ka_updated.set(updated); + } + } + } + + /// unregister keep-alive timer + fn unregister_keepalive(&self) { + if self.ka_enabled() { + self.timer.unregister( + self.ka_updated.get() + time::Duration::from(self.ka()), + &self.state, + ); + } + } +} + +#[cfg(test)] +mod tests { + use rand::Rng; + use std::sync::{atomic::AtomicBool, atomic::Ordering::Relaxed, Arc, Mutex}; + use std::{cell::RefCell, time::Duration}; + + use ntex_bytes::{Bytes, PoolId, PoolRef}; + use ntex_codec::BytesCodec; + use ntex_util::future::Ready; + use ntex_util::time::{sleep, Millis}; + + use crate::testing::IoTest; + use crate::{state::Flags, state::IoStateInner, Io, IoStream, WriteRef}; + + use super::*; + + pub(crate) struct State(Rc); + + impl State { + fn flags(&self) -> Flags { + self.0.flags.get() + } + + fn write(&'_ self) -> WriteRef<'_> { + WriteRef(self.0.as_ref()) + } + + fn close(&self) { + self.0.insert_flags(Flags::DSP_STOP); + self.0.dispatch_task.wake(); + } + + fn set_memory_pool(&self, pool: PoolRef) { + self.0.pool.set(pool); + } + } + + impl Dispatcher + where + S: Service, Response = Option>>, + S::Error: 'static, + S::Future: 'static, + U: Decoder + Encoder + 'static, + ::Item: 'static, + { + /// Construct new `Dispatcher` instance + pub(crate) fn debug>( + io: T, + codec: U, + service: F, + ) -> (Self, State) { + let state = Io::new(io); + let timer = Timer::default(); + let ka_timeout = Seconds(1); + let ka_updated = now(); + let shared = Rc::new(DispatcherShared { + codec: codec, + error: Cell::new(None), + inflight: Cell::new(0), + }); + let inner = State(state.0 .0.clone()); + + let expire = ka_updated + Duration::from_millis(500); + timer.register(expire, expire, &state); + + ( + Dispatcher { + service: service.into_service(), + fut: None, + inner: DispatcherInner { + ka_updated: Cell::new(ka_updated), + error: Cell::new(None), + ready_err: Cell::new(false), + st: Cell::new(DispatcherState::Processing), + pool: state.memory_pool().pool(), + state: state.into_boxed(), + shared, + timer, + ka_timeout, + }, + }, + inner, + ) + } + } + + #[ntex::test] + async fn test_basic() { + let (client, server) = IoTest::create(); + client.remote_buffer_cap(1024); + client.write("GET /test HTTP/1\r\n\r\n"); + + let (disp, _) = Dispatcher::debug( + server, + BytesCodec, + ntex_service::fn_service(|msg: DispatchItem| async move { + sleep(Millis(50)).await; + if let DispatchItem::Item(msg) = msg { + Ok::<_, ()>(Some(msg.freeze())) + } else { + panic!() + } + }), + ); + spawn(async move { + let _ = disp.await; + }); + + sleep(Millis(25)).await; + let buf = client.read().await.unwrap(); + assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); + + client.write("GET /test HTTP/1\r\n\r\n"); + let buf = client.read().await.unwrap(); + assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); + + client.close().await; + assert!(client.is_server_dropped()); + } + + #[ntex::test] + async fn test_sink() { + let (client, server) = IoTest::create(); + client.remote_buffer_cap(1024); + client.write("GET /test HTTP/1\r\n\r\n"); + + let (disp, st) = Dispatcher::debug( + server, + BytesCodec, + ntex_service::fn_service(|msg: DispatchItem| async move { + if let DispatchItem::Item(msg) = msg { + Ok::<_, ()>(Some(msg.freeze())) + } else { + panic!() + } + }), + ); + spawn(async move { + let _ = disp.disconnect_timeout(Seconds(1)).await; + }); + + let buf = client.read().await.unwrap(); + assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); + + assert!(st + .write() + .encode(Bytes::from_static(b"test"), &mut BytesCodec) + .is_ok()); + let buf = client.read().await.unwrap(); + assert_eq!(buf, Bytes::from_static(b"test")); + + st.close(); + sleep(Millis(1100)).await; + assert!(client.is_server_dropped()); + } + + #[ntex::test] + async fn test_err_in_service() { + let (client, server) = IoTest::create(); + client.remote_buffer_cap(0); + client.write("GET /test HTTP/1\r\n\r\n"); + + let (disp, state) = Dispatcher::debug( + server, + BytesCodec, + ntex_service::fn_service(|_: DispatchItem| async move { + Err::, _>(()) + }), + ); + state + .write() + .encode( + Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), + &mut BytesCodec, + ) + .unwrap(); + spawn(async move { + let _ = disp.await; + }); + + // buffer should be flushed + client.remote_buffer_cap(1024); + let buf = client.read().await.unwrap(); + assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); + + // write side must be closed, dispatcher waiting for read side to close + assert!(client.is_closed()); + + // close read side + client.close().await; + assert!(client.is_server_dropped()); + } + + #[ntex::test] + async fn test_err_in_service_ready() { + let (client, server) = IoTest::create(); + client.remote_buffer_cap(0); + client.write("GET /test HTTP/1\r\n\r\n"); + + let counter = Rc::new(Cell::new(0)); + + struct Srv(Rc>); + + impl Service for Srv { + type Request = DispatchItem; + type Response = Option>; + type Error = (); + type Future = Ready>, ()>; + + fn poll_ready(&self, _: &mut Context<'_>) -> Poll> { + self.0.set(self.0.get() + 1); + Poll::Ready(Err(())) + } + + fn call(&self, _: DispatchItem) -> Self::Future { + Ready::Ok(None) + } + } + + let (disp, state) = Dispatcher::debug(server, BytesCodec, Srv(counter.clone())); + state + .write() + .encode( + Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"), + &mut BytesCodec, + ) + .unwrap(); + spawn(async move { + let _ = disp.await; + }); + + // buffer should be flushed + client.remote_buffer_cap(1024); + let buf = client.read().await.unwrap(); + assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); + + // write side must be closed, dispatcher waiting for read side to close + assert!(client.is_closed()); + + // close read side + client.close().await; + assert!(client.is_server_dropped()); + + // service must be checked for readiness only once + assert_eq!(counter.get(), 1); + } + + #[ntex::test] + async fn test_write_backpressure() { + let (client, server) = IoTest::create(); + // do not allow to write to socket + client.remote_buffer_cap(0); + client.write("GET /test HTTP/1\r\n\r\n"); + + let data = Arc::new(Mutex::new(RefCell::new(Vec::new()))); + let data2 = data.clone(); + + let (disp, state) = Dispatcher::debug( + server, + BytesCodec, + ntex_service::fn_service(move |msg: DispatchItem| { + let data = data2.clone(); + async move { + match msg { + DispatchItem::Item(_) => { + data.lock().unwrap().borrow_mut().push(0); + let bytes = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(65_536) + .map(char::from) + .collect::(); + return Ok::<_, ()>(Some(Bytes::from(bytes))); + } + DispatchItem::WBackPressureEnabled => { + data.lock().unwrap().borrow_mut().push(1); + } + DispatchItem::WBackPressureDisabled => { + data.lock().unwrap().borrow_mut().push(2); + } + _ => (), + } + Ok(None) + } + }), + ); + let pool = PoolId::P10.pool_ref(); + pool.set_read_params(8 * 1024, 1024); + pool.set_write_params(16 * 1024, 1024); + state.set_memory_pool(pool); + + spawn(async move { + let _ = disp.await; + }); + + let buf = client.read_any(); + assert_eq!(buf, Bytes::from_static(b"")); + client.write("GET /test HTTP/1\r\n\r\n"); + sleep(Millis(25)).await; + + // buf must be consumed + assert_eq!(client.remote_buffer(|buf| buf.len()), 0); + + // response message + assert!(!state.write().is_ready()); + assert_eq!(state.write().with_buf(|buf| buf.len()), 65536); + + client.remote_buffer_cap(10240); + sleep(Millis(50)).await; + assert_eq!(state.write().with_buf(|buf| buf.len()), 55296); + + client.remote_buffer_cap(45056); + sleep(Millis(50)).await; + assert_eq!(state.write().with_buf(|buf| buf.len()), 10240); + + // backpressure disabled + assert!(state.write().is_ready()); + assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1, 2]); + } + + #[ntex::test] + async fn test_keepalive() { + let (client, server) = IoTest::create(); + // do not allow to write to socket + client.remote_buffer_cap(1024); + client.write("GET /test HTTP/1\r\n\r\n"); + + let data = Arc::new(Mutex::new(RefCell::new(Vec::new()))); + let data2 = data.clone(); + + let (disp, state) = Dispatcher::debug( + server, + BytesCodec, + ntex_service::fn_service(move |msg: DispatchItem| { + let data = data2.clone(); + async move { + match msg { + DispatchItem::Item(bytes) => { + data.lock().unwrap().borrow_mut().push(0); + return Ok::<_, ()>(Some(bytes.freeze())); + } + DispatchItem::KeepAliveTimeout => { + data.lock().unwrap().borrow_mut().push(1); + } + _ => (), + } + Ok(None) + } + }), + ); + spawn(async move { + let _ = disp + .keepalive_timeout(Seconds::ZERO) + .keepalive_timeout(Seconds(1)) + .await; + }); + + state.0.disconnect_timeout.set(Seconds(1)); + + let buf = client.read().await.unwrap(); + assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); + sleep(Millis(3500)).await; + + // write side must be closed, dispatcher should fail with keep-alive + let flags = state.flags(); + assert!(flags.contains(Flags::IO_SHUTDOWN)); + assert!(flags.contains(Flags::DSP_KEEPALIVE)); + assert!(client.is_closed()); + assert_eq!(&data.lock().unwrap().borrow()[..], &[0, 1]); + } + + #[ntex::test] + async fn test_unhandled_data() { + let handled = Arc::new(AtomicBool::new(false)); + let handled2 = handled.clone(); + + let (client, server) = IoTest::create(); + client.remote_buffer_cap(1024); + client.write("GET /test HTTP/1\r\n\r\n"); + + let (disp, _) = Dispatcher::debug( + server, + BytesCodec, + ntex_service::fn_service(move |msg: DispatchItem| { + handled2.store(true, Relaxed); + async move { + sleep(Millis(50)).await; + if let DispatchItem::Item(msg) = msg { + Ok::<_, ()>(Some(msg.freeze())) + } else { + panic!() + } + } + }), + ); + client.close().await; + spawn(async move { + let _ = disp.await; + }); + sleep(Millis(50)).await; + + assert!(handled.load(Relaxed)); + } +} diff --git a/ntex-io/src/filter.rs b/ntex-io/src/filter.rs new file mode 100644 index 00000000..a17b38d9 --- /dev/null +++ b/ntex-io/src/filter.rs @@ -0,0 +1,151 @@ +use std::{io, rc::Rc, task::Context, task::Poll}; + +use ntex_bytes::BytesMut; + +use super::state::{Flags, IoStateInner}; +use super::{Filter, ReadFilter, WriteFilter, WriteReadiness}; + +pub struct DefaultFilter(Rc); + +impl DefaultFilter { + pub(crate) fn new(inner: Rc) -> Self { + DefaultFilter(inner) + } +} + +impl Filter for DefaultFilter {} + +impl ReadFilter for DefaultFilter { + #[inline] + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + let flags = self.0.flags.get(); + + if flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { + Poll::Ready(Err(())) + } else if flags.intersects(Flags::RD_PAUSED) { + self.0.read_task.register(cx.waker()); + Poll::Pending + } else { + self.0.read_task.register(cx.waker()); + Poll::Ready(Ok(())) + } + } + + #[inline] + fn read_closed(&self, err: Option) { + if err.is_some() { + self.0.error.set(err); + } + self.0.write_task.wake(); + self.0.dispatch_task.wake(); + self.0.insert_flags(Flags::IO_ERR | Flags::DSP_STOP); + self.0.notify_disconnect(); + } + + #[inline] + fn get_read_buf(&self) -> Option { + self.0.read_buf.take() + } + + #[inline] + fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) { + if new_bytes > 0 { + if buf.len() > self.0.pool.get().read_params().high as usize { + log::trace!( + "buffer is too large {}, enable read back-pressure", + buf.len() + ); + self.0.insert_flags(Flags::RD_READY | Flags::RD_BUF_FULL); + } else { + self.0.insert_flags(Flags::RD_READY); + } + self.0.dispatch_task.wake(); + } + self.0.read_buf.set(Some(buf)); + } +} + +impl WriteFilter for DefaultFilter { + #[inline] + fn poll_write_ready( + &self, + cx: &mut Context<'_>, + ) -> Poll> { + let flags = self.0.flags.get(); + + if flags.contains(Flags::IO_ERR) { + Poll::Ready(Err(WriteReadiness::Terminate)) + } else if flags.intersects(Flags::IO_SHUTDOWN) { + Poll::Ready(Err(WriteReadiness::Shutdown)) + } else { + self.0.write_task.register(cx.waker()); + Poll::Ready(Ok(())) + } + } + + #[inline] + fn write_closed(&self, err: Option) { + if err.is_some() { + self.0.error.set(err); + } + self.0.read_task.wake(); + self.0.dispatch_task.wake(); + self.0.insert_flags(Flags::IO_ERR | Flags::DSP_STOP); + self.0.notify_disconnect(); + } + + #[inline] + fn get_write_buf(&self) -> Option { + self.0.write_buf.take() + } + + #[inline] + fn release_write_buf(&self, buf: BytesMut) { + let pool = self.0.pool.get(); + if buf.is_empty() { + pool.release_write_buf(buf); + } else { + self.0.write_buf.set(Some(buf)); + } + } +} + +pub(crate) struct NullFilter; + +const NULL: NullFilter = NullFilter; + +impl NullFilter { + pub(super) fn get() -> &'static dyn Filter { + &NULL + } +} + +impl Filter for NullFilter {} + +impl ReadFilter for NullFilter { + fn poll_read_ready(&self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Err(())) + } + + fn read_closed(&self, _: Option) {} + + fn get_read_buf(&self) -> Option { + None + } + + fn release_read_buf(&self, _: BytesMut, _: usize) {} +} + +impl WriteFilter for NullFilter { + fn poll_write_ready(&self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Err(WriteReadiness::Terminate)) + } + + fn write_closed(&self, _: Option) {} + + fn get_write_buf(&self) -> Option { + None + } + + fn release_write_buf(&self, _: BytesMut) {} +} diff --git a/ntex-io/src/lib.rs b/ntex-io/src/lib.rs new file mode 100644 index 00000000..99d337d2 --- /dev/null +++ b/ntex-io/src/lib.rs @@ -0,0 +1,141 @@ +use std::{fmt, future::Future, io, task::Context, task::Poll}; + +pub mod testing; + +mod dispatcher; +mod filter; +mod state; +mod tasks; +mod time; +mod utils; + +#[cfg(feature = "tokio")] +mod tokio_impl; + +use ntex_bytes::BytesMut; +use ntex_codec::{Decoder, Encoder}; + +pub use self::dispatcher::Dispatcher; +pub use self::state::{Io, IoRef, ReadRef, WriteRef}; +pub use self::tasks::{ReadState, WriteState}; +pub use self::time::Timer; + +pub use self::utils::{from_iostream, into_boxed}; + +pub type IoBoxed = Io>; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum WriteReadiness { + Shutdown, + Terminate, +} + +pub trait ReadFilter { + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll>; + + fn read_closed(&self, err: Option); + + fn get_read_buf(&self) -> Option; + + fn release_read_buf(&self, buf: BytesMut, new_bytes: usize); +} + +pub trait WriteFilter { + fn poll_write_ready(&self, cx: &mut Context<'_>) + -> Poll>; + + fn write_closed(&self, err: Option); + + fn get_write_buf(&self) -> Option; + + fn release_write_buf(&self, buf: BytesMut); +} + +pub trait Filter: ReadFilter + WriteFilter {} + +pub trait FilterFactory: Sized { + type Filter: Filter; + + type Error: fmt::Debug; + type Future: Future, Self::Error>>; + + fn create(&self, st: Io) -> Self::Future; +} + +pub trait IoStream { + fn start(self, _: ReadState, _: WriteState); +} + +/// Framed transport item +pub enum DispatchItem { + Item(::Item), + /// Write back-pressure enabled + WBackPressureEnabled, + /// Write back-pressure disabled + WBackPressureDisabled, + /// Keep alive timeout + KeepAliveTimeout, + /// Decoder parse error + DecoderError(::Error), + /// Encoder parse error + EncoderError(::Error), + /// Unexpected io error + IoError(io::Error), +} + +impl fmt::Debug for DispatchItem +where + U: Encoder + Decoder, + ::Item: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + DispatchItem::Item(ref item) => { + write!(fmt, "DispatchItem::Item({:?})", item) + } + DispatchItem::WBackPressureEnabled => { + write!(fmt, "DispatchItem::WBackPressureEnabled") + } + DispatchItem::WBackPressureDisabled => { + write!(fmt, "DispatchItem::WBackPressureDisabled") + } + DispatchItem::KeepAliveTimeout => { + write!(fmt, "DispatchItem::KeepAliveTimeout") + } + DispatchItem::EncoderError(ref e) => { + write!(fmt, "DispatchItem::EncoderError({:?})", e) + } + DispatchItem::DecoderError(ref e) => { + write!(fmt, "DispatchItem::DecoderError({:?})", e) + } + DispatchItem::IoError(ref e) => { + write!(fmt, "DispatchItem::IoError({:?})", e) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ntex_codec::BytesCodec; + + #[test] + fn test_fmt() { + type T = DispatchItem; + + let err = T::EncoderError(io::Error::new(io::ErrorKind::Other, "err")); + assert!(format!("{:?}", err).contains("DispatchItem::Encoder")); + let err = T::DecoderError(io::Error::new(io::ErrorKind::Other, "err")); + assert!(format!("{:?}", err).contains("DispatchItem::Decoder")); + let err = T::IoError(io::Error::new(io::ErrorKind::Other, "err")); + assert!(format!("{:?}", err).contains("DispatchItem::IoError")); + + assert!(format!("{:?}", T::WBackPressureEnabled) + .contains("DispatchItem::WBackPressureEnabled")); + assert!(format!("{:?}", T::WBackPressureDisabled) + .contains("DispatchItem::WBackPressureDisabled")); + assert!(format!("{:?}", T::KeepAliveTimeout) + .contains("DispatchItem::KeepAliveTimeout")); + } +} diff --git a/ntex-io/src/state.rs b/ntex-io/src/state.rs new file mode 100644 index 00000000..251cf7a5 --- /dev/null +++ b/ntex-io/src/state.rs @@ -0,0 +1,1097 @@ +use std::cell::{Cell, RefCell}; +use std::task::{Context, Poll}; +use std::{future::Future, hash, io, mem, ops::Deref, pin::Pin, ptr, rc::Rc}; + +use ntex_bytes::{BytesMut, PoolId, PoolRef}; +use ntex_codec::{Decoder, Encoder}; +use ntex_util::{future::poll_fn, future::Either, task::LocalWaker, time::Seconds}; + +use super::filter::{DefaultFilter, NullFilter}; +use super::tasks::{ReadState, WriteState}; +use super::{Filter, FilterFactory, IoStream}; + +bitflags::bitflags! { + pub struct Flags: u16 { + /// io error occured + const IO_ERR = 0b0000_0000_0000_0001; + /// shutdown io tasks + const IO_SHUTDOWN = 0b0000_0000_0000_0100; + + /// pause io read + const RD_PAUSED = 0b0000_0000_0000_1000; + /// new data is available + const RD_READY = 0b0000_0000_0001_0000; + /// read buffer is full + const RD_BUF_FULL = 0b0000_0000_0010_0000; + + /// wait write completion + const WR_WAIT = 0b0000_0001_0000_0000; + /// write buffer is full + const WR_BACKPRESSURE = 0b0000_0010_0000_0000; + + /// dispatcher is marked stopped + const DSP_STOP = 0b0001_0000_0000_0000; + /// keep-alive timeout occured + const DSP_KEEPALIVE = 0b0010_0000_0000_0000; + /// dispatcher returned error + const DSP_ERR = 0b0100_0000_0000_0000; + } +} + +enum FilterItem { + Boxed(Box), + Ptr(*mut F), +} + +pub struct Io(pub(super) IoRef, FilterItem); + +#[derive(Clone)] +pub struct IoRef(pub(super) Rc); + +pub(crate) struct IoStateInner { + pub(super) flags: Cell, + pub(super) pool: Cell, + pub(super) disconnect_timeout: Cell, + pub(super) error: Cell>, + pub(super) read_task: LocalWaker, + pub(super) write_task: LocalWaker, + pub(super) dispatch_task: LocalWaker, + pub(super) read_buf: Cell>, + pub(super) write_buf: Cell>, + pub(super) filter: Cell<&'static dyn Filter>, + on_disconnect: RefCell>>, +} + +impl IoStateInner { + #[inline] + pub(super) fn insert_flags(&self, f: Flags) { + let mut flags = self.flags.get(); + flags.insert(f); + self.flags.set(flags); + } + + #[inline] + pub(super) fn remove_flags(&self, f: Flags) { + let mut flags = self.flags.get(); + flags.remove(f); + self.flags.set(flags); + } + + #[inline] + pub(super) fn notify_disconnect(&self) { + let mut on_disconnect = self.on_disconnect.borrow_mut(); + for item in &mut *on_disconnect { + if let Some(waker) = item.take() { + waker.wake(); + } + } + } +} + +impl Eq for IoStateInner {} + +impl PartialEq for IoStateInner { + #[inline] + fn eq(&self, other: &Self) -> bool { + ptr::eq(self, other) + } +} + +impl hash::Hash for IoStateInner { + #[inline] + fn hash(&self, state: &mut H) { + (self as *const _ as usize).hash(state); + } +} + +impl Drop for IoStateInner { + #[inline] + fn drop(&mut self) { + if let Some(buf) = self.read_buf.take() { + self.pool.get().release_read_buf(buf); + } + if let Some(buf) = self.write_buf.take() { + self.pool.get().release_write_buf(buf); + } + } +} + +impl Io { + #[inline] + /// Create `State` instance + pub fn new(io: I) -> Self { + Self::with_memory_pool(io, PoolId::DEFAULT.pool_ref()) + } + + #[inline] + /// Create `State` instance with specific memory pool. + pub fn with_memory_pool(io: I, pool: PoolRef) -> Self { + let inner = Rc::new(IoStateInner { + pool: Cell::new(pool), + flags: Cell::new(Flags::empty()), + error: Cell::new(None), + disconnect_timeout: Cell::new(Seconds(1)), + dispatch_task: LocalWaker::new(), + read_task: LocalWaker::new(), + write_task: LocalWaker::new(), + read_buf: Cell::new(None), + write_buf: Cell::new(None), + filter: Cell::new(NullFilter::get()), + on_disconnect: RefCell::new(Vec::new()), + }); + + let filter = Box::new(DefaultFilter::new(inner.clone())); + let filter_ref: &'static dyn Filter = unsafe { + let filter: &dyn Filter = filter.as_ref(); + std::mem::transmute(filter) + }; + inner.filter.replace(filter_ref); + + // start io tasks + io.start(ReadState(inner.clone()), WriteState(inner.clone())); + + Io(IoRef(inner), FilterItem::Ptr(Box::into_raw(filter))) + } +} + +impl Io { + #[inline] + /// Set memory pool + pub fn set_memory_pool(&self, pool: PoolRef) { + if let Some(mut buf) = self.0 .0.read_buf.take() { + pool.move_in(&mut buf); + self.0 .0.read_buf.set(Some(buf)); + } + if let Some(mut buf) = self.0 .0.write_buf.take() { + pool.move_in(&mut buf); + self.0 .0.write_buf.set(Some(buf)); + } + self.0 .0.pool.set(pool); + } + + #[inline] + /// Set io disconnect timeout in secs + pub fn set_disconnect_timeout(&self, timeout: Seconds) { + self.0 .0.disconnect_timeout.set(timeout); + } +} + +impl Io { + #[inline] + #[allow(clippy::should_implement_trait)] + /// Get IoRef reference + pub fn as_ref(&self) -> &IoRef { + &self.0 + } + + #[inline] + /// Get instance of IoRef + pub fn get_ref(&self) -> IoRef { + self.0.clone() + } + + #[inline] + /// Register dispatcher task + pub fn register_dispatcher(&self, cx: &mut Context<'_>) { + self.0 .0.dispatch_task.register(cx.waker()); + } + + #[inline] + /// Mark dispatcher as stopped + pub fn dispatcher_stopped(&self) { + self.0 .0.insert_flags(Flags::DSP_STOP); + } +} + +impl IoRef { + #[inline] + #[doc(hidden)] + /// Get current state flags + pub fn flags(&self) -> Flags { + self.0.flags.get() + } + + #[inline] + /// Get memory pool + pub fn memory_pool(&self) -> PoolRef { + self.0.pool.get() + } + + #[inline] + /// Check if io error occured in read or write task + pub fn is_io_err(&self) -> bool { + self.0.flags.get().contains(Flags::IO_ERR) + } + + #[inline] + /// Check if keep-alive timeout occured + pub fn is_keepalive(&self) -> bool { + self.0.flags.get().contains(Flags::DSP_KEEPALIVE) + } + + #[inline] + /// Check if dispatcher marked stopped + pub fn is_dispatcher_stopped(&self) -> bool { + self.0.flags.get().contains(Flags::DSP_STOP) + } + + #[inline] + /// Check if io stream is closed + pub fn is_closed(&self) -> bool { + self.0 + .flags + .get() + .intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN | Flags::DSP_STOP) + } + + #[inline] + /// Gracefully shutdown read and write io tasks + pub fn shutdown(&self, cx: &mut Context<'_>) { + let flags = self.0.flags.get(); + + if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { + log::trace!("initiate io shutdown {:?}", flags); + self.0.insert_flags(Flags::IO_SHUTDOWN); + self.0.read_task.wake(); + self.0.write_task.wake(); + self.0.dispatch_task.register(cx.waker()); + } + } + + #[inline] + /// Take io error if any occured + pub fn take_error(&self) -> Option { + self.0.error.take() + } + + #[inline] + /// Reset keep-alive error + pub fn reset_keepalive(&self) { + self.0.remove_flags(Flags::DSP_KEEPALIVE) + } + + #[inline] + /// Get api for read task + pub fn read(&'_ self) -> ReadRef<'_> { + ReadRef(self.0.as_ref()) + } + + #[inline] + /// Get api for write task + pub fn write(&'_ self) -> WriteRef<'_> { + WriteRef(self.0.as_ref()) + } + + #[inline] + /// Gracefully close connection + /// + /// First stop dispatcher, then dispatcher stops io tasks + pub fn close(&self) { + self.0.insert_flags(Flags::DSP_STOP); + self.0.dispatch_task.wake(); + } + + #[inline] + /// Force close connection + /// + /// Dispatcher does not wait for uncompleted responses, but flushes io buffers. + pub fn force_close(&self) { + log::trace!("force close framed object"); + self.0.insert_flags(Flags::DSP_STOP | Flags::IO_SHUTDOWN); + self.0.read_task.wake(); + self.0.write_task.wake(); + self.0.dispatch_task.wake(); + } + + #[inline] + /// Notify when io stream get disconnected + pub fn on_disconnect(&self) -> OnDisconnect { + OnDisconnect::new(self.0.clone(), self.0.flags.get().contains(Flags::IO_ERR)) + } +} + +impl Io { + #[inline] + /// Read incoming io stream and decode codec item. + pub async fn next( + &self, + codec: &U, + ) -> Result, Either> + where + U: Decoder, + { + let read = self.read(); + + loop { + let mut buf = self.0 .0.read_buf.take(); + let item = if let Some(ref mut buf) = buf { + codec.decode(buf) + } else { + Ok(None) + }; + self.0 .0.read_buf.set(buf); + + let result = match item { + Ok(Some(el)) => Ok(Some(el)), + Ok(None) => { + self.0 .0.remove_flags(Flags::RD_READY); + poll_fn(|cx| { + if read.is_ready() { + Poll::Ready(()) + } else { + read.wake(cx); + Poll::Pending + } + }) + .await; + if self.is_io_err() { + if let Some(err) = self.take_error() { + Err(Either::Right(err)) + } else { + Ok(None) + } + } else { + continue; + } + } + Err(err) => Err(Either::Left(err)), + }; + return result; + } + } + + #[inline] + /// Encode item, send to a peer + pub async fn send( + &self, + codec: &U, + item: U::Item, + ) -> Result<(), Either> + where + U: Encoder, + { + let filter = self.0 .0.filter.get(); + let mut buf = filter + .get_write_buf() + .unwrap_or_else(|| self.0 .0.pool.get().get_write_buf()); + let is_write_sleep = buf.is_empty(); + codec.encode(item, &mut buf).map_err(Either::Left)?; + filter.release_write_buf(buf); + self.0 .0.insert_flags(Flags::WR_WAIT); + if is_write_sleep { + self.0 .0.write_task.wake(); + } + + poll_fn(|cx| { + if !self.0 .0.flags.get().contains(Flags::WR_WAIT) || self.is_io_err() { + Poll::Ready(()) + } else { + self.register_dispatcher(cx); + Poll::Pending + } + }) + .await; + + if self.is_io_err() { + let err = self.0 .0.error.take().unwrap_or_else(|| { + io::Error::new(io::ErrorKind::Other, "Internal error") + }); + Err(Either::Right(err)) + } else { + Ok(()) + } + } + + #[inline] + #[allow(clippy::type_complexity)] + pub fn poll_next( + &self, + codec: &U, + cx: &mut Context<'_>, + ) -> Poll, Either>> + where + U: Decoder, + { + let mut buf = self.0 .0.read_buf.take(); + let item = if let Some(ref mut buf) = buf { + codec.decode(buf) + } else { + Ok(None) + }; + self.0 .0.read_buf.set(buf); + + match item { + Ok(Some(el)) => Poll::Ready(Ok(Some(el))), + Ok(None) => { + self.read().wake(cx); + Poll::Pending + } + Err(err) => Poll::Ready(Err(Either::Left(err))), + } + } +} + +impl Io { + #[inline] + pub fn into_boxed(mut self) -> crate::IoBoxed + where + F: 'static, + { + // get current filter + let filter = unsafe { + let item = mem::replace(&mut self.1, FilterItem::Ptr(std::ptr::null_mut())); + let filter: Box = match item { + FilterItem::Boxed(b) => b, + FilterItem::Ptr(p) => Box::new(*Box::from_raw(p)), + }; + + let filter_ref: &'static dyn Filter = { + let filter: &dyn Filter = filter.as_ref(); + std::mem::transmute(filter) + }; + self.0 .0.filter.replace(filter_ref); + filter + }; + + Io(self.0.clone(), FilterItem::Boxed(filter)) + } + + #[inline] + pub async fn add_filter(self, factory: &T) -> Result, T::Error> + where + T: FilterFactory, + { + factory.create(self).await + } + + #[inline] + pub fn map_filter(mut self, map: T) -> Io + where + T: FnOnce(F) -> U, + U: Filter, + { + // replace current filter + let filter = unsafe { + let item = mem::replace(&mut self.1, FilterItem::Ptr(std::ptr::null_mut())); + let filter = match item { + FilterItem::Boxed(_) => panic!(), + FilterItem::Ptr(p) => { + assert!(!p.is_null()); + Box::new(map(*Box::from_raw(p))) + } + }; + let filter_ref: &'static dyn Filter = { + let filter: &dyn Filter = filter.as_ref(); + std::mem::transmute(filter) + }; + self.0 .0.filter.replace(filter_ref); + filter + }; + + Io(self.0.clone(), FilterItem::Ptr(Box::into_raw(filter))) + } +} + +impl Drop for Io { + fn drop(&mut self) { + log::trace!("stopping io stream"); + if let FilterItem::Ptr(p) = self.1 { + if p.is_null() { + return; + } + self.force_close(); + self.0 .0.filter.set(NullFilter::get()); + let _ = mem::replace(&mut self.1, FilterItem::Ptr(std::ptr::null_mut())); + unsafe { Box::from_raw(p) }; + } else { + self.force_close(); + self.0 .0.filter.set(NullFilter::get()); + } + } +} + +impl Deref for Io { + type Target = IoRef; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Copy, Clone)] +pub struct WriteRef<'a>(pub(super) &'a IoStateInner); + +impl<'a> WriteRef<'a> { + #[inline] + /// Check if write task is ready + pub fn is_ready(&self) -> bool { + !self.0.flags.get().contains(Flags::WR_BACKPRESSURE) + } + + #[inline] + /// Check if write buffer is full + pub fn is_full(&self) -> bool { + if let Some(buf) = self.0.read_buf.take() { + let hw = self.0.pool.get().write_params_high(); + let result = buf.len() >= hw; + self.0.write_buf.set(Some(buf)); + result + } else { + false + } + } + + #[inline] + /// Wake dispatcher task + pub fn wake_dispatcher(&self) { + self.0.dispatch_task.wake(); + } + + #[inline] + /// Wait until write task flushes data to io stream + /// + /// Write task must be waken up separately. + pub fn enable_backpressure(&self, cx: Option<&mut Context<'_>>) { + log::trace!("enable write back-pressure"); + self.0.insert_flags(Flags::WR_BACKPRESSURE); + if let Some(cx) = cx { + self.0.dispatch_task.register(cx.waker()); + } + } + + #[inline] + /// Get mut access to write buffer + pub fn with_buf(&self, f: F) -> R + where + F: FnOnce(&mut BytesMut) -> R, + { + let filter = self.0.filter.get(); + let mut buf = filter + .get_write_buf() + .unwrap_or_else(|| self.0.pool.get().get_write_buf()); + if buf.is_empty() { + self.0.write_task.wake(); + } + + let result = f(&mut buf); + filter.release_write_buf(buf); + result + } + + #[inline] + /// Write item to a buffer and wake up write task + /// + /// Returns write buffer state, false is returned if write buffer if full. + pub fn encode( + &self, + item: U::Item, + codec: &U, + ) -> Result::Error> + where + U: Encoder, + { + let flags = self.0.flags.get(); + + if !flags.intersects(Flags::IO_ERR | Flags::IO_SHUTDOWN) { + let filter = self.0.filter.get(); + let mut buf = filter + .get_write_buf() + .unwrap_or_else(|| self.0.pool.get().get_write_buf()); + let is_write_sleep = buf.is_empty(); + let (hw, lw) = self.0.pool.get().write_params().unpack(); + + // make sure we've got room + let remaining = buf.capacity() - buf.len(); + if remaining < lw { + buf.reserve(hw - remaining); + } + + // encode item and wake write task + let result = codec.encode(item, &mut buf).map(|_| { + if is_write_sleep { + self.0.write_task.wake(); + } + buf.len() < hw + }); + filter.release_write_buf(buf); + result + } else { + Ok(true) + } + } + + #[inline] + /// Write item to a buf and wake up io task + pub fn encode_result( + &self, + item: Result, E>, + codec: &U, + ) -> Result> + where + U: Encoder, + { + let flags = self.0.flags.get(); + + if !flags.intersects(Flags::IO_ERR | Flags::DSP_ERR) { + match item { + Ok(Some(item)) => { + let filter = self.0.filter.get(); + let mut buf = filter + .get_write_buf() + .unwrap_or_else(|| self.0.pool.get().get_write_buf()); + let is_write_sleep = buf.is_empty(); + let (hw, lw) = self.0.pool.get().write_params().unpack(); + + // make sure we've got room + let remaining = buf.capacity() - buf.len(); + if remaining < lw { + buf.reserve(hw - remaining); + } + + // encode item + if let Err(err) = codec.encode(item, &mut buf) { + log::trace!("Encoder error: {:?}", err); + filter.release_write_buf(buf); + self.0.insert_flags(Flags::DSP_STOP | Flags::DSP_ERR); + self.0.dispatch_task.wake(); + return Err(Either::Right(err)); + } else if is_write_sleep { + self.0.write_task.wake(); + } + let result = Ok(buf.len() < hw); + filter.release_write_buf(buf); + result + } + Err(err) => { + self.0.insert_flags(Flags::DSP_STOP | Flags::DSP_ERR); + self.0.dispatch_task.wake(); + Err(Either::Left(err)) + } + _ => Ok(true), + } + } else { + Ok(true) + } + } +} + +#[derive(Copy, Clone)] +pub struct ReadRef<'a>(&'a IoStateInner); + +impl<'a> ReadRef<'a> { + #[inline] + /// Check if read buffer has new data + pub fn is_ready(&self) -> bool { + self.0.flags.get().contains(Flags::RD_READY) + } + + #[inline] + /// Check if read buffer is full + pub fn is_full(&self) -> bool { + if let Some(buf) = self.0.read_buf.take() { + let result = buf.len() >= self.0.pool.get().read_params_high(); + self.0.read_buf.set(Some(buf)); + result + } else { + false + } + } + + #[inline] + /// Pause read task + /// + /// Also register dispatch task + pub fn pause(&self, cx: &mut Context<'_>) { + self.0.insert_flags(Flags::RD_PAUSED); + self.0.dispatch_task.register(cx.waker()); + } + + #[inline] + /// Wake read io task if it is paused + pub fn resume(&self) -> bool { + let flags = self.0.flags.get(); + if flags.contains(Flags::RD_PAUSED) { + self.0.remove_flags(Flags::RD_PAUSED); + self.0.read_task.wake(); + true + } else { + false + } + } + + #[inline] + /// Wake read task and instruct to read more data + /// + /// Only wakes if back-pressure is enabled on read task + /// otherwise read is already awake. + pub fn wake(&self, cx: &mut Context<'_>) { + let mut flags = self.0.flags.get(); + flags.remove(Flags::RD_READY); + if flags.contains(Flags::RD_BUF_FULL) { + log::trace!("read back-pressure is enabled, wake io task"); + flags.remove(Flags::RD_BUF_FULL); + self.0.read_task.wake(); + } + if flags.contains(Flags::RD_PAUSED) { + log::trace!("read is paused, wake io task"); + flags.remove(Flags::RD_PAUSED); + self.0.read_task.wake(); + } + self.0.flags.set(flags); + self.0.dispatch_task.register(cx.waker()); + } + + #[inline] + /// Attempts to decode a frame from the read buffer. + pub fn decode( + &self, + codec: &U, + ) -> Result::Item>, ::Error> + where + U: Decoder, + { + let mut buf = self.0.read_buf.take(); + let result = if let Some(ref mut buf) = buf { + codec.decode(buf) + } else { + self.0.remove_flags(Flags::RD_READY); + Ok(None) + }; + self.0.read_buf.set(buf); + result + } + + #[inline] + /// Get mut access to read buffer + pub fn with_buf(&self, f: F) -> R + where + F: FnOnce(&mut BytesMut) -> R, + { + let mut buf = self + .0 + .read_buf + .take() + .unwrap_or_else(|| self.0.pool.get().get_read_buf()); + let res = f(&mut buf); + if buf.is_empty() { + self.0.pool.get().release_read_buf(buf); + } else { + self.0.read_buf.set(Some(buf)); + } + res + } +} + +/// OnDisconnect future resolves when socket get disconnected +#[must_use = "OnDisconnect do nothing unless polled"] +pub struct OnDisconnect { + token: usize, + inner: Rc, +} + +impl OnDisconnect { + fn new(inner: Rc, disconnected: bool) -> Self { + let token = if disconnected { + usize::MAX + } else { + let mut on_disconnect = inner.on_disconnect.borrow_mut(); + let token = on_disconnect.len(); + on_disconnect.push(Some(LocalWaker::default())); + drop(on_disconnect); + token + }; + Self { token, inner } + } + + #[inline] + /// Check if connection is disconnected + pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<()> { + if self.token == usize::MAX { + Poll::Ready(()) + } else { + let on_disconnect = self.inner.on_disconnect.borrow(); + if on_disconnect[self.token].is_some() { + on_disconnect[self.token] + .as_ref() + .unwrap() + .register(cx.waker()); + Poll::Pending + } else { + Poll::Ready(()) + } + } + } +} + +impl Clone for OnDisconnect { + fn clone(&self) -> Self { + if self.token == usize::MAX { + OnDisconnect::new(self.inner.clone(), true) + } else { + OnDisconnect::new(self.inner.clone(), false) + } + } +} + +impl Future for OnDisconnect { + type Output = (); + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.poll_ready(cx) + } +} + +impl Drop for OnDisconnect { + fn drop(&mut self) { + if self.token != usize::MAX { + self.inner.on_disconnect.borrow_mut()[self.token].take(); + } + } +} + +#[cfg(test)] +mod tests { + use ntex_bytes::Bytes; + use ntex_codec::BytesCodec; + use ntex_util::future::{lazy, Ready}; + + use super::*; + use crate::testing::IoTest; + use crate::{Filter, FilterFactory, ReadFilter, WriteFilter, WriteReadiness}; + + const BIN: &[u8] = b"GET /test HTTP/1\r\n\r\n"; + const TEXT: &str = "GET /test HTTP/1\r\n\r\n"; + + #[ntex::test] + async fn utils() { + let (client, server) = IoTest::create(); + client.remote_buffer_cap(1024); + client.write(TEXT); + + let state = Io::new(server); + assert!(!state.read().is_full()); + assert!(!state.write().is_full()); + + let msg = state.next(&BytesCodec).await.unwrap().unwrap(); + assert_eq!(msg, Bytes::from_static(BIN)); + + let res = poll_fn(|cx| Poll::Ready(state.poll_next(&BytesCodec, cx))).await; + assert!(res.is_pending()); + client.write(TEXT); + let res = poll_fn(|cx| Poll::Ready(state.poll_next(&BytesCodec, cx))).await; + if let Poll::Ready(msg) = res { + assert_eq!(msg.unwrap().unwrap(), Bytes::from_static(BIN)); + } + + client.read_error(io::Error::new(io::ErrorKind::Other, "err")); + let msg = state.next(&BytesCodec).await; + assert!(msg.is_err()); + assert!(state.flags().contains(Flags::IO_ERR)); + assert!(state.flags().contains(Flags::DSP_STOP)); + + let (client, server) = IoTest::create(); + client.remote_buffer_cap(1024); + let state = Io::new(server); + + client.read_error(io::Error::new(io::ErrorKind::Other, "err")); + let res = poll_fn(|cx| Poll::Ready(state.poll_next(&BytesCodec, cx))).await; + if let Poll::Ready(msg) = res { + assert!(msg.is_err()); + assert!(state.flags().contains(Flags::IO_ERR)); + assert!(state.flags().contains(Flags::DSP_STOP)); + } + + let (client, server) = IoTest::create(); + client.remote_buffer_cap(1024); + let state = Io::new(server); + state + .send(&BytesCodec, Bytes::from_static(b"test")) + .await + .unwrap(); + let buf = client.read().await.unwrap(); + assert_eq!(buf, Bytes::from_static(b"test")); + + client.write_error(io::Error::new(io::ErrorKind::Other, "err")); + let res = state.send(&BytesCodec, Bytes::from_static(b"test")).await; + assert!(res.is_err()); + assert!(state.flags().contains(Flags::IO_ERR)); + assert!(state.flags().contains(Flags::DSP_STOP)); + + let (client, server) = IoTest::create(); + client.remote_buffer_cap(1024); + let state = Io::new(server); + state.force_close(); + assert!(state.flags().contains(Flags::DSP_STOP)); + assert!(state.flags().contains(Flags::IO_SHUTDOWN)); + } + + #[ntex::test] + async fn on_disconnect() { + let (client, server) = IoTest::create(); + let state = Io::new(server); + let mut waiter = state.on_disconnect(); + assert_eq!( + lazy(|cx| Pin::new(&mut waiter).poll(cx)).await, + Poll::Pending + ); + let mut waiter2 = waiter.clone(); + assert_eq!( + lazy(|cx| Pin::new(&mut waiter2).poll(cx)).await, + Poll::Pending + ); + client.close().await; + assert_eq!(waiter.await, ()); + assert_eq!(waiter2.await, ()); + + let mut waiter = state.on_disconnect(); + assert_eq!( + lazy(|cx| Pin::new(&mut waiter).poll(cx)).await, + Poll::Ready(()) + ); + + let (client, server) = IoTest::create(); + let state = Io::new(server); + let mut waiter = state.on_disconnect(); + assert_eq!( + lazy(|cx| Pin::new(&mut waiter).poll(cx)).await, + Poll::Pending + ); + client.read_error(io::Error::new(io::ErrorKind::Other, "err")); + assert_eq!(waiter.await, ()); + } + + struct Counter { + inner: F, + in_bytes: Rc>, + out_bytes: Rc>, + } + impl Filter for Counter {} + + impl ReadFilter for Counter { + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_read_ready(cx) + } + + fn read_closed(&self, err: Option) { + self.inner.read_closed(err) + } + + fn get_read_buf(&self) -> Option { + self.inner.get_read_buf() + } + + fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) { + self.in_bytes.set(self.in_bytes.get() + new_bytes); + self.inner.release_read_buf(buf, new_bytes); + } + } + + impl WriteFilter for Counter { + fn poll_write_ready( + &self, + cx: &mut Context<'_>, + ) -> Poll> { + self.inner.poll_write_ready(cx) + } + + fn write_closed(&self, err: Option) { + self.inner.write_closed(err) + } + + fn get_write_buf(&self) -> Option { + if let Some(buf) = self.inner.get_write_buf() { + self.out_bytes.set(self.out_bytes.get() - buf.len()); + Some(buf) + } else { + None + } + } + + fn release_write_buf(&self, buf: BytesMut) { + self.out_bytes.set(self.out_bytes.get() + buf.len()); + self.inner.release_write_buf(buf); + } + } + + struct CounterFactory(Rc>, Rc>); + + impl FilterFactory for CounterFactory { + type Filter = Counter; + + type Error = (); + type Future = Ready>, Self::Error>; + + fn create(&self, st: Io) -> Self::Future { + let in_bytes = self.0.clone(); + let out_bytes = self.1.clone(); + Ready::Ok(st.map_filter(|inner| Counter { + inner, + in_bytes, + out_bytes, + })) + } + } + + #[ntex::test] + async fn filter() { + let in_bytes = Rc::new(Cell::new(0)); + let out_bytes = Rc::new(Cell::new(0)); + let factory = CounterFactory(in_bytes.clone(), out_bytes.clone()); + + let (client, server) = IoTest::create(); + let state = Io::new(server).add_filter(&factory).await.unwrap(); + + client.remote_buffer_cap(1024); + client.write(TEXT); + let msg = state.next(&BytesCodec).await.unwrap().unwrap(); + assert_eq!(msg, Bytes::from_static(BIN)); + + state + .send(&BytesCodec, Bytes::from_static(b"test")) + .await + .unwrap(); + let buf = client.read().await.unwrap(); + assert_eq!(buf, Bytes::from_static(b"test")); + + assert_eq!(in_bytes.get(), BIN.len()); + assert_eq!(out_bytes.get(), 4); + } + + #[ntex::test] + async fn boxed_filter() { + let in_bytes = Rc::new(Cell::new(0)); + let out_bytes = Rc::new(Cell::new(0)); + + let (client, server) = IoTest::create(); + let state = Io::new(server) + .add_filter(&CounterFactory(in_bytes.clone(), out_bytes.clone())) + .await + .unwrap() + .add_filter(&CounterFactory(in_bytes.clone(), out_bytes.clone())) + .await + .unwrap(); + let state = state.into_boxed(); + + client.remote_buffer_cap(1024); + client.write(TEXT); + let msg = state.next(&BytesCodec).await.unwrap().unwrap(); + assert_eq!(msg, Bytes::from_static(BIN)); + + state + .send(&BytesCodec, Bytes::from_static(b"test")) + .await + .unwrap(); + let buf = client.read().await.unwrap(); + assert_eq!(buf, Bytes::from_static(b"test")); + + assert_eq!(in_bytes.get(), BIN.len() * 2); + assert_eq!(out_bytes.get(), 8); + + // refs + assert_eq!(Rc::strong_count(&in_bytes), 3); + drop(state); + assert_eq!(Rc::strong_count(&in_bytes), 1); + } +} diff --git a/ntex-io/src/tasks.rs b/ntex-io/src/tasks.rs new file mode 100644 index 00000000..8795ab8d --- /dev/null +++ b/ntex-io/src/tasks.rs @@ -0,0 +1,98 @@ +use std::{io, rc::Rc, task::Context, task::Poll}; + +use ntex_bytes::{BytesMut, PoolRef}; +use ntex_util::time::Seconds; + +use super::{state::Flags, state::IoStateInner, WriteReadiness}; + +pub struct ReadState(pub(super) Rc); + +impl ReadState { + #[inline] + pub fn memory_pool(&self) -> PoolRef { + self.0.pool.get() + } + + #[inline] + pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.0.filter.get().poll_read_ready(cx) + } + + #[inline] + pub fn close(&self, err: Option) { + self.0.filter.get().read_closed(err); + } + + #[inline] + pub fn get_read_buf(&self) -> BytesMut { + self.0 + .filter + .get() + .get_read_buf() + .unwrap_or_else(|| self.0.pool.get().get_read_buf()) + } + + #[inline] + pub fn release_read_buf(&self, buf: BytesMut, new_bytes: usize) { + if buf.is_empty() { + self.0.pool.get().release_read_buf(buf); + } else { + self.0.filter.get().release_read_buf(buf, new_bytes); + } + } +} + +pub struct WriteState(pub(super) Rc); + +impl WriteState { + #[inline] + pub fn memory_pool(&self) -> PoolRef { + self.0.pool.get() + } + + #[inline] + pub fn disconnect_timeout(&self) -> Seconds { + self.0.disconnect_timeout.get() + } + + #[inline] + pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.0.filter.get().poll_write_ready(cx) + } + + #[inline] + pub fn close(&self, err: Option) { + self.0.filter.get().write_closed(err) + } + + #[inline] + pub fn get_write_buf(&self) -> Option { + self.0.write_buf.take() + } + + #[inline] + pub fn release_write_buf(&self, buf: BytesMut) { + let pool = self.0.pool.get(); + if buf.is_empty() { + pool.release_write_buf(buf); + + let mut flags = self.0.flags.get(); + if flags.intersects(Flags::WR_WAIT | Flags::WR_BACKPRESSURE) { + flags.remove(Flags::WR_WAIT | Flags::WR_BACKPRESSURE); + self.0.flags.set(flags); + self.0.dispatch_task.wake(); + } + } else { + // if write buffer is smaller than high watermark value, turn off back-pressure + if buf.len() < pool.write_params_high() << 1 { + let mut flags = self.0.flags.get(); + if flags.contains(Flags::WR_BACKPRESSURE) { + flags.remove(Flags::WR_BACKPRESSURE); + self.0.flags.set(flags); + self.0.dispatch_task.wake(); + } + } + self.0.write_buf.set(Some(buf)) + } + } +} diff --git a/ntex-io/src/testing.rs b/ntex-io/src/testing.rs new file mode 100644 index 00000000..173fe516 --- /dev/null +++ b/ntex-io/src/testing.rs @@ -0,0 +1,746 @@ +use std::cell::{Cell, RefCell}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; +use std::{cmp, fmt, io, mem}; + +use ntex_bytes::{BufMut, BytesMut}; +use ntex_util::future::poll_fn; +use ntex_util::time::{sleep, Millis}; + +#[derive(Default)] +struct AtomicWaker(Arc>>>); + +impl AtomicWaker { + fn wake(&self) { + if let Some(waker) = self.0.lock().unwrap().borrow_mut().take() { + waker.wake() + } + } +} + +impl fmt::Debug for AtomicWaker { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AtomicWaker") + } +} + +/// Async io stream +#[derive(Debug)] +pub struct IoTest { + tp: Type, + state: Arc>, + local: Arc>>, + remote: Arc>>, +} + +bitflags::bitflags! { + struct Flags: u8 { + const FLUSHED = 0b0000_0001; + const CLOSED = 0b0000_0010; + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +enum Type { + Client, + Server, + ClientClone, + ServerClone, +} + +#[derive(Copy, Clone, Default, Debug)] +struct State { + client_dropped: bool, + server_dropped: bool, +} + +#[derive(Default, Debug)] +struct Channel { + buf: BytesMut, + buf_cap: usize, + flags: Flags, + waker: AtomicWaker, + read: IoState, + write: IoState, +} + +impl Channel { + fn is_closed(&self) -> bool { + self.flags.contains(Flags::CLOSED) + } +} + +impl Default for Flags { + fn default() -> Self { + Flags::empty() + } +} + +#[derive(Debug)] +enum IoState { + Ok, + Pending, + Close, + Err(io::Error), +} + +impl Default for IoState { + fn default() -> Self { + IoState::Ok + } +} + +impl IoTest { + /// Create a two interconnected streams + pub fn create() -> (IoTest, IoTest) { + let local = Arc::new(Mutex::new(RefCell::new(Channel::default()))); + let remote = Arc::new(Mutex::new(RefCell::new(Channel::default()))); + let state = Arc::new(Cell::new(State::default())); + + ( + IoTest { + tp: Type::Client, + local: local.clone(), + remote: remote.clone(), + state: state.clone(), + }, + IoTest { + state, + tp: Type::Server, + local: remote, + remote: local, + }, + ) + } + + pub fn is_client_dropped(&self) -> bool { + self.state.get().client_dropped + } + + pub fn is_server_dropped(&self) -> bool { + self.state.get().server_dropped + } + + /// Check if channel is closed from remoote side + pub fn is_closed(&self) -> bool { + self.remote.lock().unwrap().borrow().is_closed() + } + + /// Set read to Pending state + pub fn read_pending(&self) { + self.remote.lock().unwrap().borrow_mut().read = IoState::Pending; + } + + /// Set read to error + pub fn read_error(&self, err: io::Error) { + self.remote.lock().unwrap().borrow_mut().read = IoState::Err(err); + } + + /// Set write error on remote side + pub fn write_error(&self, err: io::Error) { + self.local.lock().unwrap().borrow_mut().write = IoState::Err(err); + } + + /// Access read buffer. + pub fn local_buffer(&self, f: F) -> R + where + F: FnOnce(&mut BytesMut) -> R, + { + let guard = self.local.lock().unwrap(); + let mut ch = guard.borrow_mut(); + f(&mut ch.buf) + } + + /// Access remote buffer. + pub fn remote_buffer(&self, f: F) -> R + where + F: FnOnce(&mut BytesMut) -> R, + { + let guard = self.remote.lock().unwrap(); + let mut ch = guard.borrow_mut(); + f(&mut ch.buf) + } + + /// Closed remote side. + pub async fn close(&self) { + { + let guard = self.remote.lock().unwrap(); + let mut remote = guard.borrow_mut(); + remote.read = IoState::Close; + remote.waker.wake(); + log::trace!("close remote socket"); + } + sleep(Millis(35)).await; + } + + /// Add extra data to the remote buffer and notify reader + pub fn write>(&self, data: T) { + let guard = self.remote.lock().unwrap(); + let mut write = guard.borrow_mut(); + write.buf.extend_from_slice(data.as_ref()); + write.waker.wake(); + } + + /// Read any available data + pub fn remote_buffer_cap(&self, cap: usize) { + // change cap + self.local.lock().unwrap().borrow_mut().buf_cap = cap; + // wake remote + self.remote.lock().unwrap().borrow().waker.wake(); + } + + /// Read any available data + pub fn read_any(&self) -> BytesMut { + self.local.lock().unwrap().borrow_mut().buf.split() + } + + /// Read data, if data is not available wait for it + pub async fn read(&self) -> Result { + if self.local.lock().unwrap().borrow().buf.is_empty() { + poll_fn(|cx| { + let guard = self.local.lock().unwrap(); + let read = guard.borrow_mut(); + if read.buf.is_empty() { + let closed = match self.tp { + Type::Client | Type::ClientClone => { + self.is_server_dropped() || read.is_closed() + } + Type::Server | Type::ServerClone => self.is_client_dropped(), + }; + if closed { + Poll::Ready(()) + } else { + *read.waker.0.lock().unwrap().borrow_mut() = + Some(cx.waker().clone()); + drop(read); + drop(guard); + Poll::Pending + } + } else { + Poll::Ready(()) + } + }) + .await; + } + Ok(self.local.lock().unwrap().borrow_mut().buf.split()) + } + + pub fn poll_read_buf( + &self, + cx: &mut Context<'_>, + buf: &mut BytesMut, + ) -> Poll> { + let guard = self.local.lock().unwrap(); + let mut ch = guard.borrow_mut(); + *ch.waker.0.lock().unwrap().borrow_mut() = Some(cx.waker().clone()); + + if !ch.buf.is_empty() { + let size = std::cmp::min(ch.buf.len(), buf.remaining_mut()); + let b = ch.buf.split_to(size); + buf.put_slice(&b); + return Poll::Ready(Ok(size)); + } + + match mem::take(&mut ch.read) { + IoState::Ok => Poll::Pending, + IoState::Close => { + ch.read = IoState::Close; + Poll::Ready(Ok(0)) + } + IoState::Pending => Poll::Pending, + IoState::Err(e) => Poll::Ready(Err(e)), + } + } + + pub fn poll_write_buf( + &self, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let guard = self.remote.lock().unwrap(); + let mut ch = guard.borrow_mut(); + + match mem::take(&mut ch.write) { + IoState::Ok => { + let cap = cmp::min(buf.len(), ch.buf_cap); + if cap > 0 { + ch.buf.extend(&buf[..cap]); + ch.buf_cap -= cap; + ch.flags.remove(Flags::FLUSHED); + ch.waker.wake(); + Poll::Ready(Ok(cap)) + } else { + *self + .local + .lock() + .unwrap() + .borrow_mut() + .waker + .0 + .lock() + .unwrap() + .borrow_mut() = Some(cx.waker().clone()); + Poll::Pending + } + } + IoState::Close => Poll::Ready(Ok(0)), + IoState::Pending => { + *self + .local + .lock() + .unwrap() + .borrow_mut() + .waker + .0 + .lock() + .unwrap() + .borrow_mut() = Some(cx.waker().clone()); + Poll::Pending + } + IoState::Err(e) => Poll::Ready(Err(e)), + } + } +} + +impl Clone for IoTest { + fn clone(&self) -> Self { + let tp = match self.tp { + Type::Server => Type::ServerClone, + Type::Client => Type::ClientClone, + val => val, + }; + + IoTest { + tp, + local: self.local.clone(), + remote: self.remote.clone(), + state: self.state.clone(), + } + } +} + +impl Drop for IoTest { + fn drop(&mut self) { + let mut state = self.state.get(); + match self.tp { + Type::Server => state.server_dropped = true, + Type::Client => state.client_dropped = true, + _ => (), + } + self.state.set(state); + } +} + +#[cfg(feature = "tokio")] +mod tokio { + use std::task::{Context, Poll}; + use std::{cmp, io, mem, pin::Pin}; + + use tok_io::io::{AsyncRead, AsyncWrite, ReadBuf}; + + use super::{Flags, IoState, IoTest}; + + impl AsyncRead for IoTest { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let guard = self.local.lock().unwrap(); + let mut ch = guard.borrow_mut(); + *ch.waker.0.lock().unwrap().borrow_mut() = Some(cx.waker().clone()); + + if !ch.buf.is_empty() { + let size = std::cmp::min(ch.buf.len(), buf.remaining()); + let b = ch.buf.split_to(size); + buf.put_slice(&b); + return Poll::Ready(Ok(())); + } + + match mem::take(&mut ch.read) { + IoState::Ok => Poll::Pending, + IoState::Close => { + ch.read = IoState::Close; + Poll::Ready(Ok(())) + } + IoState::Pending => Poll::Pending, + IoState::Err(e) => Poll::Ready(Err(e)), + } + } + } + + impl AsyncWrite for IoTest { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let guard = self.remote.lock().unwrap(); + let mut ch = guard.borrow_mut(); + + match mem::take(&mut ch.write) { + IoState::Ok => { + let cap = cmp::min(buf.len(), ch.buf_cap); + if cap > 0 { + ch.buf.extend(&buf[..cap]); + ch.buf_cap -= cap; + ch.flags.remove(Flags::FLUSHED); + ch.waker.wake(); + Poll::Ready(Ok(cap)) + } else { + *self + .local + .lock() + .unwrap() + .borrow_mut() + .waker + .0 + .lock() + .unwrap() + .borrow_mut() = Some(cx.waker().clone()); + Poll::Pending + } + } + IoState::Close => Poll::Ready(Ok(0)), + IoState::Pending => { + *self + .local + .lock() + .unwrap() + .borrow_mut() + .waker + .0 + .lock() + .unwrap() + .borrow_mut() = Some(cx.waker().clone()); + Poll::Pending + } + IoState::Err(e) => Poll::Ready(Err(e)), + } + } + + fn poll_flush( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + self.local + .lock() + .unwrap() + .borrow_mut() + .flags + .insert(Flags::CLOSED); + Poll::Ready(Ok(())) + } + } +} + +#[cfg(not(feature = "tokio"))] +mod non_tokio { + impl IoStream for IoTest { + fn start(self, read: ReadState, write: WriteState) { + let io = Rc::new(self); + + ntex_util::spawn(ReadTask { + io: io.clone(), + state: read, + }); + ntex_util::spawn(WriteTask { + io, + state: write, + st: IoWriteState::Processing, + }); + } + } + + /// Read io task + struct ReadTask { + io: Rc, + state: ReadState, + } + + impl Future for ReadTask { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_ref(); + + match this.state.poll_ready(cx) { + Poll::Ready(Err(())) => { + log::trace!("read task is instructed to terminate"); + Poll::Ready(()) + } + Poll::Ready(Ok(())) => { + let io = &this.io; + let pool = this.state.memory_pool(); + let mut buf = self.state.get_read_buf(); + let (hw, lw) = pool.read_params().unpack(); + + // read data from socket + let mut new_bytes = 0; + loop { + // make sure we've got room + let remaining = buf.remaining_mut(); + if remaining < lw { + buf.reserve(hw - remaining); + } + + match io.poll_read_buf(cx, &mut buf) { + Poll::Pending => { + log::trace!("no more data in io stream"); + break; + } + Poll::Ready(Ok(n)) => { + if n == 0 { + log::trace!("io stream is disconnected"); + this.state.release_read_buf(buf, new_bytes); + this.state.close(None); + return Poll::Ready(()); + } else { + new_bytes += n; + if buf.len() > hw { + break; + } + } + } + Poll::Ready(Err(err)) => { + log::trace!("read task failed on io {:?}", err); + this.state.release_read_buf(buf, new_bytes); + this.state.close(Some(err)); + return Poll::Ready(()); + } + } + } + + this.state.release_read_buf(buf, new_bytes); + Poll::Pending + } + Poll::Pending => Poll::Pending, + } + } + } + + #[derive(Debug)] + enum IoWriteState { + Processing, + Shutdown(Option, Shutdown), + } + + #[derive(Debug)] + enum Shutdown { + None, + Flushed, + Stopping, + } + + /// Write io task + struct WriteTask { + st: IoWriteState, + io: Rc, + state: WriteState, + } + + impl Future for WriteTask { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.as_mut().get_mut(); + + match this.st { + IoWriteState::Processing => { + match this.state.poll_ready(cx) { + Poll::Ready(Ok(())) => { + // flush framed instance + match flush_io(&this.io, &this.state, cx) { + Poll::Pending | Poll::Ready(true) => Poll::Pending, + Poll::Ready(false) => Poll::Ready(()), + } + } + Poll::Ready(Err(WriteReadiness::Shutdown)) => { + log::trace!("write task is instructed to shutdown"); + + this.st = IoWriteState::Shutdown( + this.state.disconnect_timeout().map(sleep), + Shutdown::None, + ); + self.poll(cx) + } + Poll::Ready(Err(WriteReadiness::Terminate)) => { + log::trace!("write task is instructed to terminate"); + // shutdown WRITE side + this.io + .local + .lock() + .unwrap() + .borrow_mut() + .flags + .insert(Flags::CLOSED); + this.state.close(None); + Poll::Ready(()) + } + Poll::Pending => Poll::Pending, + } + } + IoWriteState::Shutdown(ref mut delay, ref mut st) => { + // close WRITE side and wait for disconnect on read side. + // use disconnect timeout, otherwise it could hang forever. + loop { + match st { + Shutdown::None => { + // flush write buffer + match flush_io(&this.io, &this.state, cx) { + Poll::Ready(true) => { + *st = Shutdown::Flushed; + continue; + } + Poll::Ready(false) => { + log::trace!( + "write task is closed with err during flush" + ); + return Poll::Ready(()); + } + _ => (), + } + } + Shutdown::Flushed => { + // shutdown WRITE side + this.io + .local + .lock() + .unwrap() + .borrow_mut() + .flags + .insert(Flags::CLOSED); + *st = Shutdown::Stopping; + continue; + } + Shutdown::Stopping => { + // read until 0 or err + let io = &this.io; + loop { + let mut buf = BytesMut::new(); + match io.poll_read_buf(cx, &mut buf) { + Poll::Ready(Err(e)) => { + this.state.close(Some(e)); + log::trace!("write task is stopped"); + return Poll::Ready(()); + } + Poll::Ready(Ok(n)) if n == 0 => { + this.state.close(None); + log::trace!("write task is stopped"); + return Poll::Ready(()); + } + Poll::Pending => break, + _ => (), + } + } + } + } + + // disconnect timeout + if let Some(ref delay) = delay { + if delay.poll_elapsed(cx).is_pending() { + return Poll::Pending; + } + } + log::trace!("write task is stopped after delay"); + this.state.close(None); + return Poll::Ready(()); + } + } + } + } + } + + /// Flush write buffer to underlying I/O stream. + pub(super) fn flush_io( + io: &IoTest, + state: &WriteState, + cx: &mut Context<'_>, + ) -> Poll { + let mut buf = if let Some(buf) = state.get_write_buf() { + buf + } else { + return Poll::Ready(true); + }; + let len = buf.len(); + let pool = state.memory_pool(); + + if len != 0 { + log::trace!("flushing framed transport: {}", len); + + let mut written = 0; + while written < len { + match io.poll_write_buf(cx, &buf[written..]) { + Poll::Ready(Ok(n)) => { + if n == 0 { + log::trace!( + "disconnected during flush, written {}", + written + ); + pool.release_write_buf(buf); + state.close(Some(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write frame to transport", + ))); + return Poll::Ready(false); + } else { + written += n + } + } + Poll::Pending => break, + Poll::Ready(Err(e)) => { + log::trace!("error during flush: {}", e); + pool.release_write_buf(buf); + state.close(Some(e)); + return Poll::Ready(false); + } + } + } + log::trace!("flushed {} bytes", written); + + // remove written data + if written == len { + buf.clear(); + state.release_write_buf(buf); + Poll::Ready(true) + } else { + buf.advance(written); + state.release_write_buf(buf); + Poll::Pending + } + } else { + Poll::Ready(true) + } + } +} + +#[cfg(test)] +#[allow(clippy::redundant_clone)] +mod tests { + use super::*; + + #[ntex::test] + async fn basic() { + let (client, server) = IoTest::create(); + assert_eq!(client.tp, Type::Client); + assert_eq!(client.clone().tp, Type::ClientClone); + assert_eq!(server.tp, Type::Server); + assert_eq!(server.clone().tp, Type::ServerClone); + + assert!(!server.is_client_dropped()); + drop(client); + assert!(server.is_client_dropped()); + + let server2 = server.clone(); + assert!(!server2.is_server_dropped()); + drop(server); + assert!(server2.is_server_dropped()); + } +} diff --git a/ntex-io/src/time.rs b/ntex-io/src/time.rs new file mode 100644 index 00000000..6509c52b --- /dev/null +++ b/ntex-io/src/time.rs @@ -0,0 +1,104 @@ +use std::{ + cell::RefCell, collections::BTreeMap, collections::HashSet, rc::Rc, time::Instant, +}; + +use ntex_util::spawn; +use ntex_util::time::{now, sleep, Millis}; + +use super::state::{Flags, IoRef, IoStateInner}; + +pub struct Timer(Rc>); + +struct Inner { + running: bool, + resolution: Millis, + notifications: BTreeMap, fxhash::FxBuildHasher>>, +} + +impl Inner { + fn new(resolution: Millis) -> Self { + Inner { + resolution, + running: false, + notifications: BTreeMap::default(), + } + } + + fn unregister(&mut self, expire: Instant, io: &IoRef) { + if let Some(states) = self.notifications.get_mut(&expire) { + states.remove(&io.0); + if states.is_empty() { + self.notifications.remove(&expire); + } + } + } +} + +impl Clone for Timer { + fn clone(&self) -> Self { + Timer(self.0.clone()) + } +} + +impl Default for Timer { + fn default() -> Self { + Timer::new(Millis::ONE_SEC) + } +} + +impl Timer { + /// Create new timer with resolution in milliseconds + pub fn new(resolution: Millis) -> Timer { + Timer(Rc::new(RefCell::new(Inner::new(resolution)))) + } + + pub fn register(&self, expire: Instant, previous: Instant, io: &IoRef) { + let mut inner = self.0.borrow_mut(); + + inner.unregister(previous, io); + inner + .notifications + .entry(expire) + .or_insert_with(HashSet::default) + .insert(io.0.clone()); + + if !inner.running { + inner.running = true; + let interval = inner.resolution; + let inner = self.0.clone(); + + spawn(async move { + loop { + sleep(interval).await; + { + let mut i = inner.borrow_mut(); + let now_time = now(); + + // notify io dispatcher + while let Some(key) = i.notifications.keys().next() { + let key = *key; + if key <= now_time { + for st in i.notifications.remove(&key).unwrap() { + st.dispatch_task.wake(); + st.insert_flags(Flags::DSP_KEEPALIVE); + } + } else { + break; + } + } + + // new tick + if i.notifications.is_empty() { + i.running = false; + break; + } + } + } + }); + } + } + + pub fn unregister(&self, expire: Instant, io: &IoRef) { + self.0.borrow_mut().unregister(expire, io); + } +} diff --git a/ntex-io/src/tokio_impl.rs b/ntex-io/src/tokio_impl.rs new file mode 100644 index 00000000..884e1a96 --- /dev/null +++ b/ntex-io/src/tokio_impl.rs @@ -0,0 +1,314 @@ +use std::task::{Context, Poll}; +use std::{cell::RefCell, future::Future, io, pin::Pin, rc::Rc}; + +use ntex_bytes::{Buf, BufMut}; +use ntex_util::time::{sleep, Sleep}; +use tok_io::{io::AsyncRead, io::AsyncWrite, io::ReadBuf}; + +use super::{IoStream, ReadState, WriteReadiness, WriteState}; + +impl IoStream for T +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn start(self, read: ReadState, write: WriteState) { + let io = Rc::new(RefCell::new(self)); + + ntex_util::spawn(ReadTask::new(io.clone(), read)); + ntex_util::spawn(WriteTask::new(io, write)); + } +} + +/// Read io task +struct ReadTask { + io: Rc>, + state: ReadState, +} + +impl ReadTask +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + /// Create new read io task + fn new(io: Rc>, state: ReadState) -> Self { + Self { io, state } + } +} + +impl Future for ReadTask +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_ref(); + + match this.state.poll_ready(cx) { + Poll::Ready(Err(())) => { + log::trace!("read task is instructed to shutdown"); + Poll::Ready(()) + } + Poll::Ready(Ok(())) => { + let pool = this.state.memory_pool(); + let mut io = this.io.borrow_mut(); + let mut buf = self.state.get_read_buf(); + let (hw, lw) = pool.read_params().unpack(); + + // read data from socket + let mut new_bytes = 0; + loop { + // make sure we've got room + let remaining = buf.remaining_mut(); + if remaining < lw { + buf.reserve(hw - remaining); + } + + match ntex_codec::poll_read_buf(Pin::new(&mut *io), cx, &mut buf) { + Poll::Pending => break, + Poll::Ready(Ok(n)) => { + if n == 0 { + log::trace!("io stream is disconnected"); + this.state.release_read_buf(buf, new_bytes); + this.state.close(None); + return Poll::Ready(()); + } else { + new_bytes += n; + if buf.len() > hw { + break; + } + } + } + Poll::Ready(Err(err)) => { + log::trace!("read task failed on io {:?}", err); + this.state.release_read_buf(buf, new_bytes); + this.state.close(Some(err)); + return Poll::Ready(()); + } + } + } + + this.state.release_read_buf(buf, new_bytes); + Poll::Pending + } + Poll::Pending => Poll::Pending, + } + } +} + +#[derive(Debug)] +enum IoWriteState { + Processing, + Shutdown(Option, Shutdown), +} + +#[derive(Debug)] +enum Shutdown { + None, + Flushed, + Stopping, +} + +/// Write io task +struct WriteTask { + st: IoWriteState, + io: Rc>, + state: WriteState, +} + +impl WriteTask +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + /// Create new write io task + fn new(io: Rc>, state: WriteState) -> Self { + Self { + io, + state, + st: IoWriteState::Processing, + } + } +} + +impl Future for WriteTask +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.as_mut().get_mut(); + + match this.st { + IoWriteState::Processing => { + match this.state.poll_ready(cx) { + Poll::Ready(Ok(())) => { + // flush framed instance + match flush_io(&mut *this.io.borrow_mut(), &this.state, cx) { + Poll::Pending | Poll::Ready(true) => Poll::Pending, + Poll::Ready(false) => Poll::Ready(()), + } + } + Poll::Ready(Err(WriteReadiness::Shutdown)) => { + log::trace!("write task is instructed to shutdown"); + + this.st = IoWriteState::Shutdown( + this.state.disconnect_timeout().map(sleep), + Shutdown::None, + ); + self.poll(cx) + } + Poll::Ready(Err(WriteReadiness::Terminate)) => { + log::trace!("write task is instructed to terminate"); + + let _ = Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx); + this.state.close(None); + Poll::Ready(()) + } + Poll::Pending => Poll::Pending, + } + } + IoWriteState::Shutdown(ref mut delay, ref mut st) => { + // close WRITE side and wait for disconnect on read side. + // use disconnect timeout, otherwise it could hang forever. + loop { + match st { + Shutdown::None => { + // flush write buffer + match flush_io(&mut *this.io.borrow_mut(), &this.state, cx) { + Poll::Ready(true) => { + *st = Shutdown::Flushed; + continue; + } + Poll::Ready(false) => { + log::trace!( + "write task is closed with err during flush" + ); + return Poll::Ready(()); + } + _ => (), + } + } + Shutdown::Flushed => { + // shutdown WRITE side + match Pin::new(&mut *this.io.borrow_mut()).poll_shutdown(cx) + { + Poll::Ready(Ok(_)) => { + *st = Shutdown::Stopping; + continue; + } + Poll::Ready(Err(e)) => { + log::trace!( + "write task is closed with err during shutdown" + ); + this.state.close(Some(e)); + return Poll::Ready(()); + } + _ => (), + } + } + Shutdown::Stopping => { + // read until 0 or err + let mut buf = [0u8; 512]; + let mut io = this.io.borrow_mut(); + loop { + let mut read_buf = ReadBuf::new(&mut buf); + match Pin::new(&mut *io).poll_read(cx, &mut read_buf) { + Poll::Ready(Err(_)) | Poll::Ready(Ok(_)) + if read_buf.filled().is_empty() => + { + this.state.close(None); + log::trace!("write task is stopped"); + return Poll::Ready(()); + } + Poll::Pending => break, + _ => (), + } + } + } + } + + // disconnect timeout + if let Some(ref delay) = delay { + if delay.poll_elapsed(cx).is_pending() { + return Poll::Pending; + } + } + log::trace!("write task is stopped after delay"); + this.state.close(None); + return Poll::Ready(()); + } + } + } + } +} + +/// Flush write buffer to underlying I/O stream. +pub(super) fn flush_io( + io: &mut T, + state: &WriteState, + cx: &mut Context<'_>, +) -> Poll { + let mut buf = if let Some(buf) = state.get_write_buf() { + buf + } else { + return Poll::Ready(true); + }; + let len = buf.len(); + let pool = state.memory_pool(); + + if len != 0 { + // log::trace!("flushing framed transport: {:?}", buf); + + let mut written = 0; + while written < len { + match Pin::new(&mut *io).poll_write(cx, &buf[written..]) { + Poll::Pending => break, + Poll::Ready(Ok(n)) => { + if n == 0 { + log::trace!("Disconnected during flush, written {}", written); + pool.release_write_buf(buf); + state.close(Some(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write frame to transport", + ))); + return Poll::Ready(false); + } else { + written += n + } + } + Poll::Ready(Err(e)) => { + log::trace!("Error during flush: {}", e); + pool.release_write_buf(buf); + state.close(Some(e)); + return Poll::Ready(false); + } + } + } + // log::trace!("flushed {} bytes", written); + + // remove written data + let result = if written == len { + buf.clear(); + state.release_write_buf(buf); + Poll::Ready(true) + } else { + buf.advance(written); + state.release_write_buf(buf); + Poll::Pending + }; + + // flush + match Pin::new(&mut *io).poll_flush(cx) { + Poll::Ready(Ok(_)) => result, + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => { + log::trace!("error during flush: {}", e); + state.close(Some(e)); + Poll::Ready(false) + } + } + } else { + Poll::Ready(true) + } +} diff --git a/ntex-io/src/utils.rs b/ntex-io/src/utils.rs new file mode 100644 index 00000000..c1c321d5 --- /dev/null +++ b/ntex-io/src/utils.rs @@ -0,0 +1,49 @@ +use ntex_service::{fn_factory_with_config, into_service, Service, ServiceFactory}; + +use super::{Filter, Io, IoBoxed, IoStream}; + +/// Service that converts any Io stream to IoBoxed stream +pub fn into_boxed( + srv: S, +) -> impl ServiceFactory< + Config = S::Config, + Request = Io, + Response = S::Response, + Error = S::Error, + InitError = S::InitError, +> +where + F: Filter + 'static, + S: ServiceFactory, +{ + fn_factory_with_config(move |cfg: S::Config| { + let fut = srv.new_service(cfg); + async move { + let srv = fut.await?; + Ok(into_service(move |io: Io| srv.call(io.into_boxed()))) + } + }) +} + +/// Service that converts IoStream stream to IoBoxed stream +pub fn from_iostream( + srv: S, +) -> impl ServiceFactory< + Config = S::Config, + Request = I, + Response = S::Response, + Error = S::Error, + InitError = S::InitError, +> +where + I: IoStream, + S: ServiceFactory, +{ + fn_factory_with_config(move |cfg: S::Config| { + let fut = srv.new_service(cfg); + async move { + let srv = fut.await?; + Ok(into_service(move |io| srv.call(Io::new(io).into_boxed()))) + } + }) +} diff --git a/ntex-macros/Cargo.toml b/ntex-macros/Cargo.toml index 5a93fe86..6f4576d3 100644 --- a/ntex-macros/Cargo.toml +++ b/ntex-macros/Cargo.toml @@ -16,5 +16,5 @@ syn = { version = "^1", features = ["full", "parsing"] } proc-macro2 = "^1" [dev-dependencies] -ntex = "0.3.1" +ntex = "0.4.10" futures = "0.3.13" diff --git a/ntex-rt/src/builder.rs b/ntex-rt/src/builder.rs index a7ba02a3..f76c05a8 100644 --- a/ntex-rt/src/builder.rs +++ b/ntex-rt/src/builder.rs @@ -75,7 +75,7 @@ impl Builder { let (stop_tx, stop) = channel(); let (sys_sender, sys_receiver) = unbounded_channel(); - let system = System::construct( + let _system = System::construct( sys_sender, Arbiter::new_system(local), self.stop_on_panic, @@ -87,10 +87,7 @@ impl Builder { // start the system arbiter let _ = local.spawn_local(arb); - AsyncSystemRunner { - stop, - _system: system, - } + AsyncSystemRunner { stop, _system } } fn create_runtime(self, f: F) -> SystemRunner @@ -108,7 +105,7 @@ impl Builder { }); // system arbiter - let system = System::construct( + let _system = System::construct( sys_sender, Arbiter::new_system(rt.local()), self.stop_on_panic, @@ -119,11 +116,7 @@ impl Builder { // init system arbiter and run configuration method rt.block_on(lazy(move |_| f())); - SystemRunner { - rt, - stop, - _system: system, - } + SystemRunner { rt, stop, _system } } } diff --git a/ntex-rt/src/time.rs b/ntex-rt/src/time.rs index 89da7c03..5f755566 100644 --- a/ntex-rt/src/time.rs +++ b/ntex-rt/src/time.rs @@ -32,9 +32,10 @@ pub fn sleep(duration: Duration) -> Sleep { } } +#[doc(hidden)] /// Creates new [`Interval`] that yields with interval of `period` with the -/// first tick completing at `start`. The default [`MissedTickBehavior`] is -/// [`Burst`](MissedTickBehavior::Burst), but this can be configured +/// first tick completing at `start`. The default `MissedTickBehavior` is +/// `Burst`, but this can be configured /// by calling [`set_missed_tick_behavior`](Interval::set_missed_tick_behavior). #[inline] pub fn interval_at(start: Instant, period: Duration) -> Interval { diff --git a/ntex-util/Cargo.toml b/ntex-util/Cargo.toml index fe88e1a4..c7328dfc 100644 --- a/ntex-util/Cargo.toml +++ b/ntex-util/Cargo.toml @@ -24,8 +24,6 @@ futures-core = { version = "0.3.18", default-features = false, features = ["allo futures-sink = { version = "0.3.18", default-features = false, features = ["alloc"] } pin-project-lite = "0.2.6" -backtrace = "*" - [dev-dependencies] ntex = "0.4.10" ntex-rt = "0.3.2" diff --git a/ntex-util/src/channel/mod.rs b/ntex-util/src/channel/mod.rs index 79b21749..b1102b91 100644 --- a/ntex-util/src/channel/mod.rs +++ b/ntex-util/src/channel/mod.rs @@ -5,8 +5,8 @@ pub mod condition; pub mod oneshot; pub mod pool; -/// Error returned from a [`Receiver`](Receiver) when the corresponding -/// [`Sender`](Sender) is dropped. +/// Error returned from a `Receiver` when the corresponding +/// `Sender` is dropped. #[derive(Clone, Copy, PartialEq, Eq, Debug)] pub struct Canceled; diff --git a/ntex-util/src/time/wheel.rs b/ntex-util/src/time/wheel.rs index 6fd54f1a..3c86f199 100644 --- a/ntex-util/src/time/wheel.rs +++ b/ntex-util/src/time/wheel.rs @@ -658,17 +658,19 @@ mod tests { fut2.await; let elapsed = Instant::now() - time; + #[cfg(not(target_os = "macos"))] assert!( - elapsed > Duration::from_millis(200) && elapsed < Duration::from_millis(500), + elapsed > Duration::from_millis(200) && elapsed < Duration::from_millis(300), "elapsed: {:?}", elapsed ); fut1.await; let elapsed = Instant::now() - time; + #[cfg(not(target_os = "macos"))] assert!( elapsed > Duration::from_millis(1000) - && elapsed < Duration::from_millis(3000), // osx + && elapsed < Duration::from_millis(1200), // osx "elapsed: {:?}", elapsed ); @@ -676,8 +678,11 @@ mod tests { let time = Instant::now(); sleep(Millis(25)).await; let elapsed = Instant::now() - time; + #[cfg(not(target_os = "macos"))] assert!( - elapsed > Duration::from_millis(20) && elapsed < Duration::from_millis(50) + elapsed > Duration::from_millis(20) && elapsed < Duration::from_millis(50), + "elapsed: {:?}", + elapsed ); } } diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 9a2984a7..44915d73 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -50,6 +50,7 @@ ntex-service = "0.2.1" ntex-macros = "0.1.3" ntex-util = "0.1.2" ntex-bytes = "0.1.7" +ntex-io = { version = "0.1", features = ["tokio"] } base64 = "0.13" bitflags = "1.3" diff --git a/ntex/src/connect/io.rs b/ntex/src/connect/io.rs new file mode 100644 index 00000000..c949add3 --- /dev/null +++ b/ntex/src/connect/io.rs @@ -0,0 +1,148 @@ +use std::task::{Context, Poll}; +use std::{future::Future, pin::Pin}; + +use crate::io::Io; +use crate::service::{Service, ServiceFactory}; +use crate::util::{PoolId, PoolRef, Ready}; + +use super::service::ConnectServiceResponse; +use super::{Address, Connect, ConnectError, Connector}; + +pub struct IoConnector { + inner: Connector, + pool: PoolRef, +} + +impl IoConnector { + /// Construct new connect service with custom dns resolver + pub fn new() -> Self { + IoConnector { + inner: Connector::new(), + pool: PoolId::P0.pool_ref(), + } + } + + /// Set memory pool. + /// + /// Use specified memory pool for memory allocations. By default P0 + /// memory pool is used. + pub fn memory_pool(mut self, id: PoolId) -> Self { + self.pool = id.pool_ref(); + self + } +} + +impl IoConnector { + /// Resolve and connect to remote host + pub fn connect(&self, message: U) -> IoConnectServiceResponse + where + Connect: From, + { + IoConnectServiceResponse { + inner: self.inner.call(message.into()), + pool: self.pool, + } + } +} + +impl Default for IoConnector { + fn default() -> Self { + IoConnector::new() + } +} + +impl Clone for IoConnector { + fn clone(&self) -> Self { + IoConnector { + inner: self.inner.clone(), + pool: self.pool, + } + } +} + +impl ServiceFactory for IoConnector { + type Request = Connect; + type Response = Io; + type Error = ConnectError; + type Config = (); + type Service = IoConnector; + type InitError = (); + type Future = Ready; + + #[inline] + fn new_service(&self, _: ()) -> Self::Future { + Ready::Ok(self.clone()) + } +} + +impl Service for IoConnector { + type Request = Connect; + type Response = Io; + type Error = ConnectError; + type Future = IoConnectServiceResponse; + + #[inline] + fn poll_ready(&self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + #[inline] + fn call(&self, req: Connect) -> Self::Future { + self.connect(req) + } +} + +#[doc(hidden)] +pub struct IoConnectServiceResponse { + inner: ConnectServiceResponse, + pool: PoolRef, +} + +impl Future for IoConnectServiceResponse { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.inner).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(stream)) => { + Poll::Ready(Ok(Io::with_memory_pool(stream, self.pool))) + } + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[crate::rt_test] + async fn test_connect() { + let server = crate::server::test_server(|| { + crate::service::fn_service(|_| async { Ok::<_, ()>(()) }) + }); + + let srv = IoConnector::default(); + let result = srv.connect("").await; + assert!(result.is_err()); + let result = srv.connect("localhost:99999").await; + assert!(result.is_err()); + + let srv = IoConnector::default(); + let result = srv.connect(format!("{}", server.addr())).await; + assert!(result.is_ok()); + + let msg = Connect::new(format!("{}", server.addr())).set_addrs(vec![ + format!("127.0.0.1:{}", server.addr().port() - 1) + .parse() + .unwrap(), + server.addr(), + ]); + let result = crate::connect::connect(msg).await; + assert!(result.is_ok()); + + let msg = Connect::new(server.addr()); + let result = crate::connect::connect(msg).await; + assert!(result.is_ok()); + } +} diff --git a/ntex/src/connect/mod.rs b/ntex/src/connect/mod.rs index 4dc233d7..37344ca2 100644 --- a/ntex/src/connect/mod.rs +++ b/ntex/src/connect/mod.rs @@ -2,6 +2,7 @@ use std::future::Future; mod error; +mod io; mod message; mod resolve; mod service; @@ -18,6 +19,7 @@ pub mod rustls; use crate::rt::net::TcpStream; pub use self::error::ConnectError; +pub use self::io::IoConnector; pub use self::message::{Address, Connect}; pub use self::resolve::Resolver; pub use self::service::Connector; diff --git a/ntex/src/connect/service.rs b/ntex/src/connect/service.rs index 9d986196..c495f386 100644 --- a/ntex/src/connect/service.rs +++ b/ntex/src/connect/service.rs @@ -35,9 +35,7 @@ impl Connector { impl Default for Connector { fn default() -> Self { - Connector { - resolver: Resolver::default(), - } + Connector::new() } } diff --git a/ntex/src/lib.rs b/ntex/src/lib.rs index 5731ce16..f13c2721 100644 --- a/ntex/src/lib.rs +++ b/ntex/src/lib.rs @@ -39,7 +39,6 @@ pub mod framed; #[cfg(feature = "http-framework")] pub mod http; pub mod server; -pub mod testing; pub mod util; #[cfg(feature = "http-framework")] pub mod web; @@ -78,3 +77,13 @@ pub mod time { //! Utilities for tracking time. pub use ntex_util::time::*; } + +pub mod io { + //! IO streaming utilities. + pub use ntex_io::*; +} + +pub mod testing { + //! IO testing utilities. + pub use ntex_io::testing::IoTest as Io; +} diff --git a/ntex/src/testing.rs b/ntex/src/testing.rs deleted file mode 100644 index 839cbad4..00000000 --- a/ntex/src/testing.rs +++ /dev/null @@ -1,373 +0,0 @@ -use std::cell::{Cell, RefCell}; -use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll, Waker}; -use std::{cmp, fmt, io, mem, pin::Pin}; - -use crate::codec::{AsyncRead, AsyncWrite, ReadBuf}; -use crate::time::{sleep, Millis}; -use crate::util::{poll_fn, BytesMut}; - -#[derive(Default)] -struct AtomicWaker(Arc>>>); - -impl AtomicWaker { - fn wake(&self) { - if let Some(waker) = self.0.lock().unwrap().borrow_mut().take() { - waker.wake() - } - } -} - -impl fmt::Debug for AtomicWaker { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "AtomicWaker") - } -} - -/// Async io stream -#[derive(Debug)] -pub struct Io { - tp: Type, - state: Arc>, - local: Arc>>, - remote: Arc>>, -} - -bitflags::bitflags! { - struct Flags: u8 { - const FLUSHED = 0b0000_0001; - const CLOSED = 0b0000_0010; - } -} - -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -enum Type { - Client, - Server, - ClientClone, - ServerClone, -} - -#[derive(Copy, Clone, Default, Debug)] -struct State { - client_dropped: bool, - server_dropped: bool, -} - -#[derive(Default, Debug)] -struct Channel { - buf: BytesMut, - buf_cap: usize, - flags: Flags, - waker: AtomicWaker, - read: IoState, - write: IoState, -} - -impl Channel { - fn is_closed(&self) -> bool { - self.flags.contains(Flags::CLOSED) - } -} - -impl Default for Flags { - fn default() -> Self { - Flags::empty() - } -} - -#[derive(Debug)] -enum IoState { - Ok, - Pending, - Close, - Err(io::Error), -} - -impl Default for IoState { - fn default() -> Self { - IoState::Ok - } -} - -impl Io { - /// Create a two interconnected streams - pub fn create() -> (Io, Io) { - let local = Arc::new(Mutex::new(RefCell::new(Channel::default()))); - let remote = Arc::new(Mutex::new(RefCell::new(Channel::default()))); - let state = Arc::new(Cell::new(State::default())); - - ( - Io { - tp: Type::Client, - local: local.clone(), - remote: remote.clone(), - state: state.clone(), - }, - Io { - state, - tp: Type::Server, - local: remote, - remote: local, - }, - ) - } - - pub fn is_client_dropped(&self) -> bool { - self.state.get().client_dropped - } - - pub fn is_server_dropped(&self) -> bool { - self.state.get().server_dropped - } - - /// Check if channel is closed from remoote side - pub fn is_closed(&self) -> bool { - self.remote.lock().unwrap().borrow().is_closed() - } - - /// Set read to Pending state - pub fn read_pending(&self) { - self.remote.lock().unwrap().borrow_mut().read = IoState::Pending; - } - - /// Set read to error - pub fn read_error(&self, err: io::Error) { - self.remote.lock().unwrap().borrow_mut().read = IoState::Err(err); - } - - /// Set write error on remote side - pub fn write_error(&self, err: io::Error) { - self.local.lock().unwrap().borrow_mut().write = IoState::Err(err); - } - - /// Access read buffer. - pub fn local_buffer(&self, f: F) -> R - where - F: FnOnce(&mut BytesMut) -> R, - { - let guard = self.local.lock().unwrap(); - let mut ch = guard.borrow_mut(); - f(&mut ch.buf) - } - - /// Access remote buffer. - pub fn remote_buffer(&self, f: F) -> R - where - F: FnOnce(&mut BytesMut) -> R, - { - let guard = self.remote.lock().unwrap(); - let mut ch = guard.borrow_mut(); - f(&mut ch.buf) - } - - /// Closed remote side. - pub async fn close(&self) { - { - let guard = self.remote.lock().unwrap(); - let mut remote = guard.borrow_mut(); - remote.read = IoState::Close; - remote.waker.wake(); - } - sleep(Millis(35)).await; - } - - /// Add extra data to the remote buffer and notify reader - pub fn write>(&self, data: T) { - let guard = self.remote.lock().unwrap(); - let mut write = guard.borrow_mut(); - write.buf.extend_from_slice(data.as_ref()); - write.waker.wake(); - } - - /// Read any available data - pub fn remote_buffer_cap(&self, cap: usize) { - // change cap - self.local.lock().unwrap().borrow_mut().buf_cap = cap; - // wake remote - self.remote.lock().unwrap().borrow().waker.wake(); - } - - /// Read any available data - pub fn read_any(&self) -> BytesMut { - self.local.lock().unwrap().borrow_mut().buf.split() - } - - /// Read data, if data is not available wait for it - pub async fn read(&self) -> Result { - if self.local.lock().unwrap().borrow().buf.is_empty() { - poll_fn(|cx| { - let guard = self.local.lock().unwrap(); - let read = guard.borrow_mut(); - if read.buf.is_empty() { - let closed = match self.tp { - Type::Client | Type::ClientClone => { - self.is_server_dropped() || read.is_closed() - } - Type::Server | Type::ServerClone => self.is_client_dropped(), - }; - if closed { - Poll::Ready(()) - } else { - *read.waker.0.lock().unwrap().borrow_mut() = - Some(cx.waker().clone()); - drop(read); - drop(guard); - Poll::Pending - } - } else { - Poll::Ready(()) - } - }) - .await; - } - Ok(self.local.lock().unwrap().borrow_mut().buf.split()) - } -} - -impl Clone for Io { - fn clone(&self) -> Self { - let tp = match self.tp { - Type::Server => Type::ServerClone, - Type::Client => Type::ClientClone, - val => val, - }; - - Io { - tp, - local: self.local.clone(), - remote: self.remote.clone(), - state: self.state.clone(), - } - } -} - -impl Drop for Io { - fn drop(&mut self) { - let mut state = self.state.get(); - match self.tp { - Type::Server => state.server_dropped = true, - Type::Client => state.client_dropped = true, - _ => (), - } - self.state.set(state); - } -} - -impl AsyncRead for Io { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - let guard = self.local.lock().unwrap(); - let mut ch = guard.borrow_mut(); - *ch.waker.0.lock().unwrap().borrow_mut() = Some(cx.waker().clone()); - - if !ch.buf.is_empty() { - let size = std::cmp::min(ch.buf.len(), buf.remaining()); - let b = ch.buf.split_to(size); - buf.put_slice(&b); - return Poll::Ready(Ok(())); - } - - match mem::take(&mut ch.read) { - IoState::Ok => Poll::Pending, - IoState::Close => { - ch.read = IoState::Close; - Poll::Ready(Ok(())) - } - IoState::Pending => Poll::Pending, - IoState::Err(e) => Poll::Ready(Err(e)), - } - } -} - -impl AsyncWrite for Io { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let guard = self.remote.lock().unwrap(); - let mut ch = guard.borrow_mut(); - - match mem::take(&mut ch.write) { - IoState::Ok => { - let cap = cmp::min(buf.len(), ch.buf_cap); - if cap > 0 { - ch.buf.extend(&buf[..cap]); - ch.buf_cap -= cap; - ch.flags.remove(Flags::FLUSHED); - ch.waker.wake(); - Poll::Ready(Ok(cap)) - } else { - *self - .local - .lock() - .unwrap() - .borrow_mut() - .waker - .0 - .lock() - .unwrap() - .borrow_mut() = Some(cx.waker().clone()); - Poll::Pending - } - } - IoState::Close => Poll::Ready(Ok(0)), - IoState::Pending => { - *self - .local - .lock() - .unwrap() - .borrow_mut() - .waker - .0 - .lock() - .unwrap() - .borrow_mut() = Some(cx.waker().clone()); - Poll::Pending - } - IoState::Err(e) => Poll::Ready(Err(e)), - } - } - - fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - self.local - .lock() - .unwrap() - .borrow_mut() - .flags - .insert(Flags::CLOSED); - Poll::Ready(Ok(())) - } -} - -#[cfg(test)] -#[allow(clippy::redundant_clone)] -mod tests { - use super::*; - - #[crate::rt_test] - async fn basic() { - let (client, server) = Io::create(); - assert_eq!(client.tp, Type::Client); - assert_eq!(client.clone().tp, Type::ClientClone); - assert_eq!(server.tp, Type::Server); - assert_eq!(server.clone().tp, Type::ServerClone); - - assert!(!server.is_client_dropped()); - drop(client); - assert!(server.is_client_dropped()); - - let server2 = server.clone(); - assert!(!server2.is_server_dropped()); - drop(server); - assert!(server2.is_server_dropped()); - } -} diff --git a/ntex/src/web/rmap.rs b/ntex/src/web/rmap.rs index 98d8a43c..f026d888 100644 --- a/ntex/src/web/rmap.rs +++ b/ntex/src/web/rmap.rs @@ -10,6 +10,7 @@ use crate::web::httprequest::HttpRequest; #[derive(Clone, Debug)] pub struct ResourceMap { + #[allow(dead_code)] root: ResourceDef, parent: RefCell>>, named: HashMap, diff --git a/ntex/tests/server.rs b/ntex/tests/server.rs index 9fb8abdc..52980397 100644 --- a/ntex/tests/server.rs +++ b/ntex/tests/server.rs @@ -129,6 +129,7 @@ fn test_start() { } #[test] +#[allow(deprecated)] fn test_configure() { let addr1 = TestServer::unused_addr(); let addr2 = TestServer::unused_addr(); @@ -179,6 +180,7 @@ fn test_configure() { } #[test] +#[allow(deprecated)] fn test_configure_async() { let addr1 = TestServer::unused_addr(); let addr2 = TestServer::unused_addr(); @@ -255,7 +257,7 @@ fn test_on_worker_start() { .bind("addr2", addr2) .unwrap() .listen("addr3", lst) - .apply_async(move |rt| { + .on_worker_start(move |rt| { let num = num.clone(); async move { rt.service("addr1", fn_service(|_| ok::<_, ()>(())));