diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs index d6d3cf17..6a65f42c 100644 --- a/ntex-io/src/io.rs +++ b/ntex-io/src/io.rs @@ -85,6 +85,7 @@ impl IoState { #[inline] pub(super) fn notify_keepalive(&self) { + log::trace!("keep-alive timeout, notify dispatcher"); let mut flags = self.flags.get(); if !flags.contains(Flags::DSP_KEEPALIVE) { flags.insert(Flags::DSP_KEEPALIVE); diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index c3f8044e..3fecef51 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.5.4] - 2022-01-02 + +* http1: Unregister keep-alive timer after request is received + ## [0.5.3] - 2021-12-31 * Fix WsTransport shutdown, send close frame diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 669b8a98..9c9b38dc 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "0.5.3" +version = "0.5.4" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index e6fc46de..4ec99842 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -19,13 +19,15 @@ use super::{codec::Codec, Message}; bitflags::bitflags! { pub struct Flags: u16 { /// We parsed one complete request message - const STARTED = 0b0000_0001; + const STARTED = 0b0000_0001; /// Keep-alive is enabled on current connection - const KEEPALIVE = 0b0000_0010; + const KEEPALIVE = 0b0000_0010; + /// Keep-alive is registered + const KEEPALIVE_REG = 0b0000_0100; /// Upgrade request - const UPGRADE = 0b0000_0100; + const UPGRADE = 0b0000_1000; /// Stop after sending payload - const SENDPAYLOAD_AND_STOP = 0b0000_0100; + const SENDPAYLOAD_AND_STOP = 0b0001_0000; } } @@ -103,7 +105,7 @@ where state, config, io: Some(io), - flags: Flags::empty(), + flags: Flags::KEEPALIVE_REG, error: None, payload: None, _t: marker::PhantomData, @@ -175,7 +177,6 @@ where }); if result.is_err() { *this.st = State::Stop; - this.inner.unregister_keepalive(); this = self.as_mut().project(); continue; } else if this.inner.flags.contains(Flags::UPGRADE) { @@ -231,17 +232,15 @@ where State::ReadRequest => { log::trace!("trying to read http message"); - let io = this.inner.io(); - // decode incoming bytes stream - match ready!(io.poll_recv(&this.inner.codec, cx)) { - Ok((mut req, pl)) => { + match this.inner.io().poll_recv(&this.inner.codec, cx) { + Poll::Ready(Ok((mut req, pl))) => { log::trace!( "http message is received: {:?} and payload {:?}", req, pl ); - req.head_mut().io = Some(io.get_ref()); + req.head_mut().io = Some(this.inner.state.clone()); // configure request payload let upgrade = match pl { @@ -265,11 +264,10 @@ where } }; - // unregister slow-request timer - if !this.inner.flags.contains(Flags::STARTED) { - this.inner.flags.insert(Flags::STARTED); - this.inner.io().remove_keepalive_timer(); - } + // slow-request first request + this.inner.flags.insert(Flags::STARTED); + this.inner.flags.remove(Flags::KEEPALIVE_REG); + this.inner.io().remove_keepalive_timer(); if upgrade { // Handle UPGRADE request @@ -297,7 +295,7 @@ where ); } } - Err(RecvError::WriteBackpressure) => { + Poll::Ready(Err(RecvError::WriteBackpressure)) => { if let Err(err) = ready!(this.inner.io().poll_flush(cx, false)) { log::trace!("peer is gone with {:?}", err); @@ -305,23 +303,23 @@ where this.inner.error = Some(DispatchError::PeerGone(Some(err))); } } - Err(RecvError::Decoder(err)) => { + Poll::Ready(Err(RecvError::Decoder(err))) => { // Malformed requests, respond with 400 log::trace!("malformed request: {:?}", err); let (res, body) = Response::BadRequest().finish().into_parts(); this.inner.error = Some(DispatchError::Parse(err)); *this.st = this.inner.send_response(res, body.into_body()); } - Err(RecvError::PeerGone(err)) => { + Poll::Ready(Err(RecvError::PeerGone(err))) => { log::trace!("peer is gone with {:?}", err); *this.st = State::Stop; this.inner.error = Some(DispatchError::PeerGone(err)); } - Err(RecvError::Stop) => { + Poll::Ready(Err(RecvError::Stop)) => { log::trace!("dispatcher is instructed to stop"); *this.st = State::Stop; } - Err(RecvError::KeepAlive) => { + Poll::Ready(Err(RecvError::KeepAlive)) => { // keep-alive timeout if !this.inner.flags.contains(Flags::STARTED) { log::trace!("slow request timeout"); @@ -334,6 +332,18 @@ where } *this.st = State::Stop; } + Poll::Pending => { + // register keep-alive timer + if this.inner.flags.contains(Flags::KEEPALIVE) + && !this.inner.flags.contains(Flags::KEEPALIVE_REG) + { + this.inner.flags.insert(Flags::KEEPALIVE_REG); + this.inner + .io() + .start_keepalive_timer(this.inner.config.keep_alive); + } + return Poll::Pending; + } } } // consume request's payload @@ -415,11 +425,9 @@ where fn switch_to_read_request(&mut self) -> State { // connection is not keep-alive, disconnect if !self.flags.contains(Flags::KEEPALIVE) || !self.codec.keepalive_enabled() { - self.unregister_keepalive(); self.state.close(); State::Stop } else { - self.reset_keepalive(); State::ReadRequest } } @@ -431,13 +439,6 @@ where } } - fn reset_keepalive(&mut self) { - // re-register keep-alive - if self.flags.contains(Flags::KEEPALIVE) { - self.io().start_keepalive_timer(self.config.keep_alive); - } - } - fn handle_error(&mut self, err: E, critical: bool) -> State where E: ResponseError + 'static, @@ -524,7 +525,6 @@ where } else if self.payload.is_some() { Some(State::ReadPayload) } else { - self.reset_keepalive(); Some(self.switch_to_read_request()) } } diff --git a/ntex/tests/http_server.rs b/ntex/tests/http_server.rs index 46f505a3..cedc49ef 100644 --- a/ntex/tests/http_server.rs +++ b/ntex/tests/http_server.rs @@ -210,6 +210,33 @@ async fn test_http1_keepalive_timeout() { assert_eq!(res, 0); } +/// Keep-alive must occure only while waiting complete request +#[ntex::test] +async fn test_http1_no_keepalive_during_response() { + let srv = test_server(|| { + HttpService::build().keep_alive(1).h1(|_| async { + sleep(Millis(1100)).await; + Ok::<_, io::Error>(Response::Ok().finish()) + }) + }); + + let mut stream = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); + + let _ = stream.write_all(b"GET /test/tests/test HTTP/1.1\r\n\r\n"); + let mut data = vec![0; 1024]; + let _ = stream.read(&mut data); + assert_eq!(&data[..17], b"HTTP/1.1 200 OK\r\n"); +} + #[ntex::test] async fn test_http1_keepalive_close() { let srv = test_server(|| {