Fix io close (#12)

* Fix io close for Framed
* Fix connection shutdown for h1 dispatcher
* Enable client disconnect for http server by default
* Add connection disconnect timeout to framed service
This commit is contained in:
Nikolay Kim 2020-04-07 21:36:48 +06:00 committed by GitHub
parent 8a753a762f
commit 3b12a77e92
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 529 additions and 169 deletions

View file

@ -1,5 +1,11 @@
# Changes
## [0.1.1] - 2020-04-07
* Optimize io operations
* Fix framed close method
## [0.1.0] - 2020-03-31
* Fork crate to ntex namespace

View file

@ -1,6 +1,6 @@
[package]
name = "ntex-codec"
version = "0.1.0"
version = "0.1.1"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Utilities for encoding and decoding frames"
keywords = ["network", "framework", "async", "futures"]
@ -8,7 +8,7 @@ homepage = "https://ntex.rs"
repository = "https://github.com/ntex-rs/ntex.git"
documentation = "https://docs.rs/ntex-codec/"
categories = ["network-programming", "asynchronous"]
license = "MIT/Apache-2.0"
license = "MIT"
edition = "2018"
[lib]
@ -20,6 +20,10 @@ bitflags = "1.2.1"
bytes = "0.5.4"
futures-core = "0.3.4"
futures-sink = "0.3.4"
tokio = { version = "0.2.4", default-features=false }
tokio = { version = "0.2.6", default-features=false }
tokio-util = { version = "0.2.0", default-features=false, features=["codec"] }
log = "0.4"
[dev-dependencies]
ntex = "0.1.4"
futures = "0.3.4"

View file

@ -15,11 +15,14 @@ bitflags::bitflags! {
struct Flags: u8 {
const EOF = 0b0001;
const READABLE = 0b0010;
const DISCONNECTED = 0b0100;
const SHUTDOWN = 0b1000;
}
}
/// A unified `Stream` and `Sink` interface to an underlying I/O object, using
/// the `Encoder` and `Decoder` traits to encode and decode frames.
/// `Framed` is heavily optimized for streaming io.
pub struct Framed<T, U> {
io: T,
codec: U,
@ -28,8 +31,6 @@ pub struct Framed<T, U> {
write_buf: BytesMut,
}
impl<T, U> Unpin for Framed<T, U> {}
impl<T, U> Framed<T, U>
where
T: AsyncRead + AsyncWrite,
@ -123,6 +124,18 @@ impl<T, U> Framed<T, U> {
&mut self.io
}
#[inline]
/// Get read buffer.
pub fn read_buf_mut(&mut self) -> &mut BytesMut {
&mut self.read_buf
}
#[inline]
/// Get write buffer.
pub fn write_buf_mut(&mut self) -> &mut BytesMut {
&mut self.write_buf
}
#[inline]
/// Check if write buffer is empty.
pub fn is_write_buf_empty(&self) -> bool {
@ -135,6 +148,12 @@ impl<T, U> Framed<T, U> {
self.write_buf.len() >= HW
}
#[inline]
/// Check if framed object is closed
pub fn is_closed(&self) -> bool {
self.flags.contains(Flags::DISCONNECTED)
}
#[inline]
/// Consume the `Frame`, returning `Frame` with different codec.
pub fn into_framed<U2>(self, codec: U2) -> Framed<T, U2> {
@ -227,34 +246,87 @@ where
pub fn flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), U::Error>> {
log::trace!("flushing framed transport");
while !self.write_buf.is_empty() {
log::trace!("writing; remaining={}", self.write_buf.len());
let len = self.write_buf.len();
if len == 0 {
return Poll::Ready(Ok(()));
}
let n = ready!(Pin::new(&mut self.io).poll_write(cx, &self.write_buf))?;
let mut written = 0;
while written < len {
match Pin::new(&mut self.io).poll_write(cx, &self.write_buf[written..]) {
Poll::Pending => break,
Poll::Ready(Ok(n)) => {
if n == 0 {
log::trace!("Disconnected during flush, written {}", written);
self.flags.insert(Flags::DISCONNECTED);
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write frame to transport",
)
.into()));
} else {
written += n
}
}
Poll::Ready(Err(e)) => {
log::trace!("Error during flush: {}", e);
self.flags.insert(Flags::DISCONNECTED);
return Poll::Ready(Err(e.into()));
}
}
}
// remove written data
self.write_buf.advance(n);
if written == len {
// flushed same amount as in buffer, we dont need to reallocate
unsafe { self.write_buf.set_len(0) }
} else {
self.write_buf.advance(written);
}
// Try flushing the underlying IO
ready!(Pin::new(&mut self.io).poll_flush(cx))?;
log::trace!("framed transport flushed");
if self.write_buf.is_empty() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
impl<T, U> Framed<T, U>
where
T: AsyncRead + AsyncWrite + Unpin,
{
#[inline]
/// Flush write buffer and shutdown underlying I/O stream.
pub fn close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), U::Error>> {
///
/// Close method shutdown write side of a io object and
/// then reads until disconnect or error, high level code must use
/// timeout for close operation.
pub fn close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
if !self.flags.contains(Flags::DISCONNECTED) {
// flush write buffer
ready!(Pin::new(&mut self.io).poll_flush(cx))?;
ready!(Pin::new(&mut self.io).poll_shutdown(cx))?;
if !self.flags.contains(Flags::SHUTDOWN) {
// shutdown WRITE side
ready!(Pin::new(&mut self.io).poll_shutdown(cx)).map_err(|e| {
self.flags.insert(Flags::DISCONNECTED);
e
})?;
self.flags.insert(Flags::SHUTDOWN);
}
// read until 0 or err
let mut buf = [0u8; 512];
loop {
match ready!(Pin::new(&mut self.io).poll_read(cx, &mut buf)) {
Err(_) | Ok(0) => {
break;
}
_ => (),
}
}
self.flags.insert(Flags::DISCONNECTED);
}
log::trace!("framed transport flushed and closed");
Poll::Ready(Ok(()))
}
@ -269,11 +341,9 @@ where
pub fn next_item(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<U::Item, U::Error>>>
where
T: AsyncRead,
U: Decoder,
{
) -> Poll<Option<Result<U::Item, U::Error>>> {
let mut done_read = false;
loop {
// Repeatedly call `decode` or `decode_eof` as long as it is
// "readable". Readable is defined as not having returned `None`. If
@ -302,26 +372,45 @@ where
}
self.flags.remove(Flags::READABLE);
if done_read {
return Poll::Pending;
}
}
debug_assert!(!self.flags.contains(Flags::EOF));
// read all data from socket
let mut updated = false;
loop {
// Otherwise, try to read more data and try again. Make sure we've got room
let remaining = self.read_buf.capacity() - self.read_buf.len();
if remaining < LW {
self.read_buf.reserve(HW - remaining)
}
let cnt = match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf)
{
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
Poll::Ready(Ok(cnt)) => cnt,
};
if cnt == 0 {
self.flags.insert(Flags::EOF);
}
match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf) {
Poll::Pending => {
if updated {
done_read = true;
self.flags.insert(Flags::READABLE);
break;
} else {
return Poll::Pending;
}
}
Poll::Ready(Ok(n)) => {
if n == 0 {
self.flags.insert(Flags::EOF | Flags::READABLE);
if updated {
done_read = true;
}
break;
} else {
updated = true;
}
}
Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
}
}
}
}
}
@ -329,7 +418,7 @@ where
impl<T, U> Stream for Framed<T, U>
where
T: AsyncRead + Unpin,
U: Decoder,
U: Decoder + Unpin,
{
type Item = Result<U::Item, U::Error>;
@ -344,8 +433,8 @@ where
impl<T, U> Sink<U::Item> for Framed<T, U>
where
T: AsyncWrite + Unpin,
U: Encoder,
T: AsyncRead + AsyncWrite + Unpin,
U: Encoder + Unpin,
U::Error: From<io::Error>,
{
type Error = U::Error;
@ -383,7 +472,7 @@ where
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.close(cx)
self.close(cx).map_err(|e| e.into())
}
}
@ -443,3 +532,77 @@ impl<T, U> FramedParts<T, U> {
}
}
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use futures::future::lazy;
use futures::Sink;
use ntex::testing::Io;
use super::*;
use crate::BytesCodec;
#[ntex::test]
async fn test_sink() {
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
let mut server = Framed::new(server, BytesCodec);
assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx))
.await
.is_ready());
let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
Pin::new(&mut server).start_send(data).unwrap();
assert_eq!(client.read_any(), b"".as_ref());
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
.await
.is_ready());
assert_eq!(client.read_any(), b"GET /test HTTP/1.1\r\n\r\n".as_ref());
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
.await
.is_pending());
client.close().await;
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
.await
.is_ready());
assert!(client.is_closed());
}
#[ntex::test]
async fn test_write_pending() {
let (client, server) = Io::create();
let mut server = Framed::new(server, BytesCodec);
assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx))
.await
.is_ready());
let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
Pin::new(&mut server).start_send(data).unwrap();
client.remote_buffer_cap(3);
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
.await
.is_pending());
assert_eq!(client.read_any(), b"GET".as_ref());
client.remote_buffer_cap(1024);
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
.await
.is_ready());
assert_eq!(client.read_any(), b" /test HTTP/1.1\r\n\r\n".as_ref());
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
.await
.is_pending());
client.close().await;
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
.await
.is_ready());
assert!(client.is_closed());
assert!(server.is_closed());
}
}

View file

@ -6,7 +6,7 @@
//!
//! [`AsyncRead`]: #
//! [`AsyncWrite`]: #
#![deny(rust_2018_idioms, warnings)]
// #![deny(rust_2018_idioms, warnings)]
mod bcodec;
mod framed;

View file

@ -1,5 +1,13 @@
# Changes
## [0.1.5] - 2020-04-07
* ntex::http: enable client disconnect timeout by default
* ntex::http: properly close h1 connection
* ntex::framed: add connection disconnect timeout to framed service
## [0.1.4] - 2020-04-06
* Remove unneeded RefCell from client connector

View file

@ -1,6 +1,6 @@
[package]
name = "ntex"
version = "0.1.4"
version = "0.1.5"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
description = "Framework for composable network services"
readme = "README.md"
@ -36,10 +36,10 @@ compress = ["flate2", "brotli2"]
cookie = ["coo-kie", "coo-kie/percent-encode"]
[dependencies]
ntex-codec = "0.1"
ntex-codec = "0.1.1"
ntex-rt = "0.1"
ntex-rt-macros = "0.1"
ntex-router = "0.3.1"
ntex-router = "0.3.2"
ntex-service = "0.1"
actix-threadpool = "0.3.1"

View file

@ -16,7 +16,7 @@ async fn main() -> io::Result<()> {
.bind("echo", "127.0.0.1:8080", || {
HttpService::build()
.client_timeout(1000)
.client_disconnect(1000)
.disconnect_timeout(1000)
.finish(|mut req: Request| async move {
let mut body = BytesMut::new();
while let Some(item) = req.payload().next().await {

View file

@ -15,7 +15,7 @@ async fn main() -> io::Result<()> {
.bind("hello-world", "127.0.0.1:8080", || {
HttpService::build()
.client_timeout(1000)
.client_disconnect(1000)
.disconnect_timeout(1000)
.finish(|_req| {
info!("{:?}", _req);
let mut res = Response::Ok();

View file

@ -1,12 +1,15 @@
//! Framed dispatcher service and related utilities
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use futures::Stream;
use futures::{ready, Stream};
use log::debug;
use crate::channel::mpsc;
use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed};
use crate::rt::time::{delay_for, Delay};
use crate::service::Service;
use super::error::ServiceError;
@ -32,6 +35,7 @@ where
state: FramedState<S, U>,
framed: Framed<T, U>,
rx: mpsc::Receiver<Result<<U as Encoder>::Item, S::Error>>,
disconnect_timeout: usize,
}
impl<S, T, U, Out> Dispatcher<S, T, U, Out>
@ -45,13 +49,19 @@ where
<U as Encoder>::Error: std::fmt::Debug,
Out: Stream<Item = <U as Encoder>::Item> + Unpin,
{
pub(super) fn new(framed: Framed<T, U>, service: S, sink: Option<Out>) -> Self {
pub(super) fn new(
framed: Framed<T, U>,
service: S,
sink: Option<Out>,
timeout: usize,
) -> Self {
Dispatcher {
sink,
service,
framed,
rx: mpsc::channel().1,
state: FramedState::Processing,
disconnect_timeout: timeout,
}
}
}
@ -61,6 +71,7 @@ enum FramedState<S: Service, U: Encoder + Decoder> {
Error(ServiceError<S::Error, U>),
FlushAndStop,
Shutdown(Option<ServiceError<S::Error, U>>),
ShutdownIo(Delay),
}
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
@ -249,13 +260,33 @@ where
return if self.service.poll_shutdown(cx, err.is_some()).is_ready() {
if let Some(err) = err.take() {
Poll::Ready(Err(err))
} else {
let pending = self.framed.close(cx).is_pending();
if self.disconnect_timeout == 0 && pending {
self.state = FramedState::ShutdownIo(delay_for(
Duration::from_millis(
self.disconnect_timeout as u64,
),
));
continue;
} else {
Poll::Ready(Ok(()))
}
}
} else {
Poll::Pending
}
}
FramedState::ShutdownIo(ref mut delay) => {
if let Poll::Ready(res) = self.framed.close(cx) {
return Poll::Ready(
res.map_err(|e| ServiceError::Encoder(e.into())),
);
} else {
ready!(Pin::new(delay).poll(cx));
return Poll::Ready(Ok(()));
}
}
}
}
}

View file

@ -130,6 +130,6 @@ where
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.get_mut().framed.close(cx)
self.get_mut().framed.close(cx).map_err(|e| e.into())
}
}

View file

@ -23,6 +23,7 @@ type ResponseItem<U> = Option<<U as Encoder>::Item>;
/// for building instances for framed services.
pub struct Builder<St, C, Io, Codec, Out> {
connect: C,
disconnect_timeout: usize,
_t: PhantomData<(St, Io, Codec, Out)>,
}
@ -46,10 +47,24 @@ where
{
Builder {
connect: connect.into_service(),
disconnect_timeout: 3000,
_t: PhantomData,
}
}
/// Set connection disconnect timeout in milliseconds.
///
/// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
/// within this time, the connection get dropped.
///
/// To disable timeout set value to 0.
///
/// By default disconnect timeout is set to 3 seconds.
pub fn disconnect_timeout(mut self, val: usize) -> Self {
self.disconnect_timeout = val;
self
}
/// Provide stream items handler service and construct service factory.
pub fn build<F, T>(self, service: F) -> FramedServiceImpl<St, C, T, Io, Codec, Out>
where
@ -65,6 +80,7 @@ where
FramedServiceImpl {
connect: self.connect,
handler: Rc::new(service.into_factory()),
disconnect_timeout: self.disconnect_timeout,
_t: PhantomData,
}
}
@ -74,6 +90,7 @@ where
/// for building instances for framed services.
pub struct FactoryBuilder<St, C, Io, Codec, Out> {
connect: C,
disconnect_timeout: usize,
_t: PhantomData<(St, Io, Codec, Out)>,
}
@ -97,10 +114,24 @@ where
{
FactoryBuilder {
connect: connect.into_factory(),
disconnect_timeout: 3000,
_t: PhantomData,
}
}
/// Set connection disconnect timeout in milliseconds.
///
/// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
/// within this time, the connection get dropped.
///
/// To disable timeout set value to 0.
///
/// By default disconnect timeout is set to 3 seconds.
pub fn disconnect_timeout(mut self, val: usize) -> Self {
self.disconnect_timeout = val;
self
}
pub fn build<F, T, Cfg>(
self,
service: F,
@ -118,6 +149,7 @@ where
FramedService {
connect: self.connect,
handler: Rc::new(service.into_factory()),
disconnect_timeout: self.disconnect_timeout,
_t: PhantomData,
}
}
@ -126,6 +158,7 @@ where
pub struct FramedService<St, C, T, Io, Codec, Out, Cfg> {
connect: C,
handler: Rc<T>,
disconnect_timeout: usize,
_t: PhantomData<(St, Io, Codec, Out, Cfg)>,
}
@ -166,6 +199,7 @@ where
FramedServiceResponse {
fut: self.connect.new_service(()),
handler: self.handler.clone(),
disconnect_timeout: self.disconnect_timeout,
}
}
}
@ -197,6 +231,7 @@ where
#[pin]
fut: C::Future,
handler: Rc<T>,
disconnect_timeout: usize,
}
impl<St, C, T, Io, Codec, Out> Future for FramedServiceResponse<St, C, T, Io, Codec, Out>
@ -232,6 +267,7 @@ where
Poll::Ready(Ok(FramedServiceImpl {
connect,
handler: this.handler.clone(),
disconnect_timeout: *this.disconnect_timeout,
_t: PhantomData,
}))
}
@ -240,6 +276,7 @@ where
pub struct FramedServiceImpl<St, C, T, Io, Codec, Out> {
connect: C,
handler: Rc<T>,
disconnect_timeout: usize,
_t: PhantomData<(St, Io, Codec, Out)>,
}
@ -287,6 +324,7 @@ where
inner: FramedServiceImplResponseInner::Handshake(
self.connect.call(Handshake::new(req)),
self.handler.clone(),
self.disconnect_timeout,
),
}
}
@ -382,8 +420,13 @@ where
<Codec as Encoder>::Error: std::fmt::Debug,
Out: Stream<Item = <Codec as Encoder>::Item> + Unpin,
{
Handshake(#[pin] C::Future, Rc<T>),
Handler(#[pin] T::Future, Option<Framed<Io, Codec>>, Option<Out>),
Handshake(#[pin] C::Future, Rc<T>, usize),
Handler(
#[pin] T::Future,
Option<Framed<Io, Codec>>,
Option<Out>,
usize,
),
Dispatcher(Dispatcher<T::Service, Io, Codec, Out>),
}
@ -419,7 +462,7 @@ where
> {
#[project]
match self.project() {
FramedServiceImplResponseInner::Handshake(fut, handler) => {
FramedServiceImplResponseInner::Handshake(fut, handler, timeout) => {
match fut.poll(cx) {
Poll::Ready(Ok(res)) => {
log::trace!("Connection handshake succeeded");
@ -427,6 +470,7 @@ where
handler.new_service(res.state),
Some(res.framed),
res.out,
*timeout,
))
}
Poll::Pending => Either::Right(Poll::Pending),
@ -436,14 +480,19 @@ where
}
}
}
FramedServiceImplResponseInner::Handler(fut, framed, out) => {
FramedServiceImplResponseInner::Handler(fut, framed, out, timeout) => {
match fut.poll(cx) {
Poll::Ready(Ok(handler)) => {
log::trace!(
"Connection handler is created, starting dispatcher"
);
Either::Left(FramedServiceImplResponseInner::Dispatcher(
Dispatcher::new(framed.take().unwrap(), handler, out.take()),
Dispatcher::new(
framed.take().unwrap(),
handler,
out.take(),
*timeout,
),
))
}
Poll::Pending => Either::Right(Poll::Pending),

View file

@ -34,8 +34,8 @@ impl<T, S> HttpServiceBuilder<T, S, ExpectHandler, UpgradeHandler<T>> {
pub fn new() -> Self {
HttpServiceBuilder {
keep_alive: KeepAlive::Timeout(5),
client_timeout: 5000,
client_disconnect: 0,
client_timeout: 3000,
client_disconnect: 3000,
handshake_timeout: 5000,
expect: ExpectHandler,
upgrade: None,
@ -76,7 +76,7 @@ where
///
/// To disable timeout set value to 0.
///
/// By default client timeout is set to 5000 milliseconds.
/// By default client timeout is set to 3 seconds.
pub fn client_timeout(mut self, val: u64) -> Self {
self.client_timeout = val;
self
@ -85,12 +85,12 @@ where
/// Set server connection disconnect timeout in milliseconds.
///
/// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
/// within this time, the request get dropped. This timeout affects secure connections.
/// within this time, the connection get dropped.
///
/// To disable timeout set value to 0.
///
/// By default disconnect timeout is set to 0.
pub fn client_disconnect(mut self, val: u64) -> Self {
/// By default disconnect timeout is set to 3 seconds.
pub fn disconnect_timeout(mut self, val: u64) -> Self {
self.client_disconnect = val;
self
}

View file

@ -7,6 +7,7 @@ use std::{fmt, io, mem, net};
use bitflags::bitflags;
use bytes::{Buf, BytesMut};
use futures::ready;
use pin_project::{pin_project, project};
use crate::codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts};
@ -42,10 +43,12 @@ bitflags! {
const STOP_READING = 0b0000_1000;
/// Shutdown is in process (flushing and io shutdown timer)
const SHUTDOWN = 0b0001_0000;
/// Io shutdown process started
const SHUTDOWN_IO = 0b0010_0000;
/// Shutdown timer is started
const SHUTDOWN_TM = 0b0010_0000;
const SHUTDOWN_TM = 0b0100_0000;
/// Connection is upgraded
const UPGRADE = 0b0100_0000;
const UPGRADE = 0b1000_0000;
}
}
@ -429,13 +432,23 @@ where
return Poll::Ready(Ok(()));
}
if !self.flags.contains(Flags::SHUTDOWN_IO) {
self.poll_flush(cx)?;
if self.write_buf.is_empty() {
if let Poll::Ready(res) =
Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx)
ready!(Pin::new(self.io.as_mut().unwrap()).poll_shutdown(cx)?);
self.flags.insert(Flags::SHUTDOWN_IO);
}
}
// read until 0 or err
let mut buf = [0u8; 512];
while let Poll::Ready(res) =
Pin::new(self.io.as_mut().unwrap()).poll_read(cx, &mut buf)
{
return Poll::Ready(res.map_err(DispatchError::from));
match res {
Err(_) | Ok(0) => return Poll::Ready(Ok(())),
_ => (),
}
}
@ -494,7 +507,7 @@ where
trace!("Disconnected during flush, written {}", written);
return Err(DispatchError::Io(io::Error::new(
io::ErrorKind::WriteZero,
"",
"failed to write frame to transport",
)));
} else {
written += n
@ -972,18 +985,24 @@ mod tests {
#[ntex_rt::test]
async fn test_req_parse_err() {
let (client, server) = Io::create();
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1\r\n\r\n");
let mut h1 = h1(server, |_| ok::<_, io::Error>(Response::Ok().finish()));
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_pending());
assert!(h1.inner.flags.contains(Flags::SHUTDOWN));
client
.read_buffer(|buf| assert_eq!(&buf[..26], b"HTTP/1.1 400 Bad Request\r\n"));
client.close().await;
assert!(lazy(|cx| Pin::new(&mut h1).poll(cx)).await.is_ready());
assert!(h1.inner.flags.contains(Flags::SHUTDOWN_IO));
}
#[ntex_rt::test]
async fn test_pipeline() {
let (client, server) = Io::create();
client.remote_buffer_cap(4096);
let mut decoder = ClientCodec::default();
spawn_h1(server, |_| ok::<_, io::Error>(Response::Ok().finish()));
@ -991,7 +1010,7 @@ mod tests {
let mut buf = client.read().await.unwrap();
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(!client.is_server_closed());
assert!(!client.is_server_dropped());
client.write("GET /test HTTP/1.1\r\n\r\n");
client.write("GET /test HTTP/1.1\r\n\r\n");
@ -1000,15 +1019,16 @@ mod tests {
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(decoder.decode(&mut buf).unwrap().is_none());
assert!(!client.is_server_closed());
assert!(!client.is_server_dropped());
client.close().await;
assert!(client.is_server_closed());
assert!(client.is_server_dropped());
}
#[ntex_rt::test]
async fn test_pipeline_with_delay() {
let (client, server) = Io::create();
client.remote_buffer_cap(4096);
let mut decoder = ClientCodec::default();
spawn_h1(server, |_| async {
delay_for(Duration::from_millis(100)).await;
@ -1019,7 +1039,7 @@ mod tests {
let mut buf = client.read().await.unwrap();
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(!client.is_server_closed());
assert!(!client.is_server_dropped());
client.write("GET /test HTTP/1.1\r\n\r\n");
client.write("GET /test HTTP/1.1\r\n\r\n");
@ -1032,15 +1052,15 @@ mod tests {
let mut buf = client.read().await.unwrap();
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(decoder.decode(&mut buf).unwrap().is_none());
assert!(!client.is_server_closed());
assert!(!client.is_server_dropped());
buf.extend(client.read().await.unwrap());
assert!(load(&mut decoder, &mut buf).status.is_success());
assert!(decoder.decode(&mut buf).unwrap().is_none());
assert!(!client.is_server_closed());
assert!(!client.is_server_dropped());
client.close().await;
assert!(client.is_server_closed());
assert!(client.is_server_dropped());
}
#[ntex_rt::test]
@ -1057,11 +1077,12 @@ mod tests {
ok::<_, io::Error>(Response::Ok().finish())
});
client.remote_buffer_cap(1024);
client.write("GET /test HTTP/1.1\r\n\r\n");
client.write("GET /test HTTP/1.1\r\n\r\n");
client.write("GET /test HTTP/1.1\r\n\r\n");
client.close().await;
assert!(client.is_server_closed());
assert!(client.is_server_dropped());
assert!(client.read_any().is_empty());
// all request must be handled

View file

@ -56,13 +56,6 @@ where
timeout: Option<Delay>,
peer_addr: Option<net::SocketAddr>,
) -> Self {
// let keepalive = config.keep_alive_enabled();
// let flags = if keepalive {
// Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED
// } else {
// Flags::empty()
// };
// keep-alive timer
let (ka_expire, ka_timer) = if let Some(delay) = timeout {
(delay.deadline(), Some(delay))

View file

@ -249,7 +249,6 @@ pub fn server<F: StreamServiceFactory<TcpStream>>(factory: F) -> TestServer {
.set_alpn_protos(b"\x02h2\x08http/1.1")
.map_err(|e| log::error!("Can not set alpn protocol: {:?}", e));
Connector::default()
.conn_lifetime(time::Duration::from_secs(0))
.timeout(time::Duration::from_millis(30000))
.openssl(builder.build())
.finish()
@ -257,7 +256,6 @@ pub fn server<F: StreamServiceFactory<TcpStream>>(factory: F) -> TestServer {
#[cfg(not(feature = "openssl"))]
{
Connector::default()
.conn_lifetime(time::Duration::from_secs(0))
.timeout(time::Duration::from_millis(30000))
.finish()
}

View file

@ -2,7 +2,7 @@ use std::cell::{Cell, RefCell};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::{io, time};
use std::{cmp, io, mem, time};
use bytes::BytesMut;
use futures::future::poll_fn;
@ -15,8 +15,15 @@ use crate::rt::time::delay_for;
pub struct Io {
tp: Type,
state: Arc<Cell<State>>,
read: Arc<Mutex<RefCell<Channel>>>,
write: Arc<Mutex<RefCell<Channel>>>,
local: Arc<Mutex<RefCell<Channel>>>,
remote: Arc<Mutex<RefCell<Channel>>>,
}
bitflags::bitflags! {
struct Flags: u8 {
const FLUSHED = 0b0000_0001;
const CLOSED = 0b0000_0010;
}
}
#[derive(Copy, Clone)]
@ -29,18 +36,49 @@ enum Type {
#[derive(Copy, Clone, Default)]
struct State {
client_closed: bool,
server_closed: bool,
client_dropped: bool,
server_dropped: bool,
}
#[derive(Default)]
struct Channel {
buf: BytesMut,
buf_cap: usize,
flags: Flags,
waker: AtomicWaker,
read_err: Option<io::Error>,
read_waker: AtomicWaker,
read_close: CloseState,
write_err: Option<io::Error>,
write_waker: AtomicWaker,
write: IoState,
flush: IoState,
}
impl Channel {
fn is_closed(&self) -> bool {
self.flags.contains(Flags::CLOSED)
}
fn is_flushed(&self) -> bool {
self.flags.contains(Flags::FLUSHED)
}
}
impl Default for Flags {
fn default() -> Self {
Flags::empty()
}
}
#[derive(Debug)]
enum IoState {
Ok,
Pending,
Err(io::Error),
}
impl Default for IoState {
fn default() -> Self {
IoState::Ok
}
}
enum CloseState {
@ -57,32 +95,42 @@ impl Default for CloseState {
impl Io {
/// Create a two interconnected streams
pub fn create() -> (Io, Io) {
let left = Arc::new(Mutex::new(RefCell::new(Channel::default())));
let right = Arc::new(Mutex::new(RefCell::new(Channel::default())));
let local = Arc::new(Mutex::new(RefCell::new(Channel::default())));
let remote = Arc::new(Mutex::new(RefCell::new(Channel::default())));
let state = Arc::new(Cell::new(State::default()));
(
Io {
tp: Type::Client,
read: left.clone(),
write: right.clone(),
local: local.clone(),
remote: remote.clone(),
state: state.clone(),
},
Io {
state,
tp: Type::Server,
read: right,
write: left,
local: remote,
remote: local,
},
)
}
pub fn is_client_closed(&self) -> bool {
self.state.get().client_closed
pub fn is_client_dropped(&self) -> bool {
self.state.get().client_dropped
}
pub fn is_server_closed(&self) -> bool {
self.state.get().server_closed
pub fn is_server_dropped(&self) -> bool {
self.state.get().server_dropped
}
/// Check if channel is closed from remoote side
pub fn is_closed(&self) -> bool {
self.remote.lock().unwrap().borrow().is_closed()
}
/// Check flushed state
pub fn is_flushed(&self) -> bool {
self.remote.lock().unwrap().borrow().is_flushed()
}
/// Access read buffer.
@ -90,7 +138,7 @@ impl Io {
where
F: FnOnce(&mut BytesMut) -> R,
{
let guard = self.read.lock().unwrap();
let guard = self.local.lock().unwrap();
let mut ch = guard.borrow_mut();
f(&mut ch.buf)
}
@ -98,52 +146,59 @@ impl Io {
/// Access write buffer.
pub async fn close(&self) {
{
let guard = self.write.lock().unwrap();
let guard = self.remote.lock().unwrap();
let mut write = guard.borrow_mut();
write.read_close = CloseState::Closed;
write.read_waker.wake();
write.waker.wake();
}
delay_for(time::Duration::from_millis(35)).await;
}
/// Access write buffer.
pub fn write_buffer<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut BytesMut) -> R,
{
let guard = self.write.lock().unwrap();
let mut ch = guard.borrow_mut();
f(&mut ch.buf)
}
/// Add extra data to the buffer and notify reader
pub fn write<T: AsRef<[u8]>>(&self, data: T) {
let guard = self.write.lock().unwrap();
let guard = self.remote.lock().unwrap();
let mut write = guard.borrow_mut();
write.buf.extend_from_slice(data.as_ref());
write.read_waker.wake();
write.waker.wake();
}
/// Set flush to Pending state
pub fn flush_pending(&self) {
self.remote.lock().unwrap().borrow_mut().flush = IoState::Pending;
}
/// Set flush to errore
pub fn flush_error(&self, err: io::Error) {
self.remote.lock().unwrap().borrow_mut().flush = IoState::Err(err);
}
/// Read any available data
pub fn remote_buffer_cap(&self, cap: usize) {
self.local.lock().unwrap().borrow_mut().buf_cap = cap;
}
/// Read any available data
pub fn read_any(&self) -> BytesMut {
self.read.lock().unwrap().borrow_mut().buf.split()
self.local.lock().unwrap().borrow_mut().buf.split()
}
/// Read data, if data is not available wait for it
pub async fn read(&self) -> Result<BytesMut, io::Error> {
if self.read.lock().unwrap().borrow().buf.is_empty() {
if self.local.lock().unwrap().borrow().buf.is_empty() {
poll_fn(|cx| {
let guard = self.read.lock().unwrap();
let guard = self.local.lock().unwrap();
let read = guard.borrow_mut();
if read.buf.is_empty() {
let closed = match self.tp {
Type::Client | Type::ClientClone => self.is_server_closed(),
Type::Server | Type::ServerClone => self.is_client_closed(),
Type::Client | Type::ClientClone => {
self.is_server_dropped() || read.is_closed()
}
Type::Server | Type::ServerClone => self.is_client_dropped(),
};
if closed {
Poll::Ready(())
} else {
read.read_waker.register(cx.waker());
read.waker.register(cx.waker());
drop(read);
drop(guard);
Poll::Pending
@ -154,7 +209,7 @@ impl Io {
})
.await;
}
Ok(self.read.lock().unwrap().borrow_mut().buf.split())
Ok(self.local.lock().unwrap().borrow_mut().buf.split())
}
}
@ -168,8 +223,8 @@ impl Clone for Io {
Io {
tp,
read: self.read.clone(),
write: self.write.clone(),
local: self.local.clone(),
remote: self.remote.clone(),
state: self.state.clone(),
}
}
@ -179,8 +234,8 @@ impl Drop for Io {
fn drop(&mut self) {
let mut state = self.state.get();
match self.tp {
Type::Server => state.server_closed = true,
Type::Client => state.client_closed = true,
Type::Server => state.server_dropped = true,
Type::Client => state.client_dropped = true,
_ => (),
}
self.state.set(state);
@ -193,9 +248,9 @@ impl AsyncRead for Io {
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let guard = self.read.lock().unwrap();
let guard = self.local.lock().unwrap();
let mut ch = guard.borrow_mut();
ch.read_waker.register(cx.waker());
ch.waker.register(cx.waker());
let result = if ch.buf.is_empty() {
if let Some(err) = ch.read_err.take() {
@ -223,23 +278,48 @@ impl AsyncWrite for Io {
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let guard = self.write.lock().unwrap();
let guard = self.remote.lock().unwrap();
let mut ch = guard.borrow_mut();
if let Some(err) = ch.write_err.take() {
Poll::Ready(Err(err))
match mem::take(&mut ch.write) {
IoState::Ok => {
let cap = cmp::min(buf.len(), ch.buf_cap);
if cap > 0 {
ch.buf.extend(&buf[..cap]);
ch.buf_cap -= cap;
ch.flags.remove(Flags::FLUSHED);
ch.waker.wake();
Poll::Ready(Ok(cap))
} else {
ch.write_waker.wake();
ch.buf.extend(buf);
Poll::Ready(Ok(buf.len()))
Poll::Pending
}
}
IoState::Pending => Poll::Pending,
IoState::Err(e) => Poll::Ready(Err(e)),
}
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
let guard = self.local.lock().unwrap();
let mut ch = guard.borrow_mut();
match mem::take(&mut ch.flush) {
IoState::Ok => {
ch.flags.insert(Flags::FLUSHED);
Poll::Ready(Ok(()))
}
IoState::Pending => Poll::Pending,
IoState::Err(e) => Poll::Ready(Err(e)),
}
}
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
self.local
.lock()
.unwrap()
.borrow_mut()
.flags
.insert(Flags::CLOSED);
Poll::Ready(Ok(()))
}
}

View file

@ -30,7 +30,7 @@ struct Config {
host: Option<String>,
keep_alive: KeepAlive,
client_timeout: u64,
client_shutdown: u64,
client_disconnect: u64,
handshake_timeout: u64,
}
@ -89,7 +89,7 @@ where
host: None,
keep_alive: KeepAlive::Timeout(5),
client_timeout: 5000,
client_shutdown: 5000,
client_disconnect: 5000,
handshake_timeout: 5000,
})),
backlog: 1024,
@ -162,22 +162,22 @@ where
///
/// To disable timeout set value to 0.
///
/// By default client timeout is set to 5000 milliseconds.
/// By default client timeout is set to 5 seconds.
pub fn client_timeout(self, val: u64) -> Self {
self.config.lock().unwrap().client_timeout = val;
self
}
/// Set server connection shutdown timeout in milliseconds.
/// Set server connection disconnect timeout in milliseconds.
///
/// Defines a timeout for shutdown connection. If a shutdown procedure does not complete
/// within this time, the request is dropped.
///
/// To disable timeout set value to 0.
///
/// By default client timeout is set to 5000 milliseconds.
pub fn client_shutdown(self, val: u64) -> Self {
self.config.lock().unwrap().client_shutdown = val;
/// By default client timeout is set to 5 seconds.
pub fn disconnect_timeout(self, val: u64) -> Self {
self.config.lock().unwrap().client_disconnect = val;
self
}
@ -270,6 +270,7 @@ where
HttpService::build()
.keep_alive(c.keep_alive)
.client_timeout(c.client_timeout)
.disconnect_timeout(c.client_disconnect)
.finish(map_config(factory(), move |_| cfg.clone()))
.tcp()
},
@ -316,7 +317,7 @@ where
HttpService::build()
.keep_alive(c.keep_alive)
.client_timeout(c.client_timeout)
.client_disconnect(c.client_shutdown)
.disconnect_timeout(c.client_disconnect)
.ssl_handshake_timeout(c.handshake_timeout)
.finish(map_config(factory(), move |_| cfg.clone()))
.openssl(acceptor.clone())
@ -364,7 +365,7 @@ where
HttpService::build()
.keep_alive(c.keep_alive)
.client_timeout(c.client_timeout)
.client_disconnect(c.client_shutdown)
.disconnect_timeout(c.client_disconnect)
.ssl_handshake_timeout(c.handshake_timeout)
.finish(map_config(factory(), move |_| cfg.clone()))
.rustls(config.clone())

View file

@ -57,7 +57,12 @@ async fn test_simple() {
let bytes = response.body().await.unwrap();
assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
let mut response = srv.post("/").send().await.unwrap();
let mut response = srv
.post("/")
.timeout(Duration::from_secs(30))
.send()
.await
.unwrap();
assert!(response.status().is_success());
// read response

View file

@ -158,7 +158,7 @@ async fn test_h2_content_length() {
let req = srv
.srequest(Method::HEAD, &format!("/{}", i))
.timeout(Duration::from_secs(30))
.timeout(Duration::from_secs(100))
.send();
let response = req.await.unwrap();
assert_eq!(response.headers().get(&header), None);

View file

@ -21,7 +21,7 @@ async fn test_h1() {
HttpService::build()
.keep_alive(KeepAlive::Disabled)
.client_timeout(1000)
.client_disconnect(1000)
.disconnect_timeout(1000)
.h1(|req: Request| {
assert!(req.peer_addr().is_some());
future::ok::<_, io::Error>(Response::Ok().finish())
@ -39,7 +39,7 @@ async fn test_h1_2() {
HttpService::build()
.keep_alive(KeepAlive::Disabled)
.client_timeout(1000)
.client_disconnect(1000)
.disconnect_timeout(1000)
.finish(|req: Request| {
assert!(req.peer_addr().is_some());
assert_eq!(req.version(), http::Version::HTTP_11);

View file

@ -29,7 +29,8 @@ async fn test_start() {
.maxconnrate(10)
.keep_alive(10)
.client_timeout(5000)
.client_shutdown(0)
.disconnect_timeout(1000)
.ssl_handshake_timeout(1000)
.server_hostname("localhost")
.system_exit()
.disable_signals()