diff --git a/ntex-service/CHANGES.md b/ntex-service/CHANGES.md index b399c3f8..0f51d87f 100644 --- a/ntex-service/CHANGES.md +++ b/ntex-service/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [3.2.0] - 2024-10-19 + +* Introduce `PipelineTag`, which allows to notify pipeline binding + ## [3.1.0] - 2024-09-29 * Notify readiness waiters if ready call get dropped diff --git a/ntex-service/Cargo.toml b/ntex-service/Cargo.toml index 519ef949..eedc7549 100644 --- a/ntex-service/Cargo.toml +++ b/ntex-service/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-service" -version = "3.1.0" +version = "3.2.0" authors = ["ntex contributors "] description = "ntex service" keywords = ["network", "framework", "async", "futures"] diff --git a/ntex-service/src/boxed.rs b/ntex-service/src/boxed.rs index c785729c..ad8f2927 100644 --- a/ntex-service/src/boxed.rs +++ b/ntex-service/src/boxed.rs @@ -1,4 +1,4 @@ -use std::{fmt, future::Future, pin::Pin}; +use std::{fmt, future::Future, pin::Pin, rc::Rc}; use crate::ctx::{ServiceCtx, WaitersRef}; @@ -51,14 +51,14 @@ trait ServiceObj { fn ready<'a>( &'a self, idx: usize, - waiters: &'a WaitersRef, + waiters: &'a Rc, ) -> BoxFuture<'a, (), Self::Error>; fn call<'a>( &'a self, req: Req, idx: usize, - waiters: &'a WaitersRef, + waiters: &'a Rc, ) -> BoxFuture<'a, Self::Response, Self::Error>; fn shutdown<'a>(&'a self) -> Pin + 'a>>; @@ -76,7 +76,7 @@ where fn ready<'a>( &'a self, idx: usize, - waiters: &'a WaitersRef, + waiters: &'a Rc, ) -> BoxFuture<'a, (), Self::Error> { Box::pin(async move { ServiceCtx::<'a, S>::from_ref(idx, waiters) @@ -95,7 +95,7 @@ where &'a self, req: Req, idx: usize, - waiters: &'a WaitersRef, + waiters: &'a Rc, ) -> BoxFuture<'a, Self::Response, Self::Error> { Box::pin(async move { ServiceCtx::<'a, S>::from_ref(idx, waiters) diff --git a/ntex-service/src/ctx.rs b/ntex-service/src/ctx.rs index a5b563cd..93d3c979 100644 --- a/ntex-service/src/ctx.rs +++ b/ntex-service/src/ctx.rs @@ -1,23 +1,37 @@ -use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc, task}; +use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc, task, task::Context}; use crate::Service; pub struct ServiceCtx<'a, S: ?Sized> { idx: usize, - waiters: &'a WaitersRef, + waiters: &'a Rc, _t: marker::PhantomData>, } +#[derive(Clone, Debug)] +/// Pipeline tag allows to notify pipeline binding +pub struct PipelineTag(Rc); + pub(crate) struct Waiters { index: usize, waiters: Rc, } +#[derive(Debug)] pub(crate) struct WaitersRef { cur: cell::Cell, indexes: cell::UnsafeCell>>, } +impl PipelineTag { + /// Notify pipeline dispatcher + pub fn notify(&self) { + if let Some(waker) = self.0.get()[0].take() { + waker.wake(); + } + } +} + impl WaitersRef { #[allow(clippy::mut_from_ref)] fn get(&self) -> &mut slab::Slab> { @@ -29,16 +43,15 @@ impl WaitersRef { } fn remove(&self, idx: usize) { - self.notify(); self.get().remove(idx); } - fn register(&self, idx: usize, cx: &mut task::Context<'_>) { + fn register(&self, idx: usize, cx: &mut Context<'_>) { self.get()[idx] = Some(cx.waker().clone()); } fn notify(&self) { - for (_, waker) in self.get().iter_mut() { + for (_, waker) in self.get().iter_mut().skip(1) { if let Some(waker) = waker.take() { waker.wake(); } @@ -47,7 +60,7 @@ impl WaitersRef { self.cur.set(usize::MAX); } - pub(crate) fn can_check(&self, idx: usize, cx: &mut task::Context<'_>) -> bool { + pub(crate) fn can_check(&self, idx: usize, cx: &mut Context<'_>) -> bool { let cur = self.cur.get(); if cur == idx { true @@ -64,9 +77,12 @@ impl WaitersRef { impl Waiters { pub(crate) fn new() -> Self { let mut waiters = slab::Slab::new(); - let index = waiters.insert(None); + + // first insert for wake ups from services + let _ = waiters.insert(None); + Waiters { - index, + index: waiters.insert(None), waiters: Rc::new(WaitersRef { cur: cell::Cell::new(usize::MAX), indexes: cell::UnsafeCell::new(waiters), @@ -74,18 +90,22 @@ impl Waiters { } } - pub(crate) fn get_ref(&self) -> &WaitersRef { - self.waiters.as_ref() + pub(crate) fn get_ref(&self) -> &Rc { + &self.waiters } - pub(crate) fn can_check(&self, cx: &mut task::Context<'_>) -> bool { + pub(crate) fn can_check(&self, cx: &mut Context<'_>) -> bool { self.waiters.can_check(self.index, cx) } - pub(crate) fn register(&self, cx: &mut task::Context<'_>) { + pub(crate) fn register(&self, cx: &mut Context<'_>) { self.waiters.register(self.index, cx); } + pub(crate) fn register_pipeline(&self, cx: &mut Context<'_>) { + self.waiters.register(0, cx); + } + pub(crate) fn notify(&self) { if self.waiters.cur.get() == self.index { self.waiters.notify(); @@ -97,7 +117,7 @@ impl Drop for Waiters { #[inline] fn drop(&mut self) { self.waiters.remove(self.index); - self.waiters.notify(); + self.notify(); } } @@ -128,7 +148,7 @@ impl<'a, S> ServiceCtx<'a, S> { } } - pub(crate) fn from_ref(idx: usize, waiters: &'a WaitersRef) -> Self { + pub(crate) fn from_ref(idx: usize, waiters: &'a Rc) -> Self { Self { idx, waiters, @@ -136,7 +156,7 @@ impl<'a, S> ServiceCtx<'a, S> { } } - pub(crate) fn inner(self) -> (usize, &'a WaitersRef) { + pub(crate) fn inner(self) -> (usize, &'a Rc) { (self.idx, self.waiters) } @@ -199,6 +219,11 @@ impl<'a, S> ServiceCtx<'a, S> { ) .await } + + /// Get pipeline tag for current pipeline + pub fn tag(&self) -> PipelineTag { + PipelineTag(self.waiters.clone()) + } } impl<'a, S> Copy for ServiceCtx<'a, S> {} @@ -227,7 +252,7 @@ struct ReadyCall<'a, S: ?Sized, F: Future> { impl<'a, S: ?Sized, F: Future> Drop for ReadyCall<'a, S, F> { fn drop(&mut self) { - if !self.completed { + if !self.completed && self.ctx.waiters.cur.get() == self.ctx.idx { self.ctx.waiters.notify(); } } @@ -238,10 +263,7 @@ impl<'a, S: ?Sized, F: Future> Unpin for ReadyCall<'a, S, F> {} impl<'a, S: ?Sized, F: Future> Future for ReadyCall<'a, S, F> { type Output = F::Output; - fn poll( - mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - ) -> task::Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> task::Poll { if self.ctx.waiters.can_check(self.ctx.idx, cx) { // SAFETY: `fut` never moves let result = unsafe { Pin::new_unchecked(&mut self.as_mut().fut).poll(cx) }; @@ -310,10 +332,9 @@ mod tests { let res = lazy(|cx| srv2.poll_ready(cx)).await; assert_eq!(res, Poll::Pending); - assert_eq!(cnt.get(), 1); - con.notify(); + con.notify(); let res = lazy(|cx| srv1.poll_ready(cx)).await; assert_eq!(res, Poll::Ready(Ok(()))); assert_eq!(cnt.get(), 1); @@ -431,4 +452,50 @@ mod tests { assert_eq!(cnt.get(), 2); assert_eq!(&*data.borrow(), &["srv1", "srv2"]); } + + #[ntex::test] + async fn test_pipeline_tag() { + struct Srv(Rc>, Cell>); + + impl Service<&'static str> for Srv { + type Response = &'static str; + type Error = (); + + async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> { + self.1.set(Some(ctx.tag())); + self.0.set(self.0.get() + 1); + Ok(()) + } + + async fn call( + &self, + req: &'static str, + _: ServiceCtx<'_, Self>, + ) -> Result<&'static str, ()> { + Ok(req) + } + } + + let cnt = Rc::new(Cell::new(0)); + let con = condition::Condition::new(); + + let srv = Pipeline::from(Srv(cnt.clone(), Cell::new(None))).bind(); + + let srv1 = srv.clone(); + let waiter = con.wait(); + ntex::rt::spawn(async move { + let _ = poll_fn(|cx| { + let _ = srv1.poll_ready(cx); + waiter.poll_ready(cx) + }) + .await; + }); + time::sleep(time::Millis(50)).await; + assert_eq!(cnt.get(), 1); + + let tag = srv.get_ref().1.take().unwrap(); + tag.notify(); + time::sleep(time::Millis(50)).await; + assert_eq!(cnt.get(), 2); + } } diff --git a/ntex-service/src/lib.rs b/ntex-service/src/lib.rs index e01b6902..bf4afbed 100644 --- a/ntex-service/src/lib.rs +++ b/ntex-service/src/lib.rs @@ -27,7 +27,7 @@ mod util; pub use self::apply::{apply_fn, apply_fn_factory}; pub use self::chain::{chain, chain_factory}; -pub use self::ctx::ServiceCtx; +pub use self::ctx::{PipelineTag, ServiceCtx}; pub use self::fn_service::{fn_factory, fn_factory_with_config, fn_service}; pub use self::fn_shutdown::fn_shutdown; pub use self::map_config::{map_config, unit_config}; diff --git a/ntex-service/src/pipeline.rs b/ntex-service/src/pipeline.rs index ea5b1883..7d1d76e9 100644 --- a/ntex-service/src/pipeline.rs +++ b/ntex-service/src/pipeline.rs @@ -346,6 +346,9 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut slf = self.as_mut(); + // register pipeline tag + slf.pl.waiters.register_pipeline(cx); + if slf.pl.waiters.can_check(cx) { if let Some(ref mut fut) = slf.fut { match unsafe { Pin::new_unchecked(fut) }.poll(cx) {