From 6c68a59e998d5f99c18e7f1d9e967988c09e6dfe Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 21 Dec 2021 14:17:29 +0600 Subject: [PATCH] refactor io api and backpressure api --- ntex-io/CHANGES.md | 6 + ntex-io/src/dispatcher.rs | 83 +++++------ ntex-io/src/io.rs | 238 +++++++++++++++++++++----------- ntex-io/src/ioref.rs | 68 ++++----- ntex-tls/Cargo.toml | 4 +- ntex/CHANGES.md | 4 + ntex/Cargo.toml | 6 +- ntex/src/http/client/h1proto.rs | 6 +- ntex/src/http/h1/dispatcher.rs | 89 +++++------- ntex/tests/http_awc_ws.rs | 8 +- ntex/tests/http_ws.rs | 20 +-- ntex/tests/web_ws.rs | 8 +- 12 files changed, 290 insertions(+), 250 deletions(-) diff --git a/ntex-io/CHANGES.md b/ntex-io/CHANGES.md index e884f0af..14a40b72 100644 --- a/ntex-io/CHANGES.md +++ b/ntex-io/CHANGES.md @@ -2,8 +2,14 @@ ## [0.1.0-b.3] - 2021-12-xx +* Add .poll_write_backpressure() + +* Rename .poll_read_next() to .poll_recv() + * Rename .poll_write_ready() to .poll_flush() +* Rename .next() to .recv() + * Rename .write_ready() to .flush() ## [0.1.0-b.2] - 2021-12-20 diff --git a/ntex-io/src/dispatcher.rs b/ntex-io/src/dispatcher.rs index 73ae7e49..60b4c85c 100644 --- a/ntex-io/src/dispatcher.rs +++ b/ntex-io/src/dispatcher.rs @@ -4,8 +4,8 @@ use std::{cell::Cell, future, pin::Pin, rc::Rc, task::Context, task::Poll, time} use ntex_bytes::Pool; use ntex_codec::{Decoder, Encoder}; use ntex_service::{IntoService, Service}; -use ntex_util::future::Either; use ntex_util::time::{now, Seconds}; +use ntex_util::{future::Either, ready}; use super::{rt::spawn, DispatchItem, IoBoxed, IoRef, Timer}; @@ -203,50 +203,46 @@ where // handle memory pool pressure if slf.pool.poll_ready(cx).is_pending() { - io.pause(cx); + io.pause(); return Poll::Pending; } loop { match slf.st.get() { DispatcherState::Processing => { - let result = if let Poll::Ready(result) = - slf.poll_service(this.service, cx, io) - { - result - } else { - return Poll::Pending; - }; - - let item = match result { + let item = match ready!(slf.poll_service(this.service, cx, io)) { PollService::Ready => { - if !io.is_write_ready() { - // instruct write task to notify dispatcher when data is flushed - io.enable_write_backpressure(cx); - slf.st.set(DispatcherState::Backpressure); - DispatchItem::WBackPressureEnabled - } else { - // decode incoming bytes if buffer is ready - match io.poll_read_next(&slf.shared.codec, cx) { - Poll::Ready(Some(Ok(el))) => { - slf.update_keepalive(); - DispatchItem::Item(el) - } - Poll::Ready(Some(Err(Either::Left(err)))) => { - slf.st.set(DispatcherState::Stop); - slf.unregister_keepalive(); - DispatchItem::DecoderError(err) - } - Poll::Ready(Some(Err(Either::Right(err)))) => { - slf.st.set(DispatcherState::Stop); - slf.unregister_keepalive(); - DispatchItem::Disconnect(Some(err)) - } - Poll::Ready(None) => DispatchItem::Disconnect(None), - Poll::Pending => { - log::trace!("not enough data to decode next frame, register dispatch task"); - io.resume(); - return Poll::Pending; + match io.poll_write_backpressure(cx) { + Poll::Pending => { + // instruct write task to notify dispatcher when data is flushed + slf.st.set(DispatcherState::Backpressure); + DispatchItem::WBackPressureEnabled + } + Poll::Ready(()) => { + // decode incoming bytes if buffer is ready + match io.poll_recv(&slf.shared.codec, cx) { + Poll::Ready(Some(Ok(el))) => { + slf.update_keepalive(); + DispatchItem::Item(el) + } + Poll::Ready(Some(Err(Either::Left(err)))) => { + slf.st.set(DispatcherState::Stop); + slf.unregister_keepalive(); + DispatchItem::DecoderError(err) + } + Poll::Ready(Some(Err(Either::Right(err)))) => { + slf.st.set(DispatcherState::Stop); + slf.unregister_keepalive(); + DispatchItem::Disconnect(Some(err)) + } + Poll::Ready(None) => { + DispatchItem::Disconnect(None) + } + Poll::Pending => { + log::trace!("not enough data to decode next frame, register dispatch task"); + io.resume(); + return Poll::Pending; + } } } } @@ -274,13 +270,10 @@ where } // handle write back-pressure DispatcherState::Backpressure => { - let result = match slf.poll_service(this.service, cx, io) { - Poll::Ready(result) => result, - Poll::Pending => return Poll::Pending, - }; + let result = ready!(slf.poll_service(this.service, cx, io)); let item = match result { PollService::Ready => { - if io.is_write_ready() { + if slf.io.poll_write_backpressure(cx).is_ready() { slf.st.set(DispatcherState::Processing); DispatchItem::WBackPressureDisabled } else { @@ -308,7 +301,7 @@ where slf.spawn_service_call(this.service.call(item)); } } - // drain service responses + // drain service responses and shutdown io DispatcherState::Stop => { // service may relay on poll_ready for response results if !this.inner.ready_err.get() { @@ -434,7 +427,7 @@ where // pause io read task Poll::Pending => { log::trace!("service is not ready, register dispatch task"); - io.pause(cx); + io.pause(); Poll::Pending } // handle service readiness error diff --git a/ntex-io/src/io.rs b/ntex-io/src/io.rs index ef393245..7e45676b 100644 --- a/ntex-io/src/io.rs +++ b/ntex-io/src/io.rs @@ -167,6 +167,37 @@ impl IoState { Ok(()) } } + + #[inline] + pub(super) fn with_read_buf(&self, release: bool, f: Fn) -> Ret + where + Fn: FnOnce(&mut Option) -> Ret, + { + let buf = self.read_buf.as_ptr(); + let ref_buf = unsafe { buf.as_mut().unwrap() }; + let result = f(ref_buf); + + // release buffer + if release { + if let Some(ref buf) = ref_buf { + if buf.is_empty() { + let buf = mem::take(ref_buf).unwrap(); + self.pool.get().release_read_buf(buf); + } + } + } + result + } + + #[inline] + pub(super) fn with_write_buf(&self, f: Fn) -> Ret + where + Fn: FnOnce(&mut Option) -> Ret, + { + let buf = self.write_buf.as_ptr(); + let ref_buf = unsafe { buf.as_mut().unwrap() }; + f(ref_buf) + } } impl Eq for IoState {} @@ -376,14 +407,29 @@ impl Io { impl Io { #[inline] /// Read incoming io stream and decode codec item. - pub async fn next( + pub async fn recv( &self, codec: &U, ) -> Option>> where U: Decoder, { - poll_fn(|cx| self.poll_read_next(codec, cx)).await + poll_fn(|cx| self.poll_recv(codec, cx)).await + } + + #[inline] + /// Pause read task + pub fn pause(&self) { + self.0 .0.insert_flags(Flags::RD_PAUSED); + } + + #[inline] + /// Wake read io ask if it is paused + pub fn resume(&self) { + if self.flags().contains(Flags::RD_PAUSED) { + self.0 .0.remove_flags(Flags::RD_PAUSED); + self.0 .0.read_task.wake(); + } } #[inline] @@ -400,13 +446,8 @@ impl Io { let mut buf = filter .get_write_buf() .unwrap_or_else(|| self.memory_pool().get_write_buf()); - - let is_write_sleep = buf.is_empty(); codec.encode(item, &mut buf).map_err(Either::Left)?; filter.release_write_buf(buf).map_err(Either::Right)?; - if is_write_sleep { - self.0 .0.write_task.wake(); - } poll_fn(|cx| self.poll_flush(cx, true)) .await @@ -422,67 +463,11 @@ impl Io { poll_fn(|cx| self.poll_flush(cx, full)).await } - #[doc(hidden)] - #[deprecated] - #[inline] - pub async fn write_ready(&self, full: bool) -> Result<(), io::Error> { - poll_fn(|cx| self.poll_flush(cx, full)).await - } - #[inline] /// Shut down connection pub async fn shutdown(&self) -> Result<(), io::Error> { poll_fn(|cx| self.poll_shutdown(cx)).await } -} - -impl Io { - #[inline] - /// Wake write task and instruct to flush data. - /// - /// If `full` is true then wake up dispatcher when all data is flushed - /// otherwise wake up when size of write buffer is lower than - /// buffer max size. - pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll> { - // check io error - if !self.0 .0.is_io_open() { - return Poll::Ready(Err(self.0 .0.error.take().unwrap_or_else(|| { - io::Error::new(io::ErrorKind::Other, "disconnected") - }))); - } - - if let Some(buf) = self.0 .0.write_buf.take() { - let len = buf.len(); - if len != 0 { - self.0 .0.write_buf.set(Some(buf)); - - if full { - self.0 .0.insert_flags(Flags::WR_WAIT); - self.0 .0.dispatch_task.register(cx.waker()); - return Poll::Pending; - } else if len >= self.0.memory_pool().write_params_high() << 1 { - self.0 .0.insert_flags(Flags::WR_BACKPRESSURE); - self.0 .0.dispatch_task.register(cx.waker()); - return Poll::Pending; - } else { - self.0 .0.remove_flags(Flags::WR_BACKPRESSURE); - } - } - } - - Poll::Ready(Ok(())) - } - - #[doc(hidden)] - #[deprecated] - #[inline] - pub fn poll_write_ready( - &self, - cx: &mut Context<'_>, - full: bool, - ) -> Poll> { - self.poll_flush(cx, full) - } #[inline] /// Wake read task and instruct to read more data @@ -525,7 +510,10 @@ impl Io { #[inline] #[allow(clippy::type_complexity)] - pub fn poll_read_next( + /// Decode codec item from incoming bytes stream. + /// + /// Wake read task and request to read more data if data is not enough for decoding. + pub fn poll_recv( &self, codec: &U, cx: &mut Context<'_>, @@ -544,6 +532,69 @@ impl Io { } } + #[inline] + /// Wake write task and instruct to flush data. + /// + /// If `full` is true then wake up dispatcher when all data is flushed + /// otherwise wake up when size of write buffer is lower than + /// buffer max size. + pub fn poll_flush(&self, cx: &mut Context<'_>, full: bool) -> Poll> { + // check io error + if !self.0 .0.is_io_open() { + return Poll::Ready(Err(self.0 .0.error.take().unwrap_or_else(|| { + io::Error::new(io::ErrorKind::Other, "disconnected") + }))); + } + + if let Some(buf) = self.0 .0.write_buf.take() { + let len = buf.len(); + if len != 0 { + self.0 .0.write_buf.set(Some(buf)); + + if full { + self.0 .0.insert_flags(Flags::WR_WAIT); + self.0 .0.dispatch_task.register(cx.waker()); + return Poll::Pending; + } else if len >= self.0.memory_pool().write_params_high() << 1 { + self.0 .0.insert_flags(Flags::WR_BACKPRESSURE); + self.0 .0.dispatch_task.register(cx.waker()); + return Poll::Pending; + } else { + self.0 .0.remove_flags(Flags::WR_BACKPRESSURE); + } + } + } + + Poll::Ready(Ok(())) + } + + #[inline] + /// Wait until write task flushes data to io stream + /// + /// Write task must be waken up separately. + pub fn poll_write_backpressure(&self, cx: &mut Context<'_>) -> Poll<()> { + if !self.is_io_open() { + Poll::Ready(()) + } else if self.flags().contains(Flags::WR_BACKPRESSURE) { + self.0 .0.dispatch_task.register(cx.waker()); + Poll::Pending + } else { + let len = self + .0 + .0 + .with_write_buf(|buf| buf.as_ref().map(|b| b.len()).unwrap_or(0)); + let hw = self.memory_pool().write_params_high(); + if len >= hw { + log::trace!("enable write back-pressure"); + self.0 .0.insert_flags(Flags::WR_BACKPRESSURE); + self.0 .0.dispatch_task.register(cx.waker()); + Poll::Pending + } else { + Poll::Ready(()) + } + } + } + #[inline] /// Shut down connection pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll> { @@ -565,30 +616,55 @@ impl Io { } } + #[doc(hidden)] + #[deprecated] #[inline] - /// Pause read task - pub fn pause(&self, cx: &mut Context<'_>) { - self.0 .0.insert_flags(Flags::RD_PAUSED); - self.0 .0.dispatch_task.register(cx.waker()); + pub async fn next( + &self, + codec: &U, + ) -> Option>> + where + U: Decoder, + { + self.recv(codec).await } + #[doc(hidden)] + #[deprecated] #[inline] - /// Wake read io task if it is paused - pub fn resume(&self) -> bool { - let flags = self.0 .0.flags.get(); - if flags.contains(Flags::RD_PAUSED) { - self.0 .0.remove_flags(Flags::RD_PAUSED); - self.0 .0.read_task.wake(); - true - } else { - false - } + pub async fn write_ready(&self, full: bool) -> Result<(), io::Error> { + poll_fn(|cx| self.poll_flush(cx, full)).await } + #[doc(hidden)] + #[deprecated] + #[inline] + pub fn poll_write_ready( + &self, + cx: &mut Context<'_>, + full: bool, + ) -> Poll> { + self.poll_flush(cx, full) + } + + #[doc(hidden)] + #[deprecated] + #[inline] + #[allow(clippy::type_complexity)] + pub fn poll_read_next( + &self, + codec: &U, + cx: &mut Context<'_>, + ) -> Poll>>> + where + U: Decoder, + { + self.poll_recv(codec, cx) + } + + #[doc(hidden)] + #[deprecated] #[inline] - /// Wait until write task flushes data to io stream - /// - /// Write task must be waken up separately. pub fn enable_write_backpressure(&self, cx: &mut Context<'_>) { log::trace!("enable write back-pressure for dispatcher"); self.0 .0.insert_flags(Flags::WR_BACKPRESSURE); diff --git a/ntex-io/src/ioref.rs b/ntex-io/src/ioref.rs index 91853346..57cfe463 100644 --- a/ntex-io/src/ioref.rs +++ b/ntex-io/src/ioref.rs @@ -126,26 +126,19 @@ impl IoRef { #[inline] /// Check if write buffer is full pub fn is_write_buf_full(&self) -> bool { - if let Some(buf) = self.0.read_buf.take() { - let hw = self.memory_pool().write_params_high(); - let result = buf.len() >= hw; - self.0.write_buf.set(Some(buf)); - result - } else { - false - } + let len = self + .0 + .with_write_buf(|buf| buf.as_ref().map(|b| b.len()).unwrap_or(0)); + len >= self.memory_pool().write_params_high() } #[inline] /// Check if read buffer is full pub fn is_read_buf_full(&self) -> bool { - if let Some(buf) = self.0.read_buf.take() { - let result = buf.len() >= self.memory_pool().read_params_high(); - self.0.read_buf.set(Some(buf)); - result - } else { - false - } + let len = self + .0 + .with_read_buf(false, |buf| buf.as_ref().map(|b| b.len()).unwrap_or(0)); + len >= self.memory_pool().read_params_high() } #[inline] @@ -167,9 +160,6 @@ impl IoRef { let mut buf = filter .get_write_buf() .unwrap_or_else(|| self.memory_pool().get_write_buf()); - if buf.is_empty() { - self.0.write_task.wake(); - } let result = f(&mut buf); filter.release_write_buf(buf)?; @@ -182,18 +172,13 @@ impl IoRef { where F: FnOnce(&mut BytesMut) -> R, { - let mut buf = self - .0 - .read_buf - .take() - .unwrap_or_else(|| self.memory_pool().get_read_buf()); - let res = f(&mut buf); - if buf.is_empty() { - self.memory_pool().release_read_buf(buf); - } else { - self.0.read_buf.set(Some(buf)); - } - res + self.0.with_read_buf(true, |buf| { + // set buf + if buf.is_none() { + *buf = Some(self.memory_pool().get_read_buf()); + } + f(buf.as_mut().unwrap()) + }) } #[inline] @@ -252,12 +237,9 @@ impl IoRef { where U: Decoder, { - if let Some(mut buf) = self.0.read_buf.take() { - let result = codec.decode(&mut buf); - self.0.read_buf.set(Some(buf)); - return result; - } - Ok(None) + self.0.with_read_buf(false, |buf| { + buf.as_mut().map(|b| codec.decode(b)).unwrap_or(Ok(None)) + }) } #[inline] @@ -325,20 +307,20 @@ mod tests { assert!(!state.is_read_buf_full()); assert!(!state.is_write_buf_full()); - let msg = state.next(&BytesCodec).await.unwrap().unwrap(); + let msg = state.recv(&BytesCodec).await.unwrap().unwrap(); assert_eq!(msg, Bytes::from_static(BIN)); - let res = poll_fn(|cx| Poll::Ready(state.poll_read_next(&BytesCodec, cx))).await; + let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await; assert!(res.is_pending()); client.write(TEXT); sleep(Millis(50)).await; - let res = poll_fn(|cx| Poll::Ready(state.poll_read_next(&BytesCodec, cx))).await; + let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await; if let Poll::Ready(msg) = res { assert_eq!(msg.unwrap().unwrap(), Bytes::from_static(BIN)); } client.read_error(io::Error::new(io::ErrorKind::Other, "err")); - let msg = state.next(&BytesCodec).await; + let msg = state.recv(&BytesCodec).await; assert!(msg.unwrap().is_err()); assert!(state.flags().contains(Flags::IO_ERR)); assert!(state.flags().contains(Flags::DSP_STOP)); @@ -348,7 +330,7 @@ mod tests { let state = Io::new(server); client.read_error(io::Error::new(io::ErrorKind::Other, "err")); - let res = poll_fn(|cx| Poll::Ready(state.poll_read_next(&BytesCodec, cx))).await; + let res = poll_fn(|cx| Poll::Ready(state.poll_recv(&BytesCodec, cx))).await; if let Poll::Ready(msg) = res { assert!(msg.unwrap().is_err()); assert!(state.flags().contains(Flags::IO_ERR)); @@ -506,7 +488,7 @@ mod tests { client.remote_buffer_cap(1024); client.write(TEXT); - let msg = state.next(&BytesCodec).await.unwrap().unwrap(); + let msg = state.recv(&BytesCodec).await.unwrap().unwrap(); assert_eq!(msg, Bytes::from_static(BIN)); state @@ -537,7 +519,7 @@ mod tests { client.remote_buffer_cap(1024); client.write(TEXT); - let msg = state.next(&BytesCodec).await.unwrap().unwrap(); + let msg = state.recv(&BytesCodec).await.unwrap().unwrap(); assert_eq!(msg, Bytes::from_static(BIN)); state diff --git a/ntex-tls/Cargo.toml b/ntex-tls/Cargo.toml index 3471cf55..588838ae 100644 --- a/ntex-tls/Cargo.toml +++ b/ntex-tls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-tls" -version = "0.1.0-b.1" +version = "0.1.0-b.2" authors = ["ntex contributors "] description = "An implementation of SSL streams for ntex backed by OpenSSL" keywords = ["network", "framework", "async", "futures"] @@ -26,7 +26,7 @@ rustls = ["tls_rust"] [dependencies] ntex-bytes = "0.1.8" -ntex-io = "0.1.0-b.1" +ntex-io = "0.1.0-b.3" ntex-util = "0.1.3" ntex-service = "0.2.1" pin-project-lite = "0.2" diff --git a/ntex/CHANGES.md b/ntex/CHANGES.md index c5edc188..8dbf63a3 100644 --- a/ntex/CHANGES.md +++ b/ntex/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [0.5.0-b.2] - 2021-12-xx + +* Refactor write back-pressure for http1 + ## [0.5.0-b.1] - 2021-12-20 * Refactor http/1 dispatcher diff --git a/ntex/Cargo.toml b/ntex/Cargo.toml index 19e77cf4..da74fca3 100644 --- a/ntex/Cargo.toml +++ b/ntex/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex" -version = "0.5.0-b.1" +version = "0.5.0-b.2" authors = ["ntex contributors "] description = "Framework for composable network services" readme = "README.md" @@ -45,8 +45,8 @@ ntex-service = "0.2.1" ntex-macros = "0.1.3" ntex-util = "0.1.3" ntex-bytes = "0.1.8" -ntex-tls = "=0.1.0-b.1" -ntex-io = "=0.1.0-b.3" +ntex-tls = "0.1.0-b.2" +ntex-io = "0.1.0-b.3" ntex-rt = { version = "0.4.0-b.0", default-features = false, features = ["tokio"] } base64 = "0.13" diff --git a/ntex/src/http/client/h1proto.rs b/ntex/src/http/client/h1proto.rs index 7e2ac336..2490a853 100644 --- a/ntex/src/http/client/h1proto.rs +++ b/ntex/src/http/client/h1proto.rs @@ -74,7 +74,7 @@ where log::trace!("reading http1 response"); // read response and init read body - let head = if let Some(result) = io.next(&codec).await { + let head = if let Some(result) = io.recv(&codec).await { let result = result?; log::trace!( "http1 response is received, type: {:?}, response: {:#?}", @@ -108,7 +108,7 @@ pub(super) async fn open_tunnel( io.send((head, BodySize::None).into(), &codec).await?; // read response - if let Some(head) = io.next(&codec).await { + if let Some(head) = io.recv(&codec).await { Ok((head?, io, codec)) } else { Err(SendRequestError::from(ConnectError::Disconnected)) @@ -173,7 +173,7 @@ impl Stream for PlStream { cx: &mut Context<'_>, ) -> Poll> { let mut this = self.as_mut(); - match this.io.as_ref().unwrap().poll_read_next(&this.codec, cx)? { + match this.io.as_ref().unwrap().poll_recv(&this.codec, cx)? { Poll::Pending => Poll::Pending, Poll::Ready(Some(chunk)) => { if let Some(chunk) = chunk { diff --git a/ntex/src/http/h1/dispatcher.rs b/ntex/src/http/h1/dispatcher.rs index a8693671..fe7f3016 100644 --- a/ntex/src/http/h1/dispatcher.rs +++ b/ntex/src/http/h1/dispatcher.rs @@ -76,12 +76,6 @@ struct DispatcherInner { _t: marker::PhantomData<(S, B)>, } -enum WritePayloadStatus { - Next(State), - Pause, - Continue, -} - impl Dispatcher where F: Filter + 'static, @@ -128,6 +122,12 @@ where } } +macro_rules! set_error ({ $slf:tt, $err:ident } => { + *$slf.st = State::Stop; + $slf.inner.error = Some($err); + $slf.inner.unregister_keepalive(); +}); + impl Future for Dispatcher where F: Filter, @@ -167,9 +167,7 @@ where if let Err(e) = ready!(this.inner.poll_read_payload(cx)) { - *this.st = State::Stop; - this.inner.unregister_keepalive(); - this.inner.error = Some(e); + set_error!(this, e); } } else { return Poll::Pending; @@ -286,7 +284,7 @@ where let io = this.inner.io(); // decode incoming bytes stream - match io.poll_read_next(&this.inner.codec, cx) { + match io.poll_recv(&this.inner.codec, cx) { Poll::Ready(Some(Ok((mut req, pl)))) => { log::trace!( "http message is received: {:?} and payload {:?}", @@ -363,19 +361,14 @@ where Poll::Ready(Some(Err(Either::Right(err)))) => { log::trace!("peer is gone with {:?}", err); // peer is gone - *this.st = State::Stop; - this.inner.unregister_keepalive(); - this.inner.state.stop_dispatcher(); - return Poll::Ready(Err(DispatchError::Disconnect(Some( - err, - )))); + let e = DispatchError::Disconnect(Some(err)); + set_error!(this, e); } Poll::Ready(None) => { log::trace!("peer is gone"); // peer is gone - this.inner.unregister_keepalive(); - this.inner.state.stop_dispatcher(); - return Poll::Ready(Err(DispatchError::Disconnect(None))); + let e = DispatchError::Disconnect(None); + set_error!(this, e); } Poll::Pending => { log::trace!("not enough data to decode http message"); @@ -389,35 +382,25 @@ where *this.st = this.inner.switch_to_read_request(); } Err(e) => { - *this.st = State::Stop; - this.inner.error = Some(e); - this.inner.unregister_keepalive(); + set_error!(this, e); } }, // send response body State::SendPayload { ref mut body } => { if !this.inner.state.is_io_open() { - *this.st = State::Stop; - this.inner.error = Some(this.inner.state.take_error().into()); - this.inner.unregister_keepalive(); + let e = this.inner.state.take_error().into(); + set_error!(this, e); } else if let Poll::Ready(Err(e)) = this.inner.poll_read_payload(cx) { - *this.st = State::Stop; - this.inner.error = Some(e); - this.inner.unregister_keepalive(); + set_error!(this, e); } else { - match body.poll_next_chunk(cx) { - Poll::Ready(item) => match this.inner.send_payload(item) { - WritePayloadStatus::Next(st) => { - *this.st = st; - } - WritePayloadStatus::Pause => { - this.inner.io().enable_write_backpressure(cx); - return Poll::Pending; - } - WritePayloadStatus::Continue => (), - }, - Poll::Pending => return Poll::Pending, + loop { + ready!(this.inner.io().poll_write_backpressure(cx)); + let item = ready!(body.poll_next_chunk(cx)); + if let Some(st) = this.inner.send_payload(item) { + *this.st = st; + break; + } } } } @@ -579,21 +562,15 @@ where fn send_payload( &mut self, item: Option>>, - ) -> WritePayloadStatus { + ) -> Option> { match item { Some(Ok(item)) => { trace!("got response chunk: {:?}", item.len()); match self.io().encode(Message::Chunk(Some(item)), &self.codec) { + Ok(_) => None, Err(err) => { self.error = Some(DispatchError::Encode(err)); - WritePayloadStatus::Next(State::Stop) - } - Ok(has_space) => { - if has_space { - WritePayloadStatus::Continue - } else { - WritePayloadStatus::Pause - } + Some(State::Stop) } } } @@ -601,20 +578,20 @@ where trace!("response payload eof"); if let Err(err) = self.io().encode(Message::Chunk(None), &self.codec) { self.error = Some(DispatchError::Encode(err)); - WritePayloadStatus::Next(State::Stop) + Some(State::Stop) } else if self.flags.contains(Flags::SENDPAYLOAD_AND_STOP) { - WritePayloadStatus::Next(State::Stop) + Some(State::Stop) } else if self.payload.is_some() { - WritePayloadStatus::Next(State::ReadPayload) + Some(State::ReadPayload) } else { self.reset_keepalive(); - WritePayloadStatus::Next(self.switch_to_read_request()) + Some(self.switch_to_read_request()) } } Some(Err(e)) => { trace!("error during response body poll: {:?}", e); self.error = Some(DispatchError::ResponsePayload(e)); - WritePayloadStatus::Next(State::Stop) + Some(State::Stop) } } } @@ -633,7 +610,8 @@ where // read request payload let mut updated = false; loop { - match io.poll_read_next(&payload.0, cx) { + let res = io.poll_recv(&payload.0, cx); + match res { Poll::Ready(Some(Ok(PayloadItem::Chunk(chunk)))) => { updated = true; payload.1.feed_data(chunk); @@ -1029,6 +1007,7 @@ mod tests { #[crate::rt_test] async fn test_write_backpressure() { + env_logger::init(); let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone(); diff --git a/ntex/tests/http_awc_ws.rs b/ntex/tests/http_awc_ws.rs index e057762c..5baba936 100644 --- a/ntex/tests/http_awc_ws.rs +++ b/ntex/tests/http_awc_ws.rs @@ -57,19 +57,19 @@ async fn test_simple() { io.send(ws::Message::Text(ByteString::from_static("text")), &codec) .await .unwrap(); - let item = io.next(&codec).await.unwrap().unwrap(); + let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); io.send(ws::Message::Binary("text".into()), &codec) .await .unwrap(); - let item = io.next(&codec).await.unwrap().unwrap(); + let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text"))); io.send(ws::Message::Ping("text".into()), &codec) .await .unwrap(); - let item = io.next(&codec).await.unwrap().unwrap(); + let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Pong("text".to_string().into())); io.send( @@ -79,6 +79,6 @@ async fn test_simple() { .await .unwrap(); - let item = io.next(&codec).await.unwrap().unwrap(); + let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Normal.into()))); } diff --git a/ntex/tests/http_ws.rs b/ntex/tests/http_ws.rs index 33b78991..ce38fb82 100644 --- a/ntex/tests/http_ws.rs +++ b/ntex/tests/http_ws.rs @@ -100,7 +100,7 @@ async fn test_simple() { io.send(ws::Message::Text(ByteString::from_static("text")), &codec) .await .unwrap(); - let item = io.next(&codec).await; + let item = io.recv(&codec).await; assert_eq!( item.unwrap().unwrap(), ws::Frame::Text(Bytes::from_static(b"text")) @@ -109,7 +109,7 @@ async fn test_simple() { io.send(ws::Message::Binary("text".into()), &codec) .await .unwrap(); - let item = io.next(&codec).await; + let item = io.recv(&codec).await; assert_eq!( item.unwrap().unwrap(), ws::Frame::Binary(Bytes::from_static(&b"text"[..])) @@ -118,7 +118,7 @@ async fn test_simple() { io.send(ws::Message::Ping("text".into()), &codec) .await .unwrap(); - let item = io.next(&codec).await; + let item = io.recv(&codec).await; assert_eq!( item.unwrap().unwrap(), ws::Frame::Pong("text".to_string().into()) @@ -130,7 +130,7 @@ async fn test_simple() { ) .await .unwrap(); - let item = io.next(&codec).await; + let item = io.recv(&codec).await; assert_eq!( item.unwrap().unwrap(), ws::Frame::Continuation(ws::Item::FirstText(Bytes::from_static(b"text"))) @@ -157,7 +157,7 @@ async fn test_simple() { ) .await .unwrap(); - let item = io.next(&codec).await; + let item = io.recv(&codec).await; assert_eq!( item.unwrap().unwrap(), ws::Frame::Continuation(ws::Item::Continue(Bytes::from_static(b"text"))) @@ -169,7 +169,7 @@ async fn test_simple() { ) .await .unwrap(); - let item = io.next(&codec).await; + let item = io.recv(&codec).await; assert_eq!( item.unwrap().unwrap(), ws::Frame::Continuation(ws::Item::Last(Bytes::from_static(b"text"))) @@ -197,7 +197,7 @@ async fn test_simple() { ) .await .unwrap(); - let item = io.next(&codec).await; + let item = io.recv(&codec).await; assert_eq!( item.unwrap().unwrap(), ws::Frame::Continuation(ws::Item::FirstBinary(Bytes::from_static(b"bin"))) @@ -209,7 +209,7 @@ async fn test_simple() { ) .await .unwrap(); - let item = io.next(&codec).await; + let item = io.recv(&codec).await; assert_eq!( item.unwrap().unwrap(), ws::Frame::Continuation(ws::Item::Continue(Bytes::from_static(b"text"))) @@ -221,7 +221,7 @@ async fn test_simple() { ) .await .unwrap(); - let item = io.next(&codec).await; + let item = io.recv(&codec).await; assert_eq!( item.unwrap().unwrap(), ws::Frame::Continuation(ws::Item::Last(Bytes::from_static(b"text"))) @@ -234,7 +234,7 @@ async fn test_simple() { .await .unwrap(); - let item = io.next(&codec).await; + let item = io.recv(&codec).await; assert_eq!( item.unwrap().unwrap(), ws::Frame::Close(Some(ws::CloseCode::Normal.into())) diff --git a/ntex/tests/web_ws.rs b/ntex/tests/web_ws.rs index 98081d9f..ff40bfbb 100644 --- a/ntex/tests/web_ws.rs +++ b/ntex/tests/web_ws.rs @@ -41,19 +41,19 @@ async fn web_ws() { io.send(ws::Message::Text(ByteString::from_static("text")), &codec) .await .unwrap(); - let item = io.next(&codec).await.unwrap().unwrap(); + let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Text(Bytes::from_static(b"text"))); io.send(ws::Message::Binary("text".into()), &codec) .await .unwrap(); - let item = io.next(&codec).await.unwrap().unwrap(); + let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Binary(Bytes::from_static(b"text"))); io.send(ws::Message::Ping("text".into()), &codec) .await .unwrap(); - let item = io.next(&codec).await.unwrap().unwrap(); + let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Pong("text".to_string().into())); io.send( @@ -63,7 +63,7 @@ async fn web_ws() { .await .unwrap(); - let item = io.next(&codec).await.unwrap().unwrap(); + let item = io.recv(&codec).await.unwrap().unwrap(); assert_eq!(item, ws::Frame::Close(Some(ws::CloseCode::Away.into()))); }