From 9c29de14cfd2cb23bf50e924fce11295a5600210 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 29 May 2024 17:42:18 +0500 Subject: [PATCH] Fix handling payload timer after payload got consumed (#366) --- ntex-service/src/lib.rs | 9 +- ntex/CHANGES.md | 4 + ntex/Cargo.toml | 6 +- ntex/src/http/h1/control.rs | 15 ++-- ntex/src/http/h1/default.rs | 2 + ntex/src/http/h1/dispatcher.rs | 157 ++++++++++++++++++--------------- ntex/tests/http_server.rs | 41 +++++++-- 7 files changed, 138 insertions(+), 96 deletions(-) diff --git a/ntex-service/src/lib.rs b/ntex-service/src/lib.rs index 213a0666..e01b6902 100644 --- a/ntex-service/src/lib.rs +++ b/ntex-service/src/lib.rs @@ -6,8 +6,7 @@ unreachable_pub, missing_debug_implementations )] - -use std::{future::Future, rc::Rc}; +use std::rc::Rc; mod and_then; mod apply; @@ -183,11 +182,9 @@ pub trait ServiceFactory { type InitError; /// Create and return a new service value asynchronously. - fn create( - &self, - cfg: Cfg, - ) -> impl Future>; + async fn create(&self, cfg: Cfg) -> Result; + #[inline] /// Create and return a new service value asynchronously and wrap into a container async fn pipeline(&self, cfg: Cfg) -> Result, Self::InitError> where diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index a6f19e1f..87253296 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [2.0.1] - 2024-05-29 + +* http: Fix handling payload timer after payload got consumed + ## [2.0.0] - 2024-05-28 * Use "async fn" for Service::ready() and Service::shutdown() diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index e0c6f775..4b87d4e5 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "2.0.0" +version = "2.0.1" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" @@ -63,10 +63,10 @@ ntex-router = "0.5.3" ntex-service = "3.0" ntex-macros = "0.1.3" ntex-util = "2.0" -ntex-bytes = "0.1.25" +ntex-bytes = "0.1.27" 
ntex-server = "2.0" ntex-h2 = "1.0" -ntex-rt = "0.4.12" +ntex-rt = "0.4.13" ntex-io = "2.0" ntex-net = "2.0" ntex-tls = "2.0" diff --git a/ntex/src/http/h1/control.rs b/ntex/src/http/h1/control.rs index 5ecaec95..6d4a985d 100644 --- a/ntex/src/http/h1/control.rs +++ b/ntex/src/http/h1/control.rs @@ -4,6 +4,7 @@ use crate::http::message::CurrentIo; use crate::http::{body::Body, h1::Codec, Request, Response, ResponseError}; use crate::io::{Filter, Io, IoBoxed}; +#[derive(Debug)] pub enum Control { /// New request is loaded NewRequest(NewRequest), @@ -40,19 +41,19 @@ bitflags::bitflags! { #[derive(Debug)] pub(super) enum ControlResult { - // handle request expect + /// handle request expect Expect(Request), - // handle request upgrade + /// handle request upgrade Upgrade(Request), - // forward request to publish service + /// forward request to publish service Publish(Request), - // forward request to publish service + /// forward request to publish service PublishUpgrade(Request), - // send response + /// send response Response(Response<()>, Body), - // send response + /// send response ResponseWithIo(Response<()>, Body, IoBoxed), - // drop connection + /// drop connection Stop, } diff --git a/ntex/src/http/h1/default.rs b/ntex/src/http/h1/default.rs index 4ec6f3a4..3c0d3268 100644 --- a/ntex/src/http/h1/default.rs +++ b/ntex/src/http/h1/default.rs @@ -20,6 +20,7 @@ where type Service = DefaultControlService; type InitError = io::Error; + #[inline] async fn create(&self, _: ()) -> Result { Ok(DefaultControlService) } @@ -33,6 +34,7 @@ where type Response = ControlAck; type Error = io::Error; + #[inline] async fn call( &self, req: Control, diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index 98a1bf6b..3a338071 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -507,13 +507,20 @@ where fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { if self.payload.is_some() { if let Some(st) = 
ready!(self.poll_request_payload(cx)) { - return Poll::Ready(st); + Poll::Ready(st) + } else { + Poll::Pending + } + } else { + // check for io changes, could close while waiting for service call + match ready!(self.io.poll_status_update(cx)) { + IoStatusUpdate::KeepAlive => Poll::Pending, + IoStatusUpdate::Stop | IoStatusUpdate::PeerGone(_) => { + Poll::Ready(self.stop()) + } + IoStatusUpdate::WriteBackpressure => Poll::Pending, } - } else if self.poll_io_closed(cx) { - // check if io is closed - return Poll::Ready(self.stop()); } - Poll::Pending } fn set_payload_error(&mut self, err: PayloadError) { @@ -580,6 +587,7 @@ where self.payload.as_mut().unwrap().1.feed_data(chunk); } Ok(PayloadItem::Eof) => { + self.flags.remove(Flags::READ_PL_TIMEOUT); self.payload.as_mut().unwrap().1.feed_eof(); self.payload = None; break; @@ -651,76 +659,66 @@ where } } - /// check for io changes, could close while waiting for service call - fn poll_io_closed(&self, cx: &mut Context<'_>) -> bool { - match self.io.poll_status_update(cx) { - Poll::Pending => false, - Poll::Ready( - IoStatusUpdate::KeepAlive - | IoStatusUpdate::Stop - | IoStatusUpdate::PeerGone(_), - ) => true, - Poll::Ready(IoStatusUpdate::WriteBackpressure) => false, - } - } - fn handle_timeout(&mut self) -> Result<(), ProtocolError> { // check read rate - if self - .flags - .intersects(Flags::READ_PL_TIMEOUT | Flags::READ_HDRS_TIMEOUT) - { - let cfg = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) { - &self.config.headers_read_rate + let cfg = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) { + &self.config.headers_read_rate + } else if self.flags.contains(Flags::READ_PL_TIMEOUT) { + &self.config.payload_read_rate + } else { + return Ok(()); + }; + + if let Some(ref cfg) = cfg { + let total = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) { + let total = (self.read_remains - self.read_consumed) + .try_into() + .unwrap_or(u16::MAX); + self.read_remains = 0; + total } else { - &self.config.payload_read_rate + let 
total = (self.read_remains + self.read_consumed) + .try_into() + .unwrap_or(u16::MAX); + self.read_consumed = 0; + total }; - if let Some(ref cfg) = cfg { - let total = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) { - let total = (self.read_remains - self.read_consumed) - .try_into() - .unwrap_or(u16::MAX); - self.read_remains = 0; - total - } else { - let total = (self.read_remains + self.read_consumed) - .try_into() - .unwrap_or(u16::MAX); - self.read_consumed = 0; - total - }; + if total > cfg.rate { + // update max timeout + if !cfg.max_timeout.is_zero() { + self.read_max_timeout = + Seconds(self.read_max_timeout.0.saturating_sub(cfg.timeout.0)); + } - if total > cfg.rate { - // update max timeout - if !cfg.max_timeout.is_zero() { - self.read_max_timeout = - Seconds(self.read_max_timeout.0.saturating_sub(cfg.timeout.0)); - } - - // start timer for next period - if cfg.max_timeout.is_zero() || !self.read_max_timeout.is_zero() { - log::trace!( - "{}: Bytes read rate {:?}, extend timer", - self.io.tag(), - total - ); - self.io.start_timer(cfg.timeout); - return Ok(()); - } + // start timer for next period + if cfg.max_timeout.is_zero() || !self.read_max_timeout.is_zero() { + log::trace!( + "{}: Bytes read rate {:?}, extend timer", + self.io.tag(), + total + ); + self.io.start_timer(cfg.timeout); + return Ok(()); } } - } - log::trace!("{}: Timeout during reading", self.io.tag()); - if self.flags.contains(Flags::READ_PL_TIMEOUT) { - self.set_payload_error(PayloadError::Io(io::Error::new( - io::ErrorKind::TimedOut, - "Keep-alive", - ))); - Err(ProtocolError::SlowPayloadTimeout) + log::trace!( + "{}: Timeout during reading, {:?}", + self.io.tag(), + self.flags + ); + if self.flags.contains(Flags::READ_PL_TIMEOUT) { + self.set_payload_error(PayloadError::Io(io::Error::new( + io::ErrorKind::TimedOut, + "Keep-alive", + ))); + Err(ProtocolError::SlowPayloadTimeout) + } else { + Err(ProtocolError::SlowRequestTimeout) + } } else { - 
Err(ProtocolError::SlowRequestTimeout) + Ok(()) } } @@ -731,7 +729,6 @@ where // got parsed frame if decoded.item.is_some() { self.read_remains = 0; - self.io.stop_timer(); self.flags.remove( Flags::READ_KA_TIMEOUT | Flags::READ_HDRS_TIMEOUT | Flags::READ_PL_TIMEOUT, ); @@ -741,16 +738,16 @@ where } else if self.read_remains == 0 && decoded.remains == 0 { // no new data, start keep-alive timer if self.codec.keepalive() { - if !self.flags.contains(Flags::READ_KA_TIMEOUT) { + if !self.flags.contains(Flags::READ_KA_TIMEOUT) + && self.config.keep_alive_enabled() + { log::debug!( "{}: Start keep-alive timer {:?}", self.io.tag(), self.config.keep_alive ); self.flags.insert(Flags::READ_KA_TIMEOUT); - if self.config.keep_alive_enabled() { - self.io.start_timer(self.config.keep_alive); - } + self.io.start_timer(self.config.keep_alive); } } else { self.io.close(); @@ -765,7 +762,8 @@ where // we got new data but not enough to parse single frame // start read timer - self.flags.remove(Flags::READ_KA_TIMEOUT); + self.flags + .remove(Flags::READ_KA_TIMEOUT | Flags::READ_PL_TIMEOUT); self.flags.insert(Flags::READ_HDRS_TIMEOUT); self.read_consumed = 0; @@ -781,6 +779,8 @@ where self.read_remains = decoded.remains as u32; self.read_consumed += decoded.consumed as u32; } else if let Some(ref cfg) = self.config.payload_read_rate { + log::debug!("{}: Start payload timer {:?}", self.io.tag(), cfg.timeout); + // start payload timer self.flags.insert(Flags::READ_PL_TIMEOUT); @@ -1298,6 +1298,8 @@ mod tests { async fn test_payload_timeout() { let mark = Arc::new(AtomicUsize::new(0)); let mark2 = mark.clone(); + let err_mark = Arc::new(AtomicUsize::new(0)); + let err_mark2 = err_mark.clone(); let (client, server) = Io::create(); client.remote_buffer_cap(4096); @@ -1332,7 +1334,17 @@ mod tests { Rc::new(DispatcherConfig::new( config, svc.into_service(), - DefaultControlService, + fn_service(move |msg: Control<_, _>| { + if let Control::ProtocolError(ref err) = msg { + if 
matches!(err.err(), ProtocolError::SlowPayloadTimeout) { + err_mark2.store( + err_mark2.load(Ordering::Relaxed) + 1, + Ordering::Relaxed, + ); + } + } + async move { Ok::<_, io::Error>(msg.ack()) } + }), )), ); crate::rt::spawn(disp); @@ -1347,5 +1359,6 @@ mod tests { sleep(Millis(750)).await; } assert!(mark.load(Ordering::Relaxed) == 1536); + assert!(err_mark.load(Ordering::Relaxed) == 1); } } diff --git a/ntex/tests/http_server.rs b/ntex/tests/http_server.rs index 433a968d..50f724b8 100644 --- a/ntex/tests/http_server.rs +++ b/ntex/tests/http_server.rs @@ -5,15 +5,11 @@ use futures_util::future::{self, FutureExt}; use futures_util::stream::{once, StreamExt}; use regex::Regex; -use ntex::http::h1::Control; use ntex::http::header::{self, HeaderName, HeaderValue}; -use ntex::http::test::server as test_server; -use ntex::http::{ - body, HttpService, KeepAlive, Method, Request, Response, StatusCode, Version, -}; -use ntex::service::fn_service; +use ntex::http::{body, h1::Control, test::server as test_server}; +use ntex::http::{HttpService, KeepAlive, Method, Request, Response, StatusCode, Version}; use ntex::time::{sleep, timeout, Millis, Seconds}; -use ntex::{util::Bytes, util::Ready, web::error}; +use ntex::{service::fn_service, util::Bytes, util::Ready, web::error}; #[ntex::test] async fn test_h1() { @@ -256,7 +252,7 @@ async fn test_http1_keepalive_timeout() { async fn test_http1_no_keepalive_during_response() { let srv = test_server(|| { HttpService::build().keep_alive(1).h1(|_| async { - sleep(Millis(1100)).await; + sleep(Millis(1200)).await; Ok::<_, io::Error>(Response::Ok().finish()) }) }); @@ -355,6 +351,35 @@ async fn test_http1_keepalive_disabled() { assert_eq!(res, 0); } +/// Payload timer should not fire after dispatcher has read whole payload +#[ntex::test] +async fn test_http1_disable_payload_timer_after_whole_pl_has_been_read() { + let srv = test_server(|| { + HttpService::build() + .headers_read_rate(Seconds(1), Seconds(1), 128) 
.payload_read_rate(Seconds(1), Seconds(1), 512) + .keep_alive(1) + .h1_control(fn_service(move |msg: Control<_, _>| async move { + Ok::<_, io::Error>(msg.ack()) + })) + .h1(|mut req: Request| async move { + req.payload().recv().await; + sleep(Millis(1500)).await; + Ok::<_, io::Error>(Response::Ok().finish()) + }) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\ncontent-length: 4\r\n"); + sleep(Millis(250)).await; + let _ = stream.write_all(b"\r\n"); + sleep(Millis(250)).await; + let _ = stream.write_all(b"1234"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); +} + #[ntex::test] async fn test_content_length() { let srv = test_server(|| {