From 5f20ee2be5f8fca50b5a0faa69fc3e299ce8bc43 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 12 Aug 2024 20:07:26 +0500 Subject: [PATCH] Http gracefull shutdown support (#393) --- ntex-io/CHANGES.md | 4 ++ ntex-io/Cargo.toml | 2 +- ntex-io/src/ioref.rs | 15 +++++-- ntex/CHANGES.md | 4 ++ ntex/Cargo.toml | 8 ++-- ntex/src/http/config.rs | 33 ++++++++++++-- ntex/src/http/h1/dispatcher.rs | 6 +++ ntex/src/http/h1/service.rs | 73 +++++++++++++++++++++++++----- ntex/src/http/h2/service.rs | 59 +++++++++++++++++++++--- ntex/src/http/service.rs | 73 +++++++++++++++++++++++++----- ntex/src/http/test.rs | 14 +++--- ntex/tests/http_openssl.rs | 49 +++++++++++++++++++- ntex/tests/http_server.rs | 82 +++++++++++++++++++++++++++++++++- 13 files changed, 377 insertions(+), 45 deletions(-) diff --git a/ntex-io/CHANGES.md b/ntex-io/CHANGES.md index 99d74e09..3c9c32cf 100644 --- a/ntex-io/CHANGES.md +++ b/ntex-io/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [2.2.0] - 2024-08-12 + +* Allow to notify dispatcher from IoRef + ## [2.1.0] - 2024-07-30 * Optimize `Io` layout diff --git a/ntex-io/Cargo.toml b/ntex-io/Cargo.toml index bbcbe8f7..47f52cb0 100644 --- a/ntex-io/Cargo.toml +++ b/ntex-io/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-io" -version = "2.1.0" +version = "2.2.0" authors = ["ntex contributors "] description = "Utilities for encoding and decoding frames" keywords = ["network", "framework", "async", "futures"] diff --git a/ntex-io/src/ioref.rs b/ntex-io/src/ioref.rs index 6cc9bbd9..6ae3dead 100644 --- a/ntex-io/src/ioref.rs +++ b/ntex-io/src/ioref.rs @@ -217,17 +217,24 @@ impl IoRef { } #[inline] - /// current timer handle - pub fn timer_handle(&self) -> timer::TimerHandle { - self.0.timeout.get() + /// Wakeup dispatcher + pub fn notify_dispatcher(&self) { + self.0.dispatch_task.wake(); + log::trace!("{}: Timer, notify dispatcher", self.tag()); } #[inline] - /// wakeup dispatcher and send keep-alive error + /// Wakeup dispatcher and send keep-alive error pub fn notify_timeout(&self) { self.0.notify_timeout() } + #[inline] + /// current timer handle + pub fn timer_handle(&self) -> timer::TimerHandle { + self.0.timeout.get() + } + #[inline] /// Start timer pub fn start_timer(&self, timeout: Seconds) -> timer::TimerHandle { diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index 91376c91..2980b2a9 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [2.2.0] - 2024-08-12 + +* Http server gracefull shutdown support + ## [2.1.0] - 2024-07-30 * Better handling for connection upgrade #385 diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index b3d4cb93..fb73093b 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "2.1.0" +version = "2.2.0" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" @@ -65,10 +65,10 @@ ntex-service = "3.0" ntex-macros = "0.1.3" ntex-util = "2" ntex-bytes = "0.1.27" -ntex-server = "2.1" -ntex-h2 = "1.0" +ntex-server = "2.3" +ntex-h2 = "1.1" ntex-rt = "0.4.13" -ntex-io = "2.1" +ntex-io = "2.2" ntex-net = "2.0" ntex-tls = "2.0" diff --git a/ntex/src/http/config.rs b/ntex/src/http/config.rs index a2b5091d..f68da851 100644 --- a/ntex/src/http/config.rs +++ b/ntex/src/http/config.rs @@ -234,13 +234,23 @@ impl ServiceConfig { } } +bitflags::bitflags! { + #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] + struct Flags: u8 { + /// Keep-alive enabled + const KA_ENABLED = 0b0000_0001; + /// Shutdown service + const SHUTDOWN = 0b0000_0010; + } +} + pub(super) struct DispatcherConfig { + flags: Cell, pub(super) service: Pipeline, pub(super) control: Pipeline, pub(super) keep_alive: Seconds, pub(super) client_disconnect: Seconds, pub(super) h2config: h2::Config, - pub(super) ka_enabled: bool, pub(super) headers_read_rate: Option, pub(super) payload_read_rate: Option, pub(super) timer: DateService, @@ -253,22 +263,39 @@ impl DispatcherConfig { control: control.into(), keep_alive: cfg.keep_alive, client_disconnect: cfg.client_disconnect, - ka_enabled: cfg.ka_enabled, headers_read_rate: cfg.headers_read_rate, payload_read_rate: cfg.payload_read_rate, h2config: cfg.h2config.clone(), timer: cfg.timer.clone(), + flags: Cell::new(if cfg.ka_enabled { + Flags::KA_ENABLED + } else { + Flags::empty() + }), } } /// Return state of connection keep-alive functionality pub(super) fn keep_alive_enabled(&self) -> bool { - self.ka_enabled + self.flags.get().contains(Flags::KA_ENABLED) } pub(super) fn headers_read_rate(&self) -> Option<&ReadRate> { self.headers_read_rate.as_ref() } + + /// Service is shuting down + pub(super) fn is_shutdown(&self) -> bool { + self.flags.get().contains(Flags::SHUTDOWN) + } + + pub(super) fn shutdown(&self) { + self.h2config.shutdown(); + + let mut flags = self.flags.get(); + flags.insert(Flags::SHUTDOWN); + self.flags.set(flags); + } } const DATE_VALUE_LENGTH_HDR: usize = 39; diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index e7e59cdf..244853b8 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -223,6 +223,12 @@ where B: MessageBody, { fn poll_read_request(&mut self, cx: &mut Context<'_>) -> Poll> { + // stop dispatcher + if self.config.is_shutdown() { + log::trace!("{}: Service is shutting down", self.io.tag()); + return Poll::Ready(self.stop()); + } + log::trace!("{}: Trying to read http message", self.io.tag()); let result = match self.io.poll_recv_decode(&self.codec, cx) { diff --git a/ntex/src/http/h1/service.rs b/ntex/src/http/h1/service.rs index dccf3486..4e69e158 100644 --- a/ntex/src/http/h1/service.rs +++ b/ntex/src/http/h1/service.rs @@ -1,12 +1,12 @@ -use std::{error::Error, fmt, marker, rc::Rc}; +use std::{cell::Cell, cell::RefCell, error::Error, fmt, marker, rc::Rc}; use crate::http::body::MessageBody; use crate::http::config::{DispatcherConfig, ServiceConfig}; use crate::http::error::{DispatchError, ResponseError}; use crate::http::{request::Request, response::Response}; -use crate::io::{types, Filter, Io}; +use crate::io::{types, Filter, Io, IoRef}; use crate::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; -use crate::util::join; +use crate::{channel::oneshot, util::join, util::HashSet}; use super::control::{Control, ControlAck}; use super::default::DefaultControlService; @@ -181,10 +181,14 @@ where .await .map_err(|e| log::error!("Cannot construct control service: {:?}", e))?; + let (tx, rx) = oneshot::channel(); let config = Rc::new(DispatcherConfig::new(self.cfg.clone(), service, control)); Ok(H1ServiceHandler { config, + inflight: RefCell::new(Default::default()), + rx: Cell::new(Some(rx)), + tx: Cell::new(Some(tx)), _t: marker::PhantomData, }) } @@ -193,6 +197,9 @@ where /// `Service` implementation for HTTP1 transport pub struct H1ServiceHandler { config: Rc>, + inflight: RefCell>, + rx: Cell>>, + tx: Cell>>, _t: marker::PhantomData<(F, B)>, } @@ -224,18 +231,62 @@ where } async fn shutdown(&self) { - self.config.control.shutdown().await; - self.config.service.shutdown().await; + self.config.shutdown(); + + // check inflight connections + let inflight = { + let inflight = self.inflight.borrow(); + for io in inflight.iter() { + io.notify_dispatcher(); + } + inflight.len() + }; + if inflight != 0 { + log::trace!("Shutting down service, in-flight connections: {}", inflight); + + if let Some(rx) = self.rx.take() { + let _ = rx.await; + } + + log::trace!("Shutting down is complected",); + } + + join( + self.config.control.shutdown(), + self.config.service.shutdown(), + ) + .await; } async fn call(&self, io: Io, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> { - log::trace!( - "New http1 connection, peer address {:?}", - io.query::().get() - ); + let inflight = { + let mut inflight = self.inflight.borrow_mut(); + inflight.insert(io.get_ref()); + inflight.len() + }; - Dispatcher::new(io, self.config.clone()) + log::trace!( + "New http1 connection, peer address {:?}, inflight: {}", + io.query::().get(), + inflight + ); + let ioref = io.get_ref(); + + let result = Dispatcher::new(io, self.config.clone()) .await - .map_err(DispatchError::Control) + .map_err(DispatchError::Control); + + { + let mut inflight = self.inflight.borrow_mut(); + inflight.remove(&ioref); + + if inflight.len() == 0 { + if let Some(tx) = self.tx.take() { + let _ = tx.send(()); + } + } + } + + result } } diff --git a/ntex/src/http/h2/service.rs b/ntex/src/http/h2/service.rs index ca53f251..79ba99a8 100644 --- a/ntex/src/http/h2/service.rs +++ b/ntex/src/http/h2/service.rs @@ -1,7 +1,9 @@ -use std::{cell::RefCell, error::Error, fmt, future::poll_fn, io, marker, mem, rc::Rc}; +use std::cell::{Cell, RefCell}; +use std::{error::Error, fmt, future::poll_fn, io, marker, mem, rc::Rc}; use ntex_h2::{self as h2, frame::StreamId, server}; +use crate::channel::oneshot; use crate::http::body::{BodySize, MessageBody}; use crate::http::config::{DispatcherConfig, ServiceConfig}; use crate::http::error::{DispatchError, H2Error, ResponseError}; @@ -10,7 +12,7 @@ use crate::http::message::{CurrentIo, ResponseHead}; use crate::http::{DateService, Method, Request, Response, StatusCode, Uri, Version}; use crate::io::{types, Filter, Io, IoBoxed, IoRef}; use crate::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; -use crate::util::{Bytes, BytesMut, HashMap}; +use crate::util::{Bytes, BytesMut, HashMap, HashSet}; use super::payload::{Payload, PayloadSender}; use super::DefaultControlService; @@ -177,11 +179,16 @@ where .create(()) .await .map_err(|e| log::error!("Cannot construct publish service: {:?}", e))?; + + let (tx, rx) = oneshot::channel(); let config = Rc::new(DispatcherConfig::new(self.cfg.clone(), service, ())); Ok(H2ServiceHandler { config, control: self.ctl.clone(), + inflight: RefCell::new(Default::default()), + rx: Cell::new(Some(rx)), + tx: Cell::new(Some(tx)), _t: marker::PhantomData, }) } @@ -191,6 +198,9 @@ where pub struct H2ServiceHandler, B, C> { config: Rc>, control: Rc, + inflight: RefCell>, + rx: Cell>>, + tx: Cell>>, _t: marker::PhantomData<(F, B)>, } @@ -218,6 +228,25 @@ where #[inline] async fn shutdown(&self) { + self.config.shutdown(); + + // check inflight connections + let inflight = { + let inflight = self.inflight.borrow(); + for io in inflight.iter() { + io.notify_dispatcher(); + } + inflight.len() + }; + if inflight != 0 { + log::trace!("Shutting down service, in-flight connections: {}", inflight); + if let Some(rx) = self.rx.take() { + let _ = rx.await; + } + + log::trace!("Shutting down is complected",); + } + self.config.service.shutdown().await } @@ -226,9 +255,16 @@ where io: Io, _: ServiceCtx<'_, Self>, ) -> Result { + let inflight = { + let mut inflight = self.inflight.borrow_mut(); + inflight.insert(io.get_ref()); + inflight.len() + }; + log::trace!( - "New http2 connection, peer address {:?}", - io.query::().get() + "New http2 connection, peer address {:?}, inflight: {}", + io.query::().get(), + inflight ); let control = self.control.create(()).await.map_err(|e| { DispatchError::Control( @@ -236,7 +272,20 @@ where ) })?; - handle(io.into(), control, self.config.clone()).await + let ioref = io.get_ref(); + let result = handle(io.into(), control, self.config.clone()).await; + { + let mut inflight = self.inflight.borrow_mut(); + inflight.remove(&ioref); + + if inflight.len() == 0 { + if let Some(tx) = self.tx.take() { + let _ = tx.send(()); + } + } + } + + result } } diff --git a/ntex/src/http/service.rs b/ntex/src/http/service.rs index 0c554f21..4cab50ff 100644 --- a/ntex/src/http/service.rs +++ b/ntex/src/http/service.rs @@ -1,8 +1,8 @@ -use std::{error, fmt, marker, rc::Rc}; +use std::{cell::Cell, cell::RefCell, error, fmt, marker, rc::Rc}; -use crate::io::{types, Filter, Io}; +use crate::io::{types, Filter, Io, IoRef}; use crate::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; -use crate::util::join; +use crate::{channel::oneshot, util::join, util::HashSet}; use super::body::MessageBody; use super::builder::HttpServiceBuilder; @@ -257,11 +257,15 @@ where .await .map_err(|e| log::error!("Cannot construct control service: {:?}", e))?; + let (tx, rx) = oneshot::channel(); let config = DispatcherConfig::new(self.cfg.clone(), service, control); Ok(HttpServiceHandler { config: Rc::new(config), h2_control: self.h2_control.clone(), + inflight: RefCell::new(HashSet::default()), + rx: Cell::new(Some(rx)), + tx: Cell::new(Some(tx)), _t: marker::PhantomData, }) } @@ -271,6 +275,9 @@ where pub struct HttpServiceHandler { config: Rc>, h2_control: Rc, + inflight: RefCell>, + rx: Cell>>, + tx: Cell>>, _t: marker::PhantomData<(F, B)>, } @@ -306,8 +313,31 @@ where #[inline] async fn shutdown(&self) { - self.config.control.shutdown().await; - self.config.service.shutdown().await; + self.config.shutdown(); + + // check inflight connections + let inflight = { + let inflight = self.inflight.borrow(); + for io in inflight.iter() { + io.notify_dispatcher(); + } + inflight.len() + }; + if inflight != 0 { + log::trace!("Shutting down service, in-flight connections: {}", inflight); + + if let Some(rx) = self.rx.take() { + let _ = rx.await; + } + + log::trace!("Shutting down is complected",); + } + + join( + self.config.control.shutdown(), + self.config.service.shutdown(), + ) + .await; } async fn call( @@ -315,12 +345,22 @@ where io: Io, _: ServiceCtx<'_, Self>, ) -> Result { - log::trace!( - "New http connection, peer address {:?}", - io.query::().get() - ); + let inflight = { + let mut inflight = self.inflight.borrow_mut(); + inflight.insert(io.get_ref()); + inflight.len() + }; - if io.query::().get() == Some(types::HttpProtocol::Http2) { + log::trace!( + "New http connection, peer address {:?}, in-flight: {}", + io.query::().get(), + inflight + ); + let ioref = io.get_ref(); + + let result = if io.query::().get() + == Some(types::HttpProtocol::Http2) + { let control = self.h2_control.create(()).await.map_err(|e| { DispatchError::Control( format!("Cannot construct control service: {:?}", e).into(), @@ -331,6 +371,19 @@ where h1::Dispatcher::new(io, self.config.clone()) .await .map_err(DispatchError::Control) + }; + + { + let mut inflight = self.inflight.borrow_mut(); + inflight.remove(&ioref); + + if inflight.len() == 0 { + if let Some(tx) = self.tx.take() { + let _ = tx.send(()); + } + } } + + result } } diff --git a/ntex/src/http/test.rs b/ntex/src/http/test.rs index 4ee578b0..2f6d164e 100644 --- a/ntex/src/http/test.rs +++ b/ntex/src/http/test.rs @@ -7,6 +7,7 @@ use coo_kie::{Cookie, CookieJar}; #[cfg(feature = "ws")] use crate::io::Filter; use crate::io::Io; +use crate::server::Server; #[cfg(feature = "ws")] use crate::ws::{error::WsClientError, WsClient, WsConnection}; use crate::{rt::System, service::ServiceFactory}; @@ -237,18 +238,18 @@ where let system = sys.system(); sys.run(move || { - crate::server::build() + let srv = crate::server::build() .listen("test", tcp, move |_| factory())? .set_tag("test", "HTTP-TEST-SRV") .workers(1) .disable_signals() .run(); - tx.send((system, local_addr)).unwrap(); + tx.send((system, srv, local_addr)).unwrap(); Ok(()) }) }); - let (system, addr) = rx.recv().unwrap(); + let (system, server, addr) = rx.recv().unwrap(); let client = { let connector = { @@ -286,6 +287,7 @@ where addr, client, system, + server, } } @@ -295,6 +297,7 @@ pub struct TestServer { addr: net::SocketAddr, client: Client, system: System, + server: Server, } impl TestServer { @@ -402,13 +405,14 @@ impl TestServer { } /// Stop http server - fn stop(&mut self) { + pub async fn stop(&self) { + self.server.stop(true).await; self.system.stop(); } } impl Drop for TestServer { fn drop(&mut self) { - self.stop() + self.system.stop(); } } diff --git a/ntex/tests/http_openssl.rs b/ntex/tests/http_openssl.rs index dd45fd7a..c91de0b8 100644 --- a/ntex/tests/http_openssl.rs +++ b/ntex/tests/http_openssl.rs @@ -12,7 +12,7 @@ use ntex::http::{body, h1, HttpService, Method, Request, Response, StatusCode, V use ntex::service::{fn_service, ServiceFactory}; use ntex::time::{sleep, timeout, Millis, Seconds}; use ntex::util::{Bytes, BytesMut, Ready}; -use ntex::{web::error::InternalError, ws, ws::handshake_response}; +use ntex::{channel::oneshot, rt, web::error::InternalError, ws, ws::handshake_response}; async fn load_body(stream: S) -> Result where @@ -534,3 +534,50 @@ async fn test_ws_transport() { })) ); } + +#[ntex::test] +async fn test_h2_graceful_shutdown() -> io::Result<()> { + let count = Arc::new(AtomicUsize::new(0)); + let count2 = count.clone(); + + let srv = test_server(move || { + let count = count2.clone(); + HttpService::build() + .h2(move |_| { + let count = count.clone(); + count.fetch_add(1, Ordering::Relaxed); + async move { + sleep(Millis(1000)).await; + count.fetch_sub(1, Ordering::Relaxed); + Ok::<_, io::Error>(Response::Ok().finish()) + } + }) + .openssl(ssl_acceptor()) + .map_err(|_| ()) + }); + + let req = srv.srequest(Method::GET, "/"); + rt::spawn(async move { + let _ = req.send().await.unwrap(); + sleep(Millis(100000)).await; + }); + let req = srv.srequest(Method::GET, "/"); + rt::spawn(async move { + let _ = req.send().await.unwrap(); + sleep(Millis(100000)).await; + }); + sleep(Millis(150)).await; + assert_eq!(count.load(Ordering::Relaxed), 2); + + let (tx, rx) = oneshot::channel(); + rt::spawn(async move { + srv.stop().await; + let _ = tx.send(()); + }); + sleep(Millis(150)).await; + assert_eq!(count.load(Ordering::Relaxed), 2); + + let _ = rx.await; + assert_eq!(count.load(Ordering::Relaxed), 0); + Ok(()) +} diff --git a/ntex/tests/http_server.rs b/ntex/tests/http_server.rs index 16a9aabe..3f91cb9e 100644 --- a/ntex/tests/http_server.rs +++ b/ntex/tests/http_server.rs @@ -9,7 +9,9 @@ use ntex::http::header::{self, HeaderName, HeaderValue}; use ntex::http::{body, h1::Control, test::server as test_server}; use ntex::http::{HttpService, KeepAlive, Method, Request, Response, StatusCode, Version}; use ntex::time::{sleep, timeout, Millis, Seconds}; -use ntex::{service::fn_service, util::Bytes, util::Ready, web::error}; +use ntex::{ + channel::oneshot, rt, service::fn_service, util::Bytes, util::Ready, web::error, +}; #[ntex::test] async fn test_h1() { @@ -724,3 +726,81 @@ async fn test_h1_client_drop() -> io::Result<()> { assert_eq!(count.load(Ordering::Relaxed), 1); Ok(()) } + +#[ntex::test] +async fn test_h1_gracefull_shutdown() { + let count = Arc::new(AtomicUsize::new(0)); + let count2 = count.clone(); + + let srv = test_server(move || { + let count = count2.clone(); + HttpService::build().h1(move |_: Request| { + let count = count.clone(); + count.fetch_add(1, Ordering::Relaxed); + async move { + sleep(Millis(1000)).await; + count.fetch_sub(1, Ordering::Relaxed); + Ok::<_, io::Error>(Response::Ok().finish()) + } + }) + }); + + let mut stream1 = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream1.write_all(b"GET /index.html HTTP/1.1\r\n\r\n"); + + let mut stream2 = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream2.write_all(b"GET /index.html HTTP/1.1\r\n\r\n"); + + sleep(Millis(150)).await; + assert_eq!(count.load(Ordering::Relaxed), 2); + + let (tx, rx) = oneshot::channel(); + rt::spawn(async move { + srv.stop().await; + let _ = tx.send(()); + }); + sleep(Millis(150)).await; + assert_eq!(count.load(Ordering::Relaxed), 2); + + let _ = rx.await; + assert_eq!(count.load(Ordering::Relaxed), 0); +} + +#[ntex::test] +async fn test_h1_gracefull_shutdown_2() { + let count = Arc::new(AtomicUsize::new(0)); + let count2 = count.clone(); + + let srv = test_server(move || { + let count = count2.clone(); + HttpService::build().finish(move |_: Request| { + let count = count.clone(); + count.fetch_add(1, Ordering::Relaxed); + async move { + sleep(Millis(1000)).await; + count.fetch_sub(1, Ordering::Relaxed); + Ok::<_, io::Error>(Response::Ok().finish()) + } + }) + }); + + let mut stream1 = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream1.write_all(b"GET /index.html HTTP/1.1\r\n\r\n"); + + let mut stream2 = net::TcpStream::connect(srv.addr()).unwrap(); + let _ = stream2.write_all(b"GET /index.html HTTP/1.1\r\n\r\n"); + + sleep(Millis(150)).await; + assert_eq!(count.load(Ordering::Relaxed), 2); + + let (tx, rx) = oneshot::channel(); + rt::spawn(async move { + srv.stop().await; + let _ = tx.send(()); + }); + sleep(Millis(150)).await; + assert_eq!(count.load(Ordering::Relaxed), 2); + + let _ = rx.await; + assert_eq!(count.load(Ordering::Relaxed), 0); +}