Better error information for .poll_recv() method

This commit is contained in:
Nikolay Kim 2021-12-26 15:46:27 +06:00
parent 6128b7851d
commit 7e3a4c2d00
9 changed files with 275 additions and 228 deletions

View file

@ -1,6 +1,10 @@
# Changes
## [0.1.0-b.6] - 2021-12-xx
## [0.1.0-b.6] - 2021-12-26
* Better error information for .poll_recv() method.
* Remove redundant Io::poll_write_backpressure() method.
* Fix read filters ordering

View file

@ -7,7 +7,7 @@ use ntex_service::{IntoService, Service};
use ntex_util::time::{now, Seconds};
use ntex_util::{future::Either, ready};
use super::{rt::spawn, DispatchItem, IoBoxed, IoRef, Timer};
use crate::{rt::spawn, DispatchItem, IoBoxed, IoRef, RecvError, Timer};
type Response<U> = <U as Encoder>::Item;
@ -36,7 +36,7 @@ where
io: IoBoxed,
st: Cell<DispatcherState>,
timer: Timer,
ka_timeout: Seconds,
ka_timeout: Cell<Seconds>,
ka_updated: Cell<time::Instant>,
error: Cell<Option<S::Error>>,
ready_err: Cell<bool>,
@ -100,10 +100,10 @@ where
{
let io = IoBoxed::from(io);
let updated = now();
let ka_timeout = Seconds(30);
let ka_timeout = Cell::new(Seconds(30));
// register keepalive timer
let expire = updated + time::Duration::from(ka_timeout);
let expire = updated + time::Duration::from(ka_timeout.get());
timer.register(expire, expire, &io);
Dispatcher {
@ -132,7 +132,7 @@ where
/// To disable timeout set value to 0.
///
/// By default keep-alive timeout is set to 30 seconds.
pub fn keepalive_timeout(mut self, timeout: Seconds) -> Self {
pub fn keepalive_timeout(self, timeout: Seconds) -> Self {
// register keepalive timer
let prev = self.inner.ka_updated.get() + time::Duration::from(self.inner.ka());
if timeout.is_zero() {
@ -141,7 +141,7 @@ where
let expire = self.inner.ka_updated.get() + time::Duration::from(timeout);
self.inner.timer.register(expire, prev, &self.inner.io);
}
self.inner.ka_timeout = timeout;
self.inner.ka_timeout.set(timeout);
self
}
@ -168,11 +168,11 @@ where
fn handle_result(&self, item: Result<S::Response, S::Error>, io: &IoRef) {
self.inflight.set(self.inflight.get() - 1);
match item {
Ok(Some(val)) => match io.encode(val, &self.codec) {
Ok(true) => (),
Ok(false) => io.enable_write_backpressure(),
Err(err) => self.error.set(Some(DispatcherError::Encoder(err))),
},
Ok(Some(val)) => {
if let Err(err) = io.encode(val, &self.codec) {
self.error.set(Some(DispatcherError::Encoder(err)))
}
}
Err(err) => self.error.set(Some(DispatcherError::Service(err))),
Ok(None) => return,
}
@ -216,31 +216,33 @@ where
DispatcherState::Processing => {
let item = match ready!(slf.poll_service(this.service, cx, io)) {
PollService::Ready => {
match io.poll_write_backpressure(cx) {
Poll::Pending => {
// decode incoming bytes if buffer is ready
match ready!(io.poll_recv(&slf.shared.codec, cx)) {
Ok(el) => {
slf.update_keepalive();
DispatchItem::Item(el)
}
Err(RecvError::KeepAlive) => {
slf.st.set(DispatcherState::Stop);
DispatchItem::KeepAliveTimeout
}
Err(RecvError::StopDispatcher) => {
log::trace!("dispatcher is instructed to stop");
slf.st.set(DispatcherState::Stop);
continue;
}
Err(RecvError::WriteBackpressure) => {
// instruct write task to notify dispatcher when data is flushed
slf.st.set(DispatcherState::Backpressure);
DispatchItem::WBackPressureEnabled
}
Poll::Ready(()) => {
// decode incoming bytes if buffer is ready
match ready!(io.poll_recv(&slf.shared.codec, cx)) {
Ok(Some(el)) => {
slf.update_keepalive();
DispatchItem::Item(el)
}
Err(Either::Left(err)) => {
slf.st.set(DispatcherState::Stop);
slf.unregister_keepalive();
DispatchItem::DecoderError(err)
}
Err(Either::Right(err)) => {
slf.st.set(DispatcherState::Stop);
slf.unregister_keepalive();
DispatchItem::Disconnect(Some(err))
}
Ok(None) => DispatchItem::Disconnect(None),
}
Err(RecvError::Decoder(err)) => {
slf.st.set(DispatcherState::Stop);
DispatchItem::DecoderError(err)
}
Err(RecvError::PeerGone(err)) => {
slf.st.set(DispatcherState::Stop);
DispatchItem::Disconnect(err)
}
}
}
@ -270,7 +272,7 @@ where
let result = ready!(slf.poll_service(this.service, cx, io));
let item = match result {
PollService::Ready => {
if slf.io.poll_write_backpressure(cx).is_ready() {
if slf.io.poll_flush(cx, false).is_ready() {
slf.st.set(DispatcherState::Processing);
DispatchItem::WBackPressureDisabled
} else {
@ -300,6 +302,8 @@ where
}
// drain service responses and shutdown io
DispatcherState::Stop => {
slf.unregister_keepalive();
// service may relay on poll_ready for response results
if !this.inner.ready_err.get() {
let _ = this.service.poll_ready(cx);
@ -360,11 +364,11 @@ where
io: &IoRef,
) {
match item {
Ok(Some(item)) => match io.encode(item, &self.shared.codec) {
Ok(true) => (),
Ok(false) => io.enable_write_backpressure(),
Err(err) => self.shared.error.set(Some(DispatcherError::Encoder(err))),
},
Ok(Some(item)) => {
if let Err(err) = io.encode(item, &self.shared.codec) {
self.shared.error.set(Some(DispatcherError::Encoder(err)))
}
}
Err(err) => self.shared.error.set(Some(DispatcherError::Service(err))),
Ok(None) => (),
}
@ -384,7 +388,6 @@ where
// check for errors
Poll::Ready(if let Some(err) = self.shared.error.take() {
log::trace!("error occured, stopping dispatcher");
self.unregister_keepalive();
self.st.set(DispatcherState::Stop);
match err {
@ -399,24 +402,6 @@ where
PollService::ServiceError
}
}
} else if self.io.is_dispatcher_stopped() {
log::trace!("dispatcher is instructed to stop");
self.unregister_keepalive();
// process unhandled data
if let Ok(Some(el)) = io.decode(&self.shared.codec) {
PollService::Item(DispatchItem::Item(el))
} else {
self.st.set(DispatcherState::Stop);
// get io error
if let Some(err) = self.io.take_error() {
PollService::Item(DispatchItem::Disconnect(Some(err)))
} else {
PollService::ServiceError
}
}
} else {
PollService::Ready
})
@ -432,7 +417,6 @@ where
log::trace!("service readiness check failed, stopping");
self.st.set(DispatcherState::Stop);
self.error.set(Some(err));
self.unregister_keepalive();
self.ready_err.set(true);
Poll::Ready(PollService::ServiceError)
}
@ -440,11 +424,11 @@ where
}
fn ka(&self) -> Seconds {
self.ka_timeout
self.ka_timeout.get()
}
fn ka_enabled(&self) -> bool {
self.ka_timeout.non_zero()
self.ka_timeout.get().non_zero()
}
/// check keepalive timeout
@ -475,6 +459,7 @@ where
/// unregister keep-alive timer
fn unregister_keepalive(&self) {
if self.ka_enabled() {
self.ka_timeout.set(Seconds::ZERO);
self.timer.unregister(
self.ka_updated.get() + time::Duration::from(self.ka()),
&self.io,
@ -533,7 +518,7 @@ mod tests {
) -> (Self, State) {
let state = Io::new(io);
let timer = Timer::default();
let ka_timeout = Seconds(1);
let ka_timeout = Cell::new(Seconds(1));
let ka_updated = now();
let shared = Rc::new(DispatcherShared {
codec: codec,

View file

@ -127,6 +127,9 @@ impl Filter for Base {
if buf.is_empty() {
pool.release_write_buf(buf);
} else {
if buf.len() >= pool.write_params_high() {
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
}
self.0 .0.write_buf.set(Some(buf));
self.0 .0.write_task.wake();
}

View file

@ -9,7 +9,7 @@ use ntex_util::{future::poll_fn, future::Either, task::LocalWaker, time::Millis}
use super::filter::{Base, NullFilter};
use super::seal::{IoBoxed, Sealed};
use super::tasks::{ReadContext, WriteContext};
use super::{Filter, FilterFactory, Handle, IoStream};
use super::{Filter, FilterFactory, Handle, IoStream, RecvError};
bitflags::bitflags! {
pub struct Flags: u16 {
@ -120,7 +120,7 @@ impl IoState {
self.read_task.wake();
self.write_task.wake();
self.dispatch_task.wake();
self.insert_flags(Flags::IO_ERR | Flags::DSP_STOP);
self.insert_flags(Flags::IO_ERR);
self.notify_disconnect();
}
@ -419,7 +419,28 @@ impl<F> Io<F> {
where
U: Decoder,
{
poll_fn(|cx| self.poll_recv(codec, cx)).await
loop {
return match poll_fn(|cx| self.poll_recv(codec, cx)).await {
Ok(item) => Ok(Some(item)),
Err(RecvError::KeepAlive) => Err(Either::Right(io::Error::new(
io::ErrorKind::Other,
"Keep-alive",
))),
Err(RecvError::StopDispatcher) => Err(Either::Right(io::Error::new(
io::ErrorKind::Other,
"Dispatcher stopped",
))),
Err(RecvError::WriteBackpressure) => {
poll_fn(|cx| self.poll_flush(cx, false))
.await
.map_err(Either::Right)?;
continue;
}
Err(RecvError::Decoder(err)) => Err(Either::Left(err)),
Err(RecvError::PeerGone(Some(err))) => Err(Either::Right(err)),
Err(RecvError::PeerGone(None)) => Ok(None),
};
}
}
#[inline]
@ -514,7 +535,6 @@ impl<F> Io<F> {
} else if ready {
log::trace!("waking up io read task");
flags.remove(Flags::RD_READY);
self.0 .0.read_task.wake();
self.0 .0.flags.set(flags);
Poll::Ready(Ok(Some(())))
} else {
@ -528,25 +548,41 @@ impl<F> Io<F> {
/// Decode codec item from incoming bytes stream.
///
/// Wake read task and request to read more data if data is not enough for decoding.
/// If error get returned this method does not register waker for later wake up action.
pub fn poll_recv<U>(
&self,
codec: &U,
cx: &mut Context<'_>,
) -> Poll<Result<Option<U::Item>, Either<U::Error, io::Error>>>
) -> Poll<Result<U::Item, RecvError<U>>>
where
U: Decoder,
{
match self.decode(codec) {
Ok(Some(el)) => Poll::Ready(Ok(Some(el))),
Ok(None) => match self.poll_read_ready(cx) {
Poll::Pending | Poll::Ready(Ok(Some(()))) => {
log::trace!("not enough data to decode next frame");
Poll::Pending
Ok(Some(el)) => Poll::Ready(Ok(el)),
Ok(None) => {
let flags = self.flags();
if flags.contains(Flags::DSP_STOP) {
Poll::Ready(Err(RecvError::StopDispatcher))
} else if flags.contains(Flags::DSP_KEEPALIVE) {
Poll::Ready(Err(RecvError::KeepAlive))
} else if flags.contains(Flags::WR_BACKPRESSURE) {
Poll::Ready(Err(RecvError::WriteBackpressure))
} else {
match self.poll_read_ready(cx) {
Poll::Pending | Poll::Ready(Ok(Some(()))) => {
log::trace!("not enough data to decode next frame");
Poll::Pending
}
Poll::Ready(Err(e)) => {
Poll::Ready(Err(RecvError::PeerGone(Some(e))))
}
Poll::Ready(Ok(None)) => {
Poll::Ready(Err(RecvError::PeerGone(None)))
}
}
}
Poll::Ready(Err(e)) => Poll::Ready(Err(Either::Right(e))),
Poll::Ready(Ok(None)) => Poll::Ready(Ok(None)),
},
Err(err) => Poll::Ready(Err(Either::Left(err))),
}
Err(err) => Poll::Ready(Err(RecvError::Decoder(err))),
}
}
@ -567,53 +603,26 @@ impl<F> Io<F> {
.unwrap_or_else(|| io::Error::new(io::ErrorKind::Other, "disconnected"))));
}
if let Some(buf) = self.0 .0.write_buf.take() {
let len = buf.len();
if len != 0 {
self.0 .0.write_buf.set(Some(buf));
let len = self
.0
.0
.with_write_buf(|buf| buf.as_ref().map(|b| b.len()).unwrap_or(0));
if full {
self.0 .0.insert_flags(Flags::WR_WAIT);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
} else if len >= self.0.memory_pool().write_params_high() << 1 {
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
} else {
self.0 .0.remove_flags(Flags::WR_BACKPRESSURE);
}
}
}
Poll::Ready(Ok(()))
}
#[inline]
/// Wait until write task flushes data to io stream
///
/// Write task must be waken up separately.
pub fn poll_write_backpressure(&self, cx: &mut Context<'_>) -> Poll<()> {
if !self.is_io_open() {
Poll::Ready(())
} else if self.flags().contains(Flags::WR_BACKPRESSURE) {
self.0 .0.dispatch_task.register(cx.waker());
Poll::Pending
} else {
let len = self
.0
.0
.with_write_buf(|buf| buf.as_ref().map(|b| b.len()).unwrap_or(0));
let hw = self.memory_pool().write_params_high();
if len >= hw {
log::trace!("enable write back-pressure");
if len > 0 {
if full {
self.0 .0.insert_flags(Flags::WR_WAIT);
self.0 .0.dispatch_task.register(cx.waker());
return Poll::Pending;
} else if len >= self.0.memory_pool().write_params_high() << 1 {
self.0 .0.insert_flags(Flags::WR_BACKPRESSURE);
self.0 .0.dispatch_task.register(cx.waker());
Poll::Pending
} else {
Poll::Ready(())
return Poll::Pending;
}
}
self.0
.0
.remove_flags(Flags::WR_WAIT | Flags::WR_BACKPRESSURE);
Poll::Ready(Ok(()))
}
#[inline]

View file

@ -1,6 +1,6 @@
use std::{any, fmt, io};
use ntex_bytes::{BytesMut, PoolRef};
use ntex_bytes::{BufMut, BytesMut, PoolRef};
use ntex_codec::{Decoder, Encoder};
use super::io::{Flags, IoRef, OnDisconnect};
@ -68,12 +68,6 @@ impl IoRef {
self.0.dispatch_task.wake();
}
#[inline]
/// Mark dispatcher as stopped
pub fn stop_dispatcher(&self) {
self.0.insert_flags(Flags::DSP_STOP);
}
#[inline]
/// Gracefully close connection
///
@ -141,15 +135,6 @@ impl IoRef {
len >= self.memory_pool().read_params_high()
}
#[inline]
/// Wait until write task flushes data to io stream
///
/// Write task must be waken up separately.
pub fn enable_write_backpressure(&self) {
log::trace!("enable write back-pressure");
self.0.insert_flags(Flags::WR_BACKPRESSURE);
}
#[inline]
/// Get mut access to write buffer
pub fn with_write_buf<F, R>(&self, f: F) -> Result<R, io::Error>
@ -185,7 +170,7 @@ impl IoRef {
/// Encode and write item to a buffer and wake up write task
///
/// Returns write buffer state, false is returned if write buffer if full.
pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<bool, <U as Encoder>::Error>
pub fn encode<U>(&self, item: U::Item, codec: &U) -> Result<(), <U as Encoder>::Error>
where
U: Encoder,
{
@ -200,25 +185,21 @@ impl IoRef {
let (hw, lw) = self.memory_pool().write_params().unpack();
// make sure we've got room
let remaining = buf.capacity() - buf.len();
let remaining = buf.remaining_mut();
if remaining < lw {
buf.reserve(hw - remaining);
}
// encode item and wake write task
let result = codec.encode(item, &mut buf).map(|_| {
if is_write_sleep {
self.0.write_task.wake();
}
buf.len() < hw
});
codec.encode(item, &mut buf)?;
if is_write_sleep {
self.0.write_task.wake();
}
if let Err(err) = filter.release_write_buf(buf) {
self.0.set_error(Some(err));
}
result
} else {
Ok(true)
}
Ok(())
}
#[inline]
@ -313,14 +294,13 @@ mod tests {
sleep(Millis(50)).await;
let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await;
if let Poll::Ready(msg) = res {
assert_eq!(msg.unwrap().unwrap(), Bytes::from_static(BIN));
assert_eq!(msg.unwrap(), Bytes::from_static(BIN));
}
client.read_error(io::Error::new(io::ErrorKind::Other, "err"));
let msg = state.recv(&BytesCodec).await;
assert!(msg.is_err());
assert!(state.flags().contains(Flags::IO_ERR));
assert!(state.flags().contains(Flags::DSP_STOP));
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);
@ -348,7 +328,6 @@ mod tests {
let res = state.send(&BytesCodec, Bytes::from_static(b"test")).await;
assert!(res.is_err());
assert!(state.flags().contains(Flags::IO_ERR));
assert!(state.flags().contains(Flags::DSP_STOP));
let (client, server) = IoTest::create();
client.remote_buffer_cap(1024);

View file

@ -94,7 +94,22 @@ pub trait Handle {
fn query(&self, id: TypeId) -> Option<Box<dyn Any>>;
}
/// Framed transport item
/// Recv error
#[derive(Debug)]
pub enum RecvError<U: Decoder> {
/// Keep-alive timeout occured
KeepAlive,
/// Write backpressure is enabled
WriteBackpressure,
/// Dispatcher marked stopped
StopDispatcher,
/// Unrecoverable frame decoding errors
Decoder(U::Error),
/// Peer is disconnected
PeerGone(Option<sio::Error>),
}
/// Dispatcher item
pub enum DispatchItem<U: Encoder + Decoder> {
Item(<U as Decoder>::Item),
/// Write back-pressure enabled

View file

@ -1,6 +1,6 @@
# Changes
## [0.5.0-b.4] - 2021-12-xx
## [0.5.0-b.4] - 2021-12-26
* Allow to get access to ws transport codec

View file

@ -1,4 +1,4 @@
use std::{io::Write, pin::Pin, task::Context, task::Poll, time::Instant};
use std::{io, io::Write, pin::Pin, task::Context, task::Poll, time::Instant};
use crate::http::body::{BodySize, MessageBody};
use crate::http::error::PayloadError;
@ -6,8 +6,8 @@ use crate::http::h1;
use crate::http::header::{HeaderMap, HeaderValue, HOST};
use crate::http::message::{RequestHeadType, ResponseHead};
use crate::http::payload::{Payload, PayloadStream};
use crate::io::IoBoxed;
use crate::util::{poll_fn, BufMut, Bytes, BytesMut};
use crate::io::{IoBoxed, RecvError};
use crate::util::{poll_fn, ready, BufMut, Bytes, BytesMut};
use crate::Stream;
use super::connection::{Connection, ConnectionType};
@ -110,9 +110,8 @@ where
loop {
match poll_fn(|cx| body.poll_next_chunk(cx)).await {
Some(result) => {
if !io.encode(h1::Message::Chunk(Some(result?)), codec)? {
io.flush(false).await?;
}
io.encode(h1::Message::Chunk(Some(result?)), codec)?;
io.flush(false).await?;
}
None => {
io.encode(h1::Message::Chunk(None), codec)?;
@ -156,19 +155,40 @@ impl Stream for PlStream {
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let mut this = self.as_mut();
match this.io.as_ref().unwrap().poll_recv(&this.codec, cx)? {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(chunk)) => {
if let Some(chunk) = chunk {
Poll::Ready(Some(Ok(chunk)))
} else {
let io = this.io.take().unwrap();
let force_close = !this.codec.keepalive();
release_connection(io, force_close, this.created, this.pool.take());
Poll::Ready(None)
}
}
Poll::Ready(None) => Poll::Ready(None),
loop {
return Poll::Ready(Some(
match ready!(this.io.as_ref().unwrap().poll_recv(&this.codec, cx)) {
Ok(chunk) => {
if let Some(chunk) = chunk {
Ok(chunk)
} else {
let io = this.io.take().unwrap();
let force_close = !this.codec.keepalive();
release_connection(
io,
force_close,
this.created,
this.pool.take(),
);
return Poll::Ready(None);
}
}
Err(RecvError::KeepAlive) => {
Err(io::Error::new(io::ErrorKind::Other, "Keep-alive").into())
}
Err(RecvError::StopDispatcher) => {
Err(io::Error::new(io::ErrorKind::Other, "Dispatcher stopped")
.into())
}
Err(RecvError::WriteBackpressure) => {
ready!(this.io.as_ref().unwrap().poll_flush(cx, false))?;
continue;
}
Err(RecvError::Decoder(err)) => Err(err),
Err(RecvError::PeerGone(Some(err))) => Err(err.into()),
Err(RecvError::PeerGone(None)) => return Poll::Ready(None),
},
));
}
}
}

View file

@ -1,10 +1,10 @@
//! Framed transport dispatcher
use std::task::{Context, Poll};
use std::{error::Error, fmt, future::Future, marker, pin::Pin, rc::Rc, time};
use std::{error::Error, fmt, future::Future, io, marker, pin::Pin, rc::Rc, time};
use crate::io::{Filter, Io, IoRef};
use crate::io::{Filter, Io, IoRef, RecvError};
use crate::service::Service;
use crate::{time::now, util::ready, util::Bytes, util::Either};
use crate::{time::now, util::ready, util::Bytes};
use crate::http;
use crate::http::body::{BodySize, MessageBody, ResponseBody};
@ -122,7 +122,6 @@ where
macro_rules! set_error ({ $slf:tt, $err:ident } => {
*$slf.st = State::Stop;
$slf.inner.error = Some($err);
$slf.inner.unregister_keepalive();
});
impl<F, S, B, X, U> Future for Dispatcher<F, S, B, X, U>
@ -239,35 +238,11 @@ where
State::ReadRequest => {
log::trace!("trying to read http message");
// stop dispatcher
if this.inner.io().is_dispatcher_stopped() {
log::trace!("dispatcher is instructed to stop");
*this.st = State::Stop;
this.inner.unregister_keepalive();
continue;
}
// keep-alive timeout
if this.inner.state.is_keepalive() {
if !this.inner.flags.contains(Flags::STARTED) {
log::trace!("slow request timeout");
let (req, body) =
Response::RequestTimeout().finish().into_parts();
let _ = this.inner.send_response(req, body.into_body());
this.inner.error = Some(DispatchError::SlowRequestTimeout);
} else {
log::trace!("keep-alive timeout, close connection");
}
*this.st = State::Stop;
this.inner.unregister_keepalive();
continue;
}
let io = this.inner.io();
// decode incoming bytes stream
match ready!(io.poll_recv(&this.inner.codec, cx)) {
Ok(Some((mut req, pl))) => {
Ok((mut req, pl)) => {
log::trace!(
"http message is received: {:?} and payload {:?}",
req,
@ -332,24 +307,43 @@ where
);
}
}
Ok(None) => {
// peer is gone
log::trace!("peer is gone");
let e = DispatchError::Disconnect(None);
set_error!(this, e);
Err(RecvError::WriteBackpressure) => {
if let Err(err) = ready!(this.inner.io().poll_flush(cx, false))
{
log::trace!("peer is gone with {:?}", err);
*this.st = State::Stop;
this.inner.error =
Some(DispatchError::Disconnect(Some(err)));
}
}
Err(Either::Left(err)) => {
Err(RecvError::Decoder(err)) => {
// Malformed requests, respond with 400
log::trace!("malformed request: {:?}", err);
let (res, body) = Response::BadRequest().finish().into_parts();
this.inner.error = Some(DispatchError::Parse(err));
*this.st = this.inner.send_response(res, body.into_body());
}
Err(Either::Right(err)) => {
Err(RecvError::PeerGone(err)) => {
log::trace!("peer is gone with {:?}", err);
// peer is gone
let e = DispatchError::Disconnect(Some(err));
set_error!(this, e);
*this.st = State::Stop;
this.inner.error = Some(DispatchError::Disconnect(err));
}
Err(RecvError::StopDispatcher) => {
log::trace!("dispatcher is instructed to stop");
*this.st = State::Stop;
}
Err(RecvError::KeepAlive) => {
// keep-alive timeout
if !this.inner.flags.contains(Flags::STARTED) {
log::trace!("slow request timeout");
let (req, body) =
Response::RequestTimeout().finish().into_parts();
let _ = this.inner.send_response(req, body.into_body());
this.inner.error = Some(DispatchError::SlowRequestTimeout);
} else {
log::trace!("keep-alive timeout, close connection");
}
*this.st = State::Stop;
}
}
}
@ -371,7 +365,7 @@ where
set_error!(this, e);
} else {
loop {
ready!(this.inner.io().poll_write_backpressure(cx));
let _ = ready!(this.inner.io().poll_flush(cx, false));
let item = ready!(body.poll_next_chunk(cx));
if let Some(st) = this.inner.send_payload(item) {
*this.st = st;
@ -397,6 +391,8 @@ where
}
// prepare to shutdown
State::Stop => {
this.inner.unregister_keepalive();
if this
.inner
.io
@ -441,7 +437,7 @@ where
// connection is not keep-alive, disconnect
if !self.flags.contains(Flags::KEEPALIVE) || !self.codec.keepalive_enabled() {
self.unregister_keepalive();
self.state.stop_dispatcher();
self.state.close();
State::Stop
} else {
self.reset_keepalive();
@ -452,6 +448,7 @@ where
fn unregister_keepalive(&mut self) {
if self.flags.contains(Flags::KEEPALIVE) {
self.config.timer_h1.unregister(self.expire, &self.state);
self.flags.remove(Flags::KEEPALIVE);
}
}
@ -583,28 +580,64 @@ where
loop {
let res = io.poll_recv(&payload.0, cx);
match res {
Poll::Ready(Ok(Some(PayloadItem::Chunk(chunk)))) => {
Poll::Ready(Ok(PayloadItem::Chunk(chunk))) => {
updated = true;
payload.1.feed_data(chunk);
}
Poll::Ready(Ok(Some(PayloadItem::Eof))) => {
Poll::Ready(Ok(PayloadItem::Eof)) => {
updated = true;
payload.1.feed_eof();
self.payload = None;
break;
}
Poll::Ready(Ok(None)) => {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
return Poll::Ready(Err(ParseError::Incomplete.into()));
}
Poll::Ready(Err(e)) => {
payload.1.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
return Poll::Ready(Err(match e {
Either::Left(e) => DispatchError::Parse(e),
Either::Right(e) => DispatchError::Disconnect(Some(e)),
}));
Poll::Ready(Err(err)) => {
let err = match err {
RecvError::WriteBackpressure => {
if io.poll_flush(cx, false)?.is_pending() {
break;
} else {
continue;
}
}
RecvError::KeepAlive => {
payload
.1
.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
io::Error::new(io::ErrorKind::Other, "Keep-alive")
.into()
}
RecvError::StopDispatcher => {
payload
.1
.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
io::Error::new(
io::ErrorKind::Other,
"Dispatcher stopped",
)
.into()
}
RecvError::PeerGone(err) => {
payload
.1
.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
if let Some(err) = err {
DispatchError::Disconnect(Some(err))
} else {
ParseError::Incomplete.into()
}
}
RecvError::Decoder(e) => {
payload
.1
.set_error(PayloadError::EncodingCorrupted);
self.payload = None;
DispatchError::Parse(e)
}
};
return Poll::Ready(Err(err));
}
Poll::Pending => break,
}
@ -870,9 +903,8 @@ mod tests {
}
#[crate::rt_test]
/// if socket is disconnected, h1 dispatcher does not process any data
// /// h1 dispatcher still processes all incoming requests
// /// but it does not write any data to socket
/// /// h1 dispatcher still processes all incoming requests
/// /// but it does not write any data to socket
async fn test_write_disconnected() {
let num = Arc::new(AtomicUsize::new(0));
let num2 = num.clone();
@ -892,7 +924,7 @@ mod tests {
assert!(client.read_any().is_empty());
// only first request get handled
assert_eq!(num.load(Ordering::Relaxed), 0);
assert_eq!(num.load(Ordering::Relaxed), 1);
}
#[crate::rt_test]