diff --git a/ntex-service/CHANGES.md b/ntex-service/CHANGES.md index 42b1b33b..a74e8b2f 100644 --- a/ntex-service/CHANGES.md +++ b/ntex-service/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [1.2.2] - 2023-06-24 + +* Added `ServiceCall::advance_to_call` + ## [1.2.1] - 2023-06-23 * Make `PipelineCall` static diff --git a/ntex-service/Cargo.toml b/ntex-service/Cargo.toml index 67550494..a0f65478 100644 --- a/ntex-service/Cargo.toml +++ b/ntex-service/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-service" -version = "1.2.1" +version = "1.2.2" authors = ["ntex contributors "] description = "ntex service" keywords = ["network", "framework", "async", "futures"] diff --git a/ntex-service/src/ctx.rs b/ntex-service/src/ctx.rs index f0395dc9..f01fa997 100644 --- a/ntex-service/src/ctx.rs +++ b/ntex-service/src/ctx.rs @@ -166,6 +166,26 @@ pin_project_lite::pin_project! { } } +impl<'a, S, Req> ServiceCall<'a, S, Req> +where + S: Service, + S: 'a, + S: ?Sized, + Req: 'a, +{ + pub fn advance_to_call(self) -> ServiceCallToCall<'a, S, Req> { + match self.state { + ServiceCallState::Ready { .. } => {} + ServiceCallState::Call { .. } | ServiceCallState::Empty => { + panic!( + "`ServiceCall::advance_to_call` must be called before `ServiceCall::poll`" + ) + } + } + ServiceCallToCall { state: self.state } + } +} + pin_project_lite::pin_project! { #[project = ServiceCallStateProject] enum ServiceCallState<'a, S, Req> @@ -234,6 +254,68 @@ where } } +pin_project_lite::pin_project! { + #[must_use = "futures do nothing unless polled"] + pub struct ServiceCallToCall<'a, S, Req> + where + S: Service, + S: 'a, + S: ?Sized, + Req: 'a, + { + #[pin] + state: ServiceCallState<'a, S, Req>, + } +} + +impl<'a, S, Req> Future for ServiceCallToCall<'a, S, Req> +where + S: Service + ?Sized, +{ + type Output = Result, S::Error>; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll { + let mut this = self.as_mut().project(); + + match this.state.as_mut().project() { + ServiceCallStateProject::Ready { + req, + svc, + idx, + waiters, + } => match svc.poll_ready(cx)? { + task::Poll::Ready(()) => { + waiters.notify(); + + let fut = svc.call( + req.take().unwrap(), + ServiceCtx { + idx: *idx, + waiters, + _t: marker::PhantomData, + }, + ); + this.state.set(ServiceCallState::Empty); + task::Poll::Ready(Ok(fut)) + } + task::Poll::Pending => { + waiters.register(*idx, cx); + task::Poll::Pending + } + }, + ServiceCallStateProject::Call { .. } => { + unreachable!("`ServiceCallToCall` can only be constructed in `Ready` state") + } + ServiceCallStateProject::Empty => { + panic!("future must not be polled after it returned `Poll::Ready`") + } + } + } +} + #[cfg(test)] mod tests { use ntex_util::future::{lazy, poll_fn, Ready}; diff --git a/ntex-service/src/lib.rs b/ntex-service/src/lib.rs index ca9fdbf2..887d83dd 100644 --- a/ntex-service/src/lib.rs +++ b/ntex-service/src/lib.rs @@ -24,7 +24,7 @@ mod then; pub use self::apply::{apply_fn, apply_fn_factory}; pub use self::chain::{chain, chain_factory}; -pub use self::ctx::{ServiceCall, ServiceCtx}; +pub use self::ctx::{ServiceCall, ServiceCallToCall, ServiceCtx}; pub use self::fn_service::{fn_factory, fn_factory_with_config, fn_service}; pub use self::fn_shutdown::fn_shutdown; pub use self::map_config::{map_config, unit_config}; diff --git a/ntex-util/CHANGES.md b/ntex-util/CHANGES.md index 9b1864c3..63f14bc1 100644 --- a/ntex-util/CHANGES.md +++ b/ntex-util/CHANGES.md @@ -1,5 +1,11 @@ # Changes +## [0.3.1] - 2023-06-24 + +* Changed `BufferService` to maintain order + +* Buffer error type changed to indicate cancellation + ## [0.3.0] - 2023-06-22 * Release v0.3.0 diff --git a/ntex-util/Cargo.toml b/ntex-util/Cargo.toml index bef46d01..522f1144 100644 --- a/ntex-util/Cargo.toml +++ b/ntex-util/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-util" -version = "0.3.0" +version = "0.3.1" authors = ["ntex contributors "] description = "Utilities for ntex framework" keywords = ["network", "framework", "async", "futures"] @@ -17,7 +17,7 @@ path = "src/lib.rs" [dependencies] ntex-rt = "0.4.7" -ntex-service = "1.2.0" +ntex-service = "1.2.2" bitflags = "1.3" fxhash = "0.2.1" log = "0.4" diff --git a/ntex-util/src/services/buffer.rs b/ntex-util/src/services/buffer.rs index 7f7aae2a..ce1dcb6b 100644 --- a/ntex-util/src/services/buffer.rs +++ b/ntex-util/src/services/buffer.rs @@ -3,15 +3,16 @@ use std::cell::{Cell, RefCell}; use std::task::{ready, Context, Poll}; use std::{collections::VecDeque, future::Future, marker::PhantomData, pin::Pin}; -use ntex_service::{IntoService, Middleware, Service, ServiceCall, ServiceCtx}; +use ntex_service::{IntoService, Middleware, Service, ServiceCallToCall, ServiceCtx}; -use crate::{channel::oneshot, future::Either, task::LocalWaker}; +use crate::channel::{oneshot, Canceled}; /// Buffer - service factory for service that can buffer incoming request. /// /// Default number of buffered requests is 16 pub struct Buffer { buf_size: usize, + cancel_on_shutdown: bool, _t: PhantomData, } @@ -19,6 +20,7 @@ impl Default for Buffer { fn default() -> Self { Self { buf_size: 16, + cancel_on_shutdown: false, _t: PhantomData, } } @@ -29,12 +31,21 @@ impl Buffer { self.buf_size = size; self } + + /// Cancel all buffered requests on shutdown + /// + /// By default buffered requests are flushed during poll_shutdown + pub fn cancel_on_shutdown(mut self) -> Self { + self.cancel_on_shutdown = true; + self + } } impl Clone for Buffer { fn clone(&self) -> Self { Self { buf_size: self.buf_size, + cancel_on_shutdown: self.cancel_on_shutdown, _t: PhantomData, } } @@ -50,9 +61,10 @@ where BufferService { service, size: self.buf_size, + cancel_on_shutdown: self.cancel_on_shutdown, ready: Cell::new(false), - waker: LocalWaker::default(), buf: RefCell::new(VecDeque::with_capacity(self.buf_size)), + next_call: RefCell::default(), _t: PhantomData, } } @@ -63,10 +75,11 @@ where /// Default number of buffered requests is 16 pub struct BufferService> { size: usize, + cancel_on_shutdown: bool, ready: Cell, service: S, - waker: LocalWaker, - buf: RefCell>>, + buf: RefCell>>>, + next_call: RefCell>>, _t: PhantomData, } @@ -80,13 +93,21 @@ where { Self { size, + cancel_on_shutdown: false, ready: Cell::new(false), service: service.into_service(), - waker: LocalWaker::default(), buf: RefCell::new(VecDeque::with_capacity(size)), + next_call: RefCell::default(), _t: PhantomData, } } + + pub fn cancel_on_shutdown(self) -> Self { + Self { + cancel_on_shutdown: true, + ..self + } + } } impl Clone for BufferService @@ -96,10 +117,11 @@ where fn clone(&self) -> Self { Self { size: self.size, + cancel_on_shutdown: self.cancel_on_shutdown, ready: Cell::new(false), service: self.service.clone(), - waker: LocalWaker::default(), buf: RefCell::new(VecDeque::with_capacity(self.size)), + next_call: RefCell::default(), _t: PhantomData, } } @@ -110,51 +132,106 @@ where S: Service, { type Response = S::Response; - type Error = S::Error; - type Future<'f> = Either, BufferServiceResponse<'f, R, S>> where Self: 'f, R: 'f; + type Error = BufferServiceError; + type Future<'f> = BufferServiceResponse<'f, R, S> where Self: 'f, R: 'f; #[inline] fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.waker.register(cx.waker()); let mut buffer = self.buf.borrow_mut(); + let mut next_call = self.next_call.borrow_mut(); + if let Some(next_call) = &*next_call { + // hold advancement until the last released task either makes a call or is dropped + let _ = ready!(next_call.poll_recv(cx)); + } + next_call.take(); if self.service.poll_ready(cx)?.is_pending() { if buffer.len() < self.size { // buffer next request self.ready.set(false); - Poll::Ready(Ok(())) + return Poll::Ready(Ok(())); } else { log::trace!("Buffer limit exceeded"); - Poll::Pending + return Poll::Pending; } - } else if let Some(sender) = buffer.pop_front() { - let _ = sender.send(()); - self.ready.set(false); - Poll::Ready(Ok(())) - } else { - self.ready.set(true); - Poll::Ready(Ok(())) } + + while let Some(sender) = buffer.pop_front() { + let (next_call_tx, next_call_rx) = oneshot::channel(); + if sender.send(next_call_tx).is_err() || next_call_rx.poll_recv(cx).is_ready() { + // the task is gone + continue; + } + next_call.replace(next_call_rx); + self.ready.set(false); + return Poll::Ready(Ok(())); + } + + self.ready.set(true); + Poll::Ready(Ok(())) } #[inline] fn call<'a>(&'a self, req: R, ctx: ServiceCtx<'a, Self>) -> Self::Future<'a> { if self.ready.get() { self.ready.set(false); - Either::Left(ctx.call(&self.service, req)) + BufferServiceResponse { + slf: self, + state: ResponseState::Running { + fut: ctx.call_nowait(&self.service, req), + }, + } } else { let (tx, rx) = oneshot::channel(); self.buf.borrow_mut().push_back(tx); - Either::Right(BufferServiceResponse { + BufferServiceResponse { slf: self, - fut: ctx.call(&self.service, req), - rx: Some(rx), - }) + state: ResponseState::WaitingForRelease { + rx, + call: Some(ctx.call(&self.service, req).advance_to_call()), + }, + } } } - ntex_service::forward_poll_shutdown!(service); + fn poll_shutdown(&self, cx: &mut std::task::Context<'_>) -> Poll<()> { + let mut buffer = self.buf.borrow_mut(); + if self.cancel_on_shutdown { + buffer.clear(); + } else if !buffer.is_empty() { + let mut next_call = self.next_call.borrow_mut(); + if let Some(next_call) = &*next_call { + // hold advancement until the last released task either makes a call or is dropped + let _ = ready!(next_call.poll_recv(cx)); + } + next_call.take(); + + if ready!(self.service.poll_ready(cx)).is_err() { + log::error!( + "Buffered inner service failed while buffer flushing on shutdown" + ); + return Poll::Ready(()); + } + + while let Some(sender) = buffer.pop_front() { + let (next_call_tx, next_call_rx) = oneshot::channel(); + if sender.send(next_call_tx).is_err() + || next_call_rx.poll_recv(cx).is_ready() + { + // the task is gone + continue; + } + next_call.replace(next_call_rx); + if buffer.is_empty() { + break; + } + return Poll::Pending; + } + } + + self.service.poll_shutdown(cx) + } } pin_project_lite::pin_project! { @@ -163,9 +240,18 @@ pin_project_lite::pin_project! { pub struct BufferServiceResponse<'f, R, S: Service> { #[pin] - fut: ServiceCall<'f, S, R>, + state: ResponseState<'f, R, S>, slf: &'f BufferService, - rx: Option>, + } +} + +pin_project_lite::pin_project! { + #[project = ResponseStateProject] + enum ResponseState<'f, R, S: Service> + { + WaitingForRelease { rx: oneshot::Receiver>, call: Option> }, + WaitingForReady { tx: oneshot::Sender<()>, #[pin] call: ServiceCallToCall<'f, S, R> }, + Running { #[pin] fut: S::Future<'f> }, } } @@ -173,22 +259,63 @@ impl<'f, R, S> Future for BufferServiceResponse<'f, R, S> where S: Service, { - type Output = Result; + type Output = Result>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut().project(); + let mut this = self.as_mut().project(); + match this.state.as_mut().project() { + ResponseStateProject::WaitingForRelease { rx, call } => { + match ready!(rx.poll_recv(cx)) { + Ok(tx) => { + let call = call.take().expect("always set in this state"); + this.state.set(ResponseState::WaitingForReady { tx, call }); + self.poll(cx) + } + Err(Canceled) => { + log::trace!("Buffered service request canceled"); + Poll::Ready(Err(BufferServiceError::RequestCanceled)) + } + } + } + ResponseStateProject::WaitingForReady { call, .. } => { + let fut = match ready!(call.poll(cx)) { + Ok(fut) => fut, + Err(err) => return Poll::Ready(Err(err.into())), + }; - if let Some(ref rx) = this.rx { - let _ = ready!(rx.poll_recv(cx)); - this.rx.take(); + this.state.set(ResponseState::Running { fut }); + self.poll(cx) + } + ResponseStateProject::Running { fut } => fut.poll(cx).map_err(|e| e.into()), } - - let res = ready!(this.fut.poll(cx)); - this.slf.waker.wake(); - Poll::Ready(res) } } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum BufferServiceError { + Service(E), + RequestCanceled, +} + +impl From for BufferServiceError { + fn from(err: E) -> Self { + BufferServiceError::Service(err) + } +} + +impl std::fmt::Display for BufferServiceError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BufferServiceError::Service(e) => std::fmt::Display::fmt(e, f), + BufferServiceError::RequestCanceled => { + f.write_str("buffer service request canceled") + } + } + } +} + +impl std::error::Error for BufferServiceError {} + #[cfg(test)] mod tests { use ntex_service::{apply, fn_factory, Pipeline, Service, ServiceFactory}; @@ -196,6 +323,7 @@ mod tests { use super::*; use crate::future::{lazy, Ready}; + use crate::task::LocalWaker; #[derive(Clone)] struct TestService(Rc);