From 7e3a4c2d0099d88daea12a685e1ce3258cc7dea3 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sun, 26 Dec 2021 15:46:27 +0600 Subject: [PATCH] Better error information for .poll_recv() method --- ntex-io/CHANGES.md | 6 +- ntex-io/src/dispatcher.rs | 107 ++++++++++------------- ntex-io/src/filter.rs | 3 + ntex-io/src/io.rs | 121 ++++++++++++++------------ ntex-io/src/ioref.rs | 39 ++------- ntex-io/src/lib.rs | 17 +++- ntex/CHANGES.md | 2 +- ntex/src/http/client/h1proto.rs | 58 ++++++++---- ntex/src/http/h1/dispatcher.rs | 150 +++++++++++++++++++------------- 9 files changed, 275 insertions(+), 228 deletions(-) diff --git a/ntex-io/CHANGES.md b/ntex-io/CHANGES.md index 80aeb8c4..4abeb161 100644 --- a/ntex-io/CHANGES.md +++ b/ntex-io/CHANGES.md @@ -1,6 +1,10 @@ # Changes -## [0.1.0-b.6] - 2021-12-xx +## [0.1.0-b.6] - 2021-12-26 + +* Better error information for .poll_recv() method. + +* Remove redundant Io::poll_write_backpressure() method. * Fix read filters ordering diff --git a/ntex-io/src/dispatcher.rs b/ntex-io/src/dispatcher.rs index a04bab3b..b3f61fc6 100644 --- a/ntex-io/src/dispatcher.rs +++ b/ntex-io/src/dispatcher.rs @@ -7,7 +7,7 @@ use ntex_service::{IntoService, Service}; use ntex_util::time::{now, Seconds}; use ntex_util::{future::Either, ready}; -use super::{rt::spawn, DispatchItem, IoBoxed, IoRef, Timer}; +use crate::{rt::spawn, DispatchItem, IoBoxed, IoRef, RecvError, Timer}; type Response = ::Item; @@ -36,7 +36,7 @@ where io: IoBoxed, st: Cell, timer: Timer, - ka_timeout: Seconds, + ka_timeout: Cell, ka_updated: Cell, error: Cell>, ready_err: Cell, @@ -100,10 +100,10 @@ where { let io = IoBoxed::from(io); let updated = now(); - let ka_timeout = Seconds(30); + let ka_timeout = Cell::new(Seconds(30)); // register keepalive timer - let expire = updated + time::Duration::from(ka_timeout); + let expire = updated + time::Duration::from(ka_timeout.get()); timer.register(expire, expire, &io); Dispatcher { @@ -132,7 +132,7 @@ where /// To disable timeout set value to 0. /// /// By default keep-alive timeout is set to 30 seconds. - pub fn keepalive_timeout(mut self, timeout: Seconds) -> Self { + pub fn keepalive_timeout(self, timeout: Seconds) -> Self { // register keepalive timer let prev = self.inner.ka_updated.get() + time::Duration::from(self.inner.ka()); if timeout.is_zero() { @@ -141,7 +141,7 @@ where let expire = self.inner.ka_updated.get() + time::Duration::from(timeout); self.inner.timer.register(expire, prev, &self.inner.io); } - self.inner.ka_timeout = timeout; + self.inner.ka_timeout.set(timeout); self } @@ -168,11 +168,11 @@ where fn handle_result(&self, item: Result, io: &IoRef) { self.inflight.set(self.inflight.get() - 1); match item { - Ok(Some(val)) => match io.encode(val, &self.codec) { - Ok(true) => (), - Ok(false) => io.enable_write_backpressure(), - Err(err) => self.error.set(Some(DispatcherError::Encoder(err))), - }, + Ok(Some(val)) => { + if let Err(err) = io.encode(val, &self.codec) { + self.error.set(Some(DispatcherError::Encoder(err))) + } + } Err(err) => self.error.set(Some(DispatcherError::Service(err))), Ok(None) => return, } @@ -216,31 +216,33 @@ where DispatcherState::Processing => { let item = match ready!(slf.poll_service(this.service, cx, io)) { PollService::Ready => { - match io.poll_write_backpressure(cx) { - Poll::Pending => { + // decode incoming bytes if buffer is ready + match ready!(io.poll_recv(&slf.shared.codec, cx)) { + Ok(el) => { + slf.update_keepalive(); + DispatchItem::Item(el) + } + Err(RecvError::KeepAlive) => { + slf.st.set(DispatcherState::Stop); + DispatchItem::KeepAliveTimeout + } + Err(RecvError::StopDispatcher) => { + log::trace!("dispatcher is instructed to stop"); + slf.st.set(DispatcherState::Stop); + continue; + } + Err(RecvError::WriteBackpressure) => { // instruct write task to notify dispatcher when data is flushed slf.st.set(DispatcherState::Backpressure); DispatchItem::WBackPressureEnabled } - Poll::Ready(()) => { - // decode incoming bytes if buffer is ready - match ready!(io.poll_recv(&slf.shared.codec, cx)) { - Ok(Some(el)) => { - slf.update_keepalive(); - DispatchItem::Item(el) - } - Err(Either::Left(err)) => { - slf.st.set(DispatcherState::Stop); - slf.unregister_keepalive(); - DispatchItem::DecoderError(err) - } - Err(Either::Right(err)) => { - slf.st.set(DispatcherState::Stop); - slf.unregister_keepalive(); - DispatchItem::Disconnect(Some(err)) - } - Ok(None) => DispatchItem::Disconnect(None), - } + Err(RecvError::Decoder(err)) => { + slf.st.set(DispatcherState::Stop); + DispatchItem::DecoderError(err) + } + Err(RecvError::PeerGone(err)) => { + slf.st.set(DispatcherState::Stop); + DispatchItem::Disconnect(err) } } } @@ -270,7 +272,7 @@ where let result = ready!(slf.poll_service(this.service, cx, io)); let item = match result { PollService::Ready => { - if slf.io.poll_write_backpressure(cx).is_ready() { + if slf.io.poll_flush(cx, false).is_ready() { slf.st.set(DispatcherState::Processing); DispatchItem::WBackPressureDisabled } else { @@ -300,6 +302,8 @@ where } // drain service responses and shutdown io DispatcherState::Stop => { + slf.unregister_keepalive(); + // service may relay on poll_ready for response results if !this.inner.ready_err.get() { let _ = this.service.poll_ready(cx); @@ -360,11 +364,11 @@ where io: &IoRef, ) { match item { - Ok(Some(item)) => match io.encode(item, &self.shared.codec) { - Ok(true) => (), - Ok(false) => io.enable_write_backpressure(), - Err(err) => self.shared.error.set(Some(DispatcherError::Encoder(err))), - }, + Ok(Some(item)) => { + if let Err(err) = io.encode(item, &self.shared.codec) { + self.shared.error.set(Some(DispatcherError::Encoder(err))) + } + } Err(err) => self.shared.error.set(Some(DispatcherError::Service(err))), Ok(None) => (), } @@ -384,7 +388,6 @@ where // check for errors Poll::Ready(if let Some(err) = self.shared.error.take() { log::trace!("error occured, stopping dispatcher"); - self.unregister_keepalive(); self.st.set(DispatcherState::Stop); match err { @@ -399,24 +402,6 @@ where PollService::ServiceError } } - } else if self.io.is_dispatcher_stopped() { - log::trace!("dispatcher is instructed to stop"); - - self.unregister_keepalive(); - - // process unhandled data - if let Ok(Some(el)) = io.decode(&self.shared.codec) { - PollService::Item(DispatchItem::Item(el)) - } else { - self.st.set(DispatcherState::Stop); - - // get io error - if let Some(err) = self.io.take_error() { - PollService::Item(DispatchItem::Disconnect(Some(err))) - } else { - PollService::ServiceError - } - } } else { PollService::Ready }) @@ -432,7 +417,6 @@ where log::trace!("service readiness check failed, stopping"); self.st.set(DispatcherState::Stop); self.error.set(Some(err)); - self.unregister_keepalive(); self.ready_err.set(true); Poll::Ready(PollService::ServiceError) } @@ -440,11 +424,11 @@ where } fn ka(&self) -> Seconds { - self.ka_timeout + self.ka_timeout.get() } fn ka_enabled(&self) -> bool { - self.ka_timeout.non_zero() + self.ka_timeout.get().non_zero() } /// check keepalive timeout @@ -475,6 +459,7 @@ where /// unregister keep-alive timer fn unregister_keepalive(&self) { if self.ka_enabled() { + self.ka_timeout.set(Seconds::ZERO); self.timer.unregister( self.ka_updated.get() + time::Duration::from(self.ka()), &self.io, @@ -533,7 +518,7 @@ mod tests { ) -> (Self, State) { let state = Io::new(io); let timer = Timer::default(); - let ka_timeout = Seconds(1); + let ka_timeout = Cell::new(Seconds(1)); let ka_updated = now(); let shared = Rc::new(DispatcherShared { codec: codec, diff --git a/ntex-io/src/filter.rs b/ntex-io/src/filter.rs index f8bd324e..2787c03a 100644 --- a/ntex-io/src/filter.rs +++ b/ntex-io/src/filter.rs @@ -127,6 +127,9 @@ impl Filter for Base { if buf.is_empty() { pool.release_write_buf(buf); } else { + if buf.len() >= pool.write_params_high() { + self.0 .0.insert_flags(Flags::WR_BACKPRESSURE); + } self.0 .0.write_buf.set(Some(buf)); self.0 .0.write_task.wake(); } diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs index 357faa9c..f3e35e34 100644 --- a/ntex-io/src/io.rs +++ b/ntex-io/src/io.rs @@ -9,7 +9,7 @@ use ntex_util::{future::poll_fn, future::Either, task::LocalWaker, time::Millis} use super::filter::{Base, NullFilter}; use super::seal::{IoBoxed, Sealed}; use super::tasks::{ReadContext, WriteContext}; -use super::{Filter, FilterFactory, Handle, IoStream}; +use super::{Filter, FilterFactory, Handle, IoStream, RecvError}; bitflags::bitflags! { pub struct Flags: u16 { @@ -120,7 +120,7 @@ impl IoState { self.read_task.wake(); self.write_task.wake(); self.dispatch_task.wake(); - self.insert_flags(Flags::IO_ERR | Flags::DSP_STOP); + self.insert_flags(Flags::IO_ERR); self.notify_disconnect(); } @@ -419,7 +419,28 @@ impl Io { where U: Decoder, { - poll_fn(|cx| self.poll_recv(codec, cx)).await + loop { + return match poll_fn(|cx| self.poll_recv(codec, cx)).await { + Ok(item) => Ok(Some(item)), + Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new( + io::ErrorKind::Other, + "Keep-alive", + ))), + Err(RecvError::StopDispatcher) => Err(Either::Right(io::Error::new( + io::ErrorKind::Other, + "Dispatcher stopped", + ))), + Err(RecvError::WriteBackpressure) => { + poll_fn(|cx| self.poll_flush(cx, false)) + .await + .map_err(Either::Right)?; + continue; + } + Err(RecvError::Decoder(err)) => Err(Either::Left(err)), + Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)), + Err(RecvError::PeerGone(None)) => Ok(None), + }; + } } #[inline] @@ -514,7 +535,6 @@ impl Io { } else if ready { log::trace!("waking up io read task"); flags.remove(Flags::RD_READY); - self.0 .0.read_task.wake(); self.0 .0.flags.set(flags); Poll::Ready(Ok(Some(()))) } else { @@ -528,25 +548,41 @@ impl Io { /// Decode codec item from incoming bytes stream. /// /// Wake read task and request to read more data if data is not enough for decoding. + /// If error get returned this method does not register waker for later wake up action. pub fn poll_recv( &self, codec: &U, cx: &mut Context<'_>, - ) -> Poll, Either>> + ) -> Poll>> where U: Decoder, { match self.decode(codec) { - Ok(Some(el)) => Poll::Ready(Ok(Some(el))), - Ok(None) => match self.poll_read_ready(cx) { - Poll::Pending | Poll::Ready(Ok(Some(()))) => { - log::trace!("not enough data to decode next frame"); - Poll::Pending + Ok(Some(el)) => Poll::Ready(Ok(el)), + Ok(None) => { + let flags = self.flags(); + if flags.contains(Flags::DSP_STOP) { + Poll::Ready(Err(RecvError::StopDispatcher)) + } else if flags.contains(Flags::DSP_KEEPALIVE) { + Poll::Ready(Err(RecvError::KeepAlive)) + } else if flags.contains(Flags::WR_BACKPRESSURE) { + Poll::Ready(Err(RecvError::WriteBackpressure)) + } else { + match self.poll_read_ready(cx) { + Poll::Pending | Poll::Ready(Ok(Some(()))) => { + log::trace!("not enough data to decode next frame"); + Poll::Pending + } + Poll::Ready(Err(e)) => { + Poll::Ready(Err(RecvError::PeerGone(Some(e)))) + } + Poll::Ready(Ok(None)) => { + Poll::Ready(Err(RecvError::PeerGone(None))) + } + } } - Poll::Ready(Err(e)) => Poll::Ready(Err(Either::Right(e))), - Poll::Ready(Ok(None)) => Poll::Ready(Ok(None)), - }, - Err(err) => Poll::Ready(Err(Either::Left(err))), + } + Err(err) => Poll::Ready(Err(RecvError::Decoder(err))), } } @@ -567,53 +603,26 @@ impl Io { .unwrap_or_else(|| io::Error::new(io::ErrorKind::Other, "disconnected")))); } - if let Some(buf) = self.0 .0.write_buf.take() { - let len = buf.len(); - if len != 0 { - self.0 .0.write_buf.set(Some(buf)); + let len = self + .0 + .0 + .with_write_buf(|buf| buf.as_ref().map(|b| b.len()).unwrap_or(0)); - if full { - self.0 .0.insert_flags(Flags::WR_WAIT); - self.0 .0.dispatch_task.register(cx.waker()); - return Poll::Pending; - } else if len >= self.0.memory_pool().write_params_high() << 1 { - self.0 .0.insert_flags(Flags::WR_BACKPRESSURE); - self.0 .0.dispatch_task.register(cx.waker()); - return Poll::Pending; - } else { - self.0 .0.remove_flags(Flags::WR_BACKPRESSURE); - } - } - } - - Poll::Ready(Ok(())) - } - - #[inline] - /// Wait until write task flushes data to io stream - /// - /// Write task must be waken up separately. - pub fn poll_write_backpressure(&self, cx: &mut Context<'_>) -> Poll<()> { - if !self.is_io_open() { - Poll::Ready(()) - } else if self.flags().contains(Flags::WR_BACKPRESSURE) { - self.0 .0.dispatch_task.register(cx.waker()); - Poll::Pending - } else { - let len = self - .0 - .0 - .with_write_buf(|buf| buf.as_ref().map(|b| b.len()).unwrap_or(0)); - let hw = self.memory_pool().write_params_high(); - if len >= hw { - log::trace!("enable write back-pressure"); + if len > 0 { + if full { + self.0 .0.insert_flags(Flags::WR_WAIT); + self.0 .0.dispatch_task.register(cx.waker()); + return Poll::Pending; + } else if len >= self.0.memory_pool().write_params_high() << 1 { self.0 .0.insert_flags(Flags::WR_BACKPRESSURE); self.0 .0.dispatch_task.register(cx.waker()); - Poll::Pending - } else { - Poll::Ready(()) + return Poll::Pending; } } + self.0 + .0 + .remove_flags(Flags::WR_WAIT | Flags::WR_BACKPRESSURE); + Poll::Ready(Ok(())) } #[inline] diff --git a/ntex-io/src/ioref.rs b/ntex-io/src/ioref.rs index dc06cd7d..7fbeaccb 100644 --- a/ntex-io/src/ioref.rs +++ b/ntex-io/src/ioref.rs @@ -1,6 +1,6 @@ use std::{any, fmt, io}; -use ntex_bytes::{BytesMut, PoolRef}; +use ntex_bytes::{BufMut, BytesMut, PoolRef}; use ntex_codec::{Decoder, Encoder}; use super::io::{Flags, IoRef, OnDisconnect}; @@ -68,12 +68,6 @@ impl IoRef { self.0.dispatch_task.wake(); } - #[inline] - /// Mark dispatcher as stopped - pub fn stop_dispatcher(&self) { - self.0.insert_flags(Flags::DSP_STOP); - } - #[inline] /// Gracefully close connection /// @@ -141,15 +135,6 @@ impl IoRef { len >= self.memory_pool().read_params_high() } - #[inline] - /// Wait until write task flushes data to io stream - /// - /// Write task must be waken up separately. - pub fn enable_write_backpressure(&self) { - log::trace!("enable write back-pressure"); - self.0.insert_flags(Flags::WR_BACKPRESSURE); - } - #[inline] /// Get mut access to write buffer pub fn with_write_buf(&self, f: F) -> Result @@ -185,7 +170,7 @@ impl IoRef { /// Encode and write item to a buffer and wake up write task /// /// Returns write buffer state, false is returned if write buffer if full. - pub fn encode(&self, item: U::Item, codec: &U) -> Result::Error> + pub fn encode(&self, item: U::Item, codec: &U) -> Result<(), ::Error> where U: Encoder, { @@ -200,25 +185,21 @@ impl IoRef { let (hw, lw) = self.memory_pool().write_params().unpack(); // make sure we've got room - let remaining = buf.capacity() - buf.len(); + let remaining = buf.remaining_mut(); if remaining < lw { buf.reserve(hw - remaining); } // encode item and wake write task - let result = codec.encode(item, &mut buf).map(|_| { - if is_write_sleep { - self.0.write_task.wake(); - } - buf.len() < hw - }); + codec.encode(item, &mut buf)?; + if is_write_sleep { + self.0.write_task.wake(); + } if let Err(err) = filter.release_write_buf(buf) { self.0.set_error(Some(err)); } - result - } else { - Ok(true) } + Ok(()) } #[inline] @@ -313,14 +294,13 @@ mod tests { sleep(Millis(50)).await; let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await; if let Poll::Ready(msg) = res { - assert_eq!(msg.unwrap().unwrap(), Bytes::from_static(BIN)); + assert_eq!(msg.unwrap(), Bytes::from_static(BIN)); } client.read_error(io::Error::new(io::ErrorKind::Other, "err")); let msg = state.recv(&BytesCodec).await; assert!(msg.is_err()); assert!(state.flags().contains(Flags::IO_ERR)); - assert!(state.flags().contains(Flags::DSP_STOP)); let (client, server) = IoTest::create(); client.remote_buffer_cap(1024); @@ -348,7 +328,6 @@ mod tests { let res = state.send(&BytesCodec, Bytes::from_static(b"test")).await; assert!(res.is_err()); assert!(state.flags().contains(Flags::IO_ERR)); - assert!(state.flags().contains(Flags::DSP_STOP)); let (client, server) = IoTest::create(); client.remote_buffer_cap(1024); diff --git a/ntex-io/src/lib.rs b/ntex-io/src/lib.rs index 9d31c4ab..9e744bd0 100644 --- a/ntex-io/src/lib.rs +++ b/ntex-io/src/lib.rs @@ -94,7 +94,22 @@ pub trait Handle { fn query(&self, id: TypeId) -> Option>; } -/// Framed transport item +/// Recv error +#[derive(Debug)] +pub enum RecvError { + /// Keep-alive timeout occured + KeepAlive, + /// Write backpressure is enabled + WriteBackpressure, + /// Dispatcher marked stopped + StopDispatcher, + /// Unrecoverable frame decoding errors + Decoder(U::Error), + /// Peer is disconnected + PeerGone(Option), +} + +/// Dispatcher item pub enum DispatchItem { Item(::Item), /// Write back-pressure enabled diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index 1bb4ed6e..c9a72074 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,6 +1,6 @@ # Changes -## [0.5.0-b.4] - 2021-12-xx +## [0.5.0-b.4] - 2021-12-26 * Allow to get access to ws transport codec diff --git a/ntex/src/http/client/h1proto.rs b/ntex/src/http/client/h1proto.rs index ed3c3198..61570fca 100644 --- a/ntex/src/http/client/h1proto.rs +++ b/ntex/src/http/client/h1proto.rs @@ -1,4 +1,4 @@ -use std::{io::Write, pin::Pin, task::Context, task::Poll, time::Instant}; +use std::{io, io::Write, pin::Pin, task::Context, task::Poll, time::Instant}; use crate::http::body::{BodySize, MessageBody}; use crate::http::error::PayloadError; @@ -6,8 +6,8 @@ use crate::http::h1; use crate::http::header::{HeaderMap, HeaderValue, HOST}; use crate::http::message::{RequestHeadType, ResponseHead}; use crate::http::payload::{Payload, PayloadStream}; -use crate::io::IoBoxed; -use crate::util::{poll_fn, BufMut, Bytes, BytesMut}; +use crate::io::{IoBoxed, RecvError}; +use crate::util::{poll_fn, ready, BufMut, Bytes, BytesMut}; use crate::Stream; use super::connection::{Connection, ConnectionType}; @@ -110,9 +110,8 @@ where loop { match poll_fn(|cx| body.poll_next_chunk(cx)).await { Some(result) => { - if !io.encode(h1::Message::Chunk(Some(result?)), codec)? { - io.flush(false).await?; - } + io.encode(h1::Message::Chunk(Some(result?)), codec)?; + io.flush(false).await?; } None => { io.encode(h1::Message::Chunk(None), codec)?; @@ -156,19 +155,40 @@ impl Stream for PlStream { cx: &mut Context<'_>, ) -> Poll> { let mut this = self.as_mut(); - match this.io.as_ref().unwrap().poll_recv(&this.codec, cx)? { - Poll::Pending => Poll::Pending, - Poll::Ready(Some(chunk)) => { - if let Some(chunk) = chunk { - Poll::Ready(Some(Ok(chunk))) - } else { - let io = this.io.take().unwrap(); - let force_close = !this.codec.keepalive(); - release_connection(io, force_close, this.created, this.pool.take()); - Poll::Ready(None) - } - } - Poll::Ready(None) => Poll::Ready(None), + loop { + return Poll::Ready(Some( + match ready!(this.io.as_ref().unwrap().poll_recv(&this.codec, cx)) { + Ok(chunk) => { + if let Some(chunk) = chunk { + Ok(chunk) + } else { + let io = this.io.take().unwrap(); + let force_close = !this.codec.keepalive(); + release_connection( + io, + force_close, + this.created, + this.pool.take(), + ); + return Poll::Ready(None); + } + } + Err(RecvError::KeepAlive) => { + Err(io::Error::new(io::ErrorKind::Other, "Keep-alive").into()) + } + Err(RecvError::StopDispatcher) => { + Err(io::Error::new(io::ErrorKind::Other, "Dispatcher stopped") + .into()) + } + Err(RecvError::WriteBackpressure) => { + ready!(this.io.as_ref().unwrap().poll_flush(cx, false))?; + continue; + } + Err(RecvError::Decoder(err)) => Err(err), + Err(RecvError::PeerGone(Some(err))) => Err(err.into()), + Err(RecvError::PeerGone(None)) => return Poll::Ready(None), + }, + )); } } } diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index 6299b35f..11b60d75 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -1,10 +1,10 @@ //! Framed transport dispatcher use std::task::{Context, Poll}; -use std::{error::Error, fmt, future::Future, marker, pin::Pin, rc::Rc, time}; +use std::{error::Error, fmt, future::Future, io, marker, pin::Pin, rc::Rc, time}; -use crate::io::{Filter, Io, IoRef}; +use crate::io::{Filter, Io, IoRef, RecvError}; use crate::service::Service; -use crate::{time::now, util::ready, util::Bytes, util::Either}; +use crate::{time::now, util::ready, util::Bytes}; use crate::http; use crate::http::body::{BodySize, MessageBody, ResponseBody}; @@ -122,7 +122,6 @@ where macro_rules! set_error ({ $slf:tt, $err:ident } => { *$slf.st = State::Stop; $slf.inner.error = Some($err); - $slf.inner.unregister_keepalive(); }); impl Future for Dispatcher @@ -239,35 +238,11 @@ where State::ReadRequest => { log::trace!("trying to read http message"); - // stop dispatcher - if this.inner.io().is_dispatcher_stopped() { - log::trace!("dispatcher is instructed to stop"); - *this.st = State::Stop; - this.inner.unregister_keepalive(); - continue; - } - - // keep-alive timeout - if this.inner.state.is_keepalive() { - if !this.inner.flags.contains(Flags::STARTED) { - log::trace!("slow request timeout"); - let (req, body) = - Response::RequestTimeout().finish().into_parts(); - let _ = this.inner.send_response(req, body.into_body()); - this.inner.error = Some(DispatchError::SlowRequestTimeout); - } else { - log::trace!("keep-alive timeout, close connection"); - } - *this.st = State::Stop; - this.inner.unregister_keepalive(); - continue; - } - let io = this.inner.io(); // decode incoming bytes stream match ready!(io.poll_recv(&this.inner.codec, cx)) { - Ok(Some((mut req, pl))) => { + Ok((mut req, pl)) => { log::trace!( "http message is received: {:?} and payload {:?}", req, @@ -332,24 +307,43 @@ where ); } } - Ok(None) => { - // peer is gone - log::trace!("peer is gone"); - let e = DispatchError::Disconnect(None); - set_error!(this, e); + Err(RecvError::WriteBackpressure) => { + if let Err(err) = ready!(this.inner.io().poll_flush(cx, false)) + { + log::trace!("peer is gone with {:?}", err); + *this.st = State::Stop; + this.inner.error = + Some(DispatchError::Disconnect(Some(err))); + } } - Err(Either::Left(err)) => { + Err(RecvError::Decoder(err)) => { // Malformed requests, respond with 400 log::trace!("malformed request: {:?}", err); let (res, body) = Response::BadRequest().finish().into_parts(); this.inner.error = Some(DispatchError::Parse(err)); *this.st = this.inner.send_response(res, body.into_body()); } - Err(Either::Right(err)) => { + Err(RecvError::PeerGone(err)) => { log::trace!("peer is gone with {:?}", err); - // peer is gone - let e = DispatchError::Disconnect(Some(err)); - set_error!(this, e); + *this.st = State::Stop; + this.inner.error = Some(DispatchError::Disconnect(err)); + } + Err(RecvError::StopDispatcher) => { + log::trace!("dispatcher is instructed to stop"); + *this.st = State::Stop; + } + Err(RecvError::KeepAlive) => { + // keep-alive timeout + if !this.inner.flags.contains(Flags::STARTED) { + log::trace!("slow request timeout"); + let (req, body) = + Response::RequestTimeout().finish().into_parts(); + let _ = this.inner.send_response(req, body.into_body()); + this.inner.error = Some(DispatchError::SlowRequestTimeout); + } else { + log::trace!("keep-alive timeout, close connection"); + } + *this.st = State::Stop; } } } @@ -371,7 +365,7 @@ where set_error!(this, e); } else { loop { - ready!(this.inner.io().poll_write_backpressure(cx)); + let _ = ready!(this.inner.io().poll_flush(cx, false)); let item = ready!(body.poll_next_chunk(cx)); if let Some(st) = this.inner.send_payload(item) { *this.st = st; @@ -397,6 +391,8 @@ where } // prepare to shutdown State::Stop => { + this.inner.unregister_keepalive(); + if this .inner .io @@ -441,7 +437,7 @@ where // connection is not keep-alive, disconnect if !self.flags.contains(Flags::KEEPALIVE) || !self.codec.keepalive_enabled() { self.unregister_keepalive(); - self.state.stop_dispatcher(); + self.state.close(); State::Stop } else { self.reset_keepalive(); @@ -452,6 +448,7 @@ where fn unregister_keepalive(&mut self) { if self.flags.contains(Flags::KEEPALIVE) { self.config.timer_h1.unregister(self.expire, &self.state); + self.flags.remove(Flags::KEEPALIVE); } } @@ -583,28 +580,64 @@ where loop { let res = io.poll_recv(&payload.0, cx); match res { - Poll::Ready(Ok(Some(PayloadItem::Chunk(chunk)))) => { + Poll::Ready(Ok(PayloadItem::Chunk(chunk))) => { updated = true; payload.1.feed_data(chunk); } - Poll::Ready(Ok(Some(PayloadItem::Eof))) => { + Poll::Ready(Ok(PayloadItem::Eof)) => { updated = true; payload.1.feed_eof(); self.payload = None; break; } - Poll::Ready(Ok(None)) => { - payload.1.set_error(PayloadError::EncodingCorrupted); - self.payload = None; - return Poll::Ready(Err(ParseError::Incomplete.into())); - } - Poll::Ready(Err(e)) => { - payload.1.set_error(PayloadError::EncodingCorrupted); - self.payload = None; - return Poll::Ready(Err(match e { - Either::Left(e) => DispatchError::Parse(e), - Either::Right(e) => DispatchError::Disconnect(Some(e)), - })); + Poll::Ready(Err(err)) => { + let err = match err { + RecvError::WriteBackpressure => { + if io.poll_flush(cx, false)?.is_pending() { + break; + } else { + continue; + } + } + RecvError::KeepAlive => { + payload + .1 + .set_error(PayloadError::EncodingCorrupted); + self.payload = None; + io::Error::new(io::ErrorKind::Other, "Keep-alive") + .into() + } + RecvError::StopDispatcher => { + payload + .1 + .set_error(PayloadError::EncodingCorrupted); + self.payload = None; + io::Error::new( + io::ErrorKind::Other, + "Dispatcher stopped", + ) + .into() + } + RecvError::PeerGone(err) => { + payload + .1 + .set_error(PayloadError::EncodingCorrupted); + self.payload = None; + if let Some(err) = err { + DispatchError::Disconnect(Some(err)) + } else { + ParseError::Incomplete.into() + } + } + RecvError::Decoder(e) => { + payload + .1 + .set_error(PayloadError::EncodingCorrupted); + self.payload = None; + DispatchError::Parse(e) + } + }; + return Poll::Ready(Err(err)); } Poll::Pending => break, } @@ -870,9 +903,8 @@ mod tests { } #[crate::rt_test] - /// if socket is disconnected, h1 dispatcher does not process any data - // /// h1 dispatcher still processes all incoming requests - // /// but it does not write any data to socket + /// /// h1 dispatcher still processes all incoming requests + /// /// but it does not write any data to socket async fn test_write_disconnected() { let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone(); @@ -892,7 +924,7 @@ mod tests { assert!(client.read_any().is_empty()); // only first request get handled - assert_eq!(num.load(Ordering::Relaxed), 0); + assert_eq!(num.load(Ordering::Relaxed), 1); } #[crate::rt_test]