From 5f6600c814ac986164cf87f0a83275ac533d3700 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 4 Nov 2024 12:49:18 +0500 Subject: [PATCH] Update Service trait and use unified Counter (#455) --- ntex-server/Cargo.toml | 2 +- ntex-server/src/net/counter.rs | 109 ------------------------------ ntex-server/src/net/mod.rs | 6 +- ntex-server/src/net/service.rs | 5 +- ntex-tls/CHANGES.md | 4 ++ ntex-tls/Cargo.toml | 6 +- ntex-tls/src/counter.rs | 84 ----------------------- ntex-tls/src/lib.rs | 5 +- ntex-tls/src/openssl/accept.rs | 19 ++++-- ntex-tls/src/rustls/accept.rs | 16 +++-- ntex-util/CHANGES.md | 4 +- ntex-util/src/services/counter.rs | 92 +++++++++++++++++-------- ntex-util/src/services/mod.rs | 5 +- 13 files changed, 112 insertions(+), 245 deletions(-) delete mode 100644 ntex-server/src/net/counter.rs delete mode 100644 ntex-tls/src/counter.rs diff --git a/ntex-server/Cargo.toml b/ntex-server/Cargo.toml index a21943b2..4ada3557 100644 --- a/ntex-server/Cargo.toml +++ b/ntex-server/Cargo.toml @@ -20,7 +20,7 @@ ntex-bytes = "0.1" ntex-net = "2" ntex-service = "3.3" ntex-rt = "0.4" -ntex-util = "2" +ntex-util = "2.5" async-channel = "2" async-broadcast = "0.7" diff --git a/ntex-server/src/net/counter.rs b/ntex-server/src/net/counter.rs deleted file mode 100644 index 87dbe2af..00000000 --- a/ntex-server/src/net/counter.rs +++ /dev/null @@ -1,109 +0,0 @@ -use std::{cell::Cell, future::poll_fn, rc::Rc, task}; - -use ntex_util::task::LocalWaker; - -#[derive(Debug)] -/// Simple counter with ability to notify task on reaching specific number -/// -/// Counter could be cloned, total count is shared across all clones. -pub(super) struct Counter(Rc); - -#[derive(Debug)] -struct CounterInner { - count: Cell, - capacity: usize, - task: LocalWaker, -} - -impl Counter { - /// Create `Counter` instance and set max value. - pub(super) fn new(capacity: usize) -> Self { - Counter(Rc::new(CounterInner { - capacity, - count: Cell::new(0), - task: LocalWaker::new(), - })) - } - - /// Get counter guard. - pub(super) fn get(&self) -> CounterGuard { - CounterGuard::new(self.0.clone()) - } - - pub(crate) fn is_available(&self) -> bool { - self.0.count.get() < self.0.capacity - } - - /// Check if counter is not at capacity. If counter at capacity - /// it registers notification for current task. - pub(crate) async fn available(&self) { - poll_fn(|cx| { - if self.0.available(cx) { - task::Poll::Ready(()) - } else { - task::Poll::Pending - } - }) - .await - } - - pub(crate) async fn unavailable(&self) { - poll_fn(|cx| { - if self.0.available(cx) { - task::Poll::Pending - } else { - task::Poll::Ready(()) - } - }) - .await - } - - /// Get total number of acquired counts - pub(super) fn total(&self) -> usize { - self.0.count.get() - } - - pub(super) fn priv_clone(&self) -> Self { - Counter(self.0.clone()) - } -} - -pub(super) struct CounterGuard(Rc); - -impl CounterGuard { - fn new(inner: Rc) -> Self { - inner.inc(); - CounterGuard(inner) - } -} - -impl Unpin for CounterGuard {} - -impl Drop for CounterGuard { - fn drop(&mut self) { - self.0.dec(); - } -} - -impl CounterInner { - fn inc(&self) { - let num = self.count.get() + 1; - self.count.set(num); - if num == self.capacity { - self.task.wake(); - } - } - - fn dec(&self) { - let num = self.count.get(); - self.count.set(num - 1); - if num == self.capacity { - self.task.wake(); - } - } - - fn available(&self, cx: &mut task::Context<'_>) -> bool { - self.task.register(cx.waker()); - self.count.get() < self.capacity - } -} diff --git a/ntex-server/src/net/mod.rs b/ntex-server/src/net/mod.rs index 4a4d8b82..626a97fc 100644 --- a/ntex-server/src/net/mod.rs +++ b/ntex-server/src/net/mod.rs @@ -1,10 +1,10 @@ //! General purpose tcp server +use ntex_util::services::Counter; use std::sync::atomic::{AtomicUsize, Ordering}; mod accept; mod builder; mod config; -mod counter; mod factory; mod service; mod socket; @@ -56,8 +56,7 @@ pub enum SslError { static MAX_CONNS: AtomicUsize = AtomicUsize::new(25600); thread_local! { - static MAX_CONNS_COUNTER: self::counter::Counter = - self::counter::Counter::new(MAX_CONNS.load(Ordering::Relaxed)); + static MAX_CONNS_COUNTER: Counter = Counter::new(MAX_CONNS.load(Ordering::Relaxed)); } /// Sets the maximum per-worker number of concurrent connections. @@ -68,6 +67,7 @@ thread_local! { /// By default max connections is set to a 25k per worker. pub(super) fn max_concurrent_connections(num: usize) { MAX_CONNS.store(num, Ordering::Relaxed); + MAX_CONNS_COUNTER.with(|conns| conns.set_capacity(num)); } pub(super) fn num_connections() -> usize { diff --git a/ntex-server/src/net/service.rs b/ntex-server/src/net/service.rs index b67df07d..4be6c828 100644 --- a/ntex-server/src/net/service.rs +++ b/ntex-server/src/net/service.rs @@ -3,12 +3,11 @@ use std::{fmt, future::poll_fn, future::Future, pin::Pin, task::Poll}; use ntex_bytes::{Pool, PoolRef}; use ntex_net::Io; use ntex_service::{boxed, Service, ServiceCtx, ServiceFactory}; -use ntex_util::{future::join_all, HashMap}; +use ntex_util::{future::join_all, services::Counter, HashMap}; use crate::ServerConfiguration; use super::accept::{AcceptNotify, AcceptorCommand}; -use super::counter::Counter; use super::factory::{FactoryServiceType, NetService, OnWorkerStart}; use super::{socket::Connection, Token, MAX_CONNS_COUNTER}; @@ -135,7 +134,7 @@ impl ServiceFactory for StreamService { Ok(StreamServiceImpl { tokens, services, - conns: MAX_CONNS_COUNTER.with(|conns| conns.priv_clone()), + conns: MAX_CONNS_COUNTER.with(|conns| conns.clone()), }) } } diff --git a/ntex-tls/CHANGES.md b/ntex-tls/CHANGES.md index 042aa771..8e417bc4 100644 --- a/ntex-tls/CHANGES.md +++ b/ntex-tls/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [2.3.0] - 2024-11-04 + +* Use updated Service trait + ## [2.2.0] - 2024-09-25 * Disable default features for rustls diff --git a/ntex-tls/Cargo.toml b/ntex-tls/Cargo.toml index b033e329..4a731964 100644 --- a/ntex-tls/Cargo.toml +++ b/ntex-tls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-tls" -version = "2.2.0" +version = "2.3.0" authors = ["ntex contributors "] description = "An implementation of SSL streams for ntex backed by OpenSSL" keywords = ["network", "framework", "async", "futures"] @@ -28,8 +28,8 @@ rustls-ring = ["tls_rust", "tls_rust/ring", "tls_rust/std"] [dependencies] ntex-bytes = "0.1" ntex-io = "2.3" -ntex-util = "2" -ntex-service = "3" +ntex-util = "2.5" +ntex-service = "3.3" ntex-net = "2" log = "0.4" diff --git a/ntex-tls/src/counter.rs b/ntex-tls/src/counter.rs deleted file mode 100644 index d91bfc90..00000000 --- a/ntex-tls/src/counter.rs +++ /dev/null @@ -1,84 +0,0 @@ -#![allow(dead_code)] -use std::{cell::Cell, future::poll_fn, rc::Rc, task, task::Poll}; - -use ntex_util::task::LocalWaker; - -#[derive(Debug, Clone)] -/// Simple counter with ability to notify task on reaching specific number -/// -/// Counter could be cloned, total count is shared across all clones. -pub(super) struct Counter(Rc); - -#[derive(Debug)] -struct CounterInner { - count: Cell, - capacity: usize, - task: LocalWaker, -} - -impl Counter { - /// Create `Counter` instance and set max value. - pub(super) fn new(capacity: usize) -> Self { - Counter(Rc::new(CounterInner { - capacity, - count: Cell::new(0), - task: LocalWaker::new(), - })) - } - - /// Get counter guard. - pub(super) fn get(&self) -> CounterGuard { - CounterGuard::new(self.0.clone()) - } - - /// Check if counter is not at capacity. If counter at capacity - /// it registers notification for current task. - pub(super) async fn available(&self) { - poll_fn(|cx| { - if self.0.available(cx) { - Poll::Ready(()) - } else { - Poll::Pending - } - }) - .await - } -} - -pub(super) struct CounterGuard(Rc); - -impl CounterGuard { - fn new(inner: Rc) -> Self { - inner.inc(); - CounterGuard(inner) - } -} - -impl Drop for CounterGuard { - fn drop(&mut self) { - self.0.dec(); - } -} - -impl CounterInner { - fn inc(&self) { - self.count.set(self.count.get() + 1); - } - - fn dec(&self) { - let num = self.count.get(); - self.count.set(num - 1); - if num == self.capacity { - self.task.wake(); - } - } - - fn available(&self, cx: &mut task::Context<'_>) -> bool { - if self.count.get() < self.capacity { - true - } else { - self.task.register(cx.waker()); - false - } - } -} diff --git a/ntex-tls/src/lib.rs b/ntex-tls/src/lib.rs index 128cea8d..897c8c27 100644 --- a/ntex-tls/src/lib.rs +++ b/ntex-tls/src/lib.rs @@ -9,7 +9,7 @@ pub mod openssl; #[cfg(feature = "rustls")] pub mod rustls; -mod counter; +use ntex_util::services::Counter; /// Sets the maximum per-worker concurrent ssl connection establish process. /// @@ -19,12 +19,13 @@ mod counter; /// By default max connections is set to a 256. pub fn max_concurrent_ssl_accept(num: usize) { MAX_SSL_ACCEPT.store(num, Ordering::Relaxed); + MAX_SSL_ACCEPT_COUNTER.with(|counts| counts.set_capacity(num)); } static MAX_SSL_ACCEPT: AtomicUsize = AtomicUsize::new(256); thread_local! { - static MAX_SSL_ACCEPT_COUNTER: counter::Counter = counter::Counter::new(MAX_SSL_ACCEPT.load(Ordering::Relaxed)); + static MAX_SSL_ACCEPT_COUNTER: Counter = Counter::new(MAX_SSL_ACCEPT.load(Ordering::Relaxed)); } /// A TLS PSK identity. diff --git a/ntex-tls/src/openssl/accept.rs b/ntex-tls/src/openssl/accept.rs index a083f69e..95314d8d 100644 --- a/ntex-tls/src/openssl/accept.rs +++ b/ntex-tls/src/openssl/accept.rs @@ -2,13 +2,10 @@ use std::{cell::RefCell, error::Error, fmt, io}; use ntex_io::{Filter, Io, Layer}; use ntex_service::{Service, ServiceCtx, ServiceFactory}; -use ntex_util::time::{self, Millis}; +use ntex_util::{services::Counter, time, time::Millis}; use tls_openssl::ssl; -use crate::counter::Counter; -use crate::MAX_SSL_ACCEPT_COUNTER; - -use super::SslFilter; +use crate::{openssl::SslFilter, MAX_SSL_ACCEPT_COUNTER}; /// Support `TLS` server connections via openssl package /// @@ -98,15 +95,25 @@ impl Service> for SslAcceptorService { type Error = Box; async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> { - self.conns.available().await; + if !self.conns.is_available() { + self.conns.available().await + } Ok(()) } + #[inline] + async fn not_ready(&self) { + if self.conns.is_available() { + self.conns.unavailable().await + } + } + async fn call( &self, io: Io, _: ServiceCtx<'_, Self>, ) -> Result { + let _guard = self.conns.get(); let timeout = self.timeout; let ctx_result = ssl::Ssl::new(self.acceptor.context()); diff --git a/ntex-tls/src/rustls/accept.rs b/ntex-tls/src/rustls/accept.rs index d5102357..7565a7a2 100644 --- a/ntex-tls/src/rustls/accept.rs +++ b/ntex-tls/src/rustls/accept.rs @@ -4,10 +4,9 @@ use tls_rust::ServerConfig; use ntex_io::{Filter, Io, Layer}; use ntex_service::{Service, ServiceCtx, ServiceFactory}; -use ntex_util::time::Millis; +use ntex_util::{services::Counter, time::Millis}; -use super::TlsServerFilter; -use crate::{counter::Counter, MAX_SSL_ACCEPT_COUNTER}; +use crate::{rustls::TlsServerFilter, MAX_SSL_ACCEPT_COUNTER}; #[derive(Debug)] /// Support `SSL` connections via rustls package @@ -81,10 +80,19 @@ impl Service> for TlsAcceptorService { type Error = io::Error; async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> { - self.conns.available().await; + if !self.conns.is_available() { + self.conns.available().await + } Ok(()) } + #[inline] + async fn not_ready(&self) { + if self.conns.is_available() { + self.conns.unavailable().await + } + } + async fn call( &self, io: Io, diff --git a/ntex-util/CHANGES.md b/ntex-util/CHANGES.md index 6cbd96a6..80460c89 100644 --- a/ntex-util/CHANGES.md +++ b/ntex-util/CHANGES.md @@ -1,9 +1,11 @@ # Changes -## [2.5.0] - 2024-11-02 +## [2.5.0] - 2024-11-04 * Use updated Service trait +* Export Counter type + ## [2.4.0] - 2024-09-26 * Remove "must_use" from `condition::Waiter` diff --git a/ntex-util/src/services/counter.rs b/ntex-util/src/services/counter.rs index 5ceb0c37..b1139691 100644 --- a/ntex-util/src/services/counter.rs +++ b/ntex-util/src/services/counter.rs @@ -1,4 +1,4 @@ -use std::{cell::Cell, future::poll_fn, rc::Rc, task}; +use std::{cell::Cell, cell::RefCell, future::poll_fn, rc::Rc, task::Context, task::Poll}; use crate::task::LocalWaker; @@ -6,53 +6,67 @@ use crate::task::LocalWaker; /// /// Counter could be cloned, total count is shared across all clones. #[derive(Debug)] -pub struct Counter(Rc); +pub struct Counter(usize, Rc); #[derive(Debug)] struct CounterInner { count: Cell, - capacity: usize, - task: LocalWaker, + capacity: Cell, + tasks: RefCell>, } impl Counter { /// Create `Counter` instance and set max value. pub fn new(capacity: usize) -> Self { - Counter(Rc::new(CounterInner { - capacity, - count: Cell::new(0), - task: LocalWaker::new(), - })) + let mut tasks = slab::Slab::new(); + let idx = tasks.insert(LocalWaker::new()); + + Counter( + idx, + Rc::new(CounterInner { + count: Cell::new(0), + capacity: Cell::new(capacity), + tasks: RefCell::new(tasks), + }), + ) } /// Get counter guard. - pub(crate) fn get(&self) -> CounterGuard { - CounterGuard::new(self.0.clone()) + pub fn get(&self) -> CounterGuard { + CounterGuard::new(self.1.clone()) } - pub(crate) fn is_available(&self) -> bool { - self.0.count.get() < self.0.capacity + /// Set counter capacity + pub fn set_capacity(&self, cap: usize) { + self.1.capacity.set(cap); + self.1.notify(); + } + + /// Check is counter has free capacity. + pub fn is_available(&self) -> bool { + self.1.count.get() < self.1.capacity.get() } /// Check if counter is not at capacity. If counter at capacity /// it registers notification for current task. - pub(crate) async fn available(&self) { + pub async fn available(&self) { poll_fn(|cx| { if self.poll_available(cx) { - task::Poll::Ready(()) + Poll::Ready(()) } else { - task::Poll::Pending + Poll::Pending } }) .await } - pub(crate) async fn unavailable(&self) { + /// Wait untile counter becomes at capacity. + pub async fn unavailable(&self) { poll_fn(|cx| { if self.poll_available(cx) { - task::Poll::Pending + Poll::Pending } else { - task::Poll::Ready(()) + Poll::Ready(()) } }) .await @@ -60,8 +74,28 @@ impl Counter { /// Check if counter is not at capacity. If counter at capacity /// it registers notification for current task. - fn poll_available(&self, cx: &mut task::Context<'_>) -> bool { - self.0.available(cx) + fn poll_available(&self, cx: &mut Context<'_>) -> bool { + let tasks = self.1.tasks.borrow(); + tasks[self.0].register(cx.waker()); + self.1.count.get() < self.1.capacity.get() + } + + /// Get total number of acquired counts + pub fn total(&self) -> usize { + self.1.count.get() + } +} + +impl Clone for Counter { + fn clone(&self) -> Self { + let idx = self.1.tasks.borrow_mut().insert(LocalWaker::new()); + Self(idx, self.1.clone()) + } +} + +impl Drop for Counter { + fn drop(&mut self) { + self.1.tasks.borrow_mut().remove(self.0); } } @@ -87,21 +121,23 @@ impl CounterInner { fn inc(&self) { let num = self.count.get() + 1; self.count.set(num); - if num == self.capacity { - self.task.wake(); + if num == self.capacity.get() { + self.notify(); } } fn dec(&self) { let num = self.count.get(); self.count.set(num - 1); - if num == self.capacity { - self.task.wake(); + if num == self.capacity.get() { + self.notify(); } } - fn available(&self, cx: &mut task::Context<'_>) -> bool { - self.task.register(cx.waker()); - self.count.get() < self.capacity + fn notify(&self) { + let tasks = self.tasks.borrow(); + for (_, task) in &*tasks { + task.wake() + } } } diff --git a/ntex-util/src/services/mod.rs b/ntex-util/src/services/mod.rs index 28974192..6f60afb0 100644 --- a/ntex-util/src/services/mod.rs +++ b/ntex-util/src/services/mod.rs @@ -1,5 +1,4 @@ pub mod buffer; -pub mod counter; mod extensions; pub mod inflight; pub mod keepalive; @@ -7,4 +6,8 @@ pub mod onerequest; pub mod timeout; pub mod variant; +#[doc(hidden)] +pub mod counter; + +pub use self::counter::{Counter, CounterGuard}; pub use self::extensions::Extensions;