From 3b12a77e921fd35781f9062c06000adb996801a8 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 7 Apr 2020 21:36:48 +0600 Subject: [PATCH] Fix io close (#12) * Fix io close for Framed * Fix connection shutdown for h1 dispatcher * Enable client disconnect for http server by default * Add connection disconnect timeout to framed service --- ntex-codec/CHANGES.md | 6 + ntex-codec/Cargo.toml | 12 +- ntex-codec/src/framed.rs | 261 ++++++++++++++++++++++++++------- ntex-codec/src/lib.rs | 2 +- ntex/CHANGES.md | 8 + ntex/Cargo.toml | 6 +- ntex/examples/echo.rs | 2 +- ntex/examples/hello-world.rs | 2 +- ntex/src/framed/dispatcher.rs | 37 ++++- ntex/src/framed/handshake.rs | 2 +- ntex/src/framed/service.rs | 59 +++++++- ntex/src/http/builder.rs | 12 +- ntex/src/http/h1/dispatcher.rs | 57 ++++--- ntex/src/http/h2/dispatcher.rs | 7 - ntex/src/http/test.rs | 2 - ntex/src/testing.rs | 188 +++++++++++++++++------- ntex/src/web/server.rs | 19 +-- ntex/tests/http_awc_client.rs | 7 +- ntex/tests/http_rustls.rs | 2 +- ntex/tests/http_server.rs | 4 +- ntex/tests/web_httpserver.rs | 3 +- 21 files changed, 529 insertions(+), 169 deletions(-) diff --git a/ntex-codec/CHANGES.md b/ntex-codec/CHANGES.md index 82ac7abc..56efe8be 100644 --- a/ntex-codec/CHANGES.md +++ b/ntex-codec/CHANGES.md @@ -1,5 +1,11 @@ # Changes +## [0.1.1] - 2020-04-07 + +* Optimize io operations + +* Fix framed close method + ## [0.1.0] - 2020-03-31 * Fork crate to ntex namespace diff --git a/ntex-codec/Cargo.toml b/ntex-codec/Cargo.toml index 17408325..e37ae5d1 100644 --- a/ntex-codec/Cargo.toml +++ b/ntex-codec/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-codec" -version = "0.1.0" +version = "0.1.1" authors = ["Nikolay Kim "] description = "Utilities for encoding and decoding frames" keywords = ["network", "framework", "async", "futures"] @@ -8,7 +8,7 @@ homepage = "https://ntex.rs" repository = "https://github.com/ntex-rs/ntex.git" documentation = "https://docs.rs/ntex-codec/" categories = ["network-programming", "asynchronous"] -license = "MIT/Apache-2.0" +license = "MIT" edition = "2018" [lib] @@ -20,6 +20,10 @@ bitflags = "1.2.1" bytes = "0.5.4" futures-core = "0.3.4" futures-sink = "0.3.4" -tokio = { version = "0.2.4", default-features=false } +tokio = { version = "0.2.6", default-features=false } tokio-util = { version = "0.2.0", default-features=false, features=["codec"] } -log = "0.4" \ No newline at end of file +log = "0.4" + +[dev-dependencies] +ntex = "0.1.4" +futures = "0.3.4" diff --git a/ntex-codec/src/framed.rs b/ntex-codec/src/framed.rs index 66a02b8c..dc772ffa 100644 --- a/ntex-codec/src/framed.rs +++ b/ntex-codec/src/framed.rs @@ -13,13 +13,16 @@ const HW: usize = 8 * 1024; bitflags::bitflags! { struct Flags: u8 { - const EOF = 0b0001; - const READABLE = 0b0010; + const EOF = 0b0001; + const READABLE = 0b0010; + const DISCONNECTED = 0b0100; + const SHUTDOWN = 0b1000; } } /// A unified `Stream` and `Sink` interface to an underlying I/O object, using /// the `Encoder` and `Decoder` traits to encode and decode frames. +/// `Framed` is heavily optimized for streaming io. pub struct Framed { io: T, codec: U, @@ -28,8 +31,6 @@ pub struct Framed { write_buf: BytesMut, } -impl Unpin for Framed {} - impl Framed where T: AsyncRead + AsyncWrite, @@ -123,6 +124,18 @@ impl Framed { &mut self.io } + #[inline] + /// Get read buffer. + pub fn read_buf_mut(&mut self) -> &mut BytesMut { + &mut self.read_buf + } + + #[inline] + /// Get write buffer. + pub fn write_buf_mut(&mut self) -> &mut BytesMut { + &mut self.write_buf + } + #[inline] /// Check if write buffer is empty. pub fn is_write_buf_empty(&self) -> bool { @@ -135,6 +148,12 @@ impl Framed { self.write_buf.len() >= HW } + #[inline] + /// Check if framed object is closed + pub fn is_closed(&self) -> bool { + self.flags.contains(Flags::DISCONNECTED) + } + #[inline] /// Consume the `Frame`, returning `Frame` with different codec. pub fn into_framed(self, codec: U2) -> Framed { @@ -227,34 +246,87 @@ where pub fn flush(&mut self, cx: &mut Context<'_>) -> Poll> { log::trace!("flushing framed transport"); - while !self.write_buf.is_empty() { - log::trace!("writing; remaining={}", self.write_buf.len()); - - let n = ready!(Pin::new(&mut self.io).poll_write(cx, &self.write_buf))?; - if n == 0 { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write frame to transport", - ) - .into())); - } - - // remove written data - self.write_buf.advance(n); + let len = self.write_buf.len(); + if len == 0 { + return Poll::Ready(Ok(())); } - // Try flushing the underlying IO - ready!(Pin::new(&mut self.io).poll_flush(cx))?; + let mut written = 0; + while written < len { + match Pin::new(&mut self.io).poll_write(cx, &self.write_buf[written..]) { + Poll::Pending => break, + Poll::Ready(Ok(n)) => { + if n == 0 { + log::trace!("Disconnected during flush, written {}", written); + self.flags.insert(Flags::DISCONNECTED); + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write frame to transport", + ) + .into())); + } else { + written += n + } + } + Poll::Ready(Err(e)) => { + log::trace!("Error during flush: {}", e); + self.flags.insert(Flags::DISCONNECTED); + return Poll::Ready(Err(e.into())); + } + } + } - log::trace!("framed transport flushed"); - Poll::Ready(Ok(())) + // remove written data + if written == len { + // flushed same amount as in buffer, we dont need to reallocate + unsafe { self.write_buf.set_len(0) } + } else { + self.write_buf.advance(written); + } + if self.write_buf.is_empty() { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } } +} +impl Framed +where + T: AsyncRead + AsyncWrite + Unpin, +{ #[inline] /// Flush write buffer and shutdown underlying I/O stream. - pub fn close(&mut self, cx: &mut Context<'_>) -> Poll> { - ready!(Pin::new(&mut self.io).poll_flush(cx))?; - ready!(Pin::new(&mut self.io).poll_shutdown(cx))?; + /// + /// Close method shutdown write side of a io object and + /// then reads until disconnect or error, high level code must use + /// timeout for close operation. + pub fn close(&mut self, cx: &mut Context<'_>) -> Poll> { + if !self.flags.contains(Flags::DISCONNECTED) { + // flush write buffer + ready!(Pin::new(&mut self.io).poll_flush(cx))?; + + if !self.flags.contains(Flags::SHUTDOWN) { + // shutdown WRITE side + ready!(Pin::new(&mut self.io).poll_shutdown(cx)).map_err(|e| { + self.flags.insert(Flags::DISCONNECTED); + e + })?; + self.flags.insert(Flags::SHUTDOWN); + } + + // read until 0 or err + let mut buf = [0u8; 512]; + loop { + match ready!(Pin::new(&mut self.io).poll_read(cx, &mut buf)) { + Err(_) | Ok(0) => { + break; + } + _ => (), + } + } + self.flags.insert(Flags::DISCONNECTED); + } log::trace!("framed transport flushed and closed"); Poll::Ready(Ok(())) } @@ -269,11 +341,9 @@ where pub fn next_item( &mut self, cx: &mut Context<'_>, - ) -> Poll>> - where - T: AsyncRead, - U: Decoder, - { + ) -> Poll>> { + let mut done_read = false; + loop { // Repeatedly call `decode` or `decode_eof` as long as it is // "readable". Readable is defined as not having returned `None`. If @@ -302,26 +372,45 @@ where } self.flags.remove(Flags::READABLE); + if done_read { + return Poll::Pending; + } } debug_assert!(!self.flags.contains(Flags::EOF)); - // Otherwise, try to read more data and try again. Make sure we've got room - let remaining = self.read_buf.capacity() - self.read_buf.len(); - if remaining < LW { - self.read_buf.reserve(HW - remaining) + // read all data from socket + let mut updated = false; + loop { + // Otherwise, try to read more data and try again. Make sure we've got room + let remaining = self.read_buf.capacity() - self.read_buf.len(); + if remaining < LW { + self.read_buf.reserve(HW - remaining) + } + match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf) { + Poll::Pending => { + if updated { + done_read = true; + self.flags.insert(Flags::READABLE); + break; + } else { + return Poll::Pending; + } + } + Poll::Ready(Ok(n)) => { + if n == 0 { + self.flags.insert(Flags::EOF | Flags::READABLE); + if updated { + done_read = true; + } + break; + } else { + updated = true; + } + } + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))), + } } - let cnt = match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf) - { - Poll::Pending => return Poll::Pending, - Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))), - Poll::Ready(Ok(cnt)) => cnt, - }; - - if cnt == 0 { - self.flags.insert(Flags::EOF); - } - self.flags.insert(Flags::READABLE); } } } @@ -329,7 +418,7 @@ where impl Stream for Framed where T: AsyncRead + Unpin, - U: Decoder, + U: Decoder + Unpin, { type Item = Result; @@ -344,8 +433,8 @@ where impl Sink for Framed where - T: AsyncWrite + Unpin, - U: Encoder, + T: AsyncRead + AsyncWrite + Unpin, + U: Encoder + Unpin, U::Error: From, { type Error = U::Error; @@ -383,7 +472,7 @@ where mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - self.close(cx) + self.close(cx).map_err(|e| e.into()) } } @@ -443,3 +532,77 @@ impl FramedParts { } } } + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use futures::future::lazy; + use futures::Sink; + use ntex::testing::Io; + + use super::*; + use crate::BytesCodec; + + #[ntex::test] + async fn test_sink() { + let (client, server) = Io::create(); + client.remote_buffer_cap(1024); + let mut server = Framed::new(server, BytesCodec); + + assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx)) + .await + .is_ready()); + + let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n"); + Pin::new(&mut server).start_send(data).unwrap(); + assert_eq!(client.read_any(), b"".as_ref()); + + assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx)) + .await + .is_ready()); + assert_eq!(client.read_any(), b"GET /test HTTP/1.1\r\n\r\n".as_ref()); + + assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx)) + .await + .is_pending()); + client.close().await; + assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx)) + .await + .is_ready()); + assert!(client.is_closed()); + } + + #[ntex::test] + async fn test_write_pending() { + let (client, server) = Io::create(); + let mut server = Framed::new(server, BytesCodec); + + assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx)) + .await + .is_ready()); + let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n"); + Pin::new(&mut server).start_send(data).unwrap(); + + client.remote_buffer_cap(3); + assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx)) + .await + .is_pending()); + assert_eq!(client.read_any(), b"GET".as_ref()); + + client.remote_buffer_cap(1024); + assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx)) + .await + .is_ready()); + assert_eq!(client.read_any(), b" /test HTTP/1.1\r\n\r\n".as_ref()); + + assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx)) + .await + .is_pending()); + client.close().await; + assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx)) + .await + .is_ready()); + assert!(client.is_closed()); + assert!(server.is_closed()); + } +} diff --git a/ntex-codec/src/lib.rs b/ntex-codec/src/lib.rs index fc53b688..0e3109ed 100644 --- a/ntex-codec/src/lib.rs +++ b/ntex-codec/src/lib.rs @@ -6,7 +6,7 @@ //! //! [`AsyncRead`]: # //! [`AsyncWrite`]: # -#![deny(rust_2018_idioms, warnings)] +// #![deny(rust_2018_idioms, warnings)] mod bcodec; mod framed; diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index 336ac53a..70241970 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,13 @@ # Changes +## [0.1.5] - 2020-04-07 + +* ntex::http: enable client disconnect timeout by default + +* ntex::http: properly close h1 connection + +* ntex::framed: add connection disconnect timeout to framed service + ## [0.1.4] - 2020-04-06 * Remove unneeded RefCell from client connector diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 4d170a60..2913bcbf 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "0.1.4" +version = "0.1.5" authors = ["Nikolay Kim "] description = "Framework for composable network services" readme = "README.md" @@ -36,10 +36,10 @@ compress = ["flate2", "brotli2"] cookie = ["coo-kie", "coo-kie/percent-encode"] [dependencies] -ntex-codec = "0.1" +ntex-codec = "0.1.1" ntex-rt = "0.1" ntex-rt-macros = "0.1" -ntex-router = "0.3.1" +ntex-router = "0.3.2" ntex-service = "0.1" actix-threadpool = "0.3.1" diff --git a/ntex/examples/echo.rs b/ntex/examples/echo.rs index c244b088..7de444a1 100644 --- a/ntex/examples/echo.rs +++ b/ntex/examples/echo.rs @@ -16,7 +16,7 @@ async fn main() -> io::Result<()> { .bind("echo", "127.0.0.1:8080", || { HttpService::build() .client_timeout(1000) - .client_disconnect(1000) + .disconnect_timeout(1000) .finish(|mut req: Request| async move { let mut body = BytesMut::new(); while let Some(item) = req.payload().next().await { diff --git a/ntex/examples/hello-world.rs b/ntex/examples/hello-world.rs index 3bc0c340..1971bd76 100644 --- a/ntex/examples/hello-world.rs +++ b/ntex/examples/hello-world.rs @@ -15,7 +15,7 @@ async fn main() -> io::Result<()> { .bind("hello-world", "127.0.0.1:8080", || { HttpService::build() .client_timeout(1000) - .client_disconnect(1000) + .disconnect_timeout(1000) .finish(|_req| { info!("{:?}", _req); let mut res = Response::Ok(); diff --git a/ntex/src/framed/dispatcher.rs b/ntex/src/framed/dispatcher.rs index edb8cdd5..5f411d09 100644 --- a/ntex/src/framed/dispatcher.rs +++ b/ntex/src/framed/dispatcher.rs @@ -1,12 +1,15 @@ //! Framed dispatcher service and related utilities +use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +use std::time::Duration; -use futures::Stream; +use futures::{ready, Stream}; use log::debug; use crate::channel::mpsc; use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed}; +use crate::rt::time::{delay_for, Delay}; use crate::service::Service; use super::error::ServiceError; @@ -32,6 +35,7 @@ where state: FramedState, framed: Framed, rx: mpsc::Receiver::Item, S::Error>>, + disconnect_timeout: usize, } impl Dispatcher @@ -45,13 +49,19 @@ where ::Error: std::fmt::Debug, Out: Stream::Item> + Unpin, { - pub(super) fn new(framed: Framed, service: S, sink: Option) -> Self { + pub(super) fn new( + framed: Framed, + service: S, + sink: Option, + timeout: usize, + ) -> Self { Dispatcher { sink, service, framed, rx: mpsc::channel().1, state: FramedState::Processing, + disconnect_timeout: timeout, } } } @@ -61,6 +71,7 @@ enum FramedState { Error(ServiceError), FlushAndStop, Shutdown(Option>), + ShutdownIo(Delay), } #[derive(Copy, Clone, PartialEq, Eq, Debug)] @@ -250,12 +261,32 @@ where if let Some(err) = err.take() { Poll::Ready(Err(err)) } else { - Poll::Ready(Ok(())) + let pending = self.framed.close(cx).is_pending(); + if self.disconnect_timeout == 0 && pending { + self.state = FramedState::ShutdownIo(delay_for( + Duration::from_millis( + self.disconnect_timeout as u64, + ), + )); + continue; + } else { + Poll::Ready(Ok(())) + } } } else { Poll::Pending } } + FramedState::ShutdownIo(ref mut delay) => { + if let Poll::Ready(res) = self.framed.close(cx) { + return Poll::Ready( + res.map_err(|e| ServiceError::Encoder(e.into())), + ); + } else { + ready!(Pin::new(delay).poll(cx)); + return Poll::Ready(Ok(())); + } + } } } } diff --git a/ntex/src/framed/handshake.rs b/ntex/src/framed/handshake.rs index fbbb6545..08a109c3 100644 --- a/ntex/src/framed/handshake.rs +++ b/ntex/src/framed/handshake.rs @@ -130,6 +130,6 @@ where self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - self.get_mut().framed.close(cx) + self.get_mut().framed.close(cx).map_err(|e| e.into()) } } diff --git a/ntex/src/framed/service.rs b/ntex/src/framed/service.rs index f955f0f1..9c029a0a 100644 --- a/ntex/src/framed/service.rs +++ b/ntex/src/framed/service.rs @@ -23,6 +23,7 @@ type ResponseItem = Option<::Item>; /// for building instances for framed services. pub struct Builder { connect: C, + disconnect_timeout: usize, _t: PhantomData<(St, Io, Codec, Out)>, } @@ -46,10 +47,24 @@ where { Builder { connect: connect.into_service(), + disconnect_timeout: 3000, _t: PhantomData, } } + /// Set connection disconnect timeout in milliseconds. + /// + /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete + /// within this time, the connection get dropped. + /// + /// To disable timeout set value to 0. + /// + /// By default disconnect timeout is set to 3 seconds. + pub fn disconnect_timeout(mut self, val: usize) -> Self { + self.disconnect_timeout = val; + self + } + /// Provide stream items handler service and construct service factory. pub fn build(self, service: F) -> FramedServiceImpl where @@ -65,6 +80,7 @@ where FramedServiceImpl { connect: self.connect, handler: Rc::new(service.into_factory()), + disconnect_timeout: self.disconnect_timeout, _t: PhantomData, } } @@ -74,6 +90,7 @@ where /// for building instances for framed services. pub struct FactoryBuilder { connect: C, + disconnect_timeout: usize, _t: PhantomData<(St, Io, Codec, Out)>, } @@ -97,10 +114,24 @@ where { FactoryBuilder { connect: connect.into_factory(), + disconnect_timeout: 3000, _t: PhantomData, } } + /// Set connection disconnect timeout in milliseconds. + /// + /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete + /// within this time, the connection get dropped. + /// + /// To disable timeout set value to 0. + /// + /// By default disconnect timeout is set to 3 seconds. + pub fn disconnect_timeout(mut self, val: usize) -> Self { + self.disconnect_timeout = val; + self + } + pub fn build( self, service: F, @@ -118,6 +149,7 @@ where FramedService { connect: self.connect, handler: Rc::new(service.into_factory()), + disconnect_timeout: self.disconnect_timeout, _t: PhantomData, } } @@ -126,6 +158,7 @@ where pub struct FramedService { connect: C, handler: Rc, + disconnect_timeout: usize, _t: PhantomData<(St, Io, Codec, Out, Cfg)>, } @@ -166,6 +199,7 @@ where FramedServiceResponse { fut: self.connect.new_service(()), handler: self.handler.clone(), + disconnect_timeout: self.disconnect_timeout, } } } @@ -197,6 +231,7 @@ where #[pin] fut: C::Future, handler: Rc, + disconnect_timeout: usize, } impl Future for FramedServiceResponse @@ -232,6 +267,7 @@ where Poll::Ready(Ok(FramedServiceImpl { connect, handler: this.handler.clone(), + disconnect_timeout: *this.disconnect_timeout, _t: PhantomData, })) } @@ -240,6 +276,7 @@ where pub struct FramedServiceImpl { connect: C, handler: Rc, + disconnect_timeout: usize, _t: PhantomData<(St, Io, Codec, Out)>, } @@ -287,6 +324,7 @@ where inner: FramedServiceImplResponseInner::Handshake( self.connect.call(Handshake::new(req)), self.handler.clone(), + self.disconnect_timeout, ), } } @@ -382,8 +420,13 @@ where ::Error: std::fmt::Debug, Out: Stream::Item> + Unpin, { - Handshake(#[pin] C::Future, Rc), - Handler(#[pin] T::Future, Option>, Option), + Handshake(#[pin] C::Future, Rc, usize), + Handler( + #[pin] T::Future, + Option>, + Option, + usize, + ), Dispatcher(Dispatcher), } @@ -419,7 +462,7 @@ where > { #[project] match self.project() { - FramedServiceImplResponseInner::Handshake(fut, handler) => { + FramedServiceImplResponseInner::Handshake(fut, handler, timeout) => { match fut.poll(cx) { Poll::Ready(Ok(res)) => { log::trace!("Connection handshake succeeded"); @@ -427,6 +470,7 @@ where handler.new_service(res.state), Some(res.framed), res.out, + *timeout, )) } Poll::Pending => Either::Right(Poll::Pending), @@ -436,14 +480,19 @@ where } } } - FramedServiceImplResponseInner::Handler(fut, framed, out) => { + FramedServiceImplResponseInner::Handler(fut, framed, out, timeout) => { match fut.poll(cx) { Poll::Ready(Ok(handler)) => { log::trace!( "Connection handler is created, starting dispatcher" ); Either::Left(FramedServiceImplResponseInner::Dispatcher( - Dispatcher::new(framed.take().unwrap(), handler, out.take()), + Dispatcher::new( + framed.take().unwrap(), + handler, + out.take(), + *timeout, + ), )) } Poll::Pending => Either::Right(Poll::Pending), diff --git a/ntex/src/http/builder.rs b/ntex/src/http/builder.rs index f842c242..33cc76bb 100644 --- a/ntex/src/http/builder.rs +++ b/ntex/src/http/builder.rs @@ -34,8 +34,8 @@ impl HttpServiceBuilder> { pub fn new() -> Self { HttpServiceBuilder { keep_alive: KeepAlive::Timeout(5), - client_timeout: 5000, - client_disconnect: 0, + client_timeout: 3000, + client_disconnect: 3000, handshake_timeout: 5000, expect: ExpectHandler, upgrade: None, @@ -76,7 +76,7 @@ where /// /// To disable timeout set value to 0. /// - /// By default client timeout is set to 5000 milliseconds. + /// By default client timeout is set to 3 seconds. pub fn client_timeout(mut self, val: u64) -> Self { self.client_timeout = val; self @@ -85,12 +85,12 @@ where /// Set server connection disconnect timeout in milliseconds. /// /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete - /// within this time, the request get dropped. This timeout affects secure connections. + /// within this time, the connection get dropped. /// /// To disable timeout set value to 0. /// - /// By default disconnect timeout is set to 0. - pub fn client_disconnect(mut self, val: u64) -> Self { + /// By default disconnect timeout is set to 3 seconds. + pub fn disconnect_timeout(mut self, val: u64) -> Self { self.client_disconnect = val; self } diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index b1c6c068..1542a265 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -7,6 +7,7 @@ use std::{fmt, io, mem, net}; use bitflags::bitflags; use bytes::{Buf, BytesMut}; +use futures::ready; use pin_project::{pin_project, project}; use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts}; @@ -42,10 +43,12 @@ bitflags! { const STOP_READING = 0b0000_1000; /// Shutdown is in process (flushing and io shutdown timer) const SHUTDOWN = 0b0001_0000; + /// Io shutdown process started + const SHUTDOWN_IO = 0b0010_0000; /// Shutdown timer is started - const SHUTDOWN_TM = 0b0010_0000; + const SHUTDOWN_TM = 0b0100_0000; /// Connection is upgraded - const UPGRADE = 0b0100_0000; + const UPGRADE = 0b1000_0000; } } @@ -429,13 +432,23 @@ where return Poll::Ready(Ok(())); } - self.poll_flush(cx)?; + if !self.flags.contains(Flags::SHUTDOWN_IO) { + self.poll_flush(cx)?; - if self.write_buf.is_empty() { - if let Poll::Ready(res) = - Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx) - { - return Poll::Ready(res.map_err(DispatchError::from)); + if self.write_buf.is_empty() { + ready!(Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx)?); + self.flags.insert(Flags::SHUTDOWN_IO); + } + } + + // read until 0 or err + let mut buf = [0u8; 512]; + while let Poll::Ready(res) = + Pin::new(self.io.as_mut().unwrap()).poll_read(cx, &mut buf) + { + match res { + Err(_) | Ok(0) => return Poll::Ready(Ok(())), + _ => (), } } @@ -494,7 +507,7 @@ where trace!("Disconnected during flush, written {}", written); return Err(DispatchError::Io(io::Error::new( io::ErrorKind::WriteZero, - "", + "failed to write frame to transport", ))); } else { written += n @@ -972,18 +985,24 @@ mod tests { #[ntex_rt::test] async fn test_req_parse_err() { let (client, server) = Io::create(); + client.remote_buffer_cap(1024); client.write("GET /test HTTP/1\r\n\r\n"); let mut h1 = h1(server, |_| ok::<_, io::Error>(Response::Ok().finish())); - assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); + assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending()); assert!(h1.inner.flags.contains(Flags::SHUTDOWN)); client .read_buffer(|buf| assert_eq!(&buf[..26], b"HTTP/1.1 400 Bad Request\r\n")); + + client.close().await; + assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready()); + assert!(h1.inner.flags.contains(Flags::SHUTDOWN_IO)); } #[ntex_rt::test] async fn test_pipeline() { let (client, server) = Io::create(); + client.remote_buffer_cap(4096); let mut decoder = ClientCodec::default(); spawn_h1(server, |_| ok::<_, io::Error>(Response::Ok().finish())); @@ -991,7 +1010,7 @@ mod tests { let mut buf = client.read().await.unwrap(); assert!(load(&mut decoder, &mut buf).status.is_success()); - assert!(!client.is_server_closed()); + assert!(!client.is_server_dropped()); client.write("GET /test HTTP/1.1\r\n\r\n"); client.write("GET /test HTTP/1.1\r\n\r\n"); @@ -1000,15 +1019,16 @@ mod tests { assert!(load(&mut decoder, &mut buf).status.is_success()); assert!(load(&mut decoder, &mut buf).status.is_success()); assert!(decoder.decode(&mut buf).unwrap().is_none()); - assert!(!client.is_server_closed()); + assert!(!client.is_server_dropped()); client.close().await; - assert!(client.is_server_closed()); + assert!(client.is_server_dropped()); } #[ntex_rt::test] async fn test_pipeline_with_delay() { let (client, server) = Io::create(); + client.remote_buffer_cap(4096); let mut decoder = ClientCodec::default(); spawn_h1(server, |_| async { delay_for(Duration::from_millis(100)).await; @@ -1019,7 +1039,7 @@ mod tests { let mut buf = client.read().await.unwrap(); assert!(load(&mut decoder, &mut buf).status.is_success()); - assert!(!client.is_server_closed()); + assert!(!client.is_server_dropped()); client.write("GET /test HTTP/1.1\r\n\r\n"); client.write("GET /test HTTP/1.1\r\n\r\n"); @@ -1032,15 +1052,15 @@ mod tests { let mut buf = client.read().await.unwrap(); assert!(load(&mut decoder, &mut buf).status.is_success()); assert!(decoder.decode(&mut buf).unwrap().is_none()); - assert!(!client.is_server_closed()); + assert!(!client.is_server_dropped()); buf.extend(client.read().await.unwrap()); assert!(load(&mut decoder, &mut buf).status.is_success()); assert!(decoder.decode(&mut buf).unwrap().is_none()); - assert!(!client.is_server_closed()); + assert!(!client.is_server_dropped()); client.close().await; - assert!(client.is_server_closed()); + assert!(client.is_server_dropped()); } #[ntex_rt::test] @@ -1057,11 +1077,12 @@ mod tests { ok::<_, io::Error>(Response::Ok().finish()) }); + client.remote_buffer_cap(1024); client.write("GET /test HTTP/1.1\r\n\r\n"); client.write("GET /test HTTP/1.1\r\n\r\n"); client.write("GET /test HTTP/1.1\r\n\r\n"); client.close().await; - assert!(client.is_server_closed()); + assert!(client.is_server_dropped()); assert!(client.read_any().is_empty()); // all request must be handled diff --git a/ntex/src/http/h2/dispatcher.rs b/ntex/src/http/h2/dispatcher.rs index 4a55b981..52162d5f 100644 --- a/ntex/src/http/h2/dispatcher.rs +++ b/ntex/src/http/h2/dispatcher.rs @@ -56,13 +56,6 @@ where timeout: Option, peer_addr: Option, ) -> Self { - // let keepalive = config.keep_alive_enabled(); - // let flags = if keepalive { - // Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED - // } else { - // Flags::empty() - // }; - // keep-alive timer let (ka_expire, ka_timer) = if let Some(delay) = timeout { (delay.deadline(), Some(delay)) diff --git a/ntex/src/http/test.rs b/ntex/src/http/test.rs index 2a59af97..7cc5fde1 100644 --- a/ntex/src/http/test.rs +++ b/ntex/src/http/test.rs @@ -249,7 +249,6 @@ pub fn server>(factory: F) -> TestServer { .set_alpn_protos(b"\x02h2\x08http/1.1") .map_err(|e| log::error!("Can not set alpn protocol: {:?}", e)); Connector::default() - .conn_lifetime(time::Duration::from_secs(0)) .timeout(time::Duration::from_millis(30000)) .openssl(builder.build()) .finish() @@ -257,7 +256,6 @@ pub fn server>(factory: F) -> TestServer { #[cfg(not(feature = "openssl"))] { Connector::default() - .conn_lifetime(time::Duration::from_secs(0)) .timeout(time::Duration::from_millis(30000)) .finish() } diff --git a/ntex/src/testing.rs b/ntex/src/testing.rs index 6cdc2005..169bd37d 100644 --- a/ntex/src/testing.rs +++ b/ntex/src/testing.rs @@ -2,7 +2,7 @@ use std::cell::{Cell, RefCell}; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; -use std::{io, time}; +use std::{cmp, io, mem, time}; use bytes::BytesMut; use futures::future::poll_fn; @@ -15,8 +15,15 @@ use crate::rt::time::delay_for; pub struct Io { tp: Type, state: Arc>, - read: Arc>>, - write: Arc>>, + local: Arc>>, + remote: Arc>>, +} + +bitflags::bitflags! { + struct Flags: u8 { + const FLUSHED = 0b0000_0001; + const CLOSED = 0b0000_0010; + } } #[derive(Copy, Clone)] @@ -29,18 +36,49 @@ enum Type { #[derive(Copy, Clone, Default)] struct State { - client_closed: bool, - server_closed: bool, + client_dropped: bool, + server_dropped: bool, } #[derive(Default)] struct Channel { buf: BytesMut, + buf_cap: usize, + flags: Flags, + waker: AtomicWaker, read_err: Option, - read_waker: AtomicWaker, read_close: CloseState, - write_err: Option, - write_waker: AtomicWaker, + write: IoState, + flush: IoState, +} + +impl Channel { + fn is_closed(&self) -> bool { + self.flags.contains(Flags::CLOSED) + } + + fn is_flushed(&self) -> bool { + self.flags.contains(Flags::FLUSHED) + } +} + +impl Default for Flags { + fn default() -> Self { + Flags::empty() + } +} + +#[derive(Debug)] +enum IoState { + Ok, + Pending, + Err(io::Error), +} + +impl Default for IoState { + fn default() -> Self { + IoState::Ok + } } enum CloseState { @@ -57,32 +95,42 @@ impl Default for CloseState { impl Io { /// Create a two interconnected streams pub fn create() -> (Io, Io) { - let left = Arc::new(Mutex::new(RefCell::new(Channel::default()))); - let right = Arc::new(Mutex::new(RefCell::new(Channel::default()))); + let local = Arc::new(Mutex::new(RefCell::new(Channel::default()))); + let remote = Arc::new(Mutex::new(RefCell::new(Channel::default()))); let state = Arc::new(Cell::new(State::default())); ( Io { tp: Type::Client, - read: left.clone(), - write: right.clone(), + local: local.clone(), + remote: remote.clone(), state: state.clone(), }, Io { state, tp: Type::Server, - read: right, - write: left, + local: remote, + remote: local, }, ) } - pub fn is_client_closed(&self) -> bool { - self.state.get().client_closed + pub fn is_client_dropped(&self) -> bool { + self.state.get().client_dropped } - pub fn is_server_closed(&self) -> bool { - self.state.get().server_closed + pub fn is_server_dropped(&self) -> bool { + self.state.get().server_dropped + } + + /// Check if channel is closed from remoote side + pub fn is_closed(&self) -> bool { + self.remote.lock().unwrap().borrow().is_closed() + } + + /// Check flushed state + pub fn is_flushed(&self) -> bool { + self.remote.lock().unwrap().borrow().is_flushed() } /// Access read buffer. @@ -90,7 +138,7 @@ impl Io { where F: FnOnce(&mut BytesMut) -> R, { - let guard = self.read.lock().unwrap(); + let guard = self.local.lock().unwrap(); let mut ch = guard.borrow_mut(); f(&mut ch.buf) } @@ -98,52 +146,59 @@ impl Io { /// Access write buffer. pub async fn close(&self) { { - let guard = self.write.lock().unwrap(); + let guard = self.remote.lock().unwrap(); let mut write = guard.borrow_mut(); write.read_close = CloseState::Closed; - write.read_waker.wake(); + write.waker.wake(); } delay_for(time::Duration::from_millis(35)).await; } - /// Access write buffer. - pub fn write_buffer(&self, f: F) -> R - where - F: FnOnce(&mut BytesMut) -> R, - { - let guard = self.write.lock().unwrap(); - let mut ch = guard.borrow_mut(); - f(&mut ch.buf) - } - /// Add extra data to the buffer and notify reader pub fn write>(&self, data: T) { - let guard = self.write.lock().unwrap(); + let guard = self.remote.lock().unwrap(); let mut write = guard.borrow_mut(); write.buf.extend_from_slice(data.as_ref()); - write.read_waker.wake(); + write.waker.wake(); + } + + /// Set flush to Pending state + pub fn flush_pending(&self) { + self.remote.lock().unwrap().borrow_mut().flush = IoState::Pending; + } + + /// Set flush to errore + pub fn flush_error(&self, err: io::Error) { + self.remote.lock().unwrap().borrow_mut().flush = IoState::Err(err); + } + + /// Read any available data + pub fn remote_buffer_cap(&self, cap: usize) { + self.local.lock().unwrap().borrow_mut().buf_cap = cap; } /// Read any available data pub fn read_any(&self) -> BytesMut { - self.read.lock().unwrap().borrow_mut().buf.split() + self.local.lock().unwrap().borrow_mut().buf.split() } /// Read data, if data is not available wait for it pub async fn read(&self) -> Result { - if self.read.lock().unwrap().borrow().buf.is_empty() { + if self.local.lock().unwrap().borrow().buf.is_empty() { poll_fn(|cx| { - let guard = self.read.lock().unwrap(); + let guard = self.local.lock().unwrap(); let read = guard.borrow_mut(); if read.buf.is_empty() { let closed = match self.tp { - Type::Client | Type::ClientClone => self.is_server_closed(), - Type::Server | Type::ServerClone => self.is_client_closed(), + Type::Client | Type::ClientClone => { + self.is_server_dropped() || read.is_closed() + } + Type::Server | Type::ServerClone => self.is_client_dropped(), }; if closed { Poll::Ready(()) } else { - read.read_waker.register(cx.waker()); + read.waker.register(cx.waker()); drop(read); drop(guard); Poll::Pending @@ -154,7 +209,7 @@ impl Io { }) .await; } - Ok(self.read.lock().unwrap().borrow_mut().buf.split()) + Ok(self.local.lock().unwrap().borrow_mut().buf.split()) } } @@ -168,8 +223,8 @@ impl Clone for Io { Io { tp, - read: self.read.clone(), - write: self.write.clone(), + local: self.local.clone(), + remote: self.remote.clone(), state: self.state.clone(), } } @@ -179,8 +234,8 @@ impl Drop for Io { fn drop(&mut self) { let mut state = self.state.get(); match self.tp { - Type::Server => state.server_closed = true, - Type::Client => state.client_closed = true, + Type::Server => state.server_dropped = true, + Type::Client => state.client_dropped = true, _ => (), } self.state.set(state); @@ -193,9 +248,9 @@ impl AsyncRead for Io { cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - let guard = self.read.lock().unwrap(); + let guard = self.local.lock().unwrap(); let mut ch = guard.borrow_mut(); - ch.read_waker.register(cx.waker()); + ch.waker.register(cx.waker()); let result = if ch.buf.is_empty() { if let Some(err) = ch.read_err.take() { @@ -223,23 +278,48 @@ impl AsyncWrite for Io { _: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let guard = self.write.lock().unwrap(); + let guard = self.remote.lock().unwrap(); let mut ch = guard.borrow_mut(); - if let Some(err) = ch.write_err.take() { - Poll::Ready(Err(err)) - } else { - ch.write_waker.wake(); - ch.buf.extend(buf); - Poll::Ready(Ok(buf.len())) + match mem::take(&mut ch.write) { + IoState::Ok => { + let cap = cmp::min(buf.len(), ch.buf_cap); + if cap > 0 { + ch.buf.extend(&buf[..cap]); + ch.buf_cap -= cap; + ch.flags.remove(Flags::FLUSHED); + ch.waker.wake(); + Poll::Ready(Ok(cap)) + } else { + Poll::Pending + } + } + IoState::Pending => Poll::Pending, + IoState::Err(e) => Poll::Ready(Err(e)), } } fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + let guard = self.local.lock().unwrap(); + let mut ch = guard.borrow_mut(); + + match mem::take(&mut ch.flush) { + IoState::Ok => { + ch.flags.insert(Flags::FLUSHED); + Poll::Ready(Ok(())) + } + IoState::Pending => Poll::Pending, + IoState::Err(e) => Poll::Ready(Err(e)), + } } fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + self.local + .lock() + .unwrap() + .borrow_mut() + .flags + .insert(Flags::CLOSED); Poll::Ready(Ok(())) } } diff --git a/ntex/src/web/server.rs b/ntex/src/web/server.rs index 0fae274f..68c07a18 100644 --- a/ntex/src/web/server.rs +++ b/ntex/src/web/server.rs @@ -30,7 +30,7 @@ struct Config { host: Option, keep_alive: KeepAlive, client_timeout: u64, - client_shutdown: u64, + client_disconnect: u64, handshake_timeout: u64, } @@ -89,7 +89,7 @@ where host: None, keep_alive: KeepAlive::Timeout(5), client_timeout: 5000, - client_shutdown: 5000, + client_disconnect: 5000, handshake_timeout: 5000, })), backlog: 1024, @@ -162,22 +162,22 @@ where /// /// To disable timeout set value to 0. /// - /// By default client timeout is set to 5000 milliseconds. + /// By default client timeout is set to 5 seconds. pub fn client_timeout(self, val: u64) -> Self { self.config.lock().unwrap().client_timeout = val; self } - /// Set server connection shutdown timeout in milliseconds. + /// Set server connection disconnect timeout in milliseconds. /// /// Defines a timeout for shutdown connection. If a shutdown procedure does not complete /// within this time, the request is dropped. /// /// To disable timeout set value to 0. /// - /// By default client timeout is set to 5000 milliseconds. - pub fn client_shutdown(self, val: u64) -> Self { - self.config.lock().unwrap().client_shutdown = val; + /// By default client timeout is set to 5 seconds. + pub fn disconnect_timeout(self, val: u64) -> Self { + self.config.lock().unwrap().client_disconnect = val; self } @@ -270,6 +270,7 @@ where HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) + .disconnect_timeout(c.client_disconnect) .finish(map_config(factory(), move |_| cfg.clone())) .tcp() }, @@ -316,7 +317,7 @@ where HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) - .client_disconnect(c.client_shutdown) + .disconnect_timeout(c.client_disconnect) .ssl_handshake_timeout(c.handshake_timeout) .finish(map_config(factory(), move |_| cfg.clone())) .openssl(acceptor.clone()) @@ -364,7 +365,7 @@ where HttpService::build() .keep_alive(c.keep_alive) .client_timeout(c.client_timeout) - .client_disconnect(c.client_shutdown) + .disconnect_timeout(c.client_disconnect) .ssl_handshake_timeout(c.handshake_timeout) .finish(map_config(factory(), move |_| cfg.clone())) .rustls(config.clone()) diff --git a/ntex/tests/http_awc_client.rs b/ntex/tests/http_awc_client.rs index dce5c2ae..ca53e8e1 100644 --- a/ntex/tests/http_awc_client.rs +++ b/ntex/tests/http_awc_client.rs @@ -57,7 +57,12 @@ async fn test_simple() { let bytes = response.body().await.unwrap(); assert_eq!(bytes, Bytes::from_static(STR.as_ref())); - let mut response = srv.post("/").send().await.unwrap(); + let mut response = srv + .post("/") + .timeout(Duration::from_secs(30)) + .send() + .await + .unwrap(); assert!(response.status().is_success()); // read response diff --git a/ntex/tests/http_rustls.rs b/ntex/tests/http_rustls.rs index a437bbd8..78ce079e 100644 --- a/ntex/tests/http_rustls.rs +++ b/ntex/tests/http_rustls.rs @@ -158,7 +158,7 @@ async fn test_h2_content_length() { let req = srv .srequest(Method::HEAD, &format!("/{}", i)) - .timeout(Duration::from_secs(30)) + .timeout(Duration::from_secs(100)) .send(); let response = req.await.unwrap(); assert_eq!(response.headers().get(&header), None); diff --git a/ntex/tests/http_server.rs b/ntex/tests/http_server.rs index 1bc7af0d..00a68d1e 100644 --- a/ntex/tests/http_server.rs +++ b/ntex/tests/http_server.rs @@ -21,7 +21,7 @@ async fn test_h1() { HttpService::build() .keep_alive(KeepAlive::Disabled) .client_timeout(1000) - .client_disconnect(1000) + .disconnect_timeout(1000) .h1(|req: Request| { assert!(req.peer_addr().is_some()); future::ok::<_, io::Error>(Response::Ok().finish()) @@ -39,7 +39,7 @@ async fn test_h1_2() { HttpService::build() .keep_alive(KeepAlive::Disabled) .client_timeout(1000) - .client_disconnect(1000) + .disconnect_timeout(1000) .finish(|req: Request| { assert!(req.peer_addr().is_some()); assert_eq!(req.version(), http::Version::HTTP_11); diff --git a/ntex/tests/web_httpserver.rs b/ntex/tests/web_httpserver.rs index de175139..128523d9 100644 --- a/ntex/tests/web_httpserver.rs +++ b/ntex/tests/web_httpserver.rs @@ -29,7 +29,8 @@ async fn test_start() { .maxconnrate(10) .keep_alive(10) .client_timeout(5000) - .client_shutdown(0) + .disconnect_timeout(1000) + .ssl_handshake_timeout(1000) .server_hostname("localhost") .system_exit() .disable_signals()