Refactor io timers (#258)

* Tune logging

* Refactor io timers

* Refactor http h1 dispatcher
This commit is contained in:
Nikolay Kim 2023-11-30 00:37:55 +06:00 committed by GitHub
parent ae766a5629
commit 5e7f3259e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 366 additions and 225 deletions

View file

@ -1,5 +1,11 @@
# Changes
## [0.3.12] - 2023-11-29
* Refactor io timers
* Tune logging
## [0.3.11] - 2023-11-25
* Fix keep-alive timeout handling

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-io"
version = "0.3.11"
version = "0.3.12"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]

View file

@ -4,8 +4,7 @@ use std::{cell::Cell, future, pin::Pin, rc::Rc, task::Context, task::Poll, time}
use ntex_bytes::Pool;
use ntex_codec::{Decoder, Encoder};
use ntex_service::{IntoService, Pipeline, Service};
use ntex_util::time::{now, Seconds};
use ntex_util::{future::Either, ready, spawn};
use ntex_util::{future::Either, ready, spawn, time::Seconds};
use crate::{Decoded, DispatchItem, IoBoxed, IoStatusUpdate, RecvError};
@ -17,31 +16,39 @@ pub struct DispatcherConfig(Rc<DispatcherConfigInner>);
#[derive(Debug)]
struct DispatcherConfigInner {
keepalive_timeout: Cell<time::Duration>,
keepalive_timeout: Cell<Seconds>,
disconnect_timeout: Cell<Seconds>,
frame_read_enabled: Cell<bool>,
frame_read_rate: Cell<u16>,
frame_read_timeout: Cell<time::Duration>,
frame_read_max_timeout: Cell<time::Duration>,
frame_read_timeout: Cell<Seconds>,
frame_read_max_timeout: Cell<Seconds>,
}
impl Default for DispatcherConfig {
fn default() -> Self {
DispatcherConfig(Rc::new(DispatcherConfigInner {
keepalive_timeout: Cell::new(Seconds(30).into()),
keepalive_timeout: Cell::new(Seconds(30)),
disconnect_timeout: Cell::new(Seconds(1)),
frame_read_rate: Cell::new(0),
frame_read_enabled: Cell::new(false),
frame_read_timeout: Cell::new(Seconds::ZERO.into()),
frame_read_max_timeout: Cell::new(Seconds::ZERO.into()),
frame_read_timeout: Cell::new(Seconds::ZERO),
frame_read_max_timeout: Cell::new(Seconds::ZERO),
}))
}
}
impl DispatcherConfig {
#[doc(hidden)]
#[deprecated(since = "0.3.12")]
#[inline]
/// Get keep-alive timeout
pub fn keepalive_timeout(&self) -> time::Duration {
self.0.keepalive_timeout.get().into()
}
#[inline]
/// Get keep-alive timeout
pub fn keepalive_timeout_secs(&self) -> Seconds {
self.0.keepalive_timeout.get()
}
@ -51,9 +58,25 @@ impl DispatcherConfig {
self.0.disconnect_timeout.get()
}
#[doc(hidden)]
#[deprecated(since = "0.3.12")]
#[inline]
/// Get frame read rate
pub fn frame_read_rate(&self) -> Option<(time::Duration, time::Duration, u16)> {
if self.0.frame_read_enabled.get() {
Some((
self.0.frame_read_timeout.get().into(),
self.0.frame_read_max_timeout.get().into(),
self.0.frame_read_rate.get(),
))
} else {
None
}
}
#[inline]
/// Get frame read rate
pub fn frame_read_rate_params(&self) -> Option<(Seconds, Seconds, u16)> {
if self.0.frame_read_enabled.get() {
Some((
self.0.frame_read_timeout.get(),
@ -71,7 +94,7 @@ impl DispatcherConfig {
///
/// By default keep-alive timeout is set to 30 seconds.
pub fn set_keepalive_timeout(&self, timeout: Seconds) -> &Self {
self.0.keepalive_timeout.set(timeout.into());
self.0.keepalive_timeout.set(timeout);
self
}
@ -102,8 +125,8 @@ impl DispatcherConfig {
rate: u16,
) -> &Self {
self.0.frame_read_enabled.set(!timeout.is_zero());
self.0.frame_read_timeout.set(timeout.into());
self.0.frame_read_max_timeout.set(max_timeout.into());
self.0.frame_read_timeout.set(timeout);
self.0.frame_read_max_timeout.set(max_timeout);
self.0.frame_read_rate.set(rate);
self
}
@ -145,7 +168,7 @@ where
cfg: DispatcherConfig,
read_remains: u32,
read_remains_prev: u32,
read_max_timeout: time::Instant,
read_max_timeout: Seconds,
}
pub(crate) struct DispatcherShared<S, U>
@ -226,7 +249,7 @@ where
flags: Flags::empty(),
read_remains: 0,
read_remains_prev: 0,
read_max_timeout: now(),
read_max_timeout: Seconds::ZERO,
st: DispatcherState::Processing,
},
}
@ -541,14 +564,19 @@ where
} else if self.read_remains == 0 && decoded.remains == 0 {
// no new data, start keep-alive timer
if !self.flags.contains(Flags::KA_TIMEOUT) {
log::debug!("Start keep-alive timer {:?}", self.cfg.keepalive_timeout());
log::debug!(
"Start keep-alive timer {:?}",
self.cfg.keepalive_timeout_secs()
);
self.flags.insert(Flags::KA_TIMEOUT);
self.shared.io.start_timer(self.cfg.keepalive_timeout());
self.shared
.io
.start_timer_secs(self.cfg.keepalive_timeout_secs());
}
} else if self.flags.contains(Flags::READ_TIMEOUT) {
// received new data but not enough for parsing complete frame
self.read_remains = decoded.remains as u32;
} else if let Some((timeout, max, _)) = self.cfg.frame_read_rate() {
} else if let Some((timeout, max, _)) = self.cfg.frame_read_rate_params() {
// we got new data but not enough to parse single frame
// start read timer
self.flags.remove(Flags::KA_TIMEOUT);
@ -556,10 +584,8 @@ where
self.read_remains = decoded.remains as u32;
self.read_remains_prev = 0;
if !max.is_zero() {
self.read_max_timeout = now() + max;
}
self.shared.io.start_timer(timeout);
self.read_max_timeout = max;
self.shared.io.start_timer_secs(timeout);
}
}
@ -568,7 +594,7 @@ where
// check read timer
if self.flags.contains(Flags::READ_TIMEOUT) {
if let Some((timeout, max, rate)) = self.cfg.frame_read_rate() {
if let Some((timeout, max, rate)) = self.cfg.frame_read_rate_params() {
let total = (self.read_remains - self.read_remains_prev)
.try_into()
.unwrap_or(u16::MAX);
@ -578,9 +604,14 @@ where
self.read_remains_prev = self.read_remains;
self.read_remains = 0;
if max.is_zero() || (!max.is_zero() && now() < self.read_max_timeout) {
if !max.is_zero() {
self.read_max_timeout =
Seconds(self.read_max_timeout.0.saturating_sub(timeout.0));
}
if max.is_zero() || !self.read_max_timeout.is_zero() {
log::trace!("Frame read rate {:?}, extend timer", total);
self.shared.io.start_timer(timeout);
self.shared.io.start_timer_secs(timeout);
return Ok(());
}
log::trace!("Max payload timeout has been reached");
@ -689,9 +720,9 @@ mod tests {
error: None,
flags: super::Flags::empty(),
st: DispatcherState::Processing,
read_max_timeout: time::Instant::now(),
read_remains: 0,
read_remains_prev: 0,
read_max_timeout: Seconds::ZERO,
pool,
shared,
cfg,
@ -1030,7 +1061,6 @@ mod tests {
#[ntex::test]
async fn test_keepalive2() {
let _ = env_logger::try_init();
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);

View file

@ -1,16 +1,16 @@
use std::cell::Cell;
use std::task::{Context, Poll};
use std::{fmt, future::Future, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc, time};
use std::{fmt, future::Future, hash, io, marker, mem, ops, pin::Pin, ptr, rc::Rc};
use ntex_bytes::{PoolId, PoolRef};
use ntex_codec::{Decoder, Encoder};
use ntex_util::time::{now, Seconds};
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker};
use ntex_util::{future::poll_fn, future::Either, task::LocalWaker, time::Seconds};
use crate::buf::Stack;
use crate::filter::{Base, Filter, Layer, NullFilter};
use crate::seal::Sealed;
use crate::tasks::{ReadContext, WriteContext};
use crate::timer::TimerHandle;
use crate::{Decoded, FilterLayer, Handle, IoStatusUpdate, IoStream, RecvError};
bitflags::bitflags! {
@ -70,7 +70,7 @@ pub(crate) struct IoState {
pub(super) handle: Cell<Option<Box<dyn Handle>>>,
#[allow(clippy::box_collection)]
pub(super) on_disconnect: Cell<Option<Box<Vec<LocalWaker>>>>,
pub(super) keepalive: Cell<time::Instant>,
pub(super) keepalive: Cell<TimerHandle>,
}
impl IoState {
@ -201,7 +201,7 @@ impl Io {
filter: Cell::new(NullFilter::get()),
handle: Cell::new(None),
on_disconnect: Cell::new(None),
keepalive: Cell::new(now()),
keepalive: Cell::new(TimerHandle::default()),
});
let filter = Box::new(Base::new(IoRef(inner.clone())));
@ -257,7 +257,7 @@ impl<F> Io<F> {
filter: Cell::new(NullFilter::get()),
handle: Cell::new(None),
on_disconnect: Cell::new(None),
keepalive: Cell::new(now()),
keepalive: Cell::new(TimerHandle::default()),
});
let state = mem::replace(&mut self.0, IoRef(inner));
@ -634,7 +634,7 @@ impl<F> Io<F> {
/// Wait for status updates
pub fn poll_status_update(&self, cx: &mut Context<'_>) -> Poll<IoStatusUpdate> {
let flags = self.flags();
if flags.contains(Flags::IO_STOPPED | Flags::IO_STOPPING) {
if flags.intersects(Flags::IO_STOPPED | Flags::IO_STOPPING) {
Poll::Ready(IoStatusUpdate::PeerGone(self.error()))
} else if flags.contains(Flags::DSP_STOP) {
self.0 .0.remove_flags(Flags::DSP_STOP);
@ -697,8 +697,9 @@ impl<F> ops::Deref for Io<F> {
impl<F> Drop for Io<F> {
fn drop(&mut self) {
self.stop_keepalive_timer();
if self.1.is_set() {
self.stop_timer();
if !self.0.flags().contains(Flags::IO_STOPPED) && self.1.is_set() {
log::trace!(
"io is dropped, force stopping io streams {:?}",
self.0.flags()

View file

@ -2,6 +2,7 @@ use std::{any, fmt, hash, io, time};
use ntex_bytes::{BytesVec, PoolRef};
use ntex_codec::{Decoder, Encoder};
use ntex_util::time::Seconds;
use super::{io::Flags, timer, types, Decoded, Filter, IoRef, OnDisconnect, WriteBuf};
@ -37,7 +38,7 @@ impl IoRef {
self.0
.flags
.get()
.contains(Flags::IO_STOPPING | Flags::IO_STOPPED)
.intersects(Flags::IO_STOPPING | Flags::IO_STOPPED)
}
#[inline]
@ -104,9 +105,7 @@ impl IoRef {
where
U: Encoder,
{
let flags = self.0.flags.get();
if !flags.contains(Flags::IO_STOPPING) {
if !self.is_closed() {
self.with_write_buf(|buf| {
// make sure we've got room
self.memory_pool().resize_write_buf(buf);
@ -114,14 +113,18 @@ impl IoRef {
// encode item and wake write task
codec.encode_vec(item, buf)
})
// .with_write_buf() could return io::Error<Result<(), U::Error>>,
// in that case mark io as failed
.map_or_else(
|err| {
log::trace!("Got io error while encoding, error: {:?}", err);
self.0.io_stopped(Some(err));
Ok(())
},
|item| item,
)
} else {
log::trace!("Io is closed/closing, skip frame encoding");
Ok(())
}
}
@ -208,24 +211,42 @@ impl IoRef {
self.0.buffer.with_read_destination(self, f)
}
#[inline]
/// current timer handle
pub fn timer_handle(&self) -> timer::TimerHandle {
self.0.keepalive.get()
}
#[doc(hidden)]
#[deprecated(since = "0.3.12")]
#[inline]
/// current timer deadline
pub fn timer_deadline(&self) -> time::Instant {
self.0.keepalive.get()
self.0.keepalive.get().instant()
}
#[inline]
/// Start timer
pub fn start_timer(&self, timeout: time::Duration) {
self.start_timer_secs(Seconds(timeout.as_secs() as u16));
}
#[inline]
/// Start timer
pub fn start_timer_secs(&self, timeout: Seconds) -> timer::TimerHandle {
if self.flags().contains(Flags::TIMEOUT) {
timer::unregister(self.0.keepalive.get(), self);
}
if !timeout.is_zero() {
log::debug!("start timer {:?}", timeout);
self.0.insert_flags(Flags::TIMEOUT);
self.0.keepalive.set(timer::register(timeout, self));
let hnd = timer::register(timeout, self);
self.0.keepalive.set(hnd);
hnd
} else {
self.0.remove_flags(Flags::TIMEOUT);
Default::default()
}
}
@ -234,16 +255,21 @@ impl IoRef {
pub fn stop_timer(&self) {
if self.flags().contains(Flags::TIMEOUT) {
log::debug!("unregister timer");
self.0.remove_flags(Flags::TIMEOUT);
timer::unregister(self.0.keepalive.get(), self)
}
}
#[doc(hidden)]
#[deprecated(since = "0.3.6")]
#[inline]
/// Start keep-alive timer
pub fn start_keepalive_timer(&self, timeout: time::Duration) {
self.start_timer(timeout);
}
#[doc(hidden)]
#[deprecated(since = "0.3.6")]
#[inline]
/// Stop keep-alive timer
pub fn stop_keepalive_timer(&self) {

View file

@ -30,6 +30,7 @@ pub use self::framed::Framed;
pub use self::io::{Io, IoRef, OnDisconnect};
pub use self::seal::{IoBoxed, Sealed};
pub use self::tasks::{ReadContext, WriteContext};
pub use self::timer::TimerHandle;
pub use self::utils::{filter, seal, Decoded};
/// Status for read task

View file

@ -1,8 +1,8 @@
#![allow(clippy::mutable_key_type)]
use std::collections::{BTreeMap, VecDeque};
use std::{cell::RefCell, rc::Rc, time::Duration, time::Instant};
use std::{cell::RefCell, ops, rc::Rc, time::Duration, time::Instant};
use ntex_util::time::{now, sleep, Millis};
use ntex_util::time::{now, sleep, Seconds};
use ntex_util::{spawn, HashSet};
use crate::{io::IoState, IoRef};
@ -14,23 +14,56 @@ thread_local! {
static TIMER: Rc<RefCell<Inner>> = Rc::new(RefCell::new(
Inner {
running: false,
base: Instant::now(),
current: 0,
cache: VecDeque::with_capacity(CAP),
notifications: BTreeMap::default(),
}));
}
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct TimerHandle(u32);
impl TimerHandle {
pub fn remains(&self) -> Seconds {
TIMER.with(|timer| {
let cur = timer.borrow().current;
if self.0 <= cur {
Seconds::ZERO
} else {
Seconds((self.0 - cur) as u16)
}
})
}
pub fn instant(&self) -> Instant {
TIMER.with(|timer| timer.borrow().base + Duration::from_secs(self.0 as u64))
}
}
impl ops::Add<Seconds> for TimerHandle {
type Output = TimerHandle;
#[inline]
fn add(self, other: Seconds) -> TimerHandle {
TimerHandle(self.0 + other.0 as u32)
}
}
struct Inner {
running: bool,
base: Instant,
current: u32,
cache: VecDeque<HashSet<Rc<IoState>>>,
notifications: BTreeMap<Instant, HashSet<Rc<IoState>>>,
notifications: BTreeMap<u32, HashSet<Rc<IoState>>>,
}
impl Inner {
fn unregister(&mut self, expire: Instant, io: &IoRef) {
if let Some(states) = self.notifications.get_mut(&expire) {
fn unregister(&mut self, hnd: TimerHandle, io: &IoRef) {
if let Some(states) = self.notifications.get_mut(&hnd.0) {
states.remove(&io.0);
if states.is_empty() {
if let Some(items) = self.notifications.remove(&expire) {
if let Some(items) = self.notifications.remove(&hnd.0) {
if self.cache.len() <= CAP {
self.cache.push_back(items);
}
@ -40,26 +73,29 @@ impl Inner {
}
}
pub(crate) fn register(timeout: Duration, io: &IoRef) -> Instant {
pub(crate) fn register(timeout: Seconds, io: &IoRef) -> TimerHandle {
TIMER.with(|timer| {
let mut inner = timer.borrow_mut();
let expire = now() + timeout;
// setup current delta
if !inner.running {
inner.current = (now() - inner.base).as_secs() as u32;
}
let hnd = inner.current + timeout.0 as u32;
// search existing key
let expire = if let Some((expire, _)) =
inner.notifications.range(expire..expire + SEC).next()
{
*expire
let hnd = if let Some((hnd, _)) = inner.notifications.range(hnd..hnd + 1).next() {
*hnd
} else {
let items = inner.cache.pop_front().unwrap_or_default();
inner.notifications.insert(expire, items);
expire
inner.notifications.insert(hnd, items);
hnd
};
inner
.notifications
.get_mut(&expire)
.get_mut(&hnd)
.unwrap()
.insert(io.0.clone());
@ -70,15 +106,15 @@ pub(crate) fn register(timeout: Duration, io: &IoRef) -> Instant {
spawn(async move {
let guard = TimerGuard(inner.clone());
loop {
sleep(Millis::ONE_SEC).await;
sleep(SEC).await;
{
let mut i = inner.borrow_mut();
let now_time = now();
i.current += 1;
// notify io dispatcher
while let Some(key) = i.notifications.keys().next() {
let key = *key;
if key <= now_time {
if key <= i.current {
let mut items = i.notifications.remove(&key).unwrap();
items.drain().for_each(|st| st.notify_timeout());
if i.cache.len() <= CAP {
@ -100,7 +136,7 @@ pub(crate) fn register(timeout: Duration, io: &IoRef) -> Instant {
});
}
expire
TimerHandle(hnd)
})
}
@ -114,8 +150,8 @@ impl Drop for TimerGuard {
}
}
pub(crate) fn unregister(expire: Instant, io: &IoRef) {
pub(crate) fn unregister(hnd: TimerHandle, io: &IoRef) {
TIMER.with(|timer| {
timer.borrow_mut().unregister(expire, io);
timer.borrow_mut().unregister(hnd, io);
})
}

View file

@ -1,5 +1,9 @@
# Changes
## [0.7.13] - 2023-11-29
* Refactor h1 timers
## [0.7.12] - 2023-11-22
* Replace async-oneshot with oneshot

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "0.7.12"
version = "0.7.13"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"
@ -58,7 +58,7 @@ ntex-util = "0.3.4"
ntex-bytes = "0.1.21"
ntex-h2 = "0.4.4"
ntex-rt = "0.4.11"
ntex-io = "0.3.10"
ntex-io = "0.3.12"
ntex-tls = "0.3.2"
ntex-tokio = { version = "0.3.1", optional = true }
ntex-glommio = { version = "0.3.1", optional = true }

View file

@ -1,4 +1,4 @@
use std::{cell::Cell, ptr::copy_nonoverlapping, rc::Rc, time, time::Duration};
use std::{cell::Cell, ptr::copy_nonoverlapping, rc::Rc, time};
use ntex_h2::{self as h2};
@ -43,7 +43,7 @@ impl From<Option<usize>> for KeepAlive {
#[derive(Debug, Clone)]
/// Http service configuration
pub struct ServiceConfig {
pub(super) keep_alive: Millis,
pub(super) keep_alive: Seconds,
pub(super) client_disconnect: Seconds,
pub(super) ka_enabled: bool,
pub(super) ssl_handshake_timeout: Millis,
@ -56,16 +56,16 @@ pub struct ServiceConfig {
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(super) struct ReadRate {
pub(super) rate: u16,
pub(super) timeout: time::Duration,
pub(super) max_timeout: time::Duration,
pub(super) timeout: Seconds,
pub(super) max_timeout: Seconds,
}
impl Default for ReadRate {
fn default() -> Self {
ReadRate {
rate: 256,
timeout: time::Duration::from_secs(1),
max_timeout: time::Duration::from_secs(4),
timeout: Seconds(1),
max_timeout: Seconds(4),
}
}
}
@ -74,7 +74,7 @@ impl Default for ServiceConfig {
fn default() -> Self {
Self::new(
KeepAlive::Timeout(Seconds(5)),
Millis(1_000),
Seconds::ONE,
Seconds::ONE,
Millis(5_000),
h2::Config::server(),
@ -86,17 +86,21 @@ impl ServiceConfig {
/// Create instance of `ServiceConfig`
pub fn new(
keep_alive: KeepAlive,
client_timeout: Millis,
client_timeout: Seconds,
client_disconnect: Seconds,
ssl_handshake_timeout: Millis,
h2config: h2::Config,
) -> ServiceConfig {
let (keep_alive, ka_enabled) = match keep_alive {
KeepAlive::Timeout(val) => (Millis::from(val), true),
KeepAlive::Os => (Millis::ZERO, true),
KeepAlive::Disabled => (Millis::ZERO, false),
KeepAlive::Timeout(val) => (val, true),
KeepAlive::Os => (Seconds::ZERO, true),
KeepAlive::Disabled => (Seconds::ZERO, false),
};
let keep_alive = if ka_enabled {
keep_alive
} else {
Seconds::ZERO
};
let keep_alive = if ka_enabled { keep_alive } else { Millis::ZERO };
ServiceConfig {
client_disconnect,
@ -107,8 +111,8 @@ impl ServiceConfig {
timer: DateService::new(),
headers_read_rate: Some(ReadRate {
rate: 256,
timeout: client_timeout.into(),
max_timeout: (client_timeout + Millis(3_000)).into(),
timeout: client_timeout,
max_timeout: client_timeout + Seconds(3),
}),
payload_read_rate: None,
}
@ -118,8 +122,8 @@ impl ServiceConfig {
if timeout.is_zero() {
self.headers_read_rate = None;
} else {
let mut rate = self.headers_read_rate.clone().unwrap_or_default();
rate.timeout = timeout.into();
let mut rate = self.headers_read_rate.unwrap_or_default();
rate.timeout = timeout;
self.headers_read_rate = Some(rate);
}
}
@ -129,11 +133,15 @@ impl ServiceConfig {
/// By default keep alive is set to a 5 seconds.
pub fn keepalive<W: Into<KeepAlive>>(&mut self, val: W) -> &mut Self {
let (keep_alive, ka_enabled) = match val.into() {
KeepAlive::Timeout(val) => (Millis::from(val), true),
KeepAlive::Os => (Millis::ZERO, true),
KeepAlive::Disabled => (Millis::ZERO, false),
KeepAlive::Timeout(val) => (val, true),
KeepAlive::Os => (Seconds::ZERO, true),
KeepAlive::Disabled => (Seconds::ZERO, false),
};
let keep_alive = if ka_enabled {
keep_alive
} else {
Seconds::ZERO
};
let keep_alive = if ka_enabled { keep_alive } else { Millis::ZERO };
self.keep_alive = keep_alive;
self.ka_enabled = ka_enabled;
@ -146,7 +154,7 @@ impl ServiceConfig {
///
/// By default keep-alive timeout is set to 30 seconds.
pub fn keepalive_timeout(&mut self, timeout: Seconds) -> &mut Self {
self.keep_alive = timeout.into();
self.keep_alive = timeout;
self.ka_enabled = !timeout.is_zero();
self
}
@ -179,8 +187,8 @@ impl ServiceConfig {
/// Set read rate parameters for request headers.
///
/// Set max timeout for reading request headers. If the client
/// sends `rate` amount of data, increase the timeout by 1 second for every.
/// Set read timeout, max timeout and rate for reading request headers. If the client
/// sends `rate` amount of data within `timeout` period of time, extend timeout by `timeout` seconds.
/// But no more than `max_timeout` timeout.
///
/// By default headers read rate is set to 1sec with max timeout 5sec.
@ -193,8 +201,8 @@ impl ServiceConfig {
if !timeout.is_zero() {
self.headers_read_rate = Some(ReadRate {
rate,
timeout: timeout.into(),
max_timeout: max_timeout.into(),
timeout,
max_timeout,
});
} else {
self.headers_read_rate = None;
@ -218,8 +226,8 @@ impl ServiceConfig {
if !timeout.is_zero() {
self.payload_read_rate = Some(ReadRate {
rate,
timeout: timeout.into(),
max_timeout: max_timeout.into(),
timeout,
max_timeout,
});
} else {
self.payload_read_rate = None;
@ -234,7 +242,7 @@ pub(super) struct DispatcherConfig<S, X, U> {
pub(super) service: Pipeline<S>,
pub(super) expect: Pipeline<X>,
pub(super) upgrade: Option<Pipeline<U>>,
pub(super) keep_alive: Duration,
pub(super) keep_alive: Seconds,
pub(super) client_disconnect: Seconds,
pub(super) h2config: h2::Config,
pub(super) ka_enabled: bool,
@ -257,8 +265,8 @@ impl<S, X, U> DispatcherConfig<S, X, U> {
expect: expect.into(),
upgrade: upgrade.map(|v| v.into()),
on_request: on_request.map(|v| v.into()),
keep_alive: Duration::from(cfg.keep_alive),
client_disconnect: cfg.client_disconnect.into(),
keep_alive: cfg.keep_alive,
client_disconnect: cfg.client_disconnect,
ka_enabled: cfg.ka_enabled,
headers_read_rate: cfg.headers_read_rate,
payload_read_rate: cfg.payload_read_rate,

View file

@ -1,12 +1,10 @@
//! Framed transport dispatcher
use std::task::{Context, Poll};
use std::{
cell::RefCell, error::Error, future::Future, io, marker, pin::Pin, rc::Rc, time,
};
use std::{cell::RefCell, error::Error, future::Future, io, marker, pin::Pin, rc::Rc};
use crate::io::{Decoded, Filter, Io, IoBoxed, IoRef, IoStatusUpdate, RecvError};
use crate::service::{Pipeline, PipelineCall, Service};
use crate::time::now;
use crate::time::Seconds;
use crate::util::{ready, Bytes};
use crate::http;
@ -21,8 +19,6 @@ use super::decoder::{PayloadDecoder, PayloadItem, PayloadType};
use super::payload::{Payload, PayloadSender, PayloadStatus};
use super::{codec::Codec, Message};
const ONE_SEC: time::Duration = time::Duration::from_secs(1);
bitflags::bitflags! {
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Flags: u8 {
@ -33,10 +29,12 @@ bitflags::bitflags! {
/// Stop after sending payload
const SENDPAYLOAD_AND_STOP = 0b0000_0100;
/// Keep-alive is enabled
const READ_KA_TIMEOUT = 0b0001_0000;
/// Read headers timer is enabled
const READ_HDRS_TIMEOUT = 0b0001_0000;
const READ_HDRS_TIMEOUT = 0b0010_0000;
/// Read headers payload is enabled
const READ_PL_TIMEOUT = 0b0010_0000;
const READ_PL_TIMEOUT = 0b0100_0000;
}
}
@ -97,7 +95,7 @@ struct DispatcherInner<F, S, B, X, U> {
payload: Option<(PayloadDecoder, PayloadSender)>,
read_remains: u32,
read_consumed: u32,
read_max_timeout: time::Instant,
read_max_timeout: Seconds,
_t: marker::PhantomData<(S, B)>,
}
@ -118,11 +116,11 @@ where
io.set_disconnect_timeout(config.client_disconnect);
// slow-request timer
let flags = if let Some(cfg) = config.headers_read_rate() {
io.start_timer(cfg.timeout);
Flags::READ_HDRS_TIMEOUT
let (flags, max_timeout) = if let Some(cfg) = config.headers_read_rate() {
io.start_timer_secs(cfg.timeout);
(Flags::READ_HDRS_TIMEOUT, cfg.max_timeout)
} else {
Flags::empty()
(Flags::empty(), Seconds::ZERO)
};
Dispatcher {
@ -137,7 +135,7 @@ where
payload: None,
read_remains: 0,
read_consumed: 0,
read_max_timeout: now(),
read_max_timeout: max_timeout,
_t: marker::PhantomData,
},
}
@ -496,17 +494,15 @@ where
cx: &mut Context<'_>,
call_state: &mut std::pin::Pin<&mut CallState<S, X>>,
) -> Poll<State<B>> {
log::trace!("trying to read http message");
log::trace!("Trying to read http message");
loop {
// let result = ready!(self.io.poll_recv(&self.codec, cx));
let result = match self.io.poll_recv_decode(&self.codec, cx) {
Ok(decoded) => {
if let Some(st) =
self.update_request_timer(decoded.item.is_some(), decoded.remains)
{
if let Some(st) = self.update_hdrs_timer(&decoded) {
return Poll::Ready(st);
}
if let Some(item) = decoded.item {
Ok(item)
} else {
@ -519,7 +515,7 @@ where
// decode incoming bytes stream
return match result {
Ok((mut req, pl)) => {
log::trace!("http message is received: {:?} and payload {:?}", req, pl);
log::trace!("Http message is received: {:?} and payload {:?}", req, pl);
// configure request payload
let upgrade = match pl {
@ -545,7 +541,7 @@ where
if upgrade {
// Handle UPGRADE request
log::trace!("prep io for upgrade handler");
log::trace!("Prepare io for upgrade handler");
Poll::Ready(State::Upgrade(Some(req)))
} else {
if req.upgrade() {
@ -567,7 +563,7 @@ where
}
Err(RecvError::WriteBackpressure) => {
if let Err(err) = ready!(self.io.poll_flush(cx, false)) {
log::trace!("peer is gone with {:?}", err);
log::trace!("Peer is gone with {:?}", err);
self.error = Some(DispatchError::PeerGone(Some(err)));
Poll::Ready(State::Stop)
} else {
@ -576,30 +572,33 @@ where
}
Err(RecvError::Decoder(err)) => {
// Malformed requests, respond with 400
log::trace!("malformed request: {:?}", err);
log::trace!("Malformed request: {:?}", err);
let (res, body) = Response::BadRequest().finish().into_parts();
self.error = Some(DispatchError::Parse(err));
Poll::Ready(self.send_response(res, body.into_body()))
}
Err(RecvError::PeerGone(err)) => {
log::trace!("peer is gone with {:?}", err);
log::trace!("Peer is gone with {:?}", err);
self.error = Some(DispatchError::PeerGone(err));
Poll::Ready(State::Stop)
}
Err(RecvError::Stop) => {
log::trace!("dispatcher is instructed to stop");
log::trace!("Dispatcher is instructed to stop");
Poll::Ready(State::Stop)
}
Err(RecvError::KeepAlive) => {
if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
log::trace!("slow request timeout");
let (req, body) = Response::RequestTimeout().finish().into_parts();
let _ = self.send_response(req, body.into_body());
self.error = Some(DispatchError::SlowRequestTimeout);
} else if self.flags.contains(Flags::READ_PL_TIMEOUT) {
log::trace!("slow payload timeout");
if let Err(err) = self.handle_timeout() {
log::trace!("Slow request timeout");
let (req, body) =
Response::RequestTimeout().finish().into_parts();
let _ = self.send_response(req, body.into_body());
self.error = Some(err);
} else {
continue;
}
} else {
log::trace!("keep-alive timeout, close connection");
log::trace!("Keep-alive timeout, close connection");
}
Poll::Ready(State::Stop)
}
@ -608,7 +607,7 @@ where
}
fn send_response(&mut self, msg: Response<()>, body: ResponseBody<B>) -> State<B> {
trace!("sending response: {:?} body: {:?}", msg, body.size());
trace!("Sending response: {:?} body: {:?}", msg, body.size());
// we dont need to process responses if socket is disconnected
// but we still want to handle requests with app service
// so we skip response processing for droppped connection
@ -650,7 +649,7 @@ where
) -> Option<State<B>> {
match item {
Some(Ok(item)) => {
trace!("got response chunk: {:?}", item.len());
trace!("Got response chunk: {:?}", item.len());
match self.io.encode(Message::Chunk(Some(item)), &self.codec) {
Ok(_) => None,
Err(err) => {
@ -660,7 +659,7 @@ where
}
}
None => {
trace!("response payload eof {:?}", self.flags);
trace!("Response payload eof {:?}", self.flags);
if let Err(err) = self.io.encode(Message::Chunk(None), &self.codec) {
self.error = Some(DispatchError::Encode(err));
Some(State::Stop)
@ -673,7 +672,7 @@ where
}
}
Some(Err(e)) => {
trace!("error during response body poll: {:?}", e);
trace!("Error during response body poll: {:?}", e);
self.error = Some(DispatchError::ResponsePayload(e));
Some(State::Stop)
}
@ -756,8 +755,8 @@ where
}
}
RecvError::KeepAlive => {
if let Err(err) = self.handle_payload_timeout() {
DispatchError::from(err)
if let Err(err) = self.handle_timeout() {
err
} else {
continue;
}
@ -824,35 +823,107 @@ where
}
}
fn handle_payload_timeout(&mut self) -> Result<(), io::Error> {
// check payload read rate
if self.flags.contains(Flags::READ_PL_TIMEOUT) {
if let Some(ref cfg) = self.config.payload_read_rate {
let total = (self.read_remains + self.read_consumed)
.try_into()
.unwrap_or(u16::MAX);
if total > cfg.rate {
fn handle_timeout(&mut self) -> Result<(), DispatchError> {
// check read rate
if self
.flags
.intersects(Flags::READ_PL_TIMEOUT | Flags::READ_HDRS_TIMEOUT)
{
let cfg = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
&self.config.headers_read_rate
} else {
&self.config.payload_read_rate
};
if let Some(ref cfg) = cfg {
let total = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
let total = (self.read_remains - self.read_consumed)
.try_into()
.unwrap_or(u16::MAX);
self.read_remains = 0;
total
} else {
let total = (self.read_remains + self.read_consumed)
.try_into()
.unwrap_or(u16::MAX);
self.read_consumed = 0;
total
};
if total > cfg.rate {
// update max timeout
if !cfg.max_timeout.is_zero() {
self.read_max_timeout =
Seconds(self.read_max_timeout.0.saturating_sub(cfg.timeout.0));
}
// start timer for next period
if cfg.max_timeout.is_zero()
|| (!cfg.max_timeout.is_zero() && now() < self.read_max_timeout)
{
log::trace!("Payload read rate {:?}, extend timer", total);
self.io.start_timer(cfg.timeout);
if cfg.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
log::trace!("Bytes read rate {:?}, extend timer", total);
self.io.start_timer_secs(cfg.timeout);
return Ok(());
}
log::trace!("Max payload timeout has been reached");
}
}
}
log::trace!("Timeout during payload reading");
self.set_payload_error(PayloadError::Io(io::Error::new(
io::ErrorKind::TimedOut,
"Keep-alive",
)));
Err(io::Error::new(io::ErrorKind::TimedOut, "Keep-alive"))
log::trace!("Timeout during reading");
if self.flags.contains(Flags::READ_PL_TIMEOUT) {
self.set_payload_error(PayloadError::Io(io::Error::new(
io::ErrorKind::TimedOut,
"Keep-alive",
)));
Err(DispatchError::from(io::Error::new(
io::ErrorKind::TimedOut,
"Keep-alive",
)))
} else {
Err(DispatchError::SlowRequestTimeout)
}
}
fn update_hdrs_timer(
&mut self,
decoded: &Decoded<(Request, PayloadType)>,
) -> Option<State<B>> {
// got parsed frame
if decoded.item.is_some() {
self.read_remains = 0;
self.io.stop_timer();
self.flags.remove(
Flags::READ_KA_TIMEOUT | Flags::READ_HDRS_TIMEOUT | Flags::READ_PL_TIMEOUT,
);
} else if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
// received new data but not enough for parsing complete frame
self.read_remains = decoded.remains as u32;
} else if self.read_remains == 0 && decoded.remains == 0 {
// no new data, start keep-alive timer
if self.codec.keepalive() {
if !self.flags.contains(Flags::READ_KA_TIMEOUT) {
log::debug!("Start keep-alive timer {:?}", self.config.keep_alive);
self.flags.insert(Flags::READ_KA_TIMEOUT);
if self.config.keep_alive_enabled() {
self.io.start_timer_secs(self.config.keep_alive);
}
}
} else {
self.io.close();
return Some(State::Stop);
}
} else if let Some(ref cfg) = self.config.headers_read_rate {
log::debug!("Start headers read timer {:?}", cfg.timeout);
// we got new data but not enough to parse single frame
// start read timer
self.flags.remove(Flags::READ_KA_TIMEOUT);
self.flags.insert(Flags::READ_HDRS_TIMEOUT);
self.read_consumed = 0;
self.read_remains = decoded.remains as u32;
self.read_max_timeout = cfg.max_timeout;
self.io.start_timer_secs(cfg.timeout);
}
None
}
fn update_payload_timer(&mut self, decoded: &Decoded<PayloadItem>) {
@ -865,74 +936,10 @@ where
self.read_remains = decoded.remains as u32;
self.read_consumed = decoded.consumed as u32;
self.io.start_timer(cfg.timeout);
if !cfg.max_timeout.is_zero() {
self.read_max_timeout = now() + cfg.max_timeout;
}
self.read_max_timeout = cfg.max_timeout;
self.io.start_timer_secs(cfg.timeout);
}
}
fn update_request_timer(&mut self, received: bool, remains: usize) -> Option<State<B>> {
// we got parsed frame
if received {
// remove all timers
self.flags
.remove(Flags::READ_HDRS_TIMEOUT | Flags::READ_PL_TIMEOUT);
self.io.stop_timer();
} else if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
// update read timer
if let Some(ref cfg) = self.config.headers_read_rate {
let bytes = remains as u32;
let delta = if bytes > self.read_remains {
(bytes - self.read_remains).try_into().unwrap_or(u16::MAX)
} else {
bytes.try_into().unwrap_or(u16::MAX)
};
// read rate higher than min rate
if delta >= cfg.rate {
let n = now();
let next = self.io.timer_deadline() + ONE_SEC;
let new_timeout = if n >= next { ONE_SEC } else { next - n };
// max timeout
if cfg.max_timeout.is_zero()
|| (n + new_timeout) <= self.read_max_timeout
{
self.io.stop_timer();
self.io.start_timer(new_timeout);
// store current buf size for future rate calculation
self.read_remains = bytes;
}
}
}
} else {
// no new data then start keep-alive timer
if remains == 0 {
if self.codec.keepalive() {
if self.config.keep_alive_enabled() {
self.io.start_timer(self.config.keep_alive);
}
} else {
self.io.close();
return Some(State::Stop);
}
} else if let Some(ref cfg) = self.config.headers_read_rate {
// we got new data but not enough to parse single frame
// start read timer
self.flags.insert(Flags::READ_HDRS_TIMEOUT);
self.read_remains = 0;
self.io.start_timer(cfg.timeout);
if !cfg.max_timeout.is_zero() {
self.read_max_timeout = now() + cfg.max_timeout;
}
}
}
None
}
}
#[cfg(test)]
@ -968,7 +975,7 @@ mod tests {
{
let config = ServiceConfig::new(
Seconds(5).into(),
Millis(1_000),
Seconds(1),
Seconds::ZERO,
Millis(5_000),
Config::server(),
@ -1021,7 +1028,7 @@ mod tests {
let data2 = data.clone();
let config = ServiceConfig::new(
Seconds(5).into(),
Millis(1_000),
Seconds(1),
Seconds::ZERO,
Millis(5_000),
Config::server(),
@ -1406,7 +1413,6 @@ mod tests {
#[crate::rt_test]
async fn test_payload_timeout() {
env_logger::init();
let mark = Arc::new(AtomicUsize::new(0));
let mark2 = mark.clone();
@ -1432,7 +1438,7 @@ mod tests {
let mut config = ServiceConfig::new(
Seconds(5).into(),
Millis(1_000),
Seconds(1),
Seconds::ZERO,
Millis(5_000),
Config::server(),

View file

@ -32,6 +32,7 @@ struct ReadRate {
}
impl Config {
#[allow(clippy::wrong_self_convention)]
fn into_cfg(&self) -> http::ServiceConfig {
let mut svc_cfg = http::ServiceConfig::default();
svc_cfg.keepalive(self.keep_alive);
@ -186,7 +187,7 @@ where
cfg.headers_read_rate = None;
} else {
let mut rate = cfg.headers_read_rate.unwrap_or_default();
rate.timeout = timeout.into();
rate.timeout = timeout;
cfg.headers_read_rate = Some(rate);
}
}
@ -258,8 +259,8 @@ where
if !timeout.is_zero() {
self.config.lock().unwrap().payload_read_rate = Some(ReadRate {
rate,
timeout: timeout.into(),
max_timeout: max_timeout.into(),
timeout,
max_timeout,
});
} else {
self.config.lock().unwrap().payload_read_rate = None;

View file

@ -171,6 +171,28 @@ async fn test_slow_request() {
assert!(data.starts_with("HTTP/1.1 408 Request Timeout"));
}
#[ntex::test]
async fn test_slow_request2() {
const DATA: &[u8] = b"GET /test/tests/test HTTP/1.1\r\n";
let srv = test_server(|| {
HttpService::build()
.headers_read_rate(Seconds(1), Seconds(2), 4)
.finish(|_| Ready::Ok::<_, io::Error>(Response::Ok().finish()))
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n");
let mut data = vec![0; 1024];
let _ = stream.read(&mut data);
assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n");
let _ = stream.write_all(DATA);
let mut data = String::new();
let _ = stream.read_to_string(&mut data);
assert!(data.starts_with("HTTP/1.1 408 Request Timeout"));
}
#[ntex::test]
async fn test_http1_malformed_request() {
let srv = test_server(|| {