From 7417ee3a4bfc7daf47ad5c242c4cdb4774786572 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sun, 16 Mar 2025 12:11:01 +0100 Subject: [PATCH] Allow to run publish future to completion in case error (#529) --- ntex/CHANGES.md | 4 +++ ntex/Cargo.toml | 2 +- ntex/src/http/h1/dispatcher.rs | 21 ++++++++++++--- ntex/src/http/h1/payload.rs | 48 +++++++++++++++++++++++++--------- ntex/tests/http_server.rs | 34 ++++++++++++++++++++++-- 5 files changed, 89 insertions(+), 20 deletions(-) diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index c2de75b2..02eb904b 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [2.12.2] - 2025-03-15 + +* http: Allow to run publish future to completion in case error + ## [2.12.1] - 2025-03-14 * Allow to disable test logging (no-test-logging features) diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 0dfac797..4f06c9b1 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "2.12.1" +version = "2.12.2" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index 18263583..7a2142ea 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -1,5 +1,5 @@ //! HTTP/1 protocol dispatcher -use std::{error, future, io, marker, pin::Pin, rc::Rc, task::Context, task::Poll}; +use std::{error, future, io, marker, mem, pin::Pin, rc::Rc, task::Context, task::Poll}; use crate::io::{Decoded, Filter, Io, IoStatusUpdate, RecvError}; use crate::service::{PipelineCall, Service}; @@ -144,7 +144,20 @@ where inner.send_response(res, body) } Poll::Ready(Err(err)) => inner.control(Control::err(err)), - Poll::Pending => ready!(inner.poll_request(cx)), + Poll::Pending => { + // state changed because of error. + // spawn current publish future to runtime + // so it could complete error handling + let st = ready!(inner.poll_request(cx)); + if inner.payload.is_some() { + if let State::CallPublish { fut } = + mem::replace(&mut *this.st, State::ReadRequest) + { + crate::rt::spawn(fut); + } + } + st + } }, // handle control service responses State::CallControl { fut } => match Pin::new(fut).poll(cx) { @@ -339,7 +352,7 @@ where .io .encode(Message::Item((msg, body.size())), &self.codec) .map_err(|err| { - if let Some(mut payload) = self.payload.take() { + if let Some(ref mut payload) = self.payload { payload.1.set_error(PayloadError::Incomplete(None)); } err @@ -438,7 +451,7 @@ where } fn set_payload_error(&mut self, err: PayloadError) { - if let Some(mut payload) = self.payload.take() { + if let Some(ref mut payload) = self.payload { payload.1.set_error(err); } } diff --git a/ntex/src/http/h1/payload.rs b/ntex/src/http/h1/payload.rs index 1fe5e5a5..ac3c8609 100644 --- a/ntex/src/http/h1/payload.rs +++ b/ntex/src/http/h1/payload.rs @@ -3,8 +3,7 @@ use std::rc::{Rc, Weak}; use std::task::{Context, Poll}; use std::{cell::RefCell, collections::VecDeque, pin::Pin}; -use crate::http::error::PayloadError; -use crate::{task::LocalWaker, util::Bytes, util::Stream}; +use crate::{http::error::PayloadError, task::LocalWaker, util::Bytes, util::Stream}; /// max buffer size 32k const MAX_BUFFER_SIZE: usize = 32_768; @@ -119,7 +118,7 @@ impl PayloadSender { // we check only if Payload (other side) is alive, // otherwise always return true (consume payload) if let Some(shared) = self.inner.upgrade() { - if shared.borrow().need_read { + if shared.borrow().flags.contains(Flags::NEED_READ) { PayloadStatus::Read } else { shared.borrow_mut().io_task.register(cx.waker()); @@ -131,12 +130,20 @@ impl PayloadSender { } } +bitflags::bitflags! { + #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] + struct Flags: u8 { + const EOF = 0b0000_0001; + const ERROR = 0b0000_0010; + const NEED_READ = 0b0000_0100; + } +} + #[derive(Debug)] struct Inner { len: usize, - eof: bool, + flags: Flags, err: Option, - need_read: bool, items: VecDeque, task: LocalWaker, io_task: LocalWaker, @@ -144,12 +151,16 @@ struct Inner { impl Inner { fn new(eof: bool) -> Self { + let flags = if eof { + Flags::EOF | Flags::NEED_READ + } else { + Flags::NEED_READ + }; Inner { - eof, + flags, len: 0, err: None, items: VecDeque::new(), - need_read: true, task: LocalWaker::new(), io_task: LocalWaker::new(), } @@ -157,18 +168,23 @@ impl Inner { fn set_error(&mut self, err: PayloadError) { self.err = Some(err); + self.flags.insert(Flags::ERROR); self.task.wake() } fn feed_eof(&mut self) { - self.eof = true; + self.flags.insert(Flags::EOF); self.task.wake() } fn feed_data(&mut self, data: Bytes) { self.len += data.len(); self.items.push_back(data); - self.need_read = self.len < MAX_BUFFER_SIZE; + if self.len < MAX_BUFFER_SIZE { + self.flags.insert(Flags::NEED_READ); + } else { + self.flags.remove(Flags::NEED_READ); + } self.task.wake(); } @@ -178,19 +194,25 @@ impl Inner { ) -> Poll>> { if let Some(data) = self.items.pop_front() { self.len -= data.len(); - self.need_read = self.len < MAX_BUFFER_SIZE; + if self.len < MAX_BUFFER_SIZE { + self.flags.insert(Flags::NEED_READ); + } else { + self.flags.remove(Flags::NEED_READ); + } - if self.need_read && !self.eof { + if self.flags.contains(Flags::NEED_READ) + && !self.flags.intersects(Flags::EOF | Flags::ERROR) + { self.task.register(cx.waker()); } self.io_task.wake(); Poll::Ready(Some(Ok(data))) } else if let Some(err) = self.err.take() { Poll::Ready(Some(Err(err))) - } else if self.eof { + } else if self.flags.intersects(Flags::EOF | Flags::ERROR) { Poll::Ready(None) } else { - self.need_read = true; + self.flags.insert(Flags::NEED_READ); self.task.register(cx.waker()); self.io_task.wake(); Poll::Pending diff --git a/ntex/tests/http_server.rs b/ntex/tests/http_server.rs index 44512500..cea9e667 100644 --- a/ntex/tests/http_server.rs +++ b/ntex/tests/http_server.rs @@ -405,6 +405,36 @@ async fn test_http1_handle_not_consumed_payload() { assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); } +/// Handle payload errors (keep-alive, disconnects) +#[ntex::test] +async fn test_http1_handle_payload_errors() { + let count = Arc::new(AtomicUsize::new(0)); + let count2 = count.clone(); + + let srv = test_server(move || { + let count = count2.clone(); + HttpService::build().h1(move |mut req: Request| { + let count = count.clone(); + async move { + let mut pl = req.take_payload(); + let result = pl.recv().await; + if result.unwrap().is_err() { + count.fetch_add(1, Ordering::Relaxed); + } + Ok::<_, io::Error>(Response::Ok().finish()) + } + }) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = + stream.write_all(b"GET /test/tests/test HTTP/1.1\r\ncontent-length: 99999\r\n\r\n"); + sleep(Millis(250)).await; + drop(stream); + sleep(Millis(250)).await; + assert_eq!(count.load(Ordering::Acquire), 1); +} + #[ntex::test] async fn test_content_length() { let srv = test_server(|| { @@ -714,7 +744,7 @@ async fn test_h1_client_drop() -> io::Result<()> { let _st = SetOnDrop(count); assert!(req.peer_addr().is_some()); assert_eq!(req.version(), Version::HTTP_11); - sleep(Seconds(100)).await; + sleep(Millis(500)).await; Ok::<_, io::Error>(Response::Ok().finish()) } }) @@ -722,7 +752,7 @@ async fn test_h1_client_drop() -> io::Result<()> { let result = timeout(Millis(100), srv.request(Method::GET, "/").send()).await; assert!(result.is_err()); - sleep(Millis(250)).await; + sleep(Millis(1000)).await; assert_eq!(count.load(Ordering::Relaxed), 1); Ok(()) }