Allow to run publish future to completion in case error (#529)

This commit is contained in:
Nikolay Kim 2025-03-16 12:11:01 +01:00 committed by GitHub
parent 1f71b200ad
commit 7417ee3a4b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 89 additions and 20 deletions

View file

@ -1,5 +1,9 @@
# Changes
## [2.12.2] - 2025-03-15
* http: Allow to run publish future to completion in case error
## [2.12.1] - 2025-03-14
* Allow to disable test logging (no-test-logging features)

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "2.12.1"
version = "2.12.2"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services"
readme = "README.md"

View file

@ -1,5 +1,5 @@
//! HTTP/1 protocol dispatcher
use std::{error, future, io, marker, pin::Pin, rc::Rc, task::Context, task::Poll};
use std::{error, future, io, marker, mem, pin::Pin, rc::Rc, task::Context, task::Poll};
use crate::io::{Decoded, Filter, Io, IoStatusUpdate, RecvError};
use crate::service::{PipelineCall, Service};
@ -144,7 +144,20 @@ where
inner.send_response(res, body)
}
Poll::Ready(Err(err)) => inner.control(Control::err(err)),
Poll::Pending => ready!(inner.poll_request(cx)),
Poll::Pending => {
// state changed because of error.
// spawn current publish future to runtime
// so it could complete error handling
let st = ready!(inner.poll_request(cx));
if inner.payload.is_some() {
if let State::CallPublish { fut } =
mem::replace(&mut *this.st, State::ReadRequest)
{
crate::rt::spawn(fut);
}
}
st
}
},
// handle control service responses
State::CallControl { fut } => match Pin::new(fut).poll(cx) {
@ -339,7 +352,7 @@ where
.io
.encode(Message::Item((msg, body.size())), &self.codec)
.map_err(|err| {
if let Some(mut payload) = self.payload.take() {
if let Some(ref mut payload) = self.payload {
payload.1.set_error(PayloadError::Incomplete(None));
}
err
@ -438,7 +451,7 @@ where
}
fn set_payload_error(&mut self, err: PayloadError) {
if let Some(mut payload) = self.payload.take() {
if let Some(ref mut payload) = self.payload {
payload.1.set_error(err);
}
}

View file

@ -3,8 +3,7 @@ use std::rc::{Rc, Weak};
use std::task::{Context, Poll};
use std::{cell::RefCell, collections::VecDeque, pin::Pin};
use crate::http::error::PayloadError;
use crate::{task::LocalWaker, util::Bytes, util::Stream};
use crate::{http::error::PayloadError, task::LocalWaker, util::Bytes, util::Stream};
/// max buffer size 32k
const MAX_BUFFER_SIZE: usize = 32_768;
@ -119,7 +118,7 @@ impl PayloadSender {
// we check only if Payload (other side) is alive,
// otherwise always return true (consume payload)
if let Some(shared) = self.inner.upgrade() {
if shared.borrow().need_read {
if shared.borrow().flags.contains(Flags::NEED_READ) {
PayloadStatus::Read
} else {
shared.borrow_mut().io_task.register(cx.waker());
@ -131,12 +130,20 @@ impl PayloadSender {
}
}
bitflags::bitflags! {
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
struct Flags: u8 {
const EOF = 0b0000_0001;
const ERROR = 0b0000_0010;
const NEED_READ = 0b0000_0100;
}
}
#[derive(Debug)]
struct Inner {
len: usize,
eof: bool,
flags: Flags,
err: Option<PayloadError>,
need_read: bool,
items: VecDeque<Bytes>,
task: LocalWaker,
io_task: LocalWaker,
@ -144,12 +151,16 @@ struct Inner {
impl Inner {
fn new(eof: bool) -> Self {
let flags = if eof {
Flags::EOF | Flags::NEED_READ
} else {
Flags::NEED_READ
};
Inner {
eof,
flags,
len: 0,
err: None,
items: VecDeque::new(),
need_read: true,
task: LocalWaker::new(),
io_task: LocalWaker::new(),
}
@ -157,18 +168,23 @@ impl Inner {
fn set_error(&mut self, err: PayloadError) {
self.err = Some(err);
self.flags.insert(Flags::ERROR);
self.task.wake()
}
fn feed_eof(&mut self) {
self.eof = true;
self.flags.insert(Flags::EOF);
self.task.wake()
}
fn feed_data(&mut self, data: Bytes) {
self.len += data.len();
self.items.push_back(data);
self.need_read = self.len < MAX_BUFFER_SIZE;
if self.len < MAX_BUFFER_SIZE {
self.flags.insert(Flags::NEED_READ);
} else {
self.flags.remove(Flags::NEED_READ);
}
self.task.wake();
}
@ -178,19 +194,25 @@ impl Inner {
) -> Poll<Option<Result<Bytes, PayloadError>>> {
if let Some(data) = self.items.pop_front() {
self.len -= data.len();
self.need_read = self.len < MAX_BUFFER_SIZE;
if self.len < MAX_BUFFER_SIZE {
self.flags.insert(Flags::NEED_READ);
} else {
self.flags.remove(Flags::NEED_READ);
}
if self.need_read && !self.eof {
if self.flags.contains(Flags::NEED_READ)
&& !self.flags.intersects(Flags::EOF | Flags::ERROR)
{
self.task.register(cx.waker());
}
self.io_task.wake();
Poll::Ready(Some(Ok(data)))
} else if let Some(err) = self.err.take() {
Poll::Ready(Some(Err(err)))
} else if self.eof {
} else if self.flags.intersects(Flags::EOF | Flags::ERROR) {
Poll::Ready(None)
} else {
self.need_read = true;
self.flags.insert(Flags::NEED_READ);
self.task.register(cx.waker());
self.io_task.wake();
Poll::Pending

View file

@ -405,6 +405,36 @@ async fn test_http1_handle_not_consumed_payload() {
assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n");
}
/// Handle payload errors (keep-alive, disconnects)
#[ntex::test]
async fn test_http1_handle_payload_errors() {
let count = Arc::new(AtomicUsize::new(0));
let count2 = count.clone();
let srv = test_server(move || {
let count = count2.clone();
HttpService::build().h1(move |mut req: Request| {
let count = count.clone();
async move {
let mut pl = req.take_payload();
let result = pl.recv().await;
if result.unwrap().is_err() {
count.fetch_add(1, Ordering::Relaxed);
}
Ok::<_, io::Error>(Response::Ok().finish())
}
})
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
let _ =
stream.write_all(b"GET /test/tests/test HTTP/1.1\r\ncontent-length: 99999\r\n\r\n");
sleep(Millis(250)).await;
drop(stream);
sleep(Millis(250)).await;
assert_eq!(count.load(Ordering::Acquire), 1);
}
#[ntex::test]
async fn test_content_length() {
let srv = test_server(|| {
@ -714,7 +744,7 @@ async fn test_h1_client_drop() -> io::Result<()> {
let _st = SetOnDrop(count);
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), Version::HTTP_11);
sleep(Seconds(100)).await;
sleep(Millis(500)).await;
Ok::<_, io::Error>(Response::Ok().finish())
}
})
@ -722,7 +752,7 @@ async fn test_h1_client_drop() -> io::Result<()> {
let result = timeout(Millis(100), srv.request(Method::GET, "/").send()).await;
assert!(result.is_err());
sleep(Millis(250)).await;
sleep(Millis(1000)).await;
assert_eq!(count.load(Ordering::Relaxed), 1);
Ok(())
}