From 9de3f3060f83ab5f829e44bda76faf67725b87d9 Mon Sep 17 00:00:00 2001 From: Will Brown Date: Thu, 6 Apr 2023 13:39:56 -0400 Subject: [PATCH] add enter/exit fn for spawn_cbs (#195) --- ntex-rt/src/lib.rs | 82 +++++++++++++++++++++++++++++++--------------- 1 file changed, 56 insertions(+), 26 deletions(-) diff --git a/ntex-rt/src/lib.rs b/ntex-rt/src/lib.rs index 5db5979f..d17d8ecf 100644 --- a/ntex-rt/src/lib.rs +++ b/ntex-rt/src/lib.rs @@ -10,29 +10,39 @@ pub use self::builder::{Builder, SystemRunner}; pub use self::system::System; thread_local! { - static CB: RefCell<(TBefore, TSpawn, TAfter)> = RefCell::new(( - Box::new(|| {None}), Box::new(|_| {ptr::null()}), Box::new(|_| {})) + static CB: RefCell<(TBefore, TEnter, TExit, TAfter)> = RefCell::new(( + Box::new(|| {None}), Box::new(|_| {ptr::null()}), Box::new(|_| {}), Box::new(|_| {})) ); } type TBefore = Box Option<*const ()>>; -type TSpawn = Box *const ()>; +type TEnter = Box *const ()>; +type TExit = Box; type TAfter = Box; -pub unsafe fn spawn_cbs( +/// # Safety +/// +/// The user must ensure that the pointer returned by `before` is `'static`. It will become +/// owned by the spawned task for the life of the task. Ownership of the pointer will be +/// returned to the user at the end of the task via `after`. The pointer is opaque to the +/// runtime. +pub unsafe fn spawn_cbs( before: FBefore, - before_spawn: FSpawn, - after_spawn: FAfter, + enter: FEnter, + exit: FExit, + after: FAfter, ) where FBefore: Fn() -> Option<*const ()> + 'static, - FSpawn: Fn(*const ()) -> *const () + 'static, + FEnter: Fn(*const ()) -> *const () + 'static, + FExit: Fn(*const ()) + 'static, FAfter: Fn(*const ()) + 'static, { CB.with(|cb| { *cb.borrow_mut() = ( Box::new(before), - Box::new(before_spawn), - Box::new(after_spawn), + Box::new(enter), + Box::new(exit), + Box::new(after), ); }); } @@ -40,7 +50,8 @@ pub unsafe fn spawn_cbs( #[allow(dead_code)] #[cfg(all(feature = "glommio", target_os = "linux"))] mod glommio { - use std::{future::Future, pin::Pin, task::Context, task::Poll}; + use std::future::{poll_fn, Future}; + use std::{pin::Pin, task::Context, task::Poll}; use futures_channel::oneshot::Canceled; use glomm_io::task; @@ -64,7 +75,7 @@ mod glommio { /// /// This function panics if ntex system is not running. #[inline] - pub fn spawn(f: F) -> JoinHandle + pub fn spawn(mut f: F) -> JoinHandle where F: Future + 'static, F::Output: 'static, @@ -74,11 +85,17 @@ mod glommio { fut: Either::Left( glomm_io::spawn_local(async move { if let Some(ptr) = ptr { - let new_ptr = crate::CB.with(|cb| (cb.borrow().1)(ptr)); glomm_io::executor().yield_now().await; - let res = f.await; - crate::CB.with(|cb| (cb.borrow().2)(new_ptr)); - res + let mut f = unsafe { Pin::new_unchecked(&mut f) }; + let result = poll_fn(|ctx| { + let new_ptr = crate::CB.with(|cb| (cb.borrow().1)(ptr)); + let result = f.as_mut().poll(ctx); + crate::CB.with(|cb| (cb.borrow().2)(new_ptr)); + result + }) + .await; + crate::CB.with(|cb| (cb.borrow().3)(ptr)); + result } else { glomm_io::executor().yield_now().await; f.await @@ -145,7 +162,7 @@ mod glommio { #[cfg(feature = "tokio")] mod tokio { - use std::future::Future; + use std::future::{poll_fn, Future}; pub use tok_io::task::{spawn_blocking, JoinError, JoinHandle}; /// Runs the provided future, blocking the current thread until the future @@ -174,10 +191,16 @@ mod tokio { let ptr = crate::CB.with(|cb| (cb.borrow().0)()); tok_io::task::spawn_local(async move { if let Some(ptr) = ptr { - let new_ptr = crate::CB.with(|cb| (cb.borrow().1)(ptr)); - let res = f.await; - crate::CB.with(|cb| (cb.borrow().2)(new_ptr)); - res + tok_io::pin!(f); + let result = poll_fn(|ctx| { + let new_ptr = crate::CB.with(|cb| (cb.borrow().1)(ptr)); + let result = f.as_mut().poll(ctx); + crate::CB.with(|cb| (cb.borrow().2)(new_ptr)); + result + }) + .await; + crate::CB.with(|cb| (cb.borrow().3)(ptr)); + result } else { f.await } @@ -205,7 +228,8 @@ mod tokio { #[cfg(feature = "async-std")] mod asyncstd { use futures_core::ready; - use std::{fmt, future::Future, pin::Pin, task::Context, task::Poll}; + use std::future::{poll_fn, Future}; + use std::{fmt, pin::Pin, task::Context, task::Poll}; /// Runs the provided future, blocking the current thread until the future /// completes. @@ -221,7 +245,7 @@ mod asyncstd { /// /// This function panics if ntex system is not running. #[inline] - pub fn spawn(f: F) -> JoinHandle + pub fn spawn(mut f: F) -> JoinHandle where F: Future + 'static, { @@ -229,10 +253,16 @@ mod asyncstd { JoinHandle { fut: async_std::task::spawn_local(async move { if let Some(ptr) = ptr { - let new_ptr = crate::CB.with(|cb| (cb.borrow().1)(ptr)); - let res = f.await; - crate::CB.with(|cb| (cb.borrow().2)(new_ptr)); - res + let mut f = unsafe { Pin::new_unchecked(&mut f) }; + let result = poll_fn(|ctx| { + let new_ptr = crate::CB.with(|cb| (cb.borrow().1)(ptr)); + let result = f.as_mut().poll(ctx); + crate::CB.with(|cb| (cb.borrow().2)(new_ptr)); + result + }) + .await; + crate::CB.with(|cb| (cb.borrow().3)(ptr)); + result } else { f.await }