mirror of
https://github.com/ntex-rs/ntex.git
synced 2025-04-03 21:07:39 +03:00
Fix handling payload timer after payload got consumed (#366)
This commit is contained in:
parent
3b49828e5f
commit
9c29de14cf
7 changed files with 138 additions and 96 deletions
|
@ -6,8 +6,7 @@
|
|||
unreachable_pub,
|
||||
missing_debug_implementations
|
||||
)]
|
||||
|
||||
use std::{future::Future, rc::Rc};
|
||||
use std::rc::Rc;
|
||||
|
||||
mod and_then;
|
||||
mod apply;
|
||||
|
@ -183,11 +182,9 @@ pub trait ServiceFactory<Req, Cfg = ()> {
|
|||
type InitError;
|
||||
|
||||
/// Create and return a new service value asynchronously.
|
||||
fn create(
|
||||
&self,
|
||||
cfg: Cfg,
|
||||
) -> impl Future<Output = Result<Self::Service, Self::InitError>>;
|
||||
async fn create(&self, cfg: Cfg) -> Result<Self::Service, Self::InitError>;
|
||||
|
||||
#[inline]
|
||||
/// Create and return a new service value asynchronously and wrap into a container
|
||||
async fn pipeline(&self, cfg: Cfg) -> Result<Pipeline<Self::Service>, Self::InitError>
|
||||
where
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
# Changes
|
||||
|
||||
## [2.0.1] - 2024-05-29
|
||||
|
||||
* http: Fix handling payload timer after payload got consumed
|
||||
|
||||
## [2.0.0] - 2024-05-28
|
||||
|
||||
* Use "async fn" for Service::ready() and Service::shutdown()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "ntex"
|
||||
version = "2.0.0"
|
||||
version = "2.0.1"
|
||||
authors = ["ntex contributors <team@ntex.rs>"]
|
||||
description = "Framework for composable network services"
|
||||
readme = "README.md"
|
||||
|
@ -63,10 +63,10 @@ ntex-router = "0.5.3"
|
|||
ntex-service = "3.0"
|
||||
ntex-macros = "0.1.3"
|
||||
ntex-util = "2.0"
|
||||
ntex-bytes = "0.1.25"
|
||||
ntex-bytes = "0.1.27"
|
||||
ntex-server = "2.0"
|
||||
ntex-h2 = "1.0"
|
||||
ntex-rt = "0.4.12"
|
||||
ntex-rt = "0.4.13"
|
||||
ntex-io = "2.0"
|
||||
ntex-net = "2.0"
|
||||
ntex-tls = "2.0"
|
||||
|
|
|
@ -4,6 +4,7 @@ use crate::http::message::CurrentIo;
|
|||
use crate::http::{body::Body, h1::Codec, Request, Response, ResponseError};
|
||||
use crate::io::{Filter, Io, IoBoxed};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Control<F, Err> {
|
||||
/// New request is loaded
|
||||
NewRequest(NewRequest),
|
||||
|
@ -40,19 +41,19 @@ bitflags::bitflags! {
|
|||
|
||||
#[derive(Debug)]
|
||||
pub(super) enum ControlResult {
|
||||
// handle request expect
|
||||
/// handle request expect
|
||||
Expect(Request),
|
||||
// handle request upgrade
|
||||
/// handle request upgrade
|
||||
Upgrade(Request),
|
||||
// forward request to publish service
|
||||
/// forward request to publish service
|
||||
Publish(Request),
|
||||
// forward request to publish service
|
||||
/// forward request to publish service
|
||||
PublishUpgrade(Request),
|
||||
// send response
|
||||
/// send response
|
||||
Response(Response<()>, Body),
|
||||
// send response
|
||||
/// send response
|
||||
ResponseWithIo(Response<()>, Body, IoBoxed),
|
||||
// drop connection
|
||||
/// drop connection
|
||||
Stop,
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ where
|
|||
type Service = DefaultControlService;
|
||||
type InitError = io::Error;
|
||||
|
||||
#[inline]
|
||||
async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
|
||||
Ok(DefaultControlService)
|
||||
}
|
||||
|
@ -33,6 +34,7 @@ where
|
|||
type Response = ControlAck;
|
||||
type Error = io::Error;
|
||||
|
||||
#[inline]
|
||||
async fn call(
|
||||
&self,
|
||||
req: Control<F, Err>,
|
||||
|
|
|
@ -507,13 +507,20 @@ where
|
|||
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<State<F, C, S, B>> {
|
||||
if self.payload.is_some() {
|
||||
if let Some(st) = ready!(self.poll_request_payload(cx)) {
|
||||
return Poll::Ready(st);
|
||||
Poll::Ready(st)
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
} else {
|
||||
// check for io changes, could close while waiting for service call
|
||||
match ready!(self.io.poll_status_update(cx)) {
|
||||
IoStatusUpdate::KeepAlive => Poll::Pending,
|
||||
IoStatusUpdate::Stop | IoStatusUpdate::PeerGone(_) => {
|
||||
Poll::Ready(self.stop())
|
||||
}
|
||||
IoStatusUpdate::WriteBackpressure => Poll::Pending,
|
||||
}
|
||||
} else if self.poll_io_closed(cx) {
|
||||
// check if io is closed
|
||||
return Poll::Ready(self.stop());
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn set_payload_error(&mut self, err: PayloadError) {
|
||||
|
@ -580,6 +587,7 @@ where
|
|||
self.payload.as_mut().unwrap().1.feed_data(chunk);
|
||||
}
|
||||
Ok(PayloadItem::Eof) => {
|
||||
self.flags.remove(Flags::READ_PL_TIMEOUT);
|
||||
self.payload.as_mut().unwrap().1.feed_eof();
|
||||
self.payload = None;
|
||||
break;
|
||||
|
@ -651,76 +659,66 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// check for io changes, could close while waiting for service call
|
||||
fn poll_io_closed(&self, cx: &mut Context<'_>) -> bool {
|
||||
match self.io.poll_status_update(cx) {
|
||||
Poll::Pending => false,
|
||||
Poll::Ready(
|
||||
IoStatusUpdate::KeepAlive
|
||||
| IoStatusUpdate::Stop
|
||||
| IoStatusUpdate::PeerGone(_),
|
||||
) => true,
|
||||
Poll::Ready(IoStatusUpdate::WriteBackpressure) => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_timeout(&mut self) -> Result<(), ProtocolError> {
|
||||
// check read rate
|
||||
if self
|
||||
.flags
|
||||
.intersects(Flags::READ_PL_TIMEOUT | Flags::READ_HDRS_TIMEOUT)
|
||||
{
|
||||
let cfg = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
|
||||
&self.config.headers_read_rate
|
||||
let cfg = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
|
||||
&self.config.headers_read_rate
|
||||
} else if self.flags.contains(Flags::READ_PL_TIMEOUT) {
|
||||
&self.config.payload_read_rate
|
||||
} else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if let Some(ref cfg) = cfg {
|
||||
let total = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
|
||||
let total = (self.read_remains - self.read_consumed)
|
||||
.try_into()
|
||||
.unwrap_or(u16::MAX);
|
||||
self.read_remains = 0;
|
||||
total
|
||||
} else {
|
||||
&self.config.payload_read_rate
|
||||
let total = (self.read_remains + self.read_consumed)
|
||||
.try_into()
|
||||
.unwrap_or(u16::MAX);
|
||||
self.read_consumed = 0;
|
||||
total
|
||||
};
|
||||
|
||||
if let Some(ref cfg) = cfg {
|
||||
let total = if self.flags.contains(Flags::READ_HDRS_TIMEOUT) {
|
||||
let total = (self.read_remains - self.read_consumed)
|
||||
.try_into()
|
||||
.unwrap_or(u16::MAX);
|
||||
self.read_remains = 0;
|
||||
total
|
||||
} else {
|
||||
let total = (self.read_remains + self.read_consumed)
|
||||
.try_into()
|
||||
.unwrap_or(u16::MAX);
|
||||
self.read_consumed = 0;
|
||||
total
|
||||
};
|
||||
if total > cfg.rate {
|
||||
// update max timeout
|
||||
if !cfg.max_timeout.is_zero() {
|
||||
self.read_max_timeout =
|
||||
Seconds(self.read_max_timeout.0.saturating_sub(cfg.timeout.0));
|
||||
}
|
||||
|
||||
if total > cfg.rate {
|
||||
// update max timeout
|
||||
if !cfg.max_timeout.is_zero() {
|
||||
self.read_max_timeout =
|
||||
Seconds(self.read_max_timeout.0.saturating_sub(cfg.timeout.0));
|
||||
}
|
||||
|
||||
// start timer for next period
|
||||
if cfg.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
|
||||
log::trace!(
|
||||
"{}: Bytes read rate {:?}, extend timer",
|
||||
self.io.tag(),
|
||||
total
|
||||
);
|
||||
self.io.start_timer(cfg.timeout);
|
||||
return Ok(());
|
||||
}
|
||||
// start timer for next period
|
||||
if cfg.max_timeout.is_zero() || !self.read_max_timeout.is_zero() {
|
||||
log::trace!(
|
||||
"{}: Bytes read rate {:?}, extend timer",
|
||||
self.io.tag(),
|
||||
total
|
||||
);
|
||||
self.io.start_timer(cfg.timeout);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::trace!("{}: Timeout during reading", self.io.tag());
|
||||
if self.flags.contains(Flags::READ_PL_TIMEOUT) {
|
||||
self.set_payload_error(PayloadError::Io(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"Keep-alive",
|
||||
)));
|
||||
Err(ProtocolError::SlowPayloadTimeout)
|
||||
log::trace!(
|
||||
"{}: Timeout during reading, {:?}",
|
||||
self.io.tag(),
|
||||
self.flags
|
||||
);
|
||||
if self.flags.contains(Flags::READ_PL_TIMEOUT) {
|
||||
self.set_payload_error(PayloadError::Io(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"Keep-alive",
|
||||
)));
|
||||
Err(ProtocolError::SlowPayloadTimeout)
|
||||
} else {
|
||||
Err(ProtocolError::SlowRequestTimeout)
|
||||
}
|
||||
} else {
|
||||
Err(ProtocolError::SlowRequestTimeout)
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -731,7 +729,6 @@ where
|
|||
// got parsed frame
|
||||
if decoded.item.is_some() {
|
||||
self.read_remains = 0;
|
||||
self.io.stop_timer();
|
||||
self.flags.remove(
|
||||
Flags::READ_KA_TIMEOUT | Flags::READ_HDRS_TIMEOUT | Flags::READ_PL_TIMEOUT,
|
||||
);
|
||||
|
@ -741,16 +738,16 @@ where
|
|||
} else if self.read_remains == 0 && decoded.remains == 0 {
|
||||
// no new data, start keep-alive timer
|
||||
if self.codec.keepalive() {
|
||||
if !self.flags.contains(Flags::READ_KA_TIMEOUT) {
|
||||
if !self.flags.contains(Flags::READ_KA_TIMEOUT)
|
||||
&& self.config.keep_alive_enabled()
|
||||
{
|
||||
log::debug!(
|
||||
"{}: Start keep-alive timer {:?}",
|
||||
self.io.tag(),
|
||||
self.config.keep_alive
|
||||
);
|
||||
self.flags.insert(Flags::READ_KA_TIMEOUT);
|
||||
if self.config.keep_alive_enabled() {
|
||||
self.io.start_timer(self.config.keep_alive);
|
||||
}
|
||||
self.io.start_timer(self.config.keep_alive);
|
||||
}
|
||||
} else {
|
||||
self.io.close();
|
||||
|
@ -765,7 +762,8 @@ where
|
|||
|
||||
// we got new data but not enough to parse single frame
|
||||
// start read timer
|
||||
self.flags.remove(Flags::READ_KA_TIMEOUT);
|
||||
self.flags
|
||||
.remove(Flags::READ_KA_TIMEOUT | Flags::READ_PL_TIMEOUT);
|
||||
self.flags.insert(Flags::READ_HDRS_TIMEOUT);
|
||||
|
||||
self.read_consumed = 0;
|
||||
|
@ -781,6 +779,8 @@ where
|
|||
self.read_remains = decoded.remains as u32;
|
||||
self.read_consumed += decoded.consumed as u32;
|
||||
} else if let Some(ref cfg) = self.config.payload_read_rate {
|
||||
log::debug!("{}: Start payload timer {:?}", self.io.tag(), cfg.timeout);
|
||||
|
||||
// start payload timer
|
||||
self.flags.insert(Flags::READ_PL_TIMEOUT);
|
||||
|
||||
|
@ -1298,6 +1298,8 @@ mod tests {
|
|||
async fn test_payload_timeout() {
|
||||
let mark = Arc::new(AtomicUsize::new(0));
|
||||
let mark2 = mark.clone();
|
||||
let err_mark = Arc::new(AtomicUsize::new(0));
|
||||
let err_mark2 = err_mark.clone();
|
||||
|
||||
let (client, server) = Io::create();
|
||||
client.remote_buffer_cap(4096);
|
||||
|
@ -1332,7 +1334,17 @@ mod tests {
|
|||
Rc::new(DispatcherConfig::new(
|
||||
config,
|
||||
svc.into_service(),
|
||||
DefaultControlService,
|
||||
fn_service(move |msg: Control<_, _>| {
|
||||
if let Control::ProtocolError(ref err) = msg {
|
||||
if matches!(err.err(), ProtocolError::SlowPayloadTimeout) {
|
||||
err_mark2.store(
|
||||
err_mark2.load(Ordering::Relaxed) + 1,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
}
|
||||
}
|
||||
async move { Ok::<_, io::Error>(msg.ack()) }
|
||||
}),
|
||||
)),
|
||||
);
|
||||
crate::rt::spawn(disp);
|
||||
|
@ -1347,5 +1359,6 @@ mod tests {
|
|||
sleep(Millis(750)).await;
|
||||
}
|
||||
assert!(mark.load(Ordering::Relaxed) == 1536);
|
||||
assert!(err_mark.load(Ordering::Relaxed) == 1);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,15 +5,11 @@ use futures_util::future::{self, FutureExt};
|
|||
use futures_util::stream::{once, StreamExt};
|
||||
use regex::Regex;
|
||||
|
||||
use ntex::http::h1::Control;
|
||||
use ntex::http::header::{self, HeaderName, HeaderValue};
|
||||
use ntex::http::test::server as test_server;
|
||||
use ntex::http::{
|
||||
body, HttpService, KeepAlive, Method, Request, Response, StatusCode, Version,
|
||||
};
|
||||
use ntex::service::fn_service;
|
||||
use ntex::http::{body, h1::Control, test::server as test_server};
|
||||
use ntex::http::{HttpService, KeepAlive, Method, Request, Response, StatusCode, Version};
|
||||
use ntex::time::{sleep, timeout, Millis, Seconds};
|
||||
use ntex::{util::Bytes, util::Ready, web::error};
|
||||
use ntex::{service::fn_service, util::Bytes, util::Ready, web::error};
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_h1() {
|
||||
|
@ -256,7 +252,7 @@ async fn test_http1_keepalive_timeout() {
|
|||
async fn test_http1_no_keepalive_during_response() {
|
||||
let srv = test_server(|| {
|
||||
HttpService::build().keep_alive(1).h1(|_| async {
|
||||
sleep(Millis(1100)).await;
|
||||
sleep(Millis(1200)).await;
|
||||
Ok::<_, io::Error>(Response::Ok().finish())
|
||||
})
|
||||
});
|
||||
|
@ -355,6 +351,35 @@ async fn test_http1_keepalive_disabled() {
|
|||
assert_eq!(res, 0);
|
||||
}
|
||||
|
||||
/// Payload timer should not fire aftre dispatcher has read whole payload
|
||||
#[ntex::test]
|
||||
async fn test_http1_disable_payload_timer_after_whole_pl_has_been_read() {
|
||||
let srv = test_server(|| {
|
||||
HttpService::build()
|
||||
.headers_read_rate(Seconds(1), Seconds(1), 128)
|
||||
.payload_read_rate(Seconds(1), Seconds(1), 512)
|
||||
.keep_alive(1)
|
||||
.h1_control(fn_service(move |msg: Control<_, _>| async move {
|
||||
Ok::<_, io::Error>(msg.ack())
|
||||
}))
|
||||
.h1(|mut req: Request| async move {
|
||||
req.payload().recv().await;
|
||||
sleep(Millis(1500)).await;
|
||||
Ok::<_, io::Error>(Response::Ok().finish())
|
||||
})
|
||||
});
|
||||
|
||||
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
|
||||
let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\ncontent-length: 4\r\n");
|
||||
sleep(Millis(250)).await;
|
||||
let _ = stream.write_all(b"\r\n");
|
||||
sleep(Millis(250)).await;
|
||||
let _ = stream.write_all(b"1234");
|
||||
let mut data = vec![0; 1024];
|
||||
let _ = stream.read(&mut data);
|
||||
assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n");
|
||||
}
|
||||
|
||||
#[ntex::test]
|
||||
async fn test_content_length() {
|
||||
let srv = test_server(|| {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue