Fix handling payload timer after payload got consumed (#366)

This commit is contained in:
Nikolay Kim 2024-05-29 17:42:18 +05:00 committed by GitHub
parent 3b49828e5f
commit 9c29de14cf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 138 additions and 96 deletions

View file

@ -6,8 +6,7 @@
unreachable_pub, unreachable_pub,
missing_debug_implementations missing_debug_implementations
)] )]
use std::rc::Rc;
use std::{future::Future, rc::Rc};
mod and_then; mod and_then;
mod apply; mod apply;
@ -183,11 +182,9 @@ pub trait ServiceFactory<Req, Cfg = ()> {
type InitError; type InitError;
/// Create and return a new service value asynchronously. /// Create and return a new service value asynchronously.
fn create( async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError>;
&self,
cfg: Cfg,
) -> impl Future<Output = Result<Self::Service, Self::InitError>>;
#[inline]
/// Create and return a new service value asynchronously and wrap into a container /// Create and return a new service value asynchronously and wrap into a container
async fn pipeline(&self, cfg: Cfg) -> Result<Pipeline<Self::Service>, Self::InitError> async fn pipeline(&self, cfg: Cfg) -> Result<Pipeline<Self::Service>, Self::InitError>
where where

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [2.0.1] - 2024-05-29
* http: Fix handling payload timer after payload got consumed
## [2.0.0] - 2024-05-28 ## [2.0.0] - 2024-05-28
* Use "async fn" for Service::ready() and Service::shutdown() * Use "async fn" for Service::ready() and Service::shutdown()

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex" name = "ntex"
version = "2.0.0" version = "2.0.1"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services" description = "Framework for composable network services"
readme = "README.md" readme = "README.md"
@ -63,10 +63,10 @@ ntex-router = "0.5.3"
ntex-service = "3.0" ntex-service = "3.0"
ntex-macros = "0.1.3" ntex-macros = "0.1.3"
ntex-util = "2.0" ntex-util = "2.0"
ntex-bytes = "0.1.25" ntex-bytes = "0.1.27"
ntex-server = "2.0" ntex-server = "2.0"
ntex-h2 = "1.0" ntex-h2 = "1.0"
ntex-rt = "0.4.12" ntex-rt = "0.4.13"
ntex-io = "2.0" ntex-io = "2.0"
ntex-net = "2.0" ntex-net = "2.0"
ntex-tls = "2.0" ntex-tls = "2.0"

View file

@ -4,6 +4,7 @@ use crate::http::message::CurrentIo;
use crate::http::{body::Body, h1::Codec, Request, Response, ResponseError}; use crate::http::{body::Body, h1::Codec, Request, Response, ResponseError};
use crate::io::{Filter, Io, IoBoxed}; use crate::io::{Filter, Io, IoBoxed};
#[derive(Debug)]
pub enum Control<F, Err> { pub enum Control<F, Err> {
/// New request is loaded /// New request is loaded
NewRequest(NewRequest), NewRequest(NewRequest),
@ -40,19 +41,19 @@ bitflags::bitflags! {
#[derive(Debug)] #[derive(Debug)]
pub(super) enum ControlResult { pub(super) enum ControlResult {
// handle request expect /// handle request expect
Expect(Request), Expect(Request),
// handle request upgrade /// handle request upgrade
Upgrade(Request), Upgrade(Request),
// forward request to publish service /// forward request to publish service
Publish(Request), Publish(Request),
// forward request to publish service /// forward request to publish service
PublishUpgrade(Request), PublishUpgrade(Request),
// send response /// send response
Response(Response<()>, Body), Response(Response<()>, Body),
// send response /// send response
ResponseWithIo(Response<()>, Body, IoBoxed), ResponseWithIo(Response<()>, Body, IoBoxed),
// drop connection /// drop connection
Stop, Stop,
} }

View file

@ -20,6 +20,7 @@ where
type Service = DefaultControlService; type Service = DefaultControlService;
type InitError = io::Error; type InitError = io::Error;
#[inline]
async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> { async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
Ok(DefaultControlService) Ok(DefaultControlService)
} }
@ -33,6 +34,7 @@ where
type Response = ControlAck; type Response = ControlAck;
type Error = io::Error; type Error = io::Error;
#[inline]
async fn call( async fn call(
&self, &self,
req: Control<F, Err>, req: Control<F, Err>,

View file

@ -507,13 +507,20 @@ where
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<State<F, C, S, B>> { fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<State<F, C, S, B>> {
if self.payload.is_some() { if self.payload.is_some() {
if let Some(st) = ready!(self.poll_request_payload(cx)) { if let Some(st) = ready!(self.poll_request_payload(cx)) {
return Poll::Ready(st); Poll::Ready(st)
} else {
Poll::Pending
}
} else {
// check for io changes, could close while waiting for service call
match ready!(self.io.poll_status_update(cx)) {
IoStatusUpdate::KeepAlive => Poll::Pending,
IoStatusUpdate::Stop | IoStatusUpdate::PeerGone(_) => {
Poll::Ready(self.stop())
}
IoStatusUpdate::WriteBackpressure => Poll::Pending,
} }
} else if self.poll_io_closed(cx) {
// check if io is closed
return Poll::Ready(self.stop());
} }
Poll::Pending
} }
fn set_payload_error(&mut self, err: PayloadError) { fn set_payload_error(&mut self, err: PayloadError) {
@ -580,6 +587,7 @@ where
self.payload.as_mut().unwrap().1.feed_data(chunk); self.payload.as_mut().unwrap().1.feed_data(chunk);
} }
Ok(PayloadItem::Eof) => { Ok(PayloadItem::Eof) => {
self.flags.remove(Flags::READ_PL_TIMEOUT);
self.payload.as_mut().unwrap().1.feed_eof(); self.payload.as_mut().unwrap().1.feed_eof();
self.payload = None; self.payload = None;
break; break;
@ -651,76 +659,66 @@ where
} }
} }
/// check for io changes, could close while waiting for service call
fn poll_io_closed(&self, cx: &mut Context<'_>) -> bool {
match self.io.poll_status_update(cx) {
Poll::Pending => false,
Poll::Ready(
IoStatusUpdate::KeepAlive
| IoStatusUpdate::Stop
| IoStatusUpdate::PeerGone(_),
) => true,
Poll::Ready(IoStatusUpdate::WriteBackpressure) => false,
}
}
fn handle_timeout(&mut self) -> Result<(), ProtocolError> { fn handle_timeout(&mut self) -> Result<(), ProtocolError> {
// check read rate // check read rate
if self let cfg = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
.flags &self.config.headers_read_rate
.intersects(Flags::READ_PL_TIMEOUT | Flags::READ_HDRS_TIMEOUT) } else if self.flags.contains(Flags::READ_PL_TIMEOUT) {
{ &self.config.payload_read_rate
let cfg = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) { } else {
&self.config.headers_read_rate return Ok(());
};
if let Some(ref cfg) = cfg {
let total = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
let total = (self.read_remains - self.read_consumed)
.try_into()
.unwrap_or(u16::MAX);
self.read_remains = 0;
total
} else { } else {
&self.config.payload_read_rate let total = (self.read_remains + self.read_consumed)
.try_into()
.unwrap_or(u16::MAX);
self.read_consumed = 0;
total
}; };
if let Some(ref cfg) = cfg { if total > cfg.rate {
let total = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) { // update max timeout
let total = (self.read_remains - self.read_consumed) if !cfg.max_timeout.is_zero() {
.try_into() self.read_max_timeout =
.unwrap_or(u16::MAX); Seconds(self.read_max_timeout.0.saturating_sub(cfg.timeout.0));
self.read_remains = 0; }
total
} else {
let total = (self.read_remains + self.read_consumed)
.try_into()
.unwrap_or(u16::MAX);
self.read_consumed = 0;
total
};
if total > cfg.rate { // start timer for next period
// update max timeout if cfg.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
if !cfg.max_timeout.is_zero() { log::trace!(
self.read_max_timeout = "{}: Bytes read rate {:?}, extend timer",
Seconds(self.read_max_timeout.0.saturating_sub(cfg.timeout.0)); self.io.tag(),
} total
);
// start timer for next period self.io.start_timer(cfg.timeout);
if cfg.max_timeout.is_zero() || !self.read_max_timeout.is_zero() { return Ok(());
log::trace!(
"{}: Bytes read rate {:?}, extend timer",
self.io.tag(),
total
);
self.io.start_timer(cfg.timeout);
return Ok(());
}
} }
} }
}
log::trace!("{}: Timeout during reading", self.io.tag()); log::trace!(
if self.flags.contains(Flags::READ_PL_TIMEOUT) { "{}: Timeout during reading, {:?}",
self.set_payload_error(PayloadError::Io(io::Error::new( self.io.tag(),
io::ErrorKind::TimedOut, self.flags
"Keep-alive", );
))); if self.flags.contains(Flags::READ_PL_TIMEOUT) {
Err(ProtocolError::SlowPayloadTimeout) self.set_payload_error(PayloadError::Io(io::Error::new(
io::ErrorKind::TimedOut,
"Keep-alive",
)));
Err(ProtocolError::SlowPayloadTimeout)
} else {
Err(ProtocolError::SlowRequestTimeout)
}
} else { } else {
Err(ProtocolError::SlowRequestTimeout) Ok(())
} }
} }
@ -731,7 +729,6 @@ where
// got parsed frame // got parsed frame
if decoded.item.is_some() { if decoded.item.is_some() {
self.read_remains = 0; self.read_remains = 0;
self.io.stop_timer();
self.flags.remove( self.flags.remove(
Flags::READ_KA_TIMEOUT | Flags::READ_HDRS_TIMEOUT | Flags::READ_PL_TIMEOUT, Flags::READ_KA_TIMEOUT | Flags::READ_HDRS_TIMEOUT | Flags::READ_PL_TIMEOUT,
); );
@ -741,16 +738,16 @@ where
} else if self.read_remains == 0 && decoded.remains == 0 { } else if self.read_remains == 0 && decoded.remains == 0 {
// no new data, start keep-alive timer // no new data, start keep-alive timer
if self.codec.keepalive() { if self.codec.keepalive() {
if !self.flags.contains(Flags::READ_KA_TIMEOUT) { if !self.flags.contains(Flags::READ_KA_TIMEOUT)
&& self.config.keep_alive_enabled()
{
log::debug!( log::debug!(
"{}: Start keep-alive timer {:?}", "{}: Start keep-alive timer {:?}",
self.io.tag(), self.io.tag(),
self.config.keep_alive self.config.keep_alive
); );
self.flags.insert(Flags::READ_KA_TIMEOUT); self.flags.insert(Flags::READ_KA_TIMEOUT);
if self.config.keep_alive_enabled() { self.io.start_timer(self.config.keep_alive);
self.io.start_timer(self.config.keep_alive);
}
} }
} else { } else {
self.io.close(); self.io.close();
@ -765,7 +762,8 @@ where
// we got new data but not enough to parse single frame // we got new data but not enough to parse single frame
// start read timer // start read timer
self.flags.remove(Flags::READ_KA_TIMEOUT); self.flags
.remove(Flags::READ_KA_TIMEOUT | Flags::READ_PL_TIMEOUT);
self.flags.insert(Flags::READ_HDRS_TIMEOUT); self.flags.insert(Flags::READ_HDRS_TIMEOUT);
self.read_consumed = 0; self.read_consumed = 0;
@ -781,6 +779,8 @@ where
self.read_remains = decoded.remains as u32; self.read_remains = decoded.remains as u32;
self.read_consumed += decoded.consumed as u32; self.read_consumed += decoded.consumed as u32;
} else if let Some(ref cfg) = self.config.payload_read_rate { } else if let Some(ref cfg) = self.config.payload_read_rate {
log::debug!("{}: Start payload timer {:?}", self.io.tag(), cfg.timeout);
// start payload timer // start payload timer
self.flags.insert(Flags::READ_PL_TIMEOUT); self.flags.insert(Flags::READ_PL_TIMEOUT);
@ -1298,6 +1298,8 @@ mod tests {
async fn test_payload_timeout() { async fn test_payload_timeout() {
let mark = Arc::new(AtomicUsize::new(0)); let mark = Arc::new(AtomicUsize::new(0));
let mark2 = mark.clone(); let mark2 = mark.clone();
let err_mark = Arc::new(AtomicUsize::new(0));
let err_mark2 = err_mark.clone();
let (client, server) = Io::create(); let (client, server) = Io::create();
client.remote_buffer_cap(4096); client.remote_buffer_cap(4096);
@ -1332,7 +1334,17 @@ mod tests {
Rc::new(DispatcherConfig::new( Rc::new(DispatcherConfig::new(
config, config,
svc.into_service(), svc.into_service(),
DefaultControlService, fn_service(move |msg: Control<_, _>| {
if let Control::ProtocolError(ref err) = msg {
if matches!(err.err(), ProtocolError::SlowPayloadTimeout) {
err_mark2.store(
err_mark2.load(Ordering::Relaxed) + 1,
Ordering::Relaxed,
);
}
}
async move { Ok::<_, io::Error>(msg.ack()) }
}),
)), )),
); );
crate::rt::spawn(disp); crate::rt::spawn(disp);
@ -1347,5 +1359,6 @@ mod tests {
sleep(Millis(750)).await; sleep(Millis(750)).await;
} }
assert!(mark.load(Ordering::Relaxed) == 1536); assert!(mark.load(Ordering::Relaxed) == 1536);
assert!(err_mark.load(Ordering::Relaxed) == 1);
} }
} }

View file

@ -5,15 +5,11 @@ use futures_util::future::{self, FutureExt};
use futures_util::stream::{once, StreamExt}; use futures_util::stream::{once, StreamExt};
use regex::Regex; use regex::Regex;
use ntex::http::h1::Control;
use ntex::http::header::{self, HeaderName, HeaderValue}; use ntex::http::header::{self, HeaderName, HeaderValue};
use ntex::http::test::server as test_server; use ntex::http::{body, h1::Control, test::server as test_server};
use ntex::http::{ use ntex::http::{HttpService, KeepAlive, Method, Request, Response, StatusCode, Version};
body, HttpService, KeepAlive, Method, Request, Response, StatusCode, Version,
};
use ntex::service::fn_service;
use ntex::time::{sleep, timeout, Millis, Seconds}; use ntex::time::{sleep, timeout, Millis, Seconds};
use ntex::{util::Bytes, util::Ready, web::error}; use ntex::{service::fn_service, util::Bytes, util::Ready, web::error};
#[ntex::test] #[ntex::test]
async fn test_h1() { async fn test_h1() {
@ -256,7 +252,7 @@ async fn test_http1_keepalive_timeout() {
async fn test_http1_no_keepalive_during_response() { async fn test_http1_no_keepalive_during_response() {
let srv = test_server(|| { let srv = test_server(|| {
HttpService::build().keep_alive(1).h1(|_| async { HttpService::build().keep_alive(1).h1(|_| async {
sleep(Millis(1100)).await; sleep(Millis(1200)).await;
Ok::<_, io::Error>(Response::Ok().finish()) Ok::<_, io::Error>(Response::Ok().finish())
}) })
}); });
@ -355,6 +351,35 @@ async fn test_http1_keepalive_disabled() {
assert_eq!(res, 0); assert_eq!(res, 0);
} }
/// Payload timer should not fire aftre dispatcher has read whole payload
#[ntex::test]
async fn test_http1_disable_payload_timer_after_whole_pl_has_been_read() {
let srv = test_server(|| {
HttpService::build()
.headers_read_rate(Seconds(1), Seconds(1), 128)
.payload_read_rate(Seconds(1), Seconds(1), 512)
.keep_alive(1)
.h1_control(fn_service(move |msg: Control<_, _>| async move {
Ok::<_, io::Error>(msg.ack())
}))
.h1(|mut req: Request| async move {
req.payload().recv().await;
sleep(Millis(1500)).await;
Ok::<_, io::Error>(Response::Ok().finish())
})
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\ncontent-length: 4\r\n");
sleep(Millis(250)).await;
let _ = stream.write_all(b"\r\n");
sleep(Millis(250)).await;
let _ = stream.write_all(b"1234");
let mut data = vec![0; 1024];
let _ = stream.read(&mut data);
assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n");
}
#[ntex::test] #[ntex::test]
async fn test_content_length() { async fn test_content_length() {
let srv = test_server(|| { let srv = test_server(|| {