Unregister keep-alive timer after request is received

This commit is contained in:
Nikolay Kim 2022-01-02 02:49:55 +06:00
parent ebc5250f3d
commit 44b00682e0
5 changed files with 64 additions and 32 deletions

View file

@ -85,6 +85,7 @@ impl IoState {
#[inline] #[inline]
pub(super) fn notify_keepalive(&self) { pub(super) fn notify_keepalive(&self) {
log::trace!("keep-alive timeout, notify dispatcher");
let mut flags = self.flags.get(); let mut flags = self.flags.get();
if !flags.contains(Flags::DSP_KEEPALIVE) { if !flags.contains(Flags::DSP_KEEPALIVE) {
flags.insert(Flags::DSP_KEEPALIVE); flags.insert(Flags::DSP_KEEPALIVE);

View file

@ -1,5 +1,9 @@
# Changes # Changes
## [0.5.4] - 2022-01-02
* http1: Unregister keep-alive timer after request is received
## [0.5.3] - 2021-12-31 ## [0.5.3] - 2021-12-31
* Fix WsTransport shutdown, send close frame * Fix WsTransport shutdown, send close frame

View file

@ -1,6 +1,6 @@
[package] [package]
name = "ntex" name = "ntex"
version = "0.5.3" version = "0.5.4"
authors = ["ntex contributors <team@ntex.rs>"] authors = ["ntex contributors <team@ntex.rs>"]
description = "Framework for composable network services" description = "Framework for composable network services"
readme = "README.md" readme = "README.md"

View file

@ -19,13 +19,15 @@ use super::{codec::Codec, Message};
bitflags::bitflags! { bitflags::bitflags! {
pub struct Flags: u16 { pub struct Flags: u16 {
/// We parsed one complete request message /// We parsed one complete request message
const STARTED = 0b0000_0001; const STARTED = 0b0000_0001;
/// Keep-alive is enabled on current connection /// Keep-alive is enabled on current connection
const KEEPALIVE = 0b0000_0010; const KEEPALIVE = 0b0000_0010;
/// Keep-alive is registered
const KEEPALIVE_REG = 0b0000_0100;
/// Upgrade request /// Upgrade request
const UPGRADE = 0b0000_0100; const UPGRADE = 0b0000_1000;
/// Stop after sending payload /// Stop after sending payload
const SENDPAYLOAD_AND_STOP = 0b0000_0100; const SENDPAYLOAD_AND_STOP = 0b0001_0000;
} }
} }
@ -103,7 +105,7 @@ where
state, state,
config, config,
io: Some(io), io: Some(io),
flags: Flags::empty(), flags: Flags::KEEPALIVE_REG,
error: None, error: None,
payload: None, payload: None,
_t: marker::PhantomData, _t: marker::PhantomData,
@ -175,7 +177,6 @@ where
}); });
if result.is_err() { if result.is_err() {
*this.st = State::Stop; *this.st = State::Stop;
this.inner.unregister_keepalive();
this = self.as_mut().project(); this = self.as_mut().project();
continue; continue;
} else if this.inner.flags.contains(Flags::UPGRADE) { } else if this.inner.flags.contains(Flags::UPGRADE) {
@ -231,17 +232,15 @@ where
State::ReadRequest => { State::ReadRequest => {
log::trace!("trying to read http message"); log::trace!("trying to read http message");
let io = this.inner.io();
// decode incoming bytes stream // decode incoming bytes stream
match ready!(io.poll_recv(&this.inner.codec, cx)) { match this.inner.io().poll_recv(&this.inner.codec, cx) {
Ok((mut req, pl)) => { Poll::Ready(Ok((mut req, pl))) => {
log::trace!( log::trace!(
"http message is received: {:?} and payload {:?}", "http message is received: {:?} and payload {:?}",
req, req,
pl pl
); );
req.head_mut().io = Some(io.get_ref()); req.head_mut().io = Some(this.inner.state.clone());
// configure request payload // configure request payload
let upgrade = match pl { let upgrade = match pl {
@ -265,11 +264,10 @@ where
} }
}; };
// unregister slow-request timer // slow-request first request
if !this.inner.flags.contains(Flags::STARTED) { this.inner.flags.insert(Flags::STARTED);
this.inner.flags.insert(Flags::STARTED); this.inner.flags.remove(Flags::KEEPALIVE_REG);
this.inner.io().remove_keepalive_timer(); this.inner.io().remove_keepalive_timer();
}
if upgrade { if upgrade {
// Handle UPGRADE request // Handle UPGRADE request
@ -297,7 +295,7 @@ where
); );
} }
} }
Err(RecvError::WriteBackpressure) => { Poll::Ready(Err(RecvError::WriteBackpressure)) => {
if let Err(err) = ready!(this.inner.io().poll_flush(cx, false)) if let Err(err) = ready!(this.inner.io().poll_flush(cx, false))
{ {
log::trace!("peer is gone with {:?}", err); log::trace!("peer is gone with {:?}", err);
@ -305,23 +303,23 @@ where
this.inner.error = Some(DispatchError::PeerGone(Some(err))); this.inner.error = Some(DispatchError::PeerGone(Some(err)));
} }
} }
Err(RecvError::Decoder(err)) => { Poll::Ready(Err(RecvError::Decoder(err))) => {
// Malformed requests, respond with 400 // Malformed requests, respond with 400
log::trace!("malformed request: {:?}", err); log::trace!("malformed request: {:?}", err);
let (res, body) = Response::BadRequest().finish().into_parts(); let (res, body) = Response::BadRequest().finish().into_parts();
this.inner.error = Some(DispatchError::Parse(err)); this.inner.error = Some(DispatchError::Parse(err));
*this.st = this.inner.send_response(res, body.into_body()); *this.st = this.inner.send_response(res, body.into_body());
} }
Err(RecvError::PeerGone(err)) => { Poll::Ready(Err(RecvError::PeerGone(err))) => {
log::trace!("peer is gone with {:?}", err); log::trace!("peer is gone with {:?}", err);
*this.st = State::Stop; *this.st = State::Stop;
this.inner.error = Some(DispatchError::PeerGone(err)); this.inner.error = Some(DispatchError::PeerGone(err));
} }
Err(RecvError::Stop) => { Poll::Ready(Err(RecvError::Stop)) => {
log::trace!("dispatcher is instructed to stop"); log::trace!("dispatcher is instructed to stop");
*this.st = State::Stop; *this.st = State::Stop;
} }
Err(RecvError::KeepAlive) => { Poll::Ready(Err(RecvError::KeepAlive)) => {
// keep-alive timeout // keep-alive timeout
if !this.inner.flags.contains(Flags::STARTED) { if !this.inner.flags.contains(Flags::STARTED) {
log::trace!("slow request timeout"); log::trace!("slow request timeout");
@ -334,6 +332,18 @@ where
} }
*this.st = State::Stop; *this.st = State::Stop;
} }
Poll::Pending => {
// register keep-alive timer
if this.inner.flags.contains(Flags::KEEPALIVE)
&& !this.inner.flags.contains(Flags::KEEPALIVE_REG)
{
this.inner.flags.insert(Flags::KEEPALIVE_REG);
this.inner
.io()
.start_keepalive_timer(this.inner.config.keep_alive);
}
return Poll::Pending;
}
} }
} }
// consume request's payload // consume request's payload
@ -415,11 +425,9 @@ where
fn switch_to_read_request(&mut self) -> State<B> { fn switch_to_read_request(&mut self) -> State<B> {
// connection is not keep-alive, disconnect // connection is not keep-alive, disconnect
if !self.flags.contains(Flags::KEEPALIVE) || !self.codec.keepalive_enabled() { if !self.flags.contains(Flags::KEEPALIVE) || !self.codec.keepalive_enabled() {
self.unregister_keepalive();
self.state.close(); self.state.close();
State::Stop State::Stop
} else { } else {
self.reset_keepalive();
State::ReadRequest State::ReadRequest
} }
} }
@ -431,13 +439,6 @@ where
} }
} }
fn reset_keepalive(&mut self) {
// re-register keep-alive
if self.flags.contains(Flags::KEEPALIVE) {
self.io().start_keepalive_timer(self.config.keep_alive);
}
}
fn handle_error<E>(&mut self, err: E, critical: bool) -> State<B> fn handle_error<E>(&mut self, err: E, critical: bool) -> State<B>
where where
E: ResponseError + 'static, E: ResponseError + 'static,
@ -524,7 +525,6 @@ where
} else if self.payload.is_some() { } else if self.payload.is_some() {
Some(State::ReadPayload) Some(State::ReadPayload)
} else { } else {
self.reset_keepalive();
Some(self.switch_to_read_request()) Some(self.switch_to_read_request())
} }
} }

View file

@ -210,6 +210,33 @@ async fn test_http1_keepalive_timeout() {
assert_eq!(res, 0); assert_eq!(res, 0);
} }
/// Keep-alive must occure only while waiting complete request
#[ntex::test]
async fn test_http1_no_keepalive_during_response() {
let srv = test_server(|| {
HttpService::build().keep_alive(1).h1(|_| async {
sleep(Millis(1100)).await;
Ok::<_, io::Error>(Response::Ok().finish())
})
});
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n");
let mut data = vec![0; 1024];
let _ = stream.read(&mut data);
assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n");
let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n");
let mut data = vec![0; 1024];
let _ = stream.read(&mut data);
assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n");
let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n");
let mut data = vec![0; 1024];
let _ = stream.read(&mut data);
assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n");
}
#[ntex::test] #[ntex::test]
async fn test_http1_keepalive_close() { async fn test_http1_keepalive_close() {
let srv = test_server(|| { let srv = test_server(|| {