Update Service trait and use unified Counter (#455)

This commit is contained in:
Nikolay Kim 2024-11-04 12:49:18 +05:00 committed by GitHub
parent 30115bf2d5
commit 5f6600c814
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 112 additions and 245 deletions

View file

@ -20,7 +20,7 @@ ntex-bytes = "0.1"
ntex-net = "2" ntex-net = "2"
ntex-service = "3.3" ntex-service = "3.3"
ntex-rt = "0.4" ntex-rt = "0.4"
ntex-util = "2" ntex-util = "2.5"
async-channel = "2" async-channel = "2"
async-broadcast = "0.7" async-broadcast = "0.7"

View file

@ -1,109 +0,0 @@
use std::{cell::Cell, future::poll_fn, rc::Rc, task};
use ntex_util::task::LocalWaker;
#[derive(Debug)]
/// Simple counter with ability to notify task on reaching specific number
///
/// Counter could be cloned, total count is shared across all clones.
pub(super) struct Counter(Rc<CounterInner>);
#[derive(Debug)]
struct CounterInner {
count: Cell<usize>,
capacity: usize,
task: LocalWaker,
}
impl Counter {
/// Create `Counter` instance and set max value.
pub(super) fn new(capacity: usize) -> Self {
Counter(Rc::new(CounterInner {
capacity,
count: Cell::new(0),
task: LocalWaker::new(),
}))
}
/// Get counter guard.
pub(super) fn get(&self) -> CounterGuard {
CounterGuard::new(self.0.clone())
}
pub(crate) fn is_available(&self) -> bool {
self.0.count.get() < self.0.capacity
}
/// Check if counter is not at capacity. If counter at capacity
/// it registers notification for current task.
pub(crate) async fn available(&self) {
poll_fn(|cx| {
if self.0.available(cx) {
task::Poll::Ready(())
} else {
task::Poll::Pending
}
})
.await
}
pub(crate) async fn unavailable(&self) {
poll_fn(|cx| {
if self.0.available(cx) {
task::Poll::Pending
} else {
task::Poll::Ready(())
}
})
.await
}
/// Get total number of acquired counts
pub(super) fn total(&self) -> usize {
self.0.count.get()
}
pub(super) fn priv_clone(&self) -> Self {
Counter(self.0.clone())
}
}
pub(super) struct CounterGuard(Rc<CounterInner>);
impl CounterGuard {
fn new(inner: Rc<CounterInner>) -> Self {
inner.inc();
CounterGuard(inner)
}
}
impl Unpin for CounterGuard {}
impl Drop for CounterGuard {
fn drop(&mut self) {
self.0.dec();
}
}
impl CounterInner {
fn inc(&self) {
let num = self.count.get() + 1;
self.count.set(num);
if num == self.capacity {
self.task.wake();
}
}
fn dec(&self) {
let num = self.count.get();
self.count.set(num - 1);
if num == self.capacity {
self.task.wake();
}
}
fn available(&self, cx: &mut task::Context<'_>) -> bool {
self.task.register(cx.waker());
self.count.get() < self.capacity
}
}

View file

@ -1,10 +1,10 @@
//! General purpose tcp server //! General purpose tcp server
use ntex_util::services::Counter;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
mod accept; mod accept;
mod builder; mod builder;
mod config; mod config;
mod counter;
mod factory; mod factory;
mod service; mod service;
mod socket; mod socket;
@ -56,8 +56,7 @@ pub enum SslError<E> {
static MAX_CONNS: AtomicUsize = AtomicUsize::new(25600); static MAX_CONNS: AtomicUsize = AtomicUsize::new(25600);
thread_local! { thread_local! {
static MAX_CONNS_COUNTER: self::counter::Counter = static MAX_CONNS_COUNTER: Counter = Counter::new(MAX_CONNS.load(Ordering::Relaxed));
self::counter::Counter::new(MAX_CONNS.load(Ordering::Relaxed));
} }
/// Sets the maximum per-worker number of concurrent connections. /// Sets the maximum per-worker number of concurrent connections.
@ -68,6 +67,7 @@ thread_local! {
/// By default max connections is set to a 25k per worker. /// By default max connections is set to a 25k per worker.
pub(super) fn max_concurrent_connections(num: usize) { pub(super) fn max_concurrent_connections(num: usize) {
MAX_CONNS.store(num, Ordering::Relaxed); MAX_CONNS.store(num, Ordering::Relaxed);
MAX_CONNS_COUNTER.with(|conns| conns.set_capacity(num));
} }
pub(super) fn num_connections() -> usize { pub(super) fn num_connections() -> usize {

View file

@ -3,12 +3,11 @@ use std::{fmt, future::poll_fn, future::Future, pin::Pin, task::Poll};
use ntex_bytes::{Pool, PoolRef}; use ntex_bytes::{Pool, PoolRef};
use ntex_net::Io; use ntex_net::Io;
use ntex_service::{boxed, Service, ServiceCtx, ServiceFactory}; use ntex_service::{boxed, Service, ServiceCtx, ServiceFactory};
use ntex_util::{future::join_all, HashMap}; use ntex_util::{future::join_all, services::Counter, HashMap};
use crate::ServerConfiguration; use crate::ServerConfiguration;
use super::accept::{AcceptNotify, AcceptorCommand}; use super::accept::{AcceptNotify, AcceptorCommand};
use super::counter::Counter;
use super::factory::{FactoryServiceType, NetService, OnWorkerStart}; use super::factory::{FactoryServiceType, NetService, OnWorkerStart};
use super::{socket::Connection, Token, MAX_CONNS_COUNTER}; use super::{socket::Connection, Token, MAX_CONNS_COUNTER};
@ -135,7 +134,7 @@ impl ServiceFactory<Connection> for StreamService {
Ok(StreamServiceImpl { Ok(StreamServiceImpl {
tokens, tokens,
services, services,
conns: MAX_CONNS_COUNTER.with(|conns| conns.priv_clone()), conns: MAX_CONNS_COUNTER.with(|conns| conns.clone()),
}) })
} }
} }

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [2.3.0] - 2024-11-04
* Use updated Service trait
## [2.2.0] - 2024-09-25 ## [2.2.0] - 2024-09-25
* Disable default features for rustls * Disable default features for rustls

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-tls" name = "ntex-tls"
version = "2.2.0" version = "2.3.0"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "An implementation of SSL streams for ntex backed by OpenSSL" description = "An implementation of SSL streams for ntex backed by OpenSSL"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]
@ -28,8 +28,8 @@ rustls-ring = ["tls_rust", "tls_rust/ring", "tls_rust/std"]
[dependencies] [dependencies]
ntex-bytes = "0.1" ntex-bytes = "0.1"
ntex-io = "2.3" ntex-io = "2.3"
ntex-util = "2" ntex-util = "2.5"
ntex-service = "3" ntex-service = "3.3"
ntex-net = "2" ntex-net = "2"
log = "0.4" log = "0.4"

View file

@ -1,84 +0,0 @@
#![allow(dead_code)]
use std::{cell::Cell, future::poll_fn, rc::Rc, task, task::Poll};
use ntex_util::task::LocalWaker;
#[derive(Debug, Clone)]
/// Simple counter with ability to notify task on reaching specific number
///
/// Counter could be cloned, total count is shared across all clones.
pub(super) struct Counter(Rc<CounterInner>);
#[derive(Debug)]
struct CounterInner {
count: Cell<usize>,
capacity: usize,
task: LocalWaker,
}
impl Counter {
/// Create `Counter` instance and set max value.
pub(super) fn new(capacity: usize) -> Self {
Counter(Rc::new(CounterInner {
capacity,
count: Cell::new(0),
task: LocalWaker::new(),
}))
}
/// Get counter guard.
pub(super) fn get(&self) -> CounterGuard {
CounterGuard::new(self.0.clone())
}
/// Check if counter is not at capacity. If counter at capacity
/// it registers notification for current task.
pub(super) async fn available(&self) {
poll_fn(|cx| {
if self.0.available(cx) {
Poll::Ready(())
} else {
Poll::Pending
}
})
.await
}
}
pub(super) struct CounterGuard(Rc<CounterInner>);
impl CounterGuard {
fn new(inner: Rc<CounterInner>) -> Self {
inner.inc();
CounterGuard(inner)
}
}
impl Drop for CounterGuard {
fn drop(&mut self) {
self.0.dec();
}
}
impl CounterInner {
fn inc(&self) {
self.count.set(self.count.get() + 1);
}
fn dec(&self) {
let num = self.count.get();
self.count.set(num - 1);
if num == self.capacity {
self.task.wake();
}
}
fn available(&self, cx: &mut task::Context<'_>) -> bool {
if self.count.get() < self.capacity {
true
} else {
self.task.register(cx.waker());
false
}
}
}

View file

@ -9,7 +9,7 @@ pub mod openssl;
#[cfg(feature = "rustls")] #[cfg(feature = "rustls")]
pub mod rustls; pub mod rustls;
mod counter; use ntex_util::services::Counter;
/// Sets the maximum per-worker concurrent ssl connection establish process. /// Sets the maximum per-worker concurrent ssl connection establish process.
/// ///
@ -19,12 +19,13 @@ mod counter;
/// By default max connections is set to a 256. /// By default max connections is set to a 256.
pub fn max_concurrent_ssl_accept(num: usize) { pub fn max_concurrent_ssl_accept(num: usize) {
MAX_SSL_ACCEPT.store(num, Ordering::Relaxed); MAX_SSL_ACCEPT.store(num, Ordering::Relaxed);
MAX_SSL_ACCEPT_COUNTER.with(|counts| counts.set_capacity(num));
} }
static MAX_SSL_ACCEPT: AtomicUsize = AtomicUsize::new(256); static MAX_SSL_ACCEPT: AtomicUsize = AtomicUsize::new(256);
thread_local! { thread_local! {
static MAX_SSL_ACCEPT_COUNTER: counter::Counter = counter::Counter::new(MAX_SSL_ACCEPT.load(Ordering::Relaxed)); static MAX_SSL_ACCEPT_COUNTER: Counter = Counter::new(MAX_SSL_ACCEPT.load(Ordering::Relaxed));
} }
/// A TLS PSK identity. /// A TLS PSK identity.

View file

@ -2,13 +2,10 @@ use std::{cell::RefCell, error::Error, fmt, io};
use ntex_io::{Filter, Io, Layer}; use ntex_io::{Filter, Io, Layer};
use ntex_service::{Service, ServiceCtx, ServiceFactory}; use ntex_service::{Service, ServiceCtx, ServiceFactory};
use ntex_util::time::{self, Millis}; use ntex_util::{services::Counter, time, time::Millis};
use tls_openssl::ssl; use tls_openssl::ssl;
use crate::counter::Counter; use crate::{openssl::SslFilter, MAX_SSL_ACCEPT_COUNTER};
use crate::MAX_SSL_ACCEPT_COUNTER;
use super::SslFilter;
/// Support `TLS` server connections via openssl package /// Support `TLS` server connections via openssl package
/// ///
@ -98,15 +95,25 @@ impl<F: Filter> Service<Io<F>> for SslAcceptorService {
type Error = Box<dyn Error>; type Error = Box<dyn Error>;
async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> { async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
self.conns.available().await; if !self.conns.is_available() {
self.conns.available().await
}
Ok(()) Ok(())
} }
#[inline]
async fn not_ready(&self) {
if self.conns.is_available() {
self.conns.unavailable().await
}
}
async fn call( async fn call(
&self, &self,
io: Io<F>, io: Io<F>,
_: ServiceCtx<'_, Self>, _: ServiceCtx<'_, Self>,
) -> Result<Self::Response, Self::Error> { ) -> Result<Self::Response, Self::Error> {
let _guard = self.conns.get();
let timeout = self.timeout; let timeout = self.timeout;
let ctx_result = ssl::Ssl::new(self.acceptor.context()); let ctx_result = ssl::Ssl::new(self.acceptor.context());

View file

@ -4,10 +4,9 @@ use tls_rust::ServerConfig;
use ntex_io::{Filter, Io, Layer}; use ntex_io::{Filter, Io, Layer};
use ntex_service::{Service, ServiceCtx, ServiceFactory}; use ntex_service::{Service, ServiceCtx, ServiceFactory};
use ntex_util::time::Millis; use ntex_util::{services::Counter, time::Millis};
use super::TlsServerFilter; use crate::{rustls::TlsServerFilter, MAX_SSL_ACCEPT_COUNTER};
use crate::{counter::Counter, MAX_SSL_ACCEPT_COUNTER};
#[derive(Debug)] #[derive(Debug)]
/// Support `SSL` connections via rustls package /// Support `SSL` connections via rustls package
@ -81,10 +80,19 @@ impl<F: Filter> Service<Io<F>> for TlsAcceptorService {
type Error = io::Error; type Error = io::Error;
async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> { async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
self.conns.available().await; if !self.conns.is_available() {
self.conns.available().await
}
Ok(()) Ok(())
} }
#[inline]
async fn not_ready(&self) {
if self.conns.is_available() {
self.conns.unavailable().await
}
}
async fn call( async fn call(
&self, &self,
io: Io<F>, io: Io<F>,

View file

@ -1,9 +1,11 @@
# Changes # Changes
## [2.5.0] - 2024-11-02 ## [2.5.0] - 2024-11-04
* Use updated Service trait * Use updated Service trait
* Export Counter type
## [2.4.0] - 2024-09-26 ## [2.4.0] - 2024-09-26
* Remove "must_use" from `condition::Waiter` * Remove "must_use" from `condition::Waiter`

View file

@ -1,4 +1,4 @@
use std::{cell::Cell, future::poll_fn, rc::Rc, task}; use std::{cell::Cell, cell::RefCell, future::poll_fn, rc::Rc, task::Context, task::Poll};
use crate::task::LocalWaker; use crate::task::LocalWaker;
@ -6,53 +6,67 @@ use crate::task::LocalWaker;
/// ///
/// Counter could be cloned, total count is shared across all clones. /// Counter could be cloned, total count is shared across all clones.
#[derive(Debug)] #[derive(Debug)]
pub struct Counter(Rc<CounterInner>); pub struct Counter(usize, Rc<CounterInner>);
#[derive(Debug)] #[derive(Debug)]
struct CounterInner { struct CounterInner {
count: Cell<usize>, count: Cell<usize>,
capacity: usize, capacity: Cell<usize>,
task: LocalWaker, tasks: RefCell<slab::Slab<LocalWaker>>,
} }
impl Counter { impl Counter {
/// Create `Counter` instance and set max value. /// Create `Counter` instance and set max value.
pub fn new(capacity: usize) -> Self { pub fn new(capacity: usize) -> Self {
Counter(Rc::new(CounterInner { let mut tasks = slab::Slab::new();
capacity, let idx = tasks.insert(LocalWaker::new());
count: Cell::new(0),
task: LocalWaker::new(), Counter(
})) idx,
Rc::new(CounterInner {
count: Cell::new(0),
capacity: Cell::new(capacity),
tasks: RefCell::new(tasks),
}),
)
} }
/// Get counter guard. /// Get counter guard.
pub(crate) fn get(&self) -> CounterGuard { pub fn get(&self) -> CounterGuard {
CounterGuard::new(self.0.clone()) CounterGuard::new(self.1.clone())
} }
pub(crate) fn is_available(&self) -> bool { /// Set counter capacity
self.0.count.get() < self.0.capacity pub fn set_capacity(&self, cap: usize) {
self.1.capacity.set(cap);
self.1.notify();
}
/// Check is counter has free capacity.
pub fn is_available(&self) -> bool {
self.1.count.get() < self.1.capacity.get()
} }
/// Check if counter is not at capacity. If counter at capacity /// Check if counter is not at capacity. If counter at capacity
/// it registers notification for current task. /// it registers notification for current task.
pub(crate) async fn available(&self) { pub async fn available(&self) {
poll_fn(|cx| { poll_fn(|cx| {
if self.poll_available(cx) { if self.poll_available(cx) {
task::Poll::Ready(()) Poll::Ready(())
} else { } else {
task::Poll::Pending Poll::Pending
} }
}) })
.await .await
} }
pub(crate) async fn unavailable(&self) { /// Wait untile counter becomes at capacity.
pub async fn unavailable(&self) {
poll_fn(|cx| { poll_fn(|cx| {
if self.poll_available(cx) { if self.poll_available(cx) {
task::Poll::Pending Poll::Pending
} else { } else {
task::Poll::Ready(()) Poll::Ready(())
} }
}) })
.await .await
@ -60,8 +74,28 @@ impl Counter {
/// Check if counter is not at capacity. If counter at capacity /// Check if counter is not at capacity. If counter at capacity
/// it registers notification for current task. /// it registers notification for current task.
fn poll_available(&self, cx: &mut task::Context<'_>) -> bool { fn poll_available(&self, cx: &mut Context<'_>) -> bool {
self.0.available(cx) let tasks = self.1.tasks.borrow();
tasks[self.0].register(cx.waker());
self.1.count.get() < self.1.capacity.get()
}
/// Get total number of acquired counts
pub fn total(&self) -> usize {
self.1.count.get()
}
}
impl Clone for Counter {
fn clone(&self) -> Self {
let idx = self.1.tasks.borrow_mut().insert(LocalWaker::new());
Self(idx, self.1.clone())
}
}
impl Drop for Counter {
fn drop(&mut self) {
self.1.tasks.borrow_mut().remove(self.0);
} }
} }
@ -87,21 +121,23 @@ impl CounterInner {
fn inc(&self) { fn inc(&self) {
let num = self.count.get() + 1; let num = self.count.get() + 1;
self.count.set(num); self.count.set(num);
if num == self.capacity { if num == self.capacity.get() {
self.task.wake(); self.notify();
} }
} }
fn dec(&self) { fn dec(&self) {
let num = self.count.get(); let num = self.count.get();
self.count.set(num - 1); self.count.set(num - 1);
if num == self.capacity { if num == self.capacity.get() {
self.task.wake(); self.notify();
} }
} }
fn available(&self, cx: &mut task::Context<'_>) -> bool { fn notify(&self) {
self.task.register(cx.waker()); let tasks = self.tasks.borrow();
self.count.get() < self.capacity for (_, task) in &*tasks {
task.wake()
}
} }
} }

View file

@ -1,5 +1,4 @@
pub mod buffer; pub mod buffer;
pub mod counter;
mod extensions; mod extensions;
pub mod inflight; pub mod inflight;
pub mod keepalive; pub mod keepalive;
@ -7,4 +6,8 @@ pub mod onerequest;
pub mod timeout; pub mod timeout;
pub mod variant; pub mod variant;
#[doc(hidden)]
pub mod counter;
pub use self::counter::{Counter, CounterGuard};
pub use self::extensions::Extensions; pub use self::extensions::Extensions;