Use Cell instead of RefCell for timer (#467)

This commit is contained in:
Nikolay Kim 2024-11-19 13:32:07 -08:00 committed by GitHub
parent daeded8f3b
commit 98646dee57
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 196 additions and 178 deletions

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [2.6.0] - 2024-11-19
* Use Cell instead of RefCell for timer
## [2.5.0] - 2024-11-04 ## [2.5.0] - 2024-11-04
* Use updated Service trait * Use updated Service trait

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex-util" name = "ntex-util"
version = "2.5.0" version = "2.6.0"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for ntex framework" description = "Utilities for ntex framework"
keywords = ["network", "framework", "async", "futures"] keywords = ["network", "framework", "async", "futures"]

View file

@ -2,9 +2,8 @@
//! //!
//! Inspired by linux kernel timers system //! Inspired by linux kernel timers system
#![allow(arithmetic_overflow, clippy::let_underscore_future)] #![allow(arithmetic_overflow, clippy::let_underscore_future)]
use std::cell::{Cell, RefCell};
use std::time::{Duration, Instant, SystemTime}; use std::time::{Duration, Instant, SystemTime};
use std::{cmp::max, future::Future, mem, pin::Pin, rc::Rc, task, task::Poll}; use std::{cell::Cell, cmp::max, future::Future, mem, pin::Pin, rc::Rc, task, task::Poll};
use futures_timer::Delay; use futures_timer::Delay;
use slab::Slab; use slab::Slab;
@ -72,7 +71,7 @@ const fn as_millis(dur: Duration) -> u64 {
/// Resolution is 5ms /// Resolution is 5ms
#[inline] #[inline]
pub fn now() -> Instant { pub fn now() -> Instant {
TIMER.with(Timer::now) TIMER.with(|t| t.with_mod(|inner| t.now(inner)))
} }
/// Returns the system time corresponding to “now”. /// Returns the system time corresponding to “now”.
@ -80,7 +79,7 @@ pub fn now() -> Instant {
/// Resolution is 5ms /// Resolution is 5ms
#[inline] #[inline]
pub fn system_time() -> SystemTime { pub fn system_time() -> SystemTime {
TIMER.with(Timer::system_time) TIMER.with(|t| t.with_mod(|inner| t.system_time(inner)))
} }
/// Returns the system time corresponding to “now”. /// Returns the system time corresponding to “now”.
@ -90,7 +89,7 @@ pub fn system_time() -> SystemTime {
#[inline] #[inline]
#[doc(hidden)] #[doc(hidden)]
pub fn query_system_time() -> SystemTime { pub fn query_system_time() -> SystemTime {
TIMER.with(Timer::system_time) TIMER.with(|t| t.with_mod(|inner| t.system_time(inner)))
} }
#[derive(Debug)] #[derive(Debug)]
@ -108,25 +107,27 @@ impl TimerHandle {
} }
pub fn is_elapsed(&self) -> bool { pub fn is_elapsed(&self) -> bool {
TIMER.with(|t| t.0.inner.borrow().timers[self.0].bucket.is_none()) TIMER.with(|t| t.with_mod(|m| m.timers[self.0].bucket.is_none()))
} }
pub fn poll_elapsed(&self, cx: &mut task::Context<'_>) -> Poll<()> { pub fn poll_elapsed(&self, cx: &mut task::Context<'_>) -> Poll<()> {
TIMER.with(|t| { TIMER.with(|t| {
let entry = &t.0.inner.borrow().timers[self.0]; t.with_mod(|inner| {
if entry.bucket.is_none() { let entry = &inner.timers[self.0];
Poll::Ready(()) if entry.bucket.is_none() {
} else { Poll::Ready(())
entry.task.register(cx.waker()); } else {
Poll::Pending entry.task.register(cx.waker());
} Poll::Pending
}
})
}) })
} }
} }
impl Drop for TimerHandle { impl Drop for TimerHandle {
fn drop(&mut self) { fn drop(&mut self) {
TIMER.with(|t| t.remove_timer(self.0)); TIMER.with(|t| t.with_mod(|inner| inner.remove_timer_bucket(self.0, true)))
} }
} }
@ -156,7 +157,7 @@ struct TimerInner {
lowres_time: Cell<Option<Instant>>, lowres_time: Cell<Option<Instant>>,
lowres_stime: Cell<Option<SystemTime>>, lowres_stime: Cell<Option<SystemTime>>,
lowres_driver: LocalWaker, lowres_driver: LocalWaker,
inner: RefCell<TimerMod>, inner: Cell<Option<Box<TimerMod>>>,
} }
struct TimerMod { struct TimerMod {
@ -170,6 +171,11 @@ struct TimerMod {
impl Timer { impl Timer {
fn new() -> Self { fn new() -> Self {
println!(
"=========== {:?}",
std::mem::size_of::<Option<Box<TimerMod>>>()
);
Timer(Rc::new(TimerInner { Timer(Rc::new(TimerInner {
elapsed: Cell::new(0), elapsed: Cell::new(0),
elapsed_time: Cell::new(None), elapsed_time: Cell::new(None),
@ -179,16 +185,26 @@ impl Timer {
lowres_time: Cell::new(None), lowres_time: Cell::new(None),
lowres_stime: Cell::new(None), lowres_stime: Cell::new(None),
lowres_driver: LocalWaker::new(), lowres_driver: LocalWaker::new(),
inner: RefCell::new(TimerMod { inner: Cell::new(Some(Box::new(TimerMod {
buckets: Self::create_buckets(), buckets: Self::create_buckets(),
timers: Slab::default(), timers: Slab::default(),
driver_sleep: Delay::new(Duration::ZERO), driver_sleep: Delay::new(Duration::ZERO),
occupied: [0; WHEEL_SIZE], occupied: [0; WHEEL_SIZE],
lowres_driver_sleep: Delay::new(Duration::ZERO), lowres_driver_sleep: Delay::new(Duration::ZERO),
}), }))),
})) }))
} }
fn with_mod<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut TimerMod) -> R,
{
let mut m = self.0.inner.take().unwrap();
let result = f(&mut m);
self.0.inner.set(Some(m));
result
}
fn create_buckets() -> Vec<Bucket> { fn create_buckets() -> Vec<Bucket> {
let mut buckets = Vec::with_capacity(WHEEL_SIZE); let mut buckets = Vec::with_capacity(WHEEL_SIZE);
for idx in 0..WHEEL_SIZE { for idx in 0..WHEEL_SIZE {
@ -199,7 +215,7 @@ impl Timer {
buckets buckets
} }
fn now(&self) -> Instant { fn now(&self, inner: &mut TimerMod) -> Instant {
if let Some(cur) = self.0.lowres_time.get() { if let Some(cur) = self.0.lowres_time.get() {
cur cur
} else { } else {
@ -212,14 +228,14 @@ impl Timer {
if flags.contains(Flags::LOWRES_DRIVER) { if flags.contains(Flags::LOWRES_DRIVER) {
self.0.lowres_driver.wake(); self.0.lowres_driver.wake();
} else { } else {
LowresTimerDriver::start(self.0.clone()); LowresTimerDriver::start(self.0.clone(), inner);
} }
} }
now now
} }
} }
fn system_time(&self) -> SystemTime { fn system_time(&self, inner: &mut TimerMod) -> SystemTime {
if let Some(cur) = self.0.lowres_stime.get() { if let Some(cur) = self.0.lowres_stime.get() {
cur cur
} else { } else {
@ -232,7 +248,7 @@ impl Timer {
if flags.contains(Flags::LOWRES_DRIVER) { if flags.contains(Flags::LOWRES_DRIVER) {
self.0.lowres_driver.wake(); self.0.lowres_driver.wake();
} else { } else {
LowresTimerDriver::start(self.0.clone()); LowresTimerDriver::start(self.0.clone(), inner);
} }
} }
now now
@ -241,105 +257,103 @@ impl Timer {
/// Add the timer into the hash bucket /// Add the timer into the hash bucket
fn add_timer(&self, millis: u64) -> TimerHandle { fn add_timer(&self, millis: u64) -> TimerHandle {
if millis == 0 { self.with_mod(|inner| {
let mut inner = self.0.inner.borrow_mut(); if millis == 0 {
let entry = inner.timers.vacant_entry();
let no = entry.key();
let entry = inner.timers.vacant_entry(); entry.insert(TimerEntry {
let no = entry.key(); bucket_entry: 0,
bucket: None,
entry.insert(TimerEntry { task: LocalWaker::new(),
bucket_entry: 0, });
bucket: None, return TimerHandle(no);
task: LocalWaker::new(),
});
return TimerHandle(no);
}
let mut flags = self.0.flags.get();
flags.insert(Flags::RUNNING);
self.0.flags.set(flags);
let now = self.now();
let elapsed_time = self.0.elapsed_time();
let delta = if now >= elapsed_time {
to_units(as_millis(now - elapsed_time) + millis)
} else {
to_units(millis)
};
let (no, bucket_expiry) = {
// crate timer entry
let (idx, bucket_expiry) = self
.0
.calc_wheel_index(self.0.elapsed.get().wrapping_add(delta), delta);
let no = self.0.inner.borrow_mut().add_entry(idx);
(no, bucket_expiry)
};
// Check whether new bucket expire earlier
if bucket_expiry < self.0.next_expiry.get() {
self.0.next_expiry.set(bucket_expiry);
if flags.contains(Flags::DRIVER_STARTED) {
flags.insert(Flags::DRIVER_RECALC);
self.0.flags.set(flags);
self.0.driver.wake();
} else {
TimerDriver::start(self.0.clone());
} }
}
TimerHandle(no) let mut flags = self.0.flags.get();
flags.insert(Flags::RUNNING);
self.0.flags.set(flags);
let now = self.now(inner);
let elapsed_time = self.0.elapsed_time();
let delta = if now >= elapsed_time {
to_units(as_millis(now - elapsed_time) + millis)
} else {
to_units(millis)
};
let (no, bucket_expiry) = {
// crate timer entry
let (idx, bucket_expiry) = self
.0
.calc_wheel_index(self.0.elapsed.get().wrapping_add(delta), delta);
let no = inner.add_entry(idx);
(no, bucket_expiry)
};
// Check whether new bucket expire earlier
if bucket_expiry < self.0.next_expiry.get() {
self.0.next_expiry.set(bucket_expiry);
if flags.contains(Flags::DRIVER_STARTED) {
flags.insert(Flags::DRIVER_RECALC);
self.0.flags.set(flags);
self.0.driver.wake();
} else {
TimerDriver::start(self.0.clone(), inner);
}
}
TimerHandle(no)
})
} }
/// Update existing timer /// Update existing timer
fn update_timer(&self, hnd: usize, millis: u64) { fn update_timer(&self, hnd: usize, millis: u64) {
if millis == 0 { self.with_mod(|inner| {
self.remove_timer_bucket(hnd); if millis == 0 {
self.0.inner.borrow_mut().timers[hnd].bucket = None; inner.remove_timer_bucket(hnd, false);
return; inner.timers[hnd].bucket = None;
} return;
let now = self.now();
let elapsed_time = self.0.elapsed_time();
let delta = if now >= elapsed_time {
max(to_units(as_millis(now - elapsed_time) + millis), 1)
} else {
max(to_units(millis), 1)
};
let bucket_expiry = {
// calc bucket
let (idx, bucket_expiry) = self
.0
.calc_wheel_index(self.0.elapsed.get().wrapping_add(delta), delta);
self.0.inner.borrow_mut().update_entry(hnd, idx);
bucket_expiry
};
// Check whether new bucket expire earlier
if bucket_expiry < self.0.next_expiry.get() {
self.0.next_expiry.set(bucket_expiry);
let mut flags = self.0.flags.get();
if flags.contains(Flags::DRIVER_STARTED) {
flags.insert(Flags::DRIVER_RECALC);
self.0.flags.set(flags);
self.0.driver.wake();
} else {
TimerDriver::start(self.0.clone());
} }
}
let now = self.now(inner);
let elapsed_time = self.0.elapsed_time();
let delta = if now >= elapsed_time {
max(to_units(as_millis(now - elapsed_time) + millis), 1)
} else {
max(to_units(millis), 1)
};
let bucket_expiry = {
// calc bucket
let (idx, bucket_expiry) = self
.0
.calc_wheel_index(self.0.elapsed.get().wrapping_add(delta), delta);
inner.update_entry(hnd, idx);
bucket_expiry
};
// Check whether new bucket expire earlier
if bucket_expiry < self.0.next_expiry.get() {
self.0.next_expiry.set(bucket_expiry);
let mut flags = self.0.flags.get();
if flags.contains(Flags::DRIVER_STARTED) {
flags.insert(Flags::DRIVER_RECALC);
self.0.flags.set(flags);
self.0.driver.wake();
} else {
TimerDriver::start(self.0.clone(), inner);
}
}
})
} }
fn remove_timer(&self, handle: usize) { // fn remove_timer(&self, handle: usize) {
self.0.inner.borrow_mut().remove_timer_bucket(handle, true) // self.0.inner.borrow_mut().remove_timer_bucket(handle, true)
} // }
fn remove_timer_bucket(&self, handle: usize) {
self.0.inner.borrow_mut().remove_timer_bucket(handle, false)
}
} }
impl TimerMod { impl TimerMod {
@ -424,6 +438,16 @@ impl TimerMod {
} }
impl TimerInner { impl TimerInner {
fn with_mod<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut TimerMod) -> R,
{
let mut m = self.inner.take().unwrap();
let result = f(&mut m);
self.inner.set(Some(m));
result
}
fn calc_wheel_index(&self, expires: u64, delta: u64) -> (usize, u64) { fn calc_wheel_index(&self, expires: u64, delta: u64) -> (usize, u64) {
if delta < lvl_start(1) { if delta < lvl_start(1) {
Self::calc_index(expires, 0) Self::calc_index(expires, 0)
@ -481,16 +505,12 @@ impl TimerInner {
} }
} }
fn execute_expired_timers(&self) { fn execute_expired_timers(&self, inner: &mut TimerMod) {
self.inner inner.execute_expired_timers(self.next_expiry.get());
.borrow_mut()
.execute_expired_timers(self.next_expiry.get());
} }
/// Find next expiration bucket /// Find next expiration bucket
fn next_pending_bucket(&self) -> Option<u64> { fn next_pending_bucket(&self, inner: &mut TimerMod) -> Option<u64> {
let inner = self.inner.borrow_mut();
let mut clk = self.elapsed.get(); let mut clk = self.elapsed.get();
let mut next = u64::MAX; let mut next = u64::MAX;
@ -537,7 +557,7 @@ impl TimerInner {
fn stop_wheel(&self) { fn stop_wheel(&self) {
// mark all timers as elapsed // mark all timers as elapsed
if let Ok(mut inner) = self.inner.try_borrow_mut() { if let Some(mut inner) = self.inner.take() {
let mut buckets = mem::take(&mut inner.buckets); let mut buckets = mem::take(&mut inner.buckets);
for b in &mut buckets { for b in &mut buckets {
for no in b.entries.drain() { for no in b.entries.drain() {
@ -555,6 +575,7 @@ impl TimerInner {
inner.buckets = buckets; inner.buckets = buckets;
inner.occupied = [0; WHEEL_SIZE]; inner.occupied = [0; WHEEL_SIZE];
self.inner.set(Some(inner));
} }
} }
} }
@ -604,12 +625,11 @@ impl TimerEntry {
struct TimerDriver(Rc<TimerInner>); struct TimerDriver(Rc<TimerInner>);
impl TimerDriver { impl TimerDriver {
fn start(timer: Rc<TimerInner>) { fn start(timer: Rc<TimerInner>, inner: &mut TimerMod) {
let mut flags = timer.flags.get(); let mut flags = timer.flags.get();
flags.insert(Flags::DRIVER_STARTED); flags.insert(Flags::DRIVER_STARTED);
timer.flags.set(flags); timer.flags.set(flags);
timer.inner.borrow_mut().driver_sleep = inner.driver_sleep = Delay::new(Duration::from_millis(timer.next_expiry_ms()));
Delay::new(Duration::from_millis(timer.next_expiry_ms()));
let _ = crate::spawn(TimerDriver(timer)); let _ = crate::spawn(TimerDriver(timer));
} }
@ -627,54 +647,53 @@ impl Future for TimerDriver {
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
self.0.driver.register(cx.waker()); self.0.driver.register(cx.waker());
let mut flags = self.0.flags.get(); self.0.with_mod(|inner| {
if flags.contains(Flags::DRIVER_RECALC) { let mut flags = self.0.flags.get();
flags.remove(Flags::DRIVER_RECALC); if flags.contains(Flags::DRIVER_RECALC) {
self.0.flags.set(flags); flags.remove(Flags::DRIVER_RECALC);
self.0.flags.set(flags);
let now = Instant::now();
let deadline =
if let Some(diff) = now.checked_duration_since(self.0.elapsed_time()) {
Duration::from_millis(self.0.next_expiry_ms()).saturating_sub(diff)
} else {
Duration::from_millis(self.0.next_expiry_ms())
};
self.0.inner.borrow_mut().driver_sleep.reset(deadline);
}
loop {
if Pin::new(&mut self.0.inner.borrow_mut().driver_sleep)
.poll(cx)
.is_ready()
{
let now = Instant::now(); let now = Instant::now();
self.0.elapsed.set(self.0.next_expiry.get()); let deadline =
self.0.elapsed_time.set(Some(now)); if let Some(diff) = now.checked_duration_since(self.0.elapsed_time()) {
self.0.execute_expired_timers(); Duration::from_millis(self.0.next_expiry_ms()).saturating_sub(diff)
} else {
if let Some(next_expiry) = self.0.next_pending_bucket() { Duration::from_millis(self.0.next_expiry_ms())
self.0.next_expiry.set(next_expiry); };
let dur = Duration::from_millis(self.0.next_expiry_ms()); inner.driver_sleep.reset(deadline);
self.0.inner.borrow_mut().driver_sleep.reset(dur);
continue;
} else {
self.0.next_expiry.set(u64::MAX);
self.0.elapsed_time.set(None);
}
} }
return Poll::Pending;
} loop {
if Pin::new(&mut inner.driver_sleep).poll(cx).is_ready() {
let now = Instant::now();
self.0.elapsed.set(self.0.next_expiry.get());
self.0.elapsed_time.set(Some(now));
self.0.execute_expired_timers(inner);
if let Some(next_expiry) = self.0.next_pending_bucket(inner) {
self.0.next_expiry.set(next_expiry);
let dur = Duration::from_millis(self.0.next_expiry_ms());
inner.driver_sleep.reset(dur);
continue;
} else {
self.0.next_expiry.set(u64::MAX);
self.0.elapsed_time.set(None);
}
}
return Poll::Pending;
}
})
} }
} }
struct LowresTimerDriver(Rc<TimerInner>); struct LowresTimerDriver(Rc<TimerInner>);
impl LowresTimerDriver { impl LowresTimerDriver {
fn start(timer: Rc<TimerInner>) { fn start(timer: Rc<TimerInner>, inner: &mut TimerMod) {
let mut flags = timer.flags.get(); let mut flags = timer.flags.get();
flags.insert(Flags::LOWRES_DRIVER); flags.insert(Flags::LOWRES_DRIVER);
timer.flags.set(flags); timer.flags.set(flags);
timer.inner.borrow_mut().lowres_driver_sleep = Delay::new(LOWRES_RESOLUTION); inner.lowres_driver_sleep = Delay::new(LOWRES_RESOLUTION);
let _ = crate::spawn(LowresTimerDriver(timer)); let _ = crate::spawn(LowresTimerDriver(timer));
} }
@ -692,27 +711,22 @@ impl Future for LowresTimerDriver {
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
self.0.lowres_driver.register(cx.waker()); self.0.lowres_driver.register(cx.waker());
let mut flags = self.0.flags.get(); self.0.with_mod(|inner| {
if !flags.contains(Flags::LOWRES_TIMER) { let mut flags = self.0.flags.get();
flags.insert(Flags::LOWRES_TIMER); if !flags.contains(Flags::LOWRES_TIMER) {
self.0.flags.set(flags); flags.insert(Flags::LOWRES_TIMER);
self.0 self.0.flags.set(flags);
.inner inner.lowres_driver_sleep.reset(LOWRES_RESOLUTION);
.borrow_mut() }
.lowres_driver_sleep
.reset(LOWRES_RESOLUTION);
}
if Pin::new(&mut self.0.inner.borrow_mut().lowres_driver_sleep) if Pin::new(&mut inner.lowres_driver_sleep).poll(cx).is_ready() {
.poll(cx) self.0.lowres_time.set(None);
.is_ready() self.0.lowres_stime.set(None);
{ flags.remove(Flags::LOWRES_TIMER);
self.0.lowres_time.set(None); self.0.flags.set(flags);
self.0.lowres_stime.set(None); }
flags.remove(Flags::LOWRES_TIMER); Poll::Pending
self.0.flags.set(flags); })
}
Poll::Pending
} }
} }